Building a computational graph: part 2
This is the second post in a series on computational graphs. You can go back to the previous post or ahead to the next post.
If you’d like to see what we are working towards in these posts, here is the Github link:
Last time we
- looked at computational graphs and their use in autodiff packages.
- looked at the autodiff problem and the structure of the grad function in autograd
- showed how Python breaks down expressions to create computational graphs
- created a simple graph manually using a simplified Node class
How do we automatically create a computational graph for a function? We could create it manually last time, but we’ll need to be able to do it automatically for any function. That’s what we cover here.
As a running example we’ll use this logistic function throughout:
import numpy as np

def logistic(z): return 1 / (1 + np.exp(-z))
Primitives
Loosely speaking, a primitive is a basic operation, like $+, \times, /, \exp$ or $\log$. We want to create a function for each primitive that adds itself to a computation graph whenever it is called. Something like this:
def add_new(x,y):
    # add to computation graph
    print('Add to graph!')
    return x+y
The numpy package implements well-tested functions for each primitive, like np.add, np.multiply or np.exp. Because numpy goes to all the work of creating reliable, tested primitives, it’d be great to reuse their work instead of creating our functions from scratch. So that’s what we’ll do.
We create a function primitive that
- takes a function f as an input (which will be a numpy function)
- returns the same function f, except we add f to our computation graph as a Node.
Here’s the basic structure of primitive, just with placeholder code for the computational-graph adding bit.
def primitive(f):
    def inner(*args, **kwargs):
        """This is a nested function"""
        # add to graph
        print("add to graph!")
        return f(*args, **kwargs)
    return inner
Use it like this.
mult_new = primitive(np.multiply)
print(mult_new(1,4))
add to graph!
4
Since primitive is a function that returns a function, we can also use it as a decorator. I’ve written this other post on decorators if you want to know more.
# another way to use it
@primitive
def mult_new2(*args, **kwargs): return np.multiply(*args, **kwargs)
print(mult_new2(1,4))
add to graph!
4
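The decorator syntax is only shorthand for the explicit call used earlier. Here is a minimal sketch of that equivalence. It assumes a list named calls (a hypothetical stand-in for the print statement) and uses operator.mul from the standard library in place of np.multiply, so it runs without numpy:

```python
import operator

calls = []  # hypothetical log, standing in for the "add to graph!" print

def primitive(f):
    def inner(*args, **kwargs):
        calls.append(f.__name__)  # placeholder for adding to the graph
        return f(*args, **kwargs)
    return inner

# Explicit wrapping
mult_a = primitive(operator.mul)

# Decorator syntax: equivalent to writing mult_b = primitive(mult_b)
@primitive
def mult_b(x, y):
    return operator.mul(x, y)

print(mult_a(1, 4), mult_b(1, 4))  # both calls return 4
print(calls)                       # one entry per wrapped call
```

Either form gives a wrapped function; the decorator form is handy when you are defining the function yourself, the explicit form when you are wrapping one that already exists.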
A problem with this as it stands is that we lose all the metadata of the numpy function we wrap in primitive, like its documentation and name. It won’t get copied over. Instead this new function has the metadata of the nested function inner inside primitive.
print("Name of new function:", mult_new.__name__)
print("Doc of new function:", mult_new.__doc__)
Name of new function: inner
Doc of new function: This is a nested function
We obviously don’t want this, but we can get around it by adding the @wraps(f) decorator from the functools package above inner inside the primitive definition. This copies over the name, docs, and some other things from the numpy function to our version. Now we don’t lose all the documentation.
from functools import wraps

def primitive(f):
    @wraps(f)
    def inner(*args, **kwargs):
        """This is a nested function"""
        # add to graph
        print("add to graph!")
        return f(*args, **kwargs)
    return inner
mult_new3 = primitive(np.multiply)
mult_new3.__name__ # multiply
print(mult_new3.__doc__[0:300])
multiply(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Multiply arguments element-wise.
Parameters
----------
x1, x2 : array_like
Input arrays to be multiplied.
out : ndarray, None, or tuple of ndarray and None, optional
A
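Besides the name and docstring, wraps also stores the original function on the wrapper as the __wrapped__ attribute, which can be useful when you need to get back to the unwrapped version. A small sketch, using math.exp from the standard library as a stand-in for a numpy function (an assumption made so the example runs without numpy):

```python
import math
from functools import wraps

def primitive(f):
    @wraps(f)
    def inner(*args, **kwargs):
        """This is a nested function"""
        return f(*args, **kwargs)
    return inner

exp_wrapped = primitive(math.exp)

print(exp_wrapped.__name__)                      # exp
print(exp_wrapped.__wrapped__ is math.exp)       # True
print(exp_wrapped.__doc__ == math.exp.__doc__)   # True
```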
Creating primitives
Last time we created a Node class. Remember, Nodes hold operations/primitives in them (as the fun attribute), the value at that point, and their parents in the graph.
Below is the same Node class. I have just added a __repr__ method to make debugging a bit easier.
class Node:
    """A node in a computation graph."""
    def __init__(self, value, fun, parents):
        self.parents = parents
        self.value = value
        self.fun = fun

    def __repr__(self):
        """A (very) basic string representation"""
        if self.value is None: str_val = 'None'
        else: str_val = str(round(self.value,3))
        return "\n" + "Fun: " + str(self.fun) +\
               " Value: "+ str_val + \
               " Parents: " + str(self.parents)
Let’s create some primitives. There are a few differences to before:
- inner doesn’t return a function value like f(*args, **kwargs), but a Node with the function value as the value attribute: Node(f(*args, **kwargs), f, args). In the code below, only the arguments that are themselves Nodes are kept as parents.
- Sometimes Nodes interact with integers. There is some extra code below to handle that situation, mostly around extracting the value attribute of the node and saving that in args and kwargs for use in f.
from functools import wraps

def primitive(f):
    @wraps(f)
    def inner(*args, **kwargs):
        ## Code to add operation/primitive to computation graph
        # We need to separate out the integer/non node case. Sometimes you are adding
        # constants to nodes.
        def getval(o): return o.value if type(o) == Node else o
        if len(args): argvals = [getval(o) for o in args]
        else: argvals = args
        if len(kwargs): kwargvals = dict([(k,getval(o)) for k,o in kwargs.items()])
        else: kwargvals = kwargs

        # get parents
        l = list(args) + list(kwargs.values())
        parents = [o for o in l if type(o) == Node ]

        value = f(*argvals, **kwargvals)
        print("add", "'" + f.__name__ + "'", "to graph with value",value)
        return Node(value, f, parents)
    return inner
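To see the constant handling in isolation, here is a stripped-down sketch. It is a simplification under stated assumptions: positional arguments only, operator.add from the standard library in place of np.add, and a minimal Node without __repr__. Node arguments are unwrapped to their values and recorded as parents; plain constants pass straight through and never appear in the graph.

```python
import operator

class Node:
    """Minimal stand-in for the Node class above."""
    def __init__(self, value, fun, parents):
        self.value, self.fun, self.parents = value, fun, parents

def primitive(f):
    def inner(*args):
        # Unwrap Nodes to raw values; leave constants untouched
        argvals = [a.value if isinstance(a, Node) else a for a in args]
        # Only Node arguments become parents in the graph
        parents = [a for a in args if isinstance(a, Node)]
        return Node(f(*argvals), f, parents)
    return inner

add_new = primitive(operator.add)

z = Node(1.5, None, [])
t = add_new(z, 1)  # Node + constant

print(t.value)            # 2.5
print(len(t.parents))     # 1: the constant is not a parent
print(t.parents[0] is z)  # True
```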
Now wrap some basic numpy functions with primitive to get computational-graph versions of these functions:
add_new = primitive(np.add)
mul_new = primitive(np.multiply)
div_new = primitive(np.divide)
sub_new = primitive(np.subtract)
neg_new = primitive(np.negative)
exp_new = primitive(np.exp)
Let’s try it out! We can’t try it out on our logistic function yet, because that uses operators like $+$ and $/$ instead of np.add and np.divide, and we haven’t done any operator overloading. But we can write out the logistic function in terms of the wrapped functions and see if it works. We should get a final value of 0.818 (and indeed we do).
def start_node(value = None):
    """A function to create an empty node to start off the graph"""
    fun,parents = lambda x: x, []
    return Node(value, fun, parents)
z = start_node(1.5)
t1 = mul_new(z, -1)
t2 = exp_new(t1)
t3 = add_new(t2, 1)
y = div_new(1,t3)
print("Final answer:", round(y.value,3)) # correct final output
print(y)
add 'multiply' to graph with value -1.5
add 'exp' to graph with value 0.22313016014842982
add 'add' to graph with value 1.22313016014843
add 'true_divide' to graph with value 0.8175744761936437
Final answer: 0.818
Fun: <ufunc 'true_divide'> Value: 0.818 Parents: [
Fun: <ufunc 'add'> Value: 1.223 Parents: [
Fun: <ufunc 'exp'> Value: 0.223 Parents: [
Fun: <ufunc 'multiply'> Value: -1.5 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x10fea27b8> Value: 1.5 Parents: []]]]]
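As a sanity check, the same four steps can be evaluated without any graph at all. This sketch uses math.exp from the standard library instead of np.exp (an assumption, to keep it dependency-free) and mirrors the intermediate names t1, t2, t3 used above:

```python
import math

z = 1.5
t1 = z * -1          # multiply
t2 = math.exp(t1)    # exp
t3 = t2 + 1          # add
y = 1 / t3           # true_divide

print(round(y, 3))  # 0.818
```

Each intermediate value lines up with the value stored in the corresponding Node, which is what we want: tracing should record the computation without changing its result.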
Operator overloading
We want to be able to use these functions for common operators. In other words, if we define a function def f(x,y): return x+y, and we pass in two Node objects to f as x and y, we want f to use our add_new function.
Let’s do this. All we have to do is redefine a version of Node that implements the relevant dunder methods:
class Node:
    """A node in a computation graph."""
    def __init__(self, value, fun, parents):
        self.parents = parents
        self.value = value
        self.fun = fun

    def __repr__(self):
        """A (very) basic string representation"""
        if self.value is None: str_val = 'None'
        else: str_val = str(round(self.value,3))
        return "\n" + "Fun: " + str(self.fun) +\
               " Value: "+ str_val + \
               " Parents: " + str(self.parents)

    ## Code to overload operators
    # Don't put self.value or other.value in the arguments of these functions,
    # otherwise you won't be able to access the Node object to create the
    # computational graph.
    # Instead, pass the whole node through. And to prevent recursion errors,
    # extract the value inside the `primitive` function.
    def __add__(self, other): return add_new(self, other)
    def __radd__(self, other): return add_new(other, self)
    def __sub__(self, other): return sub_new(self, other)
    def __rsub__(self, other): return sub_new(other, self)
    def __truediv__(self, other): return div_new(self, other)
    def __rtruediv__(self, other): return div_new(other, self)
    def __mul__(self, other): return mul_new(self, other)
    def __rmul__(self, other): return mul_new(other, self)
    def __neg__(self): return neg_new(self)
    # Note: Python has no __exp__ special method, so no operator triggers this.
    # Call exp_new(node) directly instead.
    def __exp__(self): return exp_new(self)
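One caveat about the last method: unlike __add__ or __neg__, __exp__ is not part of Python's data model, so it is just an ordinary method that has to be called by name. A small sketch of what that means in practice, using math.exp from the standard library as the plain function (an assumption; I have not exercised np.exp here) and a hypothetical exp_called flag to show whether the method ran:

```python
import math

class Node:
    """Minimal stand-in with a method named __exp__."""
    def __init__(self, value):
        self.value = value
        self.exp_called = False  # hypothetical flag, for illustration only

    def __exp__(self):
        self.exp_called = True
        return Node(math.exp(self.value))

z = Node(1.5)

try:
    math.exp(z)  # math.exp wants a real number and does not look for __exp__
except TypeError as err:
    print("TypeError:", err)

print(z.exp_called)                 # False: __exp__ was never invoked
print(round(z.__exp__().value, 3))  # 4.482, but only when called by name
```

This is why functions that call exp on a Node need to use exp_new explicitly.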
Now we can add nodes using $+$, divide them with $/$ and so on. Here is a basic example of adding Nodes with $+$:
val_z = 1.5
z = Node(val_z, None, [])
val_t1 = 4
t1 = Node(val_t1, None, [])
y = z + t1
add 'add' to graph with value 5.5
Here is the graph of y:
print(y)
Fun: <ufunc 'add'> Value: 5.5 Parents: [
Fun: None Value: 1.5 Parents: [],
Fun: None Value: 4 Parents: []]
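The reflected methods (__radd__, __rsub__ and so on) matter when the constant is on the left. For 1 + node, Python first tries the int's own __add__, which returns NotImplemented for a Node, and then falls back to the Node's __radd__. A minimal sketch of that dispatch, using a cut-down Node with a hypothetical via attribute that records which method handled the operation:

```python
class Node:
    """Minimal stand-in: records which dunder method handled the operation."""
    def __init__(self, value, via=None):
        self.value = value
        self.via = via  # hypothetical attribute, for illustration only

    def __add__(self, other):
        o = other.value if isinstance(other, Node) else other
        return Node(self.value + o, via="__add__")

    def __radd__(self, other):
        # Reached only when the left operand cannot handle the addition itself
        return Node(other + self.value, via="__radd__")

z = Node(1.5)
print((z + 1).via)      # __add__
print((1 + z).via)      # __radd__
print((1).__add__(z))   # NotImplemented
```

Without __radd__, expressions like 1 + exp_new(-z) would raise a TypeError instead of adding a node to the graph.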
Let’s try it out on a modified version of the logistic function that uses our exp_new function.
def logistic2(z): return 1 / (1 + exp_new(-z))
y = logistic2(start_node(value = 1.5))
add 'negative' to graph with value -1.5
add 'exp' to graph with value 0.22313016014842982
add 'add' to graph with value 1.22313016014843
add 'true_divide' to graph with value 0.8175744761936437
The graph of y:
print(y)
Fun: <ufunc 'true_divide'> Value: 0.818 Parents: [
Fun: <ufunc 'add'> Value: 1.223 Parents: [
Fun: <ufunc 'exp'> Value: 0.223 Parents: [
Fun: <ufunc 'negative'> Value: -1.5 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x10fe90f28> Value: 1.5 Parents: []]]]]
Sweet! It is working. Now try a multivariate function.
def somefun(x,y): return (x*y + exp_new(x)*exp_new(y))/(4*y)
def somefun2(x,y): return (x*y + np.exp(x)*np.exp(y))/(4*y)
val_x, val_y = 3,4
ans = somefun(start_node(3), start_node(4))
add 'multiply' to graph with value 12
add 'exp' to graph with value 20.085536923187668
add 'exp' to graph with value 54.598150033144236
add 'multiply' to graph with value 1096.6331584284585
add 'add' to graph with value 1108.6331584284585
add 'multiply' to graph with value 16
add 'true_divide' to graph with value 69.28957240177866
Graph of ans:
print(ans)
Fun: <ufunc 'true_divide'> Value: 69.29 Parents: [
Fun: <ufunc 'add'> Value: 1108.633 Parents: [
Fun: <ufunc 'multiply'> Value: 12 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566f28> Value: 3 Parents: [],
Fun: <function start_node.<locals>.<lambda> at 0x11c566730> Value: 4 Parents: []],
Fun: <ufunc 'multiply'> Value: 1096.633 Parents: [
Fun: <ufunc 'exp'> Value: 20.086 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566f28> Value: 3 Parents: []],
Fun: <ufunc 'exp'> Value: 54.598 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566730> Value: 4 Parents: []]]],
Fun: <ufunc 'multiply'> Value: 16 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566730> Value: 4 Parents: []]]
The result looks complex, but that is because our __repr__ function is basic and doesn’t handle nested representations. Still, all the information is there, and we have created a computational graph successfully.
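One way to make the nesting easier to read is to indent each node by its depth. The sketch below is only an illustration: format_graph is a hypothetical helper, the Node class is a minimal stand-in, and the graph is built by hand with plain strings in the fun slot for brevity.

```python
class Node:
    """Minimal stand-in for the Node class above."""
    def __init__(self, value, fun, parents):
        self.value, self.fun, self.parents = value, fun, parents

def format_graph(node, depth=0):
    """Return one line per node, indented by its depth in the graph."""
    # Use the function's name if it has one, otherwise its string form
    name = getattr(node.fun, "__name__", str(node.fun))
    lines = ["  " * depth + f"{name}: {round(node.value, 3)}"]
    for parent in node.parents:
        lines.extend(format_graph(parent, depth + 1))
    return lines

# Hand-built graph for x*y + x
x = Node(3, "input", [])
y = Node(4, "input", [])
prod = Node(12, "multiply", [x, y])
total = Node(15, "add", [prod, x])

print("\n".join(format_graph(total)))
```

Note that a node reached by more than one path (x here) is printed once per path, just as the start nodes are repeated in the output above.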
Next steps
At this point we can create functions using common operators and automatically trace their computation graph. Nice!
But we aren’t quite there yet. There are a few things missing.
- we don’t want to replace np.add with add_new, np.exp with exp_new etc. everywhere. That’s a pain, especially if we have a lot of code to do that for.
- currently we have to implement primitives for every numpy function we want. Is there a way to get them all?
- how do we handle non-differentiable functions?
We’ll cover these in the next post!