Symbolic Differentiation


This started out as a little exercise for the PenlightLibraries, but became sufficiently obsessive to warrant a more thorough implementation.

The first step in symbolic algebra is defining a representation. Getting expressions into a suitable form is actually straightforward; no parsing of expressions is needed, since we have Lua to do that for us. The pl.func library does all the hard work; it redefines the arithmetic operations to work on placeholder expressions (PEs), which are Lua expressions involving dummy variables called placeholders. pl.func defines standard placeholders for arguments called _1, _2, etc., but the Var function will create new ones of our choosing:

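require 'pl'            -- load Penlight; utils, tablex, List etc. become available as globals
utils.import 'pl.func'  -- bring Var, isPE and friends into the global environment

a,b,c,d = Var 'a,b,c,d'
print(a+b+c+d)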

Which will indeed print out the expression in a readable form. PE operator expressions are stored as combinations of tables looking like {op='+',x,y} , which have an associated metatable which defines the metamethods like __add and so forth. As a tree, with the usual associativity of Lua operators we get:

It is irritating to draw these diagrams, so a better notation is Lisp-style S-expressions:


1: (+ (+ (+ a b) c) d)
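To make the representation concrete, here is a minimal standalone sketch of the idea, without Penlight. It is not how pl.func is actually written; the names Expr, node, var and is_expr are made up for illustration. The arithmetic metamethods build {op=...,x,y} nodes instead of computing anything, and Lua's own left associativity produces form 1.

```lua
-- Standalone sketch (no Penlight): arithmetic metamethods build
-- {op=..., x, y} nodes instead of computing a value.
local Expr = {}

local function node(op, x, y)
  return setmetatable({op = op, x, y}, Expr)
end

-- hypothetical stand-in for pl.func's Var
local function var(name)
  return setmetatable({op = 'X', repr = name}, Expr)
end

Expr.__add = function(x, y) return node('+', x, y) end
Expr.__mul = function(x, y) return node('*', x, y) end
Expr.__pow = function(x, y) return node('^', x, y) end

-- hypothetical stand-in for isPE: check the metatable
local function is_expr(e)
  return getmetatable(e) == Expr
end

-- render a node as an S-expression
local function sexpr(e)
  if not is_expr(e) then return tostring(e) end
  if e.op == 'X' then return e.repr end
  local parts = {}
  for i = 1, #e do parts[i] = sexpr(e[i]) end
  return '(' .. e.op .. ' ' .. table.concat(parts, ' ') .. ')'
end

local a, b, c, d = var 'a', var 'b', var 'c', var 'd'
print(sexpr(a + b + c + d))       --> (+ (+ (+ a b) c) d)
print(sexpr(a + (b + (c + d))))   --> (+ a (+ b (+ c d)))
```

The second line shows that explicit parentheses are enough to produce a different tree for the same sum.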

However, with the various manipulations we will perform, this canonical form is not the only possible representation of a+b+c+d:


2: (+ a (+ b (+ c d)))

3: (+ (+ a b) (+ c d))

Now, experience shows that this leads to madness. Instead, it's easier to go for the canonical Lisp representation:


4: (+ a b c d)

Many operations become straightforward once this is in place, for instance comparing with (+ a c b d) is just a matter of doing a 'compare with no order' on the arguments.

Displaying PEs in this form is straightforward. isPE simply checks the expression to see if it is a placeholder expression, by looking at the metatable. PEs with op=='X' are placeholder variables, so the rest must be expression nodes.

function sexpr (e)
    if isPE(e) then
        if e.op ~= 'X' then
            local args = tablex.imap(sexpr,e)
            return '('..e.op..' '..table.concat(args,' ')..')'
        else
            return e.repr
        end
    else
        return tostring(e)
    end
end
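The 'compare with no order' idea mentioned above can be sketched in plain Lua as follows. Penlight's tablex has a compare_no_order function for this purpose; the version below is only an illustration, with a made-up name (same_args_any_order) and plain strings standing in for the arguments.

```lua
-- Standalone sketch: do two argument lists hold the same items,
-- ignoring order? Each item of ys may be matched only once.
local function same_args_any_order(xs, ys, eq)
  if #xs ~= #ys then return false end
  local used = {}
  for i = 1, #xs do
    local found = false
    for j = 1, #ys do
      if not used[j] and eq(xs[i], ys[j]) then
        used[j] = true
        found = true
        break
      end
    end
    if not found then return false end
  end
  return true
end

local function equal(x, y) return x == y end

-- (+ a b c d) against (+ a c b d)
print(same_args_any_order({'a','b','c','d'}, {'a','c','b','d'}, equal)) --> true
-- repeated arguments are counted, so (+ a a b) differs from (+ a b b)
print(same_args_any_order({'a','a','b'}, {'a','b','b'}, equal))         --> false
```

This is quadratic in the number of arguments, which hardly matters for expressions of this size.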

The first task is to balance the expressions, which converts representations 1-3 into 4.

function balance (e)
    if isPE(e) and e.op ~= 'X' then
        local op,args = e.op
        if op == '+' or op == '*' then
            args = rcollect(e)
        else
            args = imap(balance,e)
        end
        for i = 1,#args do
            e[i] = args[i]
        end
    end
    return e
end

For the non-commutative operators, the idea is just to balance all the subexpressions by mapping balance over the array part of the PE, which is the argument list. These are then copied back in-place. The non-trivial part is dealing with + and *, where it is necessary to collect all the arguments from expression trees looking like 1,2 or 3 and convert them into the fourth form.

function tcollect (op,e,ls)
    if isPE(e) and e.op == op then
        for i = 1,#e do
            tcollect(op,e[i],ls)
        end
    else
        ls:append(e)
        return
    end
end

function rcollect (e)
    local res = List()
    tcollect(e.op,e,res)
    return res
end

This recursively goes down same-operator chains (the (+ (+ ...) mentioned earlier) and collects the arguments, flattening them into n-ary + or * expressions.
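The flattening can be tried out without Penlight. The sketch below uses plain tables for nodes and strings for variables; is_node, collect and flatten are illustrative names, not part of the original code. All three binary forms of a+b+c+d end up as the same n-ary form.

```lua
-- Standalone sketch: flatten chains of the same operator into one
-- n-ary node. Nodes are plain tables {op=..., ...}; leaves are strings.
local unpack = table.unpack or unpack

local function is_node(e)
  return type(e) == 'table' and e.op ~= nil
end

-- append the leaves of a same-operator chain to out, left to right
local function collect(op, e, out)
  if is_node(e) and e.op == op then
    for i = 1, #e do collect(op, e[i], out) end
  else
    out[#out + 1] = e
  end
  return out
end

local function flatten(e)
  return {op = e.op, unpack(collect(e.op, e, {}))}
end

local form1 = {op='+', {op='+', {op='+', 'a', 'b'}, 'c'}, 'd'}
local form2 = {op='+', 'a', {op='+', 'b', {op='+', 'c', 'd'}}}
local form3 = {op='+', {op='+', 'a', 'b'}, {op='+', 'c', 'd'}}

for _, form in ipairs{form1, form2, form3} do
  local flat = flatten(form)
  print('(' .. flat.op .. ' ' .. table.concat(flat, ' ') .. ')')  --> (+ a b c d)
end
```

Unlike balance, this sketch returns a new node rather than modifying the expression in place.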

Here is a useful function, which follows the same recursive pattern:

-- does this PE contain a reference to x?
function references (e,x)
    if isPE(e) then
        if e.op == 'X' then return x.repr == e.repr
        else
            return find_if(e,references,x)
        end
    else
        return false
    end
end

Here are functions to create n-ary products and sums:

function muli (args) return PE{op='*',unpack(args)} end
function addi (args) return PE{op='+',unpack(args)} end

With this in place, the basic differentiation rules are not difficult. Firstly, only consider subexpressions which do contain the variable:

function diff (e,x)
    if isPE(e) and references(e,x) then
        local op = e.op
        if op == 'X' then
            return 1
        else
            local a,b = e[1],e[2]
            if op == '+' then -- differentiation is linear
                local args = imap(diff,e,x)
                return balance(addi(args))
            elseif op == '*' then -- product rule
                local res,d,ee = {}
                for i = 1,#e do
                    d = fold(diff(e[i],x))
                    if d ~= 0 then
                        ee = {unpack(e)} -- make a copy
                        ee[i] = d
                        append(res,balance(muli(ee)))
                    end
                end
                if #res > 1 then return addi(res)
                else return res[1] end
            elseif op == '^' and isnumber(b) then -- power rule
                -- note: this treats the base as x itself; no chain rule is applied
                return b*x^(b-1)
            end
        end
    else
        return 0
    end
end

The derivative of a sum of expressions is the sum of the derivatives. Again, imap does the job of applying the function recursively over the subexpressions. After constructing the result, we re-balance for luck.

The product rule is given here in its general form, with an explicit check for terms which work out to zero - that is the job of fold, which will be discussed next.


(uvw...)' = u'vw... + uv'w... + uvw'... + ...

And finally, the simple power rule. Note how the result can be expressed in a straightforward fashion, since all these operators are acting on PEs.
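The copy-and-replace loop behind the general product rule can be shown on its own. This is a sketch on plain strings, not on PEs, and product_rule_terms is a made-up name: for each factor, copy the factor list and swap in that factor's derivative (marked here with a prime).

```lua
-- Standalone sketch: generate the terms of (uvw...)' symbolically.
-- Factors are plain strings; a derivative is written with a trailing '.
local unpack = table.unpack or unpack

local function product_rule_terms(factors)
  local terms = {}
  for i = 1, #factors do
    local term = {unpack(factors)}   -- make a copy
    term[i] = factors[i] .. "'"      -- differentiate the i-th factor only
    terms[#terms + 1] = table.concat(term, '*')
  end
  return table.concat(terms, ' + ')
end

print(product_rule_terms{'u', 'v', 'w'})  --> u'*v*w + u*v'*w + u*v*w'
```

In diff the same loop additionally skips any term whose differentiated factor folds to zero.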

In fact, all these rules are certainly clearer if you use form 1, binary + and *! But then simplification becomes unbearable. And simplification ('folding') is the tricky one to get right. fold is a longish function, so I will deal with it in sections:

local op = e.op
local addmul = op == '*' or op == '+'
-- first fold all arguments
local args = imap(fold,e)
if not addmul and not find_if(args,isPE) then
    -- no placeholders in these args, we can fold the expression.
    local opfn = optable[op]
    if opfn then
        return opfn(unpack(args))
    else
        return '?'
    end
elseif addmul then

The first if is looking for a case where a subexpression has no symbols, i.e. it is something like 2*5 or 10^2; in this case, the constant can be completely folded. optable (defined in pl.operator) gives a mapping between the operator names and the actual function implementing them.
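Here is a small standalone sketch of constant folding driven by an operator table. The local optable below is a hand-written stand-in for the one in pl.operator and only covers three binary operators; fold_constants is an illustrative name, and unlike fold it also folds all-constant + and * nodes directly.

```lua
-- Standalone sketch: fold subexpressions that contain no symbols.
-- Nodes are plain tables {op=..., x, y}; symbols are plain strings.
local unpack = table.unpack or unpack

-- stand-in for pl.operator's optable (assumption: binary operators only)
local optable = {
  ['+'] = function(a, b) return a + b end,
  ['*'] = function(a, b) return a * b end,
  ['^'] = function(a, b) return a ^ b end,
}

local function fold_constants(e)
  if type(e) ~= 'table' then return e end   -- a number or a symbol name
  local args, all_numbers = {}, true
  for i = 1, #e do
    args[i] = fold_constants(e[i])          -- fold the arguments first
    if type(args[i]) ~= 'number' then all_numbers = false end
  end
  local opfn = optable[e.op]
  if all_numbers and opfn and #args == 2 then
    return opfn(args[1], args[2])           -- nothing symbolic left
  end
  return {op = e.op, unpack(args)}          -- rebuild with folded arguments
end

print(fold_constants{op='*', 2, 5})          --> 10
local e = fold_constants{op='+', 'x', {op='*', 2, 5}}
print(e.op, e[1], e[2])                      --> +	x	10
```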

elseif op == '^' then
    if args[2] == 1 then return args[1] end -- identity
    if args[2] == 0 then return 1 end
end
return PE{op=op,unpack(args)}

This clause is clearing up expressions like x^1 and y^0 which naturally arise from the power rule in diff. Once args has been processed, the expression can be put together again.

The bulk of this routine handles the awkward twins, + and *.

-- split the args into two classes, PE args and non-PE args.
local classes = List.partition(args,isPE)
local pe,npe = classes[true],classes[false]

List.partition takes a list and a function, which takes one argument and returns a single value. The result is a table where the keys are the returned values, and the values are lists of those elements for which the function returned that value. So:

List{1,2,3,4}:partition(function(x) return x > 2 end)
--> {false={1,2},true={3,4}}

List{'one',math.sin,10,20,{1,2}}:partition(type)
--> {function={function: 00369110},string={one},number={10,20},table={{1,2}}}

(Mathematically, these are referred to as equivalence classes and partition would be called the quotient set)

In this case, we want to separate the non-symbolic arguments from the symbolic arguments; order does not matter. The non-symbolic arguments npe can be folded into a constant. At this point, operator identity rules can kick in, so that we can drop (* 0 x) and simplify (+ 0 x) to be just x.
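The partition-then-fold step can be sketched in plain Lua. This is not the code from fold; partition and fold_flat are illustrative names, symbols are plain strings, and the argument list is assumed to be already flattened.

```lua
-- Standalone sketch: split a flat argument list into numbers and symbols,
-- fold the numbers into one constant, then apply the identity rules.
local function partition(list, classify)
  local classes = {}
  for i = 1, #list do
    local key = classify(list[i])
    local bucket = classes[key]
    if not bucket then
      bucket = {}
      classes[key] = bucket
    end
    bucket[#bucket + 1] = list[i]
  end
  return classes
end

local function fold_flat(op, args)
  local classes = partition(args, function(v) return type(v) == 'number' end)
  local nums, syms = classes[true] or {}, classes[false] or {}
  local identity = (op == '*') and 1 or 0
  local const = identity
  for _, n in ipairs(nums) do
    if op == '*' then const = const * n else const = const + n end
  end
  if op == '*' and const == 0 then return 0 end   -- (* 0 x) vanishes
  local out = {op = op}
  if const ~= identity then out[1] = const end    -- drop 1 from *, 0 from +
  for _, s in ipairs(syms) do out[#out + 1] = s end
  if #out == 0 then return const end              -- nothing but the identity
  if #out == 1 then return out[1] end             -- (+ 0 x) is just x
  return out
end

print(fold_flat('+', {0, 'x'}))   --> x
print(fold_flat('*', {0, 'x'}))   --> 0
local e = fold_flat('*', {2, 'x', 5})
print(e.op, e[1], e[2])           --> *	10	x
```

Putting the folded constant first is a choice made in this sketch; since order does not matter for + and *, any position would do.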

The final simplification is replacing repeated values, so that (* x x) should become (^ x 2) and (+ x x x) should become (* x 3). count_map from pl.tablex will do the job. It is given a list-like table and a function which defines equivalence, and returns a map from the values to the number of their occurrences, so that count_map{'a','b','a'} is {a=2,b=1} .
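A sketch of this last step, again in plain Lua. Here count_occurrences is a simplified stand-in for count_map: it compares with plain equality on strings rather than taking an equivalence function, and it also records first-seen order so the output is predictable. The name collapse is made up for illustration.

```lua
-- Standalone sketch: replace repeated arguments of a flat + or * node.
-- (* x x) becomes (^ x 2) and (+ x x x) becomes (* x 3).
local function count_occurrences(list)
  local counts, order = {}, {}
  for i = 1, #list do
    local v = list[i]
    if counts[v] == nil then
      counts[v] = 0
      order[#order + 1] = v   -- remember first-seen order
    end
    counts[v] = counts[v] + 1
  end
  return counts, order
end

local function collapse(op, args)
  local counts, order = count_occurrences(args)
  local out = {}
  for _, v in ipairs(order) do
    local n = counts[v]
    if n == 1 then
      out[#out + 1] = v
    elseif op == '*' then
      out[#out + 1] = {op = '^', v, n}   -- repeated factor becomes a power
    else
      out[#out + 1] = {op = '*', v, n}   -- repeated term becomes a multiple
    end
  end
  return out
end

local p = collapse('*', {'x', 'x'})[1]
print(p.op, p[1], p[2])   --> ^	x	2
local s = collapse('+', {'x', 'x', 'x'})[1]
print(s.op, s[1], s[2])   --> *	x	3
```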

Given this test function:

function testdiff (e)
    balance(e)
    e = diff(e,x)
    balance(e)
    print('+ ',e)
    e = fold(e)
    print('- ',e)
end

and these cases:

testdiff(x^2+1)
testdiff(3*x^2)
testdiff(x^2 + 2*x^3)
testdiff(x^2 + 2*a*x^3 + x^4)
testdiff(2*a*x^3)
testdiff(x*x*x)

we get these results, showing why something like fold is so necessary to process the result of diff.


+ 	2 * x ^ 1 + 0
- 	2 * x
+ 	3 * 2 * x
- 	6 * x
+ 	2 * x ^ 1 + 2 * 3 * x ^ 2
- 	2 * x + 6 * x ^ 2
+ 	2 * x ^ 1 + 2 * a * 3 * x ^ 2 + 4 * x ^ 3
- 	6 * a * x ^ 2 + 4 * x ^ 3 + 2 * x
+ 	2 * a * 3 * x ^ 2
- 	6 * a * x ^ 2
+ 	1 * x * x + x * 1 * x + x * x * 1
- 	x ^ 2 * 3

https://github.com/stevedonovan/Penlight/blob/master/examples/symbols.lua

https://github.com/stevedonovan/Penlight/blob/master/examples/test-symbols.lua

SteveDonovan

Last edited July 4, 2012 9:41 am GMT