Curried Lua


Currying is defined by Wikipedia[1] as follows:

"In computer science, currying is the technique of transforming a function taking multiple arguments into a function that takes a single argument (the first of the arguments to the original function) and returns a new function that takes the remainder of the arguments and returns the result"

You can implement curried functions in all languages that support functions as first-class objects. For example, there's a little [tutorial about curried JavaScript].

Here is a small Lua example of a curried function:

function sum(number)
  return function(anothernumber)
    return number + anothernumber
  end
end

local f = sum(5)
print(f(3)) --> 8

-- WalterCruz

Here is another, contributed by [GavinWraith], which accepts any number of single-argument calls and returns their total when the chain is terminated by an empty call "()":

function addup(x)
  local sum = 0
  local function f(n)
    if type(n) == "number" then
      sum = sum + n
      return f
    else
      return sum
    end
  end
  return f(x)
end

print(addup (1) (2) (3) ())  --> 6
print(addup (4) (5) (6) ())  --> 15
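Because addup keeps its running total in a single shared variable, a partially built chain cannot safely be reused: every further call keeps adding to the same sum. The sketch below contrasts this with a variant (addup_pure is a hypothetical name, not part of the original page) that passes the total along instead, so each call returns a fresh function and a partial chain can be reused.

```lua
-- addup as on this page: the running total lives in one shared upvalue
local function addup(x)
  local sum = 0
  local function f(n)
    if type(n) == "number" then
      sum = sum + n
      return f
    else
      return sum
    end
  end
  return f(x)
end

-- hypothetical variant: each step builds a new closure holding the new total
local function addup_pure(x)
  local function make(total)
    return function(n)
      if type(n) == "number" then
        return make(total + n)  -- new function; the old one is untouched
      end
      return total              -- a non-number (e.g. no argument) ends the chain
    end
  end
  return make(0)(x)
end

local shared = addup(10)
print(shared (1) ())  --> 11
print(shared (1) ())  --> 12 (the shared total kept growing)

local pure = addup_pure(10)
print(pure (1) ())    --> 11
print(pure (1) ())    --> 11
print(addup_pure (1) (2) (3) ())  --> 6
```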

Although these pre-curried functions are useful, what we would really like is a general-purpose function that can curry any other function. This is possible because Lua functions can be passed to, and returned from, other functions; a function that does so is called a "higher-order function". The following curry function is such a higher-order function, and curries a 2-argument function:

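function curry(f)
    return function (x) return function (y) return f(x,y) end end
end

-- math.pow is deprecated since Lua 5.3 and may be missing; fall back to ^
-- (Lua 5.3 and later print the float results as 16.0, 8.0, ...)
powcurry = curry(math.pow or function (x, y) return x ^ y end)
print(powcurry (2) (4)) --> 16
pow2 = powcurry(2)
print(pow2(3)) --> 8
print(pow2(4)) --> 16
print(pow2(8)) --> 256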
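Since curry only assumes that f takes two arguments, the same helper works for any such function, including library functions. A short sketch using string.rep (which repeats a string n times) and a hypothetical two-argument function named prefix; the last lines show that a third parameter is never filled in.

```lua
-- 2-argument curry, as defined above
local function curry(f)
  return function (x) return function (y) return f(x, y) end end
end

-- string.rep(s, n) repeats s n times
local rep = curry(string.rep)
local stars = rep("*")           -- partially applied: the string is fixed
print(stars(5))                  --> *****

-- hypothetical helper, used only for illustration
local function prefix(p, s) return p .. s end
local greet = curry(prefix)("Hello, ")
print(greet("Lua"))              --> Hello, Lua

-- only two arguments are ever forwarded, so a third parameter stays nil
local function three(a, b, c) return c end
print(curry(three)(1)(2))        --> nil
```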

Going from currying 2 arguments to currying n arguments is a bit more complicated. We need to store an indeterminate number of partially applied arguments, and in Lua 5.1 there is no way to ask a function how many arguments it requires: a Lua function accepts any number of arguments, silently discarding extras and filling in missing ones with nil. So we must tell the curry function how many single-argument calls to accept before it applies the collected arguments to the original function.
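The arity problem is easy to see directly. In this sketch a hypothetical two-parameter function, pair, is called with too few and too many arguments; Lua raises no error in either case, so nothing about the call tells a curry function how many arguments to wait for. Only select("#", ...) reveals how many values a caller actually passed.

```lua
-- a hypothetical two-parameter function
local function pair(a, b)
  return a, b
end

print(pair(1))        -- prints 1 and nil: the missing argument becomes nil
print(pair(1, 2, 3))  -- prints 1 and 2: the extra argument is discarded

-- select("#", ...) counts exactly what the caller passed, nils included
local function count(...) return select("#", ...) end
print(count(1), count(1, 2, 3), count(nil, nil))  -- prints 1, 3 and 2
```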

(The n-argument curry code that follows is freely available from http://tinylittlelife.org/?p=249, which also includes a full discussion of how to tackle this problem.)

-- curry(func, num_args) : take a function requiring a tuple for num_args arguments
--                         and turn it into a series of 1-argument functions
-- e.g.: you have a function dosomething(a, b, c)
-- curried_dosomething = curry(dosomething, 3) -- we want to curry 3 arguments
-- curried_dosomething (a1) (b1) (c1) -- returns the result of dosomething(a1, b1, c1)
-- partial_dosomething1 = curried_dosomething (a_value) -- returns a function
-- partial_dosomething2 = partial_dosomething1 (b_value) -- returns a function
-- partial_dosomething2 (c_value) -- returns the result of dosomething(a_value, b_value, c_value)
function curry(func, num_args)

   -- currying 2-argument functions seems to be the most popular application
   num_args = num_args or 2

   -- no sense currying for 1 arg or less
   if num_args <= 1 then return func end

   -- helper takes an argtrace function, and number of arguments remaining to be applied
   local function curry_h(argtrace, n)
      if 0 == n then
         -- kick off argtrace, reverse argument list, and call the original function
         return func(reverse(argtrace()))
      else
         -- "push" argument (by building a wrapper function) and decrement n
         return function (onearg)
                   return curry_h(function () return onearg, argtrace() end, n - 1)
                end
      end
   end

   -- push the terminal case of argtrace into the function first
   return curry_h(function () return end, num_args)
end

-- reverse(...) : take some tuple and return a tuple of elements in reverse order
--
-- e.g. "reverse(1,2,3)" returns 3,2,1
function reverse(...)

   --reverse args by building a function to do it, similar to the unpack() example
   local function reverse_h(acc, v, ...)
      if 0 == select('#', ...) then
         return v, acc()
      else
         return reverse_h(function () return v, acc() end, ...)
      end
   end

   -- initial acc is the end of the list
   return reverse_h(function () return end, ...)
end



The above code is Lua 5.1 compatible.
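The closure chain above is one way to store the collected arguments. For comparison, here is a sketch of the same fixed-count curry that keeps them in a table instead (curry_tbl is a hypothetical name, not from the original page). Each step copies the table, so partial applications can be reused, and unpacking with explicit bounds lets nil arguments through. It should also run on Lua 5.1; the unpack line covers 5.2 and later.

```lua
-- unpack is a global in Lua 5.1 and table.unpack from 5.2 on
local unpack = table.unpack or unpack

-- hypothetical table-based variant of the fixed-count curry
local function curry_tbl(func, num_args)
  num_args = num_args or 2
  if num_args <= 1 then return func end

  -- args holds the `count` arguments collected so far
  local function collect(args, count)
    return function (onearg)
      -- copy so that each partial application keeps its own list
      local copy = {}
      for i = 1, count do copy[i] = args[i] end
      copy[count + 1] = onearg
      if count + 1 == num_args then
        -- explicit bounds keep nil arguments in place
        return func(unpack(copy, 1, num_args))
      end
      return collect(copy, count + 1)
    end
  end

  return collect({}, 0)
end

local function join3(a, b, c) return a .. b .. c end
local cjoin = curry_tbl(join3, 3)
print(cjoin ("x") ("y") ("z"))   --> xyz

local from_a = cjoin ("a")       -- reused twice below
print(from_a ("b") ("c"))        --> abc
print(from_a ("d") ("e"))        --> ade

local function second_is_nil(a, b, c) return b == nil end
print(curry_tbl(second_is_nil, 3) (1) (nil) (3))  --> true
```

Copying costs a little per step, which is usually negligible for the small argument counts currying is used with.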

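Lua 5.2 (and LuaJIT 2.0) extend debug.getinfo so that it reports how many parameters a function declares (the nparams field of the "u" option). With that, we can write a more practical function that mixes currying and partial application: each call may supply one or more arguments, and the original function runs once enough arguments have been collected. Here's the code: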

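-- unpack became table.unpack in Lua 5.2; use whichever exists
local unpack = table.unpack or unpack

function curry(func, num_args)
  num_args = num_args or debug.getinfo(func, "u").nparams
  if num_args < 2 then return func end
  -- argtrace nests each call's arguments as {previous, ...};
  -- n counts how many arguments are still needed
  local function helper(argtrace, n)
    if n < 1 then
      return func(unpack(flatten(argtrace)))
    else
      return function (...)
        return helper({argtrace, ...}, n - select("#", ...))
      end
    end
  end
  return helper({}, num_args)
end

-- flatten(t): splice the nested argtrace tables into one flat array
-- (note: table arguments are spliced too, and a nil argument ends ipairs early)
function flatten(t)
  local ret = {}
  for _, v in ipairs(t) do
    if type(v) == 'table' then
      for _, fv in ipairs(flatten(v)) do
        ret[#ret + 1] = fv
      end
    else
      ret[#ret + 1] = v
    end
  end
  return ret
end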



function multiplyAndAdd (a, b, c) return a * b + c end

curried_multiplyAndAdd = curry(multiplyAndAdd)

multiplyBySevenAndAdd = curried_multiplyAndAdd(7)

multiplySevenByEightAndAdd_v1 = multiplyBySevenAndAdd(8)
multiplySevenByEightAndAdd_v2 = curried_multiplyAndAdd(7, 8)

assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v1(9))
assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v2(9))
assert(multiplyAndAdd(7, 8, 9) == multiplyBySevenAndAdd(8, 9))
assert(multiplyAndAdd(7, 8, 9) == curried_multiplyAndAdd(7, 8, 9))
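Two caveats follow from the flatten step: a table passed as an argument is spliced into the argument list, and a nil argument stops ipairs early, dropping the arguments after it. The sketch below (curry_n is a hypothetical name) keeps the same mix of currying and partial application but stores the arguments in one flat array with an explicit count, so tables and nils pass through unchanged. It assumes Lua 5.2 or later for the nparams field. C functions such as math.max always report nparams as 0, so they need an explicit num_args.

```lua
-- assumes Lua 5.2+ (or LuaJIT 2.0) for debug.getinfo(f, "u").nparams
local unpack = table.unpack or unpack

-- hypothetical variant of curry that does not flatten its arguments
local function curry_n(func, num_args)
  num_args = num_args or debug.getinfo(func, "u").nparams
  if num_args < 2 then return func end

  -- args: flat array of collected values; count: how many (nils included)
  local function helper(args, count)
    if count >= num_args then
      return func(unpack(args, 1, count))
    end
    return function (...)
      local copy = {}
      for i = 1, count do copy[i] = args[i] end
      local extra = select("#", ...)
      for i = 1, extra do copy[count + i] = (select(i, ...)) end
      return helper(copy, count + extra)
    end
  end

  return helper({}, 0)
end

-- hypothetical function: a table, a possibly-nil value, and a number
local function summary(list, maybe, n)
  return #list .. "/" .. tostring(maybe) .. "/" .. n
end

local cs = curry_n(summary)
print(cs({10, 20, 30})(nil)(5))  --> 3/nil/5
print(cs({10, 20, 30}, nil)(5))  --> 3/nil/5
print(cs({})(false, 7))          --> 0/false/7

-- C functions report no fixed parameters, so pass the count explicitly
local max3 = curry_n(math.max, 3)
print(max3(4)(9)(2))             --> 9
```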

Last edited March 27, 2014 1:39 pm GMT