Method Chaining Wrapper

lua-users home
wiki

At times we would like to add custom methods to built-in types like strings and functions, particularly when using method chaining [1][2]:

("  test  "):trim():repeatchars(2):upper() --> "TTEESSTT"



(function(x,y) return x*y end):curry(2) --> (function(y) return 2*y end)

We can do this with the debug library (debug.setmetatable) [5]. The downside is that each built-in type has a single, common metatable. Modifying this metatable causes a global side-effect, which is a potential source of conflict between independently maintained modules in a program. Functions in the debug library are often discouraged in regular code for good reason. Many people avoid injecting into these global metatables, while others find it too convenient to avoid [3][6][ExtensionProposal]. Some have even asked why objects of built-in types don't have their own metatables [7].

...

debug.setmetatable("", string_mt)

debug.setmetatable(function()end, function_mt)

We could instead use just standalone functions:

(repeatchars(trim("test"), 2)):upper()



curry(function(x,y) return x*y end, 2)

This is the simplest solution. Simple solutions are often good ones. Nevertheless, there can be a certain level of discordance with some operations being method calls and some being standalone global functions, along with the reordering that results.

One solution to avoid touching the global metatables is to wrap the object inside our own class, perform operations on the wrapper in a method call chain, and unwrap the objects.

Examples would look like this:

S"  test  ":trim():repeatchars(2):upper()() --> TTEESSTT



S"  TEST  ":trim():lower():find('e')() --> 2 2

The S function wraps the given object into a wrapper object. A chain of method calls on the wrapper object operates on the wrapped object in-place. Finally, the wrapper object is unwrapped with a function call ().

For functions that return a single value, an alternative way to unpack is to use the unary minus:

-S"  test  ":trim():repeatchars(2):upper() --> TTEESSTT

To define S in terms of a table of string functions stringx, we can use this code:

-- Extended string method table: starts as a shallow copy of the standard
-- `string` library and adds the custom methods defined below.
local stringx = {}
for k,v in pairs(string) do stringx[k] = v end

--- Strip leading and trailing whitespace.
-- The capture is the lazy `(.-)` rather than `(%S*)`: the latter cannot span
-- interior whitespace, so a string such as "  a b  " would fail to match and
-- the function would return nil.
-- @param self string to trim
-- @return the trimmed string ("" when the input is empty or all whitespace)
function stringx.trim(self)
  return self:match('^%s*(.-)%s*$')
end

