Building a computational graph: part 3
This is the third part of a series on creating a computational graph in Python. You can go back to part one and part two.
There’s a bit of code in this post. Here is the final code in `np_wrapping.py`. The module for `numpy_autograd` is here.
In the last post we built a computational graph for a function, but it was a bit hard to use. We also had these problems:

- we didn’t want to replace `np.add` with `add_new`, `np.exp` with `exp_new`, and so on everywhere. That’s a pain, especially if we have a lot of code to do it for.
- we had to implement primitives manually for every `numpy` function we wanted. Is there a way to get them all at once?
- how do we handle non-differentiable functions?

We’ll answer these questions here.
Here’s the gist of it. We create a fake `numpy` module called `numpy_autograd` with wrapped versions of all `numpy` functions. This fake `numpy` contains all the functions and objects of the original `numpy`, except that its functions are wrapped so they add themselves to a computational graph as they are called (with a flag marking whether they are differentiable, as we’ll see below). Then, by writing `import numpy_autograd as np`, any code that uses numpy functions like `np.add` automatically builds a computation graph as it executes.
Non-differentiable functions
Autodiff packages like `autograd` have to watch out for non-differentiable functions. Many functions are not differentiable, like `np.asarray`, `np.shape` or `np.argmin`.

Take `np.floor(x)` as an example. This is a non-differentiable function: its derivative does not exist at integer values of $x$, and the derivative is 0 everywhere else. So this is not something we’d add to the computation graph if we encountered it.
How should we deal with these functions? There are a few approaches. Some packages like `autograd` don’t add them to the graph at all. The approach I take here is a bit different: I do add them to the computation graph, but with a flag `keepgrad` that indicates whether the gradient of this function should be calculated. So let’s go ahead and modify our `primitive` function from earlier to include this parameter:
```python
def primitive(f, keepgrad=True):
    @wraps(f)
    def inner(*args, **kwargs):
        ## Code to add operation/primitive to computation graph
        # We need to separate out the integer/non-node case. Sometimes you are
        # adding constants to nodes.
        def getval(o): return o.value if type(o) == Node else o
        if len(args): argvals = [getval(o) for o in args]
        else: argvals = args
        if len(kwargs): kwargvals = dict([(k, getval(o)) for k, o in kwargs.items()])
        else: kwargvals = kwargs
        # get parents
        l = list(args) + list(kwargs.values())
        parents = [o for o in l if type(o) == Node]
        value = f(*argvals, **kwargvals)
        print("add", "'" + f.__name__ + "'", "to graph with value", value)
        return Node(value, f, parents, keepgrad)
    return inner
```
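To get a feel for it, here’s what wrapping a couple of functions by hand looks like (a sketch, assuming the `Node` class from the previous post, or the version below, is in scope):

```python
import numpy as _np

exp_new = primitive(_np.exp)                      # differentiable: keepgrad=True
floor_new = primitive(_np.floor, keepgrad=False)  # non-differentiable

n = floor_new(exp_new(1.5))
# add 'exp' to graph with value 4.4816890703380645
# add 'floor' to graph with value 4.0
print(n.keepgrad)  # False
```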
Triage
A quick note: it can get confusing working with the “original” `numpy` and the “new” `numpy`, so note that throughout this post, anything prefixed with `_np` means “original” `numpy`. Later I use `anp` to refer to the “new” `numpy`.
Anyway, it’s time to create our version of `numpy`. All the attributes of `numpy` are available in `_np.__dict__`. We are going to split the objects in this dict into three categories:

a) differentiable functions (wrap with `primitive` and `keepgrad=True`, the default)
b) non-differentiable functions (wrap with `primitive` and `keepgrad=False`)
c) everything else (copy over unchanged)

We create a function `wrap_namespace` that copies everything from `_np.__dict__` into a new dictionary, wrapping functions according to the three categories above.
```python
import types
import numpy as _np

def wrap_namespace(old, new):
    """Performs triage on objects from numpy, copying them from old to new namespace.

    old: __dict__ from original numpy
    new: dict to copy old into
    """
    # List taken from here:
    # https://github.com/mattjj/autodidact/blob/b3b6e0c16863e6c7750b0fc067076c51f34fe271/autograd/numpy/numpy_wrapper.py#L8
    nograd_functions = [
        _np.ndim, _np.shape, _np.iscomplexobj, _np.result_type, _np.zeros_like,
        _np.ones_like, _np.floor, _np.ceil, _np.round, _np.rint, _np.around,
        _np.fix, _np.trunc, _np.all, _np.any, _np.argmax, _np.argmin,
        _np.argpartition, _np.argsort, _np.argwhere, _np.nonzero, _np.flatnonzero,
        _np.count_nonzero, _np.searchsorted, _np.sign, _np.ndim, _np.shape,
        _np.floor_divide, _np.logical_and, _np.logical_or, _np.logical_not,
        _np.logical_xor, _np.isfinite, _np.isinf, _np.isnan, _np.isneginf,
        _np.isposinf, _np.allclose, _np.isclose, _np.array_equal, _np.array_equiv,
        _np.greater, _np.greater_equal, _np.less, _np.less_equal, _np.equal,
        _np.not_equal, _np.iscomplexobj, _np.iscomplex, _np.size, _np.isscalar,
        _np.isreal, _np.zeros_like, _np.ones_like, _np.result_type
    ]
    function_types = {_np.ufunc, types.FunctionType, types.BuiltinFunctionType}
    for name, obj in old.items():
        if obj in nograd_functions:
            # non-differentiable functions
            new[name] = primitive(obj, keepgrad=False)
        elif type(obj) in function_types:
            # differentiable functions
            new[name] = primitive(obj)
        else:
            # everything else: just copy over
            new[name] = obj
```
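As a quick sanity check (a sketch, relying on the fact that `functools.wraps` stores the original function in a `__wrapped__` attribute), you can confirm the triage does what we expect:

```python
d = {}
wrap_namespace(_np.__dict__, d)

print(d['add'].__wrapped__ is _np.add)      # True: wrapped, differentiable
print(d['floor'].__wrapped__ is _np.floor)  # True: wrapped, but keepgrad=False
print(d['pi'] == _np.pi)                    # True: constants are copied unchanged
```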
Creating our new module
We’re ready to bring all our code together into a module. Here is an overview of the basic procedure:

- make a new folder `./numpy_autograd`
- put the code in a file `np_wrapping.py`, with definitions of `primitive`, `notrace_primitive`, `Node` and `wrap_namespace`
- add implementations of the dunder methods to the `Node` class with `setattr`
- make an `__init__.py` and at the top put `from .np_wrapping import *`
- now import it with `import numpy_autograd as np` and you are done
To get started, create a new folder called `numpy_autograd`, and then create a file in it called `np_wrapping.py`. The code below goes in this file.

Start off with the imports and `primitive`:
```python
import numpy as _np
import types
from functools import wraps

def primitive(f, keepgrad=True):
    @wraps(f)
    def inner(*args, **kwargs):
        ## Code to add operation/primitive to computation graph
        # We need to separate out the integer/non-node case. Sometimes you are
        # adding constants to nodes.
        def getval(o): return o.value if type(o) == Node else o
        if len(args): argvals = [getval(o) for o in args]
        else: argvals = args
        if len(kwargs): kwargvals = dict([(k, getval(o)) for k, o in kwargs.items()])
        else: kwargvals = kwargs
        # get parents
        l = list(args) + list(kwargs.values())
        parents = [o for o in l if type(o) == Node]
        value = f(*argvals, **kwargvals)
        print("add", "'" + f.__name__ + "'", "to graph with value", value)
        return Node(value, f, parents, keepgrad)
    return inner
```
Now add the latest iteration of the `Node` class. It’s similar to before, except it’s been modified to incorporate the `keepgrad` parameter. I’ve also moved the `start_node` function from last time into the class as a static method, instead of having it float around by itself.
```python
class Node:
    """A node in a computation graph."""
    def __init__(self, value, fun, parents, keepgrad):
        self.parents = parents
        self.value = value
        self.fun = fun
        self.keepgrad = keepgrad

    def __repr__(self):
        """A (very) basic string representation"""
        if self.value is None: str_val = 'None'
        else: str_val = str(self.value)
        return "\n" + "Fun: " + str(self.fun) + \
               " Value: " + str_val + \
               " Parents: " + str(self.parents)

    @staticmethod
    def start_node(value=None, keepgrad=True):
        """Create an empty node to start off the graph"""
        fun, parents = lambda x: x, []
        return Node(value, fun, parents, keepgrad=keepgrad)
```
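A quick check that this behaves as expected (the memory address in the output will differ on your machine):

```python
n = Node.start_node(3.0)
print(n)
# Fun: <function Node.start_node.<locals>.<lambda> at 0x7f...> Value: 3.0 Parents: []
```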
Then the `wrap_namespace` function from earlier:
```python
def wrap_namespace(old, new):
    """Performs triage on objects from numpy, copying them from old to new namespace.

    old: __dict__ from original numpy
    new: dict to copy old into
    """
    # List taken from here:
    # https://github.com/mattjj/autodidact/blob/b3b6e0c16863e6c7750b0fc067076c51f34fe271/autograd/numpy/numpy_wrapper.py#L8
    nograd_functions = [
        _np.ndim, _np.shape, _np.iscomplexobj, _np.result_type, _np.zeros_like,
        _np.ones_like, _np.floor, _np.ceil, _np.round, _np.rint, _np.around,
        _np.fix, _np.trunc, _np.all, _np.any, _np.argmax, _np.argmin,
        _np.argpartition, _np.argsort, _np.argwhere, _np.nonzero, _np.flatnonzero,
        _np.count_nonzero, _np.searchsorted, _np.sign, _np.ndim, _np.shape,
        _np.floor_divide, _np.logical_and, _np.logical_or, _np.logical_not,
        _np.logical_xor, _np.isfinite, _np.isinf, _np.isnan, _np.isneginf,
        _np.isposinf, _np.allclose, _np.isclose, _np.array_equal, _np.array_equiv,
        _np.greater, _np.greater_equal, _np.less, _np.less_equal, _np.equal,
        _np.not_equal, _np.iscomplexobj, _np.iscomplex, _np.size, _np.isscalar,
        _np.isreal, _np.zeros_like, _np.ones_like, _np.result_type
    ]
    function_types = {_np.ufunc, types.FunctionType, types.BuiltinFunctionType}
    for name, obj in old.items():
        if obj in nograd_functions:
            # non-differentiable functions
            new[name] = primitive(obj, keepgrad=False)
        elif type(obj) in function_types:
            # differentiable functions
            new[name] = primitive(obj)
        else:
            # everything else: just copy over
            new[name] = obj
```
Now call `wrap_namespace()`. We’ll hold the wrapped functions in a dict called `anp`. By making `anp` the module’s `globals()` dict itself, every wrapped function also becomes available as a top-level name in the module. Calling it `anp` makes the syntax for the next step (operator overloading) a bit clearer.
```python
# Using globals() here means each wrapped function is also available at
# module level, e.g. as np.add once the module is imported.
anp = globals()
wrap_namespace(_np.__dict__, anp)
```
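At this point every `numpy` name exists at module level, so (still inside `np_wrapping.py`, purely as an illustration) the wrapped functions already trace:

```python
n = exp(1.0)  # the wrapped exp that wrap_namespace put into globals()
# add 'exp' to graph with value 2.718281828459045
print(n.parents)  # []
```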
Finally it’s time for operator overloading. Instead of defining these all inside `Node`, we use `setattr` to add each method to the `Node` class one by one. There are many more functions than before, and many more dunder methods to match. There are also properties to deal with, like `ndim`, and we use Python’s `property` built-in to handle these.
```python
## Definitions taken from here:
## https://github.com/mattjj/autodidact/blob/b3b6e0c16863e6c7750b0fc067076c51f34fe271/autograd/numpy/numpy_boxes.py#L8
setattr(Node, 'ndim', property(lambda self: self.value.ndim))
setattr(Node, 'size', property(lambda self: self.value.size))
setattr(Node, 'dtype', property(lambda self: self.value.dtype))
setattr(Node, 'T', property(lambda self: anp['transpose'](self)))
setattr(Node, 'shape', property(lambda self: self.value.shape))
setattr(Node, '__len__', lambda self: len(self.value))
setattr(Node, 'astype', lambda self, *args, **kwargs: anp['_astype'](self, *args, **kwargs))
setattr(Node, '__neg__', lambda self: anp['negative'](self))
setattr(Node, '__add__', lambda self, other: anp['add'](self, other))
setattr(Node, '__sub__', lambda self, other: anp['subtract'](self, other))
setattr(Node, '__mul__', lambda self, other: anp['multiply'](self, other))
setattr(Node, '__pow__', lambda self, other: anp['power'](self, other))
setattr(Node, '__div__', lambda self, other: anp['divide'](self, other))
setattr(Node, '__mod__', lambda self, other: anp['mod'](self, other))
setattr(Node, '__truediv__', lambda self, other: anp['true_divide'](self, other))
setattr(Node, '__matmul__', lambda self, other: anp['matmul'](self, other))
setattr(Node, '__radd__', lambda self, other: anp['add'](other, self))
setattr(Node, '__rsub__', lambda self, other: anp['subtract'](other, self))
setattr(Node, '__rmul__', lambda self, other: anp['multiply'](other, self))
setattr(Node, '__rpow__', lambda self, other: anp['power'](other, self))
setattr(Node, '__rdiv__', lambda self, other: anp['divide'](other, self))
setattr(Node, '__rmod__', lambda self, other: anp['mod'](other, self))
setattr(Node, '__rtruediv__', lambda self, other: anp['true_divide'](other, self))
setattr(Node, '__rmatmul__', lambda self, other: anp['matmul'](other, self))
setattr(Node, '__eq__', lambda self, other: anp['equal'](self, other))
setattr(Node, '__ne__', lambda self, other: anp['not_equal'](self, other))
setattr(Node, '__gt__', lambda self, other: anp['greater'](self, other))
setattr(Node, '__ge__', lambda self, other: anp['greater_equal'](self, other))
setattr(Node, '__lt__', lambda self, other: anp['less'](self, other))
setattr(Node, '__le__', lambda self, other: anp['less_equal'](self, other))
setattr(Node, '__abs__', lambda self: anp['abs'](self))
setattr(Node, '__hash__', lambda self: id(self))
```
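With these in place, ordinary arithmetic on `Node` objects routes through the wrapped functions. A quick sketch (run anywhere the definitions above are in scope):

```python
n = Node.start_node(3.0)
m = n * 2 + 1  # dispatches to anp['multiply'] and then anp['add']
# add 'multiply' to graph with value 6.0
# add 'add' to graph with value 7.0
print(m.value, [p.fun.__name__ for p in m.parents])  # 7.0 ['multiply']
```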
Finally, make another file `__init__.py` in the `numpy_autograd` folder and add this to it:

```python
from .np_wrapping import *
```
Now you’re done!
Using the wrapped numpy
Let’s test it out:

```python
import numpy_autograd as np

def logistic(z): return 1 / (1 + np.exp(-z))

a1 = logistic(1.5)
```

```
add 'exp' to graph with value 0.22313016014842982
add 'add' to graph with value 1.22313016014843
add 'true_divide' to graph with value 0.8175744761936437
```
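Because each `Node` keeps references to its parents, you can walk back through the graph from the result (a sketch; the chain here is `true_divide` ← `add` ← `exp`):

```python
node = a1
while node.parents:
    print(node.fun.__name__, node.keepgrad)
    node = node.parents[0]
# true_divide True
# add True
# exp True
```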
This works fine. What about something using a non-differentiable function?
```python
def f2(x, y): return np.add(np.floor(np.log(x) * np.exp(y) + x*y), np.exp(x))

a2 = f2(5, 1)
```

```
add 'log' to graph with value 1.6094379124341003
add 'exp' to graph with value 2.718281828459045
add 'multiply' to graph with value 4.374905831402675
add 'add' to graph with value 9.374905831402675
add 'floor' to graph with value 9.0
add 'exp' to graph with value 148.4131591025766
add 'add' to graph with value 157.4131591025766
```
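To see the `keepgrad` flag doing its job, here’s a small hypothetical helper (`walk` is not part of the module, just for illustration) that traverses the graph and reports which nodes would get gradients. The `floor` node is the only one flagged `False`:

```python
def walk(node, depth=0):
    """Recursively print each node's function name and keepgrad flag."""
    print("  " * depth + node.fun.__name__, "keepgrad =", node.keepgrad)
    for p in node.parents:
        walk(p, depth + 1)

walk(a2)
# add keepgrad = True
#   floor keepgrad = False
#     add keepgrad = True
#       multiply keepgrad = True
#         log keepgrad = True
#         exp keepgrad = True
#   exp keepgrad = True
```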
Nice! And that’s how to create a computation graph.