Code Generation


Wiki editors: This is a work in progress. I'll link it in to appropriate pages when I finish. It's continuing on from the saga started in LuaSorting -- RiciLake

Customizing Sort

It's clear that shellsort can be sped up quite a bit by specializing it for different comparison functions. However, there is no obvious criterion for deciding which specialized versions will be useful in a particular program. Rather than trying to guess what will and will not be useful, we can take advantage of Lua's built-in (and rapid) compiler to create specialized versions on the spot, using a template. Shellsort is particularly well suited to that, because the comparison function is only called in one place.

We want to end up with a function which returns a specialized sort function, given an arbitrary Lua expression which represents the comparison. The expression will be provided as a string; to simplify things, we'll insist that the expression use the variables a and b to represent the objects being compared. For example, suppose we want to sort an array of Person objects, where each object looks like {name = "Joe", age = 32, <other fields>}. A sample call will look like this:

local sort_by_age = make_sorter [[a.age < b.age]]

-- ... somewhere later on ...

sort_by_age(folks)

make_sorter simply inserts its argument into the Shell sort template. To make the template work, we need to rename variables inside shellsort so that the values compared end up being called a and b, to match the names in the comparison function.

Our first attempt looks like this (courtesy DavidManura who actually wrote the second version, below):

   function make_sorter(compare)
     local src = [[
       local incs = { 1391376,
                      463792, 198768, 86961, 33936,
                      13776, 4592, 1968, 861, 336,
                      112, 48, 21, 7, 3, 1 }

       -- The value of the compiled chunk needs to be the sort function itself
       return function(t, n)
         -- Default to the whole array, so that sorter(t) works as in the sample call
         n = n or #t
         for _, h in ipairs(incs) do
           for i = h + 1, n do
             local a = t[i]
             -- pos tracks the hole; the loop variable i itself is left alone
             local pos = i
             for j = i - h, 1, -h do
               local b = t[j]
               if not (]] .. compare .. [[) then break end
               t[pos] = b; pos = j
             end
             t[pos] = a
           end
         end
         return t
       end
     ]]
     return assert(loadstring(src, "Shellsort "..compare))()
   end
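To see the templating idea in action without the full shellsort, here is a minimal sketch using a plain insertion sort as the template. The names make_insertion_sorter and compile are made up for this illustration, and the loadstring-or-load fallback is an assumption so that the sketch also runs on Lua versions later than 5.1. It also shows what happens when the comparison string does not compile.

```lua
-- Assumption: loadstring exists on Lua 5.1; later versions use load instead.
local compile = loadstring or load

-- Hypothetical helper: same templating idea, simpler sort.
local function make_insertion_sorter(compare)
  local src = [[
    return function(t, n)
      n = n or #t
      for i = 2, n do
        local a = t[i]
        local pos = i
        while pos > 1 do
          local b = t[pos - 1]
          if not (]] .. compare .. [[) then break end
          t[pos] = b
          pos = pos - 1
        end
        t[pos] = a
      end
      return t
    end
  ]]
  -- On a syntax error the compiler returns nil plus a message
  local chunk, err = compile(src, "=sorter " .. compare)
  if not chunk then return nil, err end
  return chunk()
end

local folks = {
  { name = "Joe", age = 32 },
  { name = "Ann", age = 27 },
  { name = "Bob", age = 45 },
}
local sort_by_age = assert(make_insertion_sorter("a.age < b.age"))
sort_by_age(folks)   -- folks is now ordered Ann, Joe, Bob

-- A malformed expression is caught when the sorter is built, not when it runs
local bad, err = make_insertion_sorter("a.age <")
```

Here bad is nil and err holds the compiler's message, so a broken comparison is reported at the point where the sorter is created.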

Sharing common data between generated code

This works nicely, but there's a small problem. After we've defined 35 different shell sorters, we've got 35 copies of the incs array cluttering up storage. (We also have 35 customized sort functions, but they actually take up less space.) What we'd like to do is use the same incs array for every specialized sortation function.

With Lua 5.1.1, we can pass arguments into a newly-compiled chunk, where they're available as the values of the vararg expression (...). So, instead of just calling the chunk to get the sortation function, and compiling a new incs array in every chunk, we simply pass the master incs table in as an argument:

   local incs = { 1391376,
                  463792, 198768, 86961, 33936,
                  13776, 4592, 1968, 861, 336,
                  112, 48, 21, 7, 3, 1 }

   function make_sorter(compare)
     -- The first line captures the argument to the chunk
     local src = [[
       local incs = ...
       return function(t, n)
         -- Default to the whole array, so that sorter(t) works as in the sample call
         n = n or #t
         for _, h in ipairs(incs) do
           for i = h + 1, n do
             local a = t[i]
             -- pos tracks the hole; the loop variable i itself is left alone
             local pos = i
             for j = i - h, 1, -h do
               local b = t[j]
               if not (]] .. compare .. [[) then break end
               t[pos] = b; pos = j
             end
             t[pos] = a
           end
         end
         return t
       end
     ]]
     -- We have to call the compiled chunk with incs:
     return assert(loadstring(src, "Shellsort "..compare))(incs)
   end
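The argument-passing mechanism is easier to see on a smaller scale. The following is only a sketch: make_getter, shared and compile are names invented for this illustration, and the loadstring-or-load fallback is an assumption for Lua versions later than 5.1. Two generated functions receive the same table through ... and therefore both see a later change to it.

```lua
-- Assumption: loadstring exists on Lua 5.1; later versions use load instead.
local compile = loadstring or load

local shared = { 10, 20, 30 }

-- Hypothetical helper: builds function(i) that evaluates expr with x = shared[i]
local function make_getter(expr)
  local src = [[
    local shared = ...   -- first argument passed to the chunk
    return function(i)
      local x = shared[i]
      return ]] .. expr .. [[

    end
  ]]
  -- Calling the chunk with shared hands it the one master table
  return assert(compile(src, "=getter " .. expr))(shared)
end

local double = make_getter("x * 2")
local plus1 = make_getter("x + 1")

-- Both generated functions hold the same table, not a copy
shared[1] = 100
```

After the last line, double(1) and plus1(1) both work from the new value 100, which is the evidence that only one copy of the table exists.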

Combining code generation with memoization

Now we can easily generate specialized sorters, but it's still a bit awkward. As our application grows, we find that we're generating the same sorter over and over again in various modules. Also, it's a bit annoying to have to think up a name for each sortation function, and call make_sorter for each one. It would be nicer if we could just put the comparison in the actual call. In other words, we've ended up with something like this:

local sort_by_age = make_sorter [[a.age < b.age]]
local sort_by_name = make_sorter [[a.name < b.name]]
-- This is inefficient, but simple. See below for an improvement.
local sort_by_number_of_children = make_sorter [[#(a.children or {}) < #(b.children or {})]]

-- ... much later ...

function show_offspring(grandmom)
  -- Arrange the families so the smallest ones come first
  sort_by_number_of_children(grandmom.children)
  -- In each family, sort the grandkids, if any, by age
  for _, child in pairs(grandmom.children) do
     if child.children then sort_by_age(child.children) end
  end
end

But now we realize that we really wanted the families sorted with the largest family first, and moreover we don't use sort_by_name at all. So we've got a code maintenance problem.

Now, we could just create the sorter when we needed it, like this:

function show_offspring(grandmom)
  -- Arrange the families so the largest ones come first
  make_sorter[[#(a.children or {}) > #(b.children or {})]](grandmom.children)
  -- In each family, sort the grandkids, if any, by age
  for _, child in pairs(grandmom.children) do
     if child.children then make_sorter[[a.age < b.age]](child.children) end
  end
end

That seems better, although it's got a lot of punctuation. However, it means we're calling make_sorter a lot of times, and even though Lua's compiler is remarkably fast, that's still a lot of excess work.

Fortunately, with Lua we can have our cake and eat it too. We can create a table of sorting functions, indexed by the comparison string, and let Lua create the functions on demand -- a sort of virtual table. This technique is called memoization, as in memo (a note to be remembered): the result of a complex computation which is likely to be used again is remembered, keyed by the actual argument of the function. (It is sometimes also called caching, but that word is used in quite a few contexts.)

We could write a cache into the definition of make_sorter, but memoization is such a common pattern (see LuaDesignPatterns) that we might as well use a general solution, particularly since it is so simple:

-- Make this part of your standard Lua library. You'll find yourself using it a lot

-- Given a function, return a memoization table for that function
function memoize(func)
  return setmetatable({}, {
    __index = function(self, k) local v = func(k); self[k] = v; return v end,
    __call = function(self, k) return self[k] end
  })
end

Now, we can turn make_sorter into a memoized table in a single line of code:

Shellsorter = memoize(make_sorter)

and we can safely write show_offspring, confident that it will result in at most two compilations in the lifetime of our application:

function show_offspring(grandmom)
  -- Arrange the families so the largest ones come first
  Shellsorter[ "#(a.children or {}) > #(b.children or {})" ](grandmom.children)
  -- In each family, sort the grandkids, if any, by age
  for _, child in pairs(grandmom.children) do
     if child.children then Shellsorter[ "a.age < b.age" ](child.children) end
  end
end

Note that I replaced the call to make_sorter with an index on Shellsorter. The implementation of memoize provided above makes this unnecessary; I could use Shellsorter as though it were a function. Sometimes that's useful, and this technique is further explained in FuncTables. In this case, though, it doesn't seem to add any clarity, and it only serves to create an unnecessary function call.
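To check that memoize really calls the underlying function only once per key, here is a small sketch with a counter. The memoize definition is repeated (as a local) so that the example stands alone; slow_square, Square and calls are names invented for this illustration. Both the indexing form and the call form are shown.

```lua
-- Same memoize as above, repeated so the sketch is self-contained
local function memoize(func)
  return setmetatable({}, {
    __index = function(self, k) local v = func(k); self[k] = v; return v end,
    __call = function(self, k) return self[k] end
  })
end

-- Hypothetical "expensive" function that counts how often it runs
local calls = 0
local function slow_square(x)
  calls = calls + 1
  return x * x
end

local Square = memoize(slow_square)

local first = Square[12]    -- not in the table yet: __index computes and stores it
local second = Square[12]   -- found in the table: slow_square is not called
local third = Square(12)    -- call form: __call just indexes the table
```

All three lookups give the same value, and calls ends up at 1. One caveat follows from the way the result is stored: if the function returns nil for some key, nothing is recorded, so that key is recomputed on every lookup.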

Interlude: the Schwartzian transform

Often (though not always) the comparison used in a sort can be expressed in the form f(a) < f(b) for some function f. The value of f(a) is often called the characteristic.

For example, we might need to sort an array of points in three-dimensional space according to their distance from some reference point. The naive version of this would be:

function distance(x, y) return ((x[1]-y[1])^2 + (x[2]-y[2])^2 + (x[3]-y[3])^2)^0.5 end
Shellsorter[" distance(a, ref) < distance(b, ref) "](array)

A moment's thought should suggest that it would be better to remove the square root computation, since it makes no difference to the final outcome. Even so, the computation is somewhat time-consuming. Roughly speaking, most sorting algorithms will perform about N log₂ N comparisons to sort an array of N objects, which means that distance above will be called 2N log₂ N times on a total of N objects; in other words, it will be called 2 log₂ N times on each object, yielding the same result each time. If we had, say, 100,000 points, that would mean approximately 33 calls to distance for each point (log₂ 100,000 ≈ 16.6). The discussion of memoization above suggests that it should be possible to call it just once on each object.
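The repeated work is easy to measure. The sketch below is an illustration under some assumptions: it uses the built-in table.sort rather than the generated shellsort, drops the square root as suggested, and invents the names dist2, points and calls along with an arbitrary, deterministic set of test points. The exact count depends on the sorting algorithm and the data, so no particular figure is claimed.

```lua
local N = 1000
local ref = { 0, 0, 0 }

-- Deterministic pseudo-points, so the sketch needs no random seed
local points = {}
for i = 1, N do
  points[i] = { (i * 37) % 101, (i * 53) % 103, (i * 71) % 107 }
end

-- Squared distance (the square root does not change the ordering),
-- instrumented with a call counter
local calls = 0
local function dist2(x, y)
  calls = calls + 1
  return (x[1] - y[1])^2 + (x[2] - y[2])^2 + (x[3] - y[3])^2
end

-- Naive sort: every comparison recomputes two distances
table.sort(points, function(a, b)
  return dist2(a, ref) < dist2(b, ref)
end)

-- Average number of distance computations per point
local per_point = calls / N
```

Since each comparison calls dist2 twice, calls is always even and well above N; per_point shows how many times the same distance was recomputed for a typical point.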

Characteristics don't need to be numbers; they just need to be simpler to compare than the original values. For example, sorting words according to the rules of a given language can be done by transforming each word into a longer string which can be compared with a simple string comparison. (See the strxfrm() function in POSIX.)

The easiest way of sorting by a characteristic function is to first compute the characteristic of each element, constructing an array of pairs {characteristic, element}. This array can then be sorted using a much simpler comparison function, and then the sorted array can be turned back into an array of objects by just selecting the second element of each pair. This approach is a common idiom in Perl, where it is called the Schwartzian transform, named after the famous Perl hacker Randal Schwartz. [wikipedia]
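As a sketch of the pairs approach just described, translated directly into Lua: the function name sort_by_characteristic and the sample data are invented for this illustration, and the built-in table.sort stands in for whatever sort is actually used.

```lua
-- Hypothetical helper: decorate, sort, undecorate
local function sort_by_characteristic(array, f)
  -- Decorate: build {characteristic, element} pairs, calling f once per element
  local decorated = {}
  for i, v in ipairs(array) do
    decorated[i] = { f(v), v }
  end
  -- Sort on the precomputed characteristic only
  table.sort(decorated, function(p, q) return p[1] < q[1] end)
  -- Undecorate: keep the second element of each pair
  for i, p in ipairs(decorated) do
    array[i] = p[2]
  end
  return array
end

-- Sort words by length (all lengths distinct, so the order is unambiguous)
local words = { "pear", "fig", "banana", "apple" }
sort_by_characteristic(words, function(w) return #w end)
-- words is now fig, pear, apple, banana
```

Note that this allocates one small table per element in addition to the array of pairs itself.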

A naive translation of the Perl idiom to Lua turns out to be extremely inefficient, though: the array of pairs uses an enormous amount of space. Fortunately, an analogous solution exists, which also elegantly solves the related problem of sorting tables other than arrays.

In general, a Lua table is a set of mappings from key to value:

T = {k₁→v₁, k₂→v₂, … kₙ→vₙ}

In order to provide an ordered view of this table, we can construct an array of keys:

K = {1→k₁, 2→k₂, … n→kₙ}

We can then iterate over the key array and recover the key-value pair by looking up the key in the original table, using an iterator like this:

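function pairs_in_order(T, K)
  local i = 0
  -- The iterator must be an anonymous function; "function iter()" is not a valid expression
  return function()
    i = i + 1
    local k = K[i]
    -- Compare against nil so that a key of false does not end the iteration
    if k ~= nil then return k, T[k] end
  end
end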

To sort such a table, we can construct a third table, the characteristic table, which maps keys onto characteristics:

C = {k₁→f(k₁, v₁), k₂→f(k₂, v₂), … kₙ→f(kₙ, vₙ)}

Note that we've given the characteristic function both the key and the value, since both might participate in the sortation order. We then sort the array of keys, looking up each key in the characteristic table:

-- K = keytable(T)
K = {}; for k in pairs(T) do K[#K+1] = k end
-- C = map(f, T)
C = {}; for k, v in pairs(T) do C[k] = f(k, v) end
-- Sort
Shellsorter["C[a] < C[b]"](K)
-- Get rid of C (see below)
C = nil
-- Iterate the sorted key, value pairs
for k, v in pairs_in_order(T, K) do
  print(k, v)
end

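This technique works regardless of the keys in the original table. What's more, because the key is given to the comparison function, we can get a well-defined, repeatable order (shellsort by itself is not a stable sort) by falling back on key comparison if the characteristics of the values are equal. This assumes the keys can be compared with each other, for instance because they are all strings or all numbers. We just replace the sort with: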

Shellsorter["C[a] < C[b] or C[a] == C[b] and a < b"](K)
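Here is the whole sequence as a self-contained sketch. To keep it short, make_compare (a name invented for this illustration) generates only the comparison function and hands it to the built-in table.sort, standing in for Shellsorter; the loadstring-or-load fallback is an assumption for Lua versions later than 5.1. The characteristic is simply the value, and ties are broken by key.

```lua
-- Assumption: loadstring exists on Lua 5.1; later versions use load instead.
local compile = loadstring or load

-- Hypothetical stand-in for Shellsorter: generate just the comparison
local function make_compare(expr)
  local src = "return function(a, b) return " .. expr .. " end"
  return assert(compile(src, "=compare"))()
end

local T = { apple = 3, pear = 1, fig = 3, kiwi = 2 }

-- Key array
local K = {}
for k in pairs(T) do K[#K + 1] = k end

-- Characteristic table; a global, as in the text, so the generated code can reach it
C = {}
for k, v in pairs(T) do C[k] = v end

-- Sort keys by characteristic, falling back on the key itself for ties
table.sort(K, make_compare("C[a] < C[b] or C[a] == C[b] and a < b"))
C = nil

-- Walk the table in sorted order
local out = {}
for i = 1, #K do
  local k = K[i]
  out[#out + 1] = k .. "=" .. T[k]
end
-- out is now { "pear=1", "kiwi=2", "apple=3", "fig=3" }
```

The tie between apple and fig is resolved by the key comparison, so the result does not depend on the order in which pairs happened to visit the table.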

But there's a problem: we'd really like the auxiliary tables to be local variables; in particular, the characteristic table is a temporary value, and we'd like it to be garbage collected as soon as possible. But the code generation we've done does not allow the comparison function to refer to upvalues.
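The limitation can be demonstrated in a few lines. This is only a sketch; the variable names are invented for this illustration, and the loadstring-or-load fallback is an assumption for Lua versions later than 5.1. A chunk compiled from a string resolves free names as globals, so a local of the same name in the calling code is simply not seen.

```lua
-- Assumption: loadstring exists on Lua 5.1; later versions use load instead.
local compile = loadstring or load

local demo_hidden = 42    -- a local: invisible to separately compiled chunks
demo_visible = 7          -- a global: reachable from any chunk

-- In the compiled chunks, both names are looked up as globals
local get_local = assert(compile("return demo_hidden"))
local get_global = assert(compile("return demo_visible"))

local seen_local = get_local()     -- nil: there is no global called demo_hidden
local seen_global = get_global()   -- 7
```

The chunk compiles without complaint in both cases; the failure to see the local only shows up as a nil at run time.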

Hygienic Macros

Often, we would like to insert a piece of code into a boilerplate template. However, that can lead to obscure problems if the inserted code has a free variable which happens to be bound at the point of insertion.
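A small sketch of the kind of problem meant here, using a made-up template rather than the sorter: make_adder, step and compile are names invented for this illustration, and the loadstring-or-load fallback is an assumption for Lua versions later than 5.1. The caller means the global variable step, but the template happens to declare a local of the same name at the point of insertion.

```lua
-- Assumption: loadstring exists on Lua 5.1; later versions use load instead.
local compile = loadstring or load

-- Hypothetical template with an internal helper variable called step
local function make_adder(expr)
  local src = [[
    return function(x)
      local step = 100   -- template-internal variable
      return x + (]] .. expr .. [[)
    end
  ]]
  return assert(compile(src, "=adder " .. expr))()
end

step = 1   -- the global the caller intends to refer to

local add_step = make_adder("step")
local result = add_step(5)
-- result is 105, not 6: the inserted "step" was captured by the template's local
```

Nothing signals the mistake; the generated code is perfectly valid Lua, it just refers to a different variable than the one the caller had in mind.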

To be continued...


Last edited May 28, 2009 2:06 am GMT