--- Repeat every character of a string `n` times in place.
-- Example: ("test"):repeatchars(2) --> "tteesstt"
-- @param self source string
-- @param n number of copies of each character
-- @return the expanded string
function stringx.repeatchars(self, n)
  local ts = {}
  for i=1,#self do
    local c = self:sub(i,i)
    -- `_` avoids shadowing the outer loop index
    for _=1,n do ts[#ts+1] = c end
  end
  return table.concat(ts)
end



local S = buildchainwrapbuilder(stringx)

The buildchainwrapbuilder function is general and implements our design pattern:

-- (c) 2009 David Manura. Licensed under the same terms as Lua (MIT license).

-- version 20090430

local select = select

local setmetatable = setmetatable

local unpack = unpack

local rawget = rawget



-- http://lua-users.org/wiki/CodeGeneration

--- Build a lazily filled lookup table around `func`.
-- Reading a missing key calls `func(key)`, stores the result under that key
-- and returns it; calling the table with a key is the same as indexing it.
-- @param func one-argument function used to compute missing entries
-- @return the caching table
local function memoize(func)
  local cache_mt = {}

  -- invoked only for keys that are not yet stored
  function cache_mt.__index(cache, key)
    local value = func(key)
    cache[key] = value
    return value
  end

  -- cache(key) is sugar for cache[key]
  function cache_mt.__call(cache, key)
    return cache[key]
  end

  return setmetatable({}, cache_mt)
end



-- unique IDs (avoid name clashes with wrapped object)
-- Each ID is a fresh table used as a key inside the wrapper object, so it can
-- never collide with a string or number key that user code indexes with.

-- key under which the wrapper stores how many values it currently holds
local N = {}

-- VALS[i] is the key for the i-th stored value; the keys are created on
-- demand and cached by memoize (the callback ignores its argument and just
-- returns a new table)
local VALS = memoize(function() return {} end)

-- key of the first stored value (the "current" value of the chain)
local VAL = VALS[1]

-- key holding the value that was current before the latest index operation
local PREV = {}



--- Store a list of values in the wrapper, replacing whatever it held before.
-- The i-th value goes under key VALS[i]; the count goes under key N, so
-- trailing nils are preserved.
-- @param ow wrapper object (its N field must already be a number)
-- @param ... values to store
local function mypack(ow, ...)
  local count = select('#', ...)

  -- write the new values
  for slot = 1, count do
    local value = select(slot, ...)
    ow[VALS[slot]] = value
  end

  -- erase slots left over from a previous, longer value list
  for slot = count + 1, ow[N] do
    ow[VALS[slot]] = nil
  end

  ow[N] = count
end



--- Return the values stored in the wrapper as multiple results.
-- @param ow wrapper object
-- @param i index of the first value to return (defaults to 1)
-- @return stored values i .. ow[N]; nothing when i is past the end
local function myunpack(ow, i)
  local slot = i or 1
  if slot > ow[N] then
    return
  end
  -- rawget: the wrapper's __index must not fire for an absent (nil) value
  return rawget(ow, VALS[slot]), myunpack(ow, slot + 1)
end



--- Create a wrapper factory bound to a method table.
-- The returned function wraps any object `o`. On the wrapper:
--   * `w.name` / `w[name]` replaces the current value with `t[name]`;
--   * `w:name(...)` calls `t[name](current, ...)` and stores all results;
--   * `w()` unwraps, returning every stored value;
--   * `-w` unwraps, returning only the first stored value.
-- @param t table in which method names are looked up
-- @return function(o) that returns a new wrapper around `o`
local function buildchainwrapbuilder(t)
  -- private slot remembering the most recently indexed key, used only to
  -- produce a readable error when that key turns out not to be callable
  local KEY = {}

  local mt = {}

  function mt:__index(k)
    local val = rawget(self, VAL)
    self[PREV] = val -- store in case of method call
    self[KEY] = k
    mypack(self, t[k])
    return self
  end

  function mt:__call(...)
    if (...) == self then -- method call: w:name(...) passes w as first arg
      local val = rawget(self, VAL)
      local prev = rawget(self, PREV)
      self[PREV] = nil
      if val == nil then
        -- report the missing method by name instead of the generic
        -- "attempt to call a nil value"
        error("chain wrapper: no method '"
              .. tostring(rawget(self, KEY)) .. "'", 2)
      end
      mypack(self, val(prev, select(2,...)))
      return self
    else -- plain call: unwrap
      return myunpack(self)
    end
  end

  -- unary minus: unwrap the first value only
  function mt:__unm() return rawget(self, VAL) end

  local function build(o)
    return setmetatable({[VAL]=o,[N]=1}, mt)
  end

  return build
end



--- One-shot convenience: wrap object `o` so that its chained methods are
-- looked up in table `t`.
-- @param o object to wrap
-- @param t method table
-- @return wrapper object
local function chainwrap(o, t)
  local wrap = buildchainwrapbuilder(t)
  return wrap(o)
end

Test suite:

-- simple examples
-- (unary minus binds looser than the call chain, so -S"x":f() unwraps the
-- result of the whole chain)

assert(-S"AA":lower() == "aa")

assert(-S"AB":lower():reverse() == "ba")

assert(-S"  test  ":trim():repeatchars(2):upper() == "TTEESSTT")

assert(S"  test  ":trim():repeatchars(2):upper()() == "TTEESSTT")



-- basics: wrapping followed directly by unwrapping, with () and with unary minus

assert(S""() == "")

assert(S"a"() == "a")

assert(-S"a" == "a")

assert(S(nil)() == nil)

assert(S"a":byte()() == 97)

-- () returns every result of the last method; find yields two values here
local a,b,c = S"TEST":lower():find('e')()

assert(a==2 and b==2 and c==nil)

-- unary minus returns only the first result
assert(-S"TEST":lower():find('e') == 2)



-- potentially tricky cases
-- metamethod names and numeric keys are looked up in stringx like any other
-- key, so they must not leak the wrapper's own metatable entries

assert(S"".__index() == nil)

assert(S"".__call() == nil)

assert(S""[1]() == nil)

stringx[1] = 'c'

assert(S"a"[1]() == 'c')

assert(S"a"[1]:upper()() == 'C')

stringx[1] = 'd'

assert(S"a"[1]() == 'd') -- uncached: the lookup table is re-read on every index

-- indexing without a method call, then unwrapping, yields the looked-up value itself
assert(S"a".lower() == string.lower)



-- improve error messages?
-- (disabled: stringx has no `z`, so the method call raises an error)

--assert(S(nil):z() == nil)



print 'DONE'

The above implementation has these qualities and assumptions:

There are alternative ways we could have expressed the chaining:

S{"  test  ", "trim", {"repeatchars",2}, "upper"}



S("  test  ", "trim | repeatchars(2) | upper")

but this looks less conventional. (Note: the second argument in the last line is point-free [4].)

Method Chaining Wrapper Take #2 - Object at End of Chain

We could instead express the call chain like this:

chain(stringx):trim():repeatchars(5):upper()('  test   ')

where the object operated on is placed at the very end. This reduces the chance of forgetting to unpack, and it allows separation and reuse:

f = chain(stringx):trim():repeatchars(5):upper()

print ( f('  test  ') )

print ( f('  again  ') )

There are various ways to implement this (functional, CodeGeneration, and VM). Here we take the last approach (VM): the call chain is recorded as a list of operations that is interpreted later.

-- method call chaining, take #2

-- (c) 2009 David Manura. Licensed under the same terms as Lua (MIT license).

-- version 20090501



-- unique IDs to avoid name conflict
-- (fresh tables: they cannot clash with any method name the user indexes)

-- key under which a chain object stores its recorded operation list
local OPS = {}

-- opcode marker: the next entry is a key to look up in the method table
local INDEX = {}

-- opcode marker: the next entry is an argument count, followed by the arguments
local METHOD = {}



-- table insert, allowing trailing nils

--- Append a value to a list that tracks its own length in field `n`.
-- Unlike `t[#t+1] = v`, this keeps nil values (including trailing ones)
-- at their positions because the length is stored explicitly.
-- @param t list table with a numeric `n` field
-- @param v value to append (may be nil)
local function myinsert(t, v)
  local slot = t.n + 1
  t[slot] = v
  t.n = slot
end



--- Replay a recorded operation list against a concrete object.
-- The list is a flat sequence of instructions:
--   INDEX, key              -- look `key` up in the method table ops.t
--   METHOD, narg, a1..aN    -- call the looked-up value as a method
-- Only the first result of each method call is carried forward.
-- @param ops operation list (fields: n = length, t = method table)
-- @param x object the chain is applied to
-- @return the value left after the last operation
local function eval(ops, x)
  --print('DEBUG:', unpack(ops,1,ops.n))

  -- Lua 5.1 provides the global `unpack`; 5.2+ moved it to table.unpack
  local unpack = table.unpack or unpack

  local t = ops.t
  local prev
  local n = ops.n

  local i = 1
  while i <= n do
    local op = ops[i]
    if op == INDEX then
      local k = ops[i+1]
      prev = x  -- save in case of method call
      x = t[k]
      i = i + 2
    elseif op == METHOD then
      local narg = ops[i+1]
      -- explicit bounds keep nil arguments intact
      x = x(prev, unpack(ops, i+2, i+1+narg))
      i = i + 2 + narg
    else
      error('chain: corrupt operation list at position ' .. i)
    end
  end

  return x
end



local mt = {}

-- Indexing a chain object records an INDEX instruction for key `k` and
-- returns the same object so that further operations can be chained.
function mt:__index(k)
  local program = self[OPS]
  local at = program.n
  -- append the opcode and its operand, then bump the explicit length
  program[at + 1] = INDEX
  program[at + 2] = k
  program.n = at + 2
  return self
end



-- Calling a chain object does one of two things:
--   * chain:name(...)  (first argument is the chain itself) records a METHOD
--     instruction with its argument count and arguments, and returns the chain;
--   * chain(obj) runs the recorded operations on `obj` and returns the result.
function mt:__call(x, ...)
  local program = self[OPS]

  if x ~= self then
    -- ordinary call: evaluate the recorded chain on x
    return eval(program, x)
  end

  -- method call: record opcode, argument count, then each argument
  local argc = select('#', ...)
  local argv = {...}
  myinsert(program, METHOD)
  myinsert(program, argc)
  for j = 1, argc do
    myinsert(program, argv[j])
  end
  return self
end



--- Start a new, empty call chain whose method names are resolved in `t`.
-- @param t method table
-- @return chain object (record operations on it, then call it with an object)
local function chain(t)
  local recorder = { [OPS] = { t = t, n = 0 } }
  return setmetatable(recorder, mt)
end

Rudimentary test code:

-- Extended string method table: a shallow copy of the standard `string`
-- library plus the custom methods defined below.
local stringx = {}
for k,v in pairs(string) do stringx[k] = v end

--- Strip leading and trailing whitespace.
-- Uses the lazy capture `(.-)`; a `(%S*)` capture cannot span interior
-- whitespace and would make the match fail (returning nil) for "  a b  ".
-- @param self string to trim
-- @return the trimmed string
function stringx.trim(self)
  return self:match('^%s*(.-)%s*$')
end

--- Repeat every character of a string `n` times in place.
-- @param self source string
-- @param n number of copies of each character
-- @return the expanded string
function stringx.repeatchars(self, n)
  local ts = {}
  for i=1,#self do
    local c = self:sub(i,i)
    -- `_` avoids shadowing the outer loop index
    for _=1,n do ts[#ts+1] = c end
  end
  return table.concat(ts)
end



local C = chain

-- build the chain and apply it to an object in a single expression
assert(C(stringx):trim():repeatchars(2):upper()("  test  ") == 'TTEESSTT')

-- build once, then reuse the recorded chain on several inputs
local f = C(stringx):trim():repeatchars(2):upper()

assert(f"  test  " == 'TTEESSTT')

assert(f"  again  " == 'AAGGAAIINN')

print 'DONE'

Method Chaining Wrapper Take #3 - Lexical injecting with scope-aware metatable

An alternate idea is to modify the string metatable so that the extensions to the string methods are only visible within a lexical scope. The following is not perfect (e.g. nested functions), but it is a start. Example:

-- test example libraries

-- stringx.trim removes all leading and trailing whitespace. The lazy `(.-)`
-- capture keeps interior whitespace; `(%S*)` would make the match fail
-- (returning nil) for a string such as "  a b  ".
local stringx = {}
function stringx.trim(self)  return self:match('^%s*(.-)%s*$') end

-- stringxx.trim deliberately differs: it removes at most one whitespace
-- character from each end, so the two libraries can be told apart in tests.
local stringxx = {}
function stringxx.trim(self) return self:match('^%s?(.-)%s?$') end



-- test example

-- Inner function: starts without any scoped methods, then opts in to stringxx.
-- NOTE(review): the registration made by scoped_string_methods is keyed on
-- the function and is not removed when the function returns, so the first
-- assertion below would presumably fail on a second call of test2 — confirm.
function test2(s)

  assert(s.trim == nil)

  scoped_string_methods(stringxx)

  assert(s:trim() == ' 123 ')

end

-- Outer function: opts in to stringx and checks that calling test2 (which
-- uses a different method table) does not disturb its own view of s:trim().
function test(s)

  scoped_string_methods(stringx)

  assert(s:trim() == '123')

  test2(s)

  assert(s:trim() == '123')

end

local s = '  123  '

-- the main chunk never registers a method table, so trim stays invisible here
assert(s.trim == nil)

test(s)

assert(s.trim == nil)

print 'DONE'

The function scoped_string_methods assigns the given function table to the scope of the currently executing function. All string indexing within the scope goes through that given table.

The above uses this framework code:

-- framework

-- shared metatable of all strings
local mt = debug.getmetatable('')

-- Maps a function value to the method table in force inside that function.
-- Keys are weak so that registering a function (for example a closure that
-- is created on every call) does not keep it alive forever.
local scope = setmetatable({}, {__mode = 'k'})

--- String indexing hook: resolve `k` in the method table registered for the
-- function that performed the indexing, falling back to the standard
-- `string` library.
-- @param s the string being indexed (unused)
-- @param k the key being looked up
function mt.__index(s, k)
  -- level 2 is the function that indexed the string (level 1 is this hook);
  -- getinfo returns nil when there is no such level, so guard against it
  local info = debug.getinfo(2, 'f')
  local methods = info and scope[info.func]
  return methods and methods[k] or string[k]
end

--- Register `t` as the string method table for the calling function.
-- After this call, string indexing performed inside that function is
-- resolved through `t` first.
-- @param t table of string methods
local function scoped_string_methods(t)
  -- level 2 is the caller of scoped_string_methods
  local caller = debug.getinfo(2, 'f')
  scope[caller.func] = t
end

Method Chaining Wrapper Take #4 - Lexical injecting with Metalua

We can do something similar to the above more robustly via MetaLua. An example is below.

-{extension "lexicalindex"}



-- test example libraries

-- stringx.trim removes all leading and trailing whitespace. The lazy `(.-)`
-- capture keeps interior whitespace; `(%S*)` would make the match fail
-- (returning nil) for a string such as "  a b  ".
local stringx = {}
function stringx.trim(self)  return self:match('^%s*(.-)%s*$') end

--- Custom index function handed to `lexicalindex`.
-- For strings, look the key up in stringx first and then in the standard
-- `string` library; for every other value, index it normally.
-- @param o value being indexed
-- @param k key
-- @return the looked-up value
local function f(o,k)
  if type(o) == 'string' then
    return stringx[k] or string[k]
  end
  return o[k]
end



-- Metalua source: `lexicalindex f` is the statement added by the extension.
local function test(s)

  -- before the lexicalindex statement, indexing is ordinary: no trim
  assert(s.trim == nil)

  -- from here to the end of the enclosing block, indexing goes through f
  lexicalindex f

  assert(s.trim ~= nil)

  assert(s:trim():upper() == 'TEST')

end

local s = '  test  '

-- outside test(), string indexing is unaffected
assert(s.trim == nil)

test(s)

assert(s.trim == nil)



print 'DONE'

The syntax extension introduces a new keyword lexicalindex that specifies a function to be called whenever a value is to be indexed inside the current scope.

Here is what the corresponding plain Lua source looks like:

--- $ ./build/bin/metalua -S vs.lua

--- Source From "@vs.lua": ---

local function __li_invoke (__li_index, o, name, ...)

   return __li_index (o, name) (o, ...)

end



local stringx = { }



function stringx:trim ()

   return self:match "^%s*(%S*)%s*$"

end



local function f (o, k)

   if type (o) == "string" then

      return stringx[k] or string[k]

   end

   return o[k]

end



local function test (s)

   assert (s.trim == nil)

   local __li_index = f

   assert (__li_index (s, "trim") ~= nil)

   assert (__li_invoke (__li_index, __li_invoke (__li_index, s, "trim"), "upper"

) == "TEST")

end



local s = "  test  "



assert (s.trim == nil)



test (s)



assert (s.trim == nil)



print "DONE"

The lexicalindex Metalua extension is implemented as

-- lexical index in scope iff depth > 0
-- (incremented by the `lexicalindex` statement builder, decremented by the
-- block transformer)

local depth = 0



-- transform indexing expressions
-- While depth > 0:
--   * an Index node  o[k]       becomes a call  __li_index(o, k)
--   * an Invoke node o:name(...) becomes a call  __li_invoke(__li_index, o, name, ...)
-- Other nodes, and all nodes when depth == 0, are left alone (the function
-- returns nothing for them).

mlp.expr.transformers:add(function(ast)

  if depth > 0 then

    if ast.tag == 'Index' then

      -- quasi-quote with the two child ASTs spliced in
      return +{__li_index(-{ast[1]}, -{ast[2]})}

    elseif ast.tag == 'Invoke' then

      -- the Invoke node's children are appended as the remaining call arguments
      return `Call{`Id'__li_invoke', `Id'__li_index', unpack(ast)}

    end

  end

end)



-- monitor scoping depth
-- If a block contains a statement tagged is_lexicalindex, decrement depth
-- once (the loop stops at the first such statement).
-- NOTE(review): presumably this transformer runs after the block has been
-- parsed, so the decrement closes the scope opened by the statement builder
-- — confirm against Metalua's transformer ordering.

mlp.block.transformers:add(function(ast)

  for _,ast2 in ipairs(ast) do

    if ast2.is_lexicalindex then

      depth = depth - 1; break

    end

  end

end)



-- handle new "lexicalindex" statement
-- Registers the keyword and a statement parser of the form
--   lexicalindex <expr>
-- The builder rewrites it to `local __li_index = <expr>`, tags the resulting
-- node so the block transformer can recognise it, and increments depth.

mlp.lexer:add'lexicalindex'

mlp.stat:add{'lexicalindex', mlp.expr, builder=function(x)

  -- x holds the parsed expression as its first element
  local e = unpack(x)

  local ast_out = +{stat: local __li_index = -{e}}

  ast_out.is_lexicalindex = true

  depth = depth + 1

  return ast_out

end}



-- utility function
-- The extension returns this block, which defines the runtime helper that
-- rewritten method calls invoke: look `name` up on `o` through the custom
-- index function, then call the result with `o` as the receiver.

-- (note: o must be indexed exactly once to preserve behavior)

return +{block:

  local function __li_invoke(__li_index, o, name, ...)

    return __li_index(o, name)(o, ...)

  end

}

--DavidManura

See Also


RecentChanges · preferences
edit · history
Last edited December 9, 2009 12:38 am GMT (diff)