Curried Memoization


Yet another function memoization implementation. It handles the general case of functions that take M arguments and return N values, and it takes care to preserve nils appearing in either the argument list or the result list.

For example:

  local mtest = weak_memoize_m_to_n( function(...) print 'exec' return ... end )

  print( mtest(nil,2) )    --> "exec", "nil,2"
  print( mtest(nil,2) )    --> "nil,2"
  -- a full collection empties the weak cache, so the next call runs the function again
  collectgarbage()
  print( mtest ( nil,2 ) ) --> "exec", "nil,2"
  print( mtest ( ) )       --> "exec", ""
  print( mtest ( nil ) )   --> "exec", "nil"

The design is equivalent to an argument-tree technique like [memoize.lua]. However, the tree is built implicitly rather than explicitly: each M>1 case is recursively reduced to an M-1 case by fixing the first argument.
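The reduction can be illustrated with a deliberately simplified sketch. It assumes a single non-nil return value, non-nil arguments, and strong tables; memoize1 and memoize_curried are hypothetical names used only for illustration and are not part of the code on this page.

```lua
-- Simplified sketch: one non-nil return value, no nil arguments, strong tables.
local function memoize1(f)
  local cache = {}
  return function(a)
    local v = cache[a]
    if v == nil then
      v = f(a)
      cache[a] = v
    end
    return v
  end
end

-- Memoize an m-argument function (m >= 1) by currying away one argument per level.
local function memoize_curried(m, f)
  if m == 1 then return memoize1(f) end
  -- This level maps the first argument to a memoizer for the remaining m-1.
  local level = memoize1(function(a)
    return memoize_curried(m - 1, function(...) return f(a, ...) end)
  end)
  return function(a, ...)
    return level(a)(...)
  end
end

local calls = 0
local add3 = memoize_curried(3, function(x, y, z)
  calls = calls + 1
  return x + y + z
end)

print(add3(1, 2, 3)) --> 6
print(add3(1, 2, 3)) --> 6 (served from the cache)
print(calls)         --> 1
```

Each nested memoizer is one node of the argument tree, created the first time its argument value is seen.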

Lua Code

  local _ENV = setmetatable({},{__index=_G})

  -- catch(...) captures its arguments, nils included, and returns a
  -- function that returns them again.  This can be made more memory-
  -- and speed-efficient by defining catch() in C, but the
  -- table.unpack approach below also works.
  function catch(...)
    local rvals = {...}
    local n = select('#',...)
    return function()
      return table.unpack(rvals,1,n)
    end
  end

  local weak_mt= {__mode='kv'}
  local function weak_table() return setmetatable({},weak_mt) end
  local function strong_table() return {} end

  -- sentinel key that stands in for a nil argument
  local null = {}

  local function arg2key(arg)
    return (arg == nil and null) or arg
  end

  -- build a memoization function that can handle the 1-argument
  -- to n rvals case.
  local function new_memoizer_1_to_n(newtable)
    return function(fun)
      local lookup = newtable()

      return function (arg)
        local k = arg2key(arg)
        local r=lookup[k]
        if r then
          return r()
        end
        r=catch( fun(arg) )
        lookup[k] = r
        return r()
      end

    end
  end

  local function new_memoizer_m_to_n( newtable, memoize_1_to_n )

    -- return a memoization of f that assumes m arguments.
    local function memoize_m_to_n(m,f)
      -- handle the m==0 case
      if m==0 then
        local memoized
        return function()
          if memoized then
            return memoized()
          end
          memoized = catch(f())
          return memoized()
        end
      end

      if m==1 then
        return memoize_1_to_n(f)
      end

      local lookup = newtable()

      -- handle the general m-to-n case, for m>=2.
      return function(arg, ...)

        local k = arg2key(arg)
        local r = lookup[k]

        if r then
          return r(...)
        end

        -- create a new (m-1) argument memoizer that will handle
        -- this arg value in the future.
        r = memoize_m_to_n(m-1, function(...)
          return f(arg,...)
        end)

        lookup[k]=r
        return r(...)
      end
    end

    -- return a memoizer that dispatches between the different m-argument cases.
    return function(f)
      local m_to_memoized = newtable()
      return function(...)
        local m = select('#',...)
        local memoized = m_to_memoized[m]
        if memoized then
          return memoized(...)
        end
        memoized = memoize_m_to_n(m,f)
        m_to_memoized[m]=memoized
        return memoized(...)
      end
    end
  end

  weak_memoize_1_to_n =  new_memoizer_1_to_n(weak_table)
  strong_memoize_1_to_n =  new_memoizer_1_to_n(strong_table)

  weak_memoize_m_to_n = new_memoizer_m_to_n(weak_table,weak_memoize_1_to_n)
  strong_memoize_m_to_n = new_memoizer_m_to_n(strong_table,strong_memoize_1_to_n)

  return _ENV
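The two nil-handling details in the listing can be tried in isolation. The sketch below restates arg2key() and catch() from the code above as standalone locals and assumes Lua 5.2 or later (for table.unpack); the cache table and the sample values are made up for illustration.

```lua
-- Assumes Lua 5.2+ (table.unpack).

-- 1. nil cannot be used as a table key, so a unique sentinel takes its place.
local null = {}
local function arg2key(arg)
  if arg == nil then return null end
  return arg
end

local cache = {}
cache[arg2key(nil)] = "cached for nil"
cache[arg2key(false)] = "cached for false"   -- false is a legal key as-is
print(cache[null])   --> cached for nil
print(cache[false])  --> cached for false

-- 2. Recording select('#', ...) alongside the values preserves nils,
--    including trailing ones, when the values are returned later.
local function catch(...)
  local rvals = {...}
  local n = select('#', ...)
  return function()
    return table.unpack(rvals, 1, n)
  end
end

local thrower = catch(nil, 2, nil)
print(select('#', thrower())) --> 3
```

Without the stored count, the length of a table containing nils is not well defined, so a plain table.unpack(rvals) could silently drop values.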

Implementing catch() in C

Memory use and performance can both be improved by implementing catch() with the C API. Although a C closure can hold at most 255 upvalues, C closures are lighter and faster data structures than Lua's generic tables. The version below spends one upvalue on the count, so it can catch at most 254 values.


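  #include <lua.h>
  #include <lauxlib.h>

  /* MAXUPVAL is internal to the Lua core rather than part of the
     public API; 255 is the documented limit of lua_pushcclosure(). */
  #ifndef MAXUPVAL
  #define MAXUPVAL 255
  #endif

  /* upvalue 1 holds n1, the total number of upvalues;
     upvalues 2..n1 hold the caught values. */
  static int throw_upvalues(lua_State *L) {
    int n1=(int)lua_tointeger(L,lua_upvalueindex(1));
    luaL_checkstack(L,n1-1,"too many upvalues");
    for(int i=2; i<=n1; i++) {
      lua_pushvalue(L,lua_upvalueindex(i));
    }
    return n1-1;
  }

  /* catch(...): store the count and all arguments as upvalues of a new C closure. */
  static int catch_args(lua_State *L) {
    int n1 = lua_gettop(L)+1;
    if(n1>MAXUPVAL) {
      return luaL_error(L,"can't catch more than %d args. (catch() called with %d arguments).",MAXUPVAL-1, n1-1);
    }
    lua_pushinteger(L,n1);
    lua_insert(L,1);
    lua_pushcclosure(L,throw_upvalues,n1);
    return 1;
  }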

See Also

Many other Lua memoization implementations are scattered around the web. The FuncTables page appears to be the de facto hub for links on this wiki. The topic also comes up frequently on the lua-users list, and the archives contain a couple of nice implementations of string-serialization-based approaches [1], [2].
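For comparison, a string-serialization memoizer can be sketched as follows. This is not the code from the archived posts; it is a minimal illustration that assumes Lua 5.2 or later (table.pack, table.unpack), and args_to_key and memoize_by_string are hypothetical names. It is only reliable for arguments whose tostring() form identifies them: nil, booleans, strings, and numbers that print distinctly.

```lua
-- Hypothetical helper: build one string key from an argument list.
local function args_to_key(...)
  local n = select('#', ...)
  local parts = { tostring(n) }
  for i = 1, n do
    local v = select(i, ...)
    local s = tostring(v)
    -- The type tag keeps "1" (string) apart from 1 (number); the length
    -- prefix keeps separators inside strings from causing collisions.
    parts[#parts + 1] = type(v) .. ":" .. #s .. ":" .. s
  end
  return table.concat(parts, "|")
end

local function memoize_by_string(f)
  local cache = {}
  return function(...)
    local key = args_to_key(...)
    local hit = cache[key]
    if hit == nil then
      hit = table.pack(f(...))  -- table.pack records the result count in hit.n
      cache[key] = hit
    end
    return table.unpack(hit, 1, hit.n)
  end
end

local calls = 0
local pair = memoize_by_string(function(a, b)
  calls = calls + 1
  return a, b
end)

print(pair(nil, 2)) --> nil and 2, tab-separated
print(pair(nil, 2)) --> same values, served from the cache
print(calls)        --> 1
```

Compared with the curried approach, this uses a single flat table per function, at the cost of building a key string on every call and of treating numbers that print identically as the same argument.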


Last edited December 31, 2013 6:50 pm GMT