Overloaded Functions

lua-users home
wiki

Overloaded functions let you call one function and have it dispatch to different implementations based on the argument types. C++ and Java have support for this at the language level; in Lua you can roll your own at runtime:

--- Create an overloadable function object.
--
-- The returned table is callable. A call builds a signature string from the
-- type() of every argument, joined with "," (for example "string,number"),
-- looks up the implementation registered under that signature and calls it
-- with the original arguments. When no implementation matches, the object's
-- `default` field is called instead; it initially raises
-- "Invalid argument types to overloaded function".
--
-- Implementations are registered by assignment:
--   f.number = fn             -- one argument
--   f.string.number = fn      -- several arguments, chained left to right
--   f["string,number"] = fn   -- the same, spelled as a raw signature
--   f.default = fn            -- fallback for unmatched signatures
--
-- @return a callable table acting as the overloaded function
function overloaded()
    local fns = {}
    local mt = {}

    local function oerror()
        return error("Invalid argument types to overloaded function")
    end

    -- Build a write-only proxy that stands for the partial signature `prefix`.
    -- Each proxy owns its own immutable prefix, so keeping a proxy around and
    -- indexing or assigning through it several times never leaks one binding's
    -- type names into the next.
    local function proxy(prefix)
        return setmetatable({}, {
            __index = function(_, key)
                return proxy(prefix .. "," .. key)
            end,
            __newindex = function(_, key, value)
                fns[prefix .. "," .. key] = value
            end,
        })
    end

    function mt:__call(...)
        -- select("#", ...) counts every argument, including nils; iterating
        -- with ipairs would stop at the first nil and truncate the signature.
        local argc = select("#", ...)
        local types = {}
        for i = 1, argc do
            types[i] = type((select(i, ...)))
        end

        -- rawget keeps a removed default from being mistaken for a proxy
        -- produced by __index.
        local impl = fns[table.concat(types, ",")]
            or rawget(self, "default")
            or oerror
        return impl(...)
    end

    -- Reading an unknown field starts a chained registration such as
    -- f.string.number = fn.
    function mt:__index(key)
        return proxy(key)
    end

    function mt:__newindex(key, value)
        if key == "default" then
            -- Only reached when `default` was previously set to nil; store it
            -- back on the object so it acts as the fallback again.
            rawset(self, key, value)
        else
            fns[key] = value
        end
    end

    return setmetatable({ default = oerror }, mt)
end

You can use this like so:

foo = overloaded()

-- one number: return its square
function foo.number(value)
    return value^2
end

-- one string: parse it as a number, then dispatch again through foo
function foo.string(text)
    return foo(tonumber(text))
end

-- a string followed by a number: behave like string.rep
foo.string.number = string.rep

-- any other combination of arguments: fall back to print
foo.default = print



--- begin test code ---

foo(6)

=> 36

foo("4")

=> 16

foo("not a valid number")

=> nil    (tonumber returns nil, no overload matches, so the default -- print -- is called)

foo("foo", 4)

=> foofoofoofoo

foo(true, false, {})

=> true    false   table: 0x12345678


RecentChanges · preferences
edit · history
Last edited July 29, 2009 10:19 pm GMT (diff)