# Building a computational graph: part 2

This is the second post in a series on computational graphs. You can go back to the previous post or ahead to the next post.

If you’d like to see what we are working towards in these posts, here is the Github link:

Last time we

- looked at computational graphs and their use in autodiff packages.
- looked at the autodiff problem and the structure of the `grad` function in `autograd`.
- showed how Python breaks down expressions to create computational graphs.
- created a simple graph manually using a simplified `Node` class.

How do we create a computational graph for a function automatically? Last time we built one by hand, but we need to be able to do it automatically for any function. That’s what we cover here.

As a running example we’ll use this logistic function throughout:

```
import numpy as np

def logistic(z):
    return 1 / (1 + np.exp(-z))
```

## Primitives

Loosely speaking, a *primitive* is a basic operation, like $+, \times, /, \exp$ or $\log$. We want to create a function for each primitive that adds it to a computation graph whenever it is called. Something like this:

```
def add_new(x, y):
    # add to computation graph
    print('Add to graph!')
    return x + y
```

The `numpy` package implements well-tested functions for each primitive, like `np.add`, `np.multiply` or `np.exp`. Because `numpy` goes to all the work of creating reliable, tested primitives, it’d be great to reuse their work instead of creating our functions from scratch. So that’s what we’ll do.
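To make this concrete, here is a sketch of the running `logistic` example rewritten as an explicit chain of `numpy` primitives, one call per line. Each call is an operation we will want recorded as a node in the graph. The helper name `logistic_primitives` is just for illustration.

```python
import numpy as np

def logistic(z):
    return 1 / (1 + np.exp(-z))

def logistic_primitives(z):
    """Same computation as `logistic`, written as one numpy primitive per line."""
    t1 = np.negative(z)      # -z
    t2 = np.exp(t1)          # exp(-z)
    t3 = np.add(1, t2)       # 1 + exp(-z)
    return np.divide(1, t3)  # 1 / (1 + exp(-z))

print(round(float(logistic_primitives(1.5)), 3))  # 0.818
```

The operator version and the primitive version compute the same value; the difference is only that the second form makes every step an explicit function call we can intercept.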

We create a function `primitive` that

- takes a function `f` as an input (which will be a `numpy` function), and
- returns a function that computes the same thing as `f`, except that calling it also adds `f` to our computation graph as a `Node`.

Here’s the basic structure of `primitive`, with placeholder code for the part that adds to the computational graph.

```
def primitive(f):
    def inner(*args, **kwargs):
        """This is a nested function"""
        # add to graph
        print("add to graph!")
        return f(*args, **kwargs)
    return inner
```

Use it like this.

```
mult_new = primitive(np.multiply)  # wrap numpy's multiply
print(mult_new(1,4))
```

```
add to graph!
4
```

Since `primitive` is a function that returns a function, we can also use it as a decorator. I’ve written this other post on decorators if you want to know more.

```
# another way to use it
@primitive
def mult_new2(*args, **kwargs): return np.multiply(*args, **kwargs)
print(mult_new2(1,4))
```

```
add to graph!
4
```

A problem with this as it stands is that we lose the metadata of the `numpy` function we wrap in `primitive`, like its name and documentation: it doesn’t get copied over. Instead, the new function has the metadata of the nested function `inner` inside `primitive`.

```
print("Name of new function:", mult_new.__name__)
print("Doc of new function:", mult_new.__doc__)
```

```
Name of new function: inner
Doc of new function: This is a nested function
```

We obviously don’t want this, but we can get around it by adding the `@wraps(f)` decorator from the `functools` module above `inner` inside the `primitive` definition. This copies the name, docstring, and a few other attributes from the `numpy` function to our version, so we no longer lose the documentation.

```
from functools import wraps

def primitive(f):
    @wraps(f)
    def inner(*args, **kwargs):
        """This is a nested function"""
        # add to graph
        print("add to graph!")
        return f(*args, **kwargs)
    return inner

mult_new3 = primitive(np.multiply)
mult_new3.__name__  # 'multiply'
print(mult_new3.__doc__[0:300])
```

```
multiply(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Multiply arguments element-wise.
Parameters
----------
x1, x2 : array_like
Input arrays to be multiplied.
out : ndarray, None, or tuple of ndarray and None, optional
A
```
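`functools.wraps` copies attributes such as `__name__`, `__doc__` and `__module__` from the wrapped function onto the wrapper, and also stores the original function in a `__wrapped__` attribute. Here is a quick sketch that checks this with a stripped-down `primitive` (the name `primitive_sketch` is ours, and the graph-building code is left out):

```python
from functools import wraps
import numpy as np

def primitive_sketch(f):
    @wraps(f)
    def inner(*args, **kwargs):
        """This docstring gets replaced by f's docstring."""
        return f(*args, **kwargs)
    return inner

mult = primitive_sketch(np.multiply)
print(mult.__name__)                    # multiply
print(mult.__wrapped__ is np.multiply)  # True
print(mult(3, 4))                       # 12
```

The `__wrapped__` attribute is handy when debugging, since it gives direct access to the original `numpy` function behind a wrapped primitive.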

### Creating primitives

Last time we created a `Node` class. Remember, nodes hold operations/primitives in them (as the `fun` attribute), the value at that point, and their parents in the graph.

Below is the same `Node` class. I have just added a `__repr__` method to make debugging a bit easier.

```
class Node:
    """A node in a computation graph."""
    def __init__(self, value, fun, parents):
        self.parents = parents
        self.value = value
        self.fun = fun

    def __repr__(self):
        """A (very) basic string representation"""
        if self.value is None: str_val = 'None'
        else: str_val = str(round(self.value, 3))
        return "\n" + "Fun: " + str(self.fun) + \
               " Value: " + str_val + \
               " Parents: " + str(self.parents)
```

Let’s create some primitives. There are a few differences from before:

- `inner` doesn’t return a function value like `f(*args, **kwargs)`, but a `Node` with the function value as its `value` attribute, `f` as its `fun` attribute, and the `Node` arguments as its parents: `Node(value, f, parents)`.
- Sometimes `Node`s interact with plain numbers, like integers. There is some extra code below to handle that situation, mostly around extracting the `value` attribute of each node and saving those values in `argvals` and `kwargvals` for use in `f`.

```
from functools import wraps

def primitive(f):
    @wraps(f)
    def inner(*args, **kwargs):
        ## Code to add operation/primitive to computation graph
        # We need to separate out the integer/non-node case. Sometimes you are
        # adding constants to nodes.
        def getval(o):
            return o.value if isinstance(o, Node) else o
        argvals = [getval(o) for o in args]
        kwargvals = {k: getval(o) for k, o in kwargs.items()}
        # get parents: only Node arguments become parents, constants don't
        l = list(args) + list(kwargs.values())
        parents = [o for o in l if isinstance(o, Node)]
        value = f(*argvals, **kwargvals)
        print("add", "'" + f.__name__ + "'", "to graph with value", value)
        return Node(value, f, parents)
    return inner
```
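To see the constant handling on its own, here is a self-contained sketch that repeats a minimal `Node` and the `primitive` logic above (without the printing or `__repr__`), then multiplies a node by the plain integer `-1`, as in `t1 = mul_new(z, -1)` further down. Both values reach `np.multiply`, but only the `Node` ends up in `parents`.

```python
import numpy as np
from functools import wraps

class Node:
    """Minimal node: its value, the primitive that produced it, and its parent nodes."""
    def __init__(self, value, fun, parents):
        self.value, self.fun, self.parents = value, fun, parents

def primitive(f):
    @wraps(f)
    def inner(*args, **kwargs):
        def getval(o):
            return o.value if isinstance(o, Node) else o
        argvals = [getval(o) for o in args]                # unwrap nodes to raw values
        kwargvals = {k: getval(o) for k, o in kwargs.items()}
        parents = [o for o in list(args) + list(kwargs.values()) if isinstance(o, Node)]
        return Node(f(*argvals, **kwargvals), f, parents)
    return inner

mul_new = primitive(np.multiply)
z = Node(1.5, None, [])
t1 = mul_new(z, -1)
print(t1.value)               # -1.5
print(t1.parents == [z])      # True: the constant -1 is not a parent
print(t1.fun is np.multiply)  # True
```

Constants never become parents, so the graph only tracks quantities that come from earlier nodes.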

Now wrap some basic `numpy` functions with `primitive` to get computational-graph versions of these functions:

```
add_new = primitive(np.add)
mul_new = primitive(np.multiply)
div_new = primitive(np.divide)
sub_new = primitive(np.subtract)
neg_new = primitive(np.negative)
exp_new = primitive(np.exp)
```

Let’s try it out! We can’t use it on our `logistic` function yet, because that uses operators like $+$ and $/$ instead of `np.add` and `np.divide`, and we haven’t done any operator overloading. But we can write out the `logistic` function in terms of our new primitives and see if it works. We should get a final value of `0.818` (and indeed we do).

```
def start_node(value=None):
    """A function to create an empty node to start off the graph"""
    fun, parents = lambda x: x, []
    return Node(value, fun, parents)

z = start_node(1.5)
t1 = mul_new(z, -1)
t2 = exp_new(t1)
t3 = add_new(t2, 1)
y = div_new(1, t3)
print("Final answer:", round(y.value, 3))  # correct final output
print(y)
```

```
add 'multiply' to graph with value -1.5
add 'exp' to graph with value 0.22313016014842982
add 'add' to graph with value 1.22313016014843
add 'true_divide' to graph with value 0.8175744761936437
Final answer: 0.818
Fun: <ufunc 'true_divide'> Value: 0.818 Parents: [
Fun: <ufunc 'add'> Value: 1.223 Parents: [
Fun: <ufunc 'exp'> Value: 0.223 Parents: [
Fun: <ufunc 'multiply'> Value: -1.5 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x10fea27b8> Value: 1.5 Parents: []]]]]
```
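As a hand check on the printed values, here is the same forward pass worked out with $z = 1.5$, using the intermediate names from the code:

$$
\begin{aligned}
t_1 &= z \times (-1) = -1.5 \\
t_2 &= e^{t_1} = e^{-1.5} \approx 0.2231 \\
t_3 &= t_2 + 1 \approx 1.2231 \\
y &= \frac{1}{t_3} \approx 0.8176
\end{aligned}
$$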

### Operator overloading

We want these functions to be used automatically by common operators. In other words, if we define a function `def f(x,y): return x+y` and pass in two `Node` objects as `x` and `y`, we want `f` to use our `add_new` function.

Let’s do this. All we have to do is redefine `Node` so that it implements the relevant dunder methods:

```
class Node:
    """A node in a computation graph."""
    def __init__(self, value, fun, parents):
        self.parents = parents
        self.value = value
        self.fun = fun

    def __repr__(self):
        """A (very) basic string representation"""
        if self.value is None: str_val = 'None'
        else: str_val = str(round(self.value, 3))
        return "\n" + "Fun: " + str(self.fun) + \
               " Value: " + str_val + \
               " Parents: " + str(self.parents)

    ## Code to overload operators
    # Don't put self.value or other.value in the arguments of these functions,
    # otherwise you won't be able to access the Node object to create the
    # computational graph.
    # Instead, pass the whole node through. And to prevent recursion errors,
    # extract the value inside the `primitive` function.
    def __add__(self, other): return add_new(self, other)
    def __radd__(self, other): return add_new(other, self)
    def __sub__(self, other): return sub_new(self, other)
    def __rsub__(self, other): return sub_new(other, self)
    def __truediv__(self, other): return div_new(self, other)
    def __rtruediv__(self, other): return div_new(other, self)
    def __mul__(self, other): return mul_new(self, other)
    def __rmul__(self, other): return mul_new(other, self)
    def __neg__(self): return neg_new(self)
    # Note: Python never calls `__exp__` automatically (there is no exp
    # operator), so we still call `exp_new` explicitly for exponentials.
    def __exp__(self): return exp_new(self)
```
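The reflected methods (`__radd__`, `__rmul__`, ...) matter whenever a plain number sits on the left of the operator, as in `1 + exp_new(-z)` inside `logistic2` below. Python first tries `int.__add__`, which returns `NotImplemented` for a type it doesn't know, and then falls back to the right operand's `__radd__`. Here is a standalone sketch with a hypothetical `Tracer` class that just records which method fired:

```python
class Tracer:
    """Hypothetical class that logs which arithmetic dunder Python calls."""
    def __init__(self):
        self.calls = []

    def __add__(self, other):
        self.calls.append("__add__")
        return self

    def __radd__(self, other):
        self.calls.append("__radd__")
        return self

    def __neg__(self):
        self.calls.append("__neg__")
        return self

t = Tracer()
t + 1   # Tracer on the left: Python calls t.__add__(1)
1 + t   # int can't add a Tracer, so Python falls back to t.__radd__(1)
-t      # unary minus calls t.__neg__()
print(t.calls)  # ['__add__', '__radd__', '__neg__']
```

Without `__radd__`, the expression `1 + t` would raise a `TypeError`, which is why `Node` defines both the normal and the reflected version of each binary operator.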

Now we can add nodes using $+$, divide them with $/$ and so on. Here is a basic example of adding Nodes with $+$:

```
val_z = 1.5
z = Node(val_z, None, [])
val_t1 = 4
t1 = Node(val_t1, None, [])
y = z + t1
```

```
add 'add' to graph with value 5.5
```

Here is the graph of `y`:

```
print(y)
```

```
Fun: <ufunc 'add'> Value: 5.5 Parents: [
Fun: None Value: 1.5 Parents: [],
Fun: None Value: 4 Parents: []]
```

Let’s try it out on a modified version of the `logistic` function that uses our `exp_new` function.

```
def logistic2(z): return 1 / (1 + exp_new(-z))
y = logistic2(start_node(value = 1.5))
```

```
add 'negative' to graph with value -1.5
add 'exp' to graph with value 0.22313016014842982
add 'add' to graph with value 1.22313016014843
add 'true_divide' to graph with value 0.8175744761936437
```

The graph of `y`:

```
print(y)
```

```
Fun: <ufunc 'true_divide'> Value: 0.818 Parents: [
Fun: <ufunc 'add'> Value: 1.223 Parents: [
Fun: <ufunc 'exp'> Value: 0.223 Parents: [
Fun: <ufunc 'negative'> Value: -1.5 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x10fe90f28> Value: 1.5 Parents: []]]]]
```

Sweet! It works. Now let’s try a multivariate function.

```
def somefun(x,y): return (x*y + exp_new(x)*exp_new(y))/(4*y)
def somefun2(x,y): return (x*y + np.exp(x)*np.exp(y))/(4*y)
```

```
val_x, val_y = 3, 4
ans = somefun(start_node(val_x), start_node(val_y))
```

```
add 'multiply' to graph with value 12
add 'exp' to graph with value 20.085536923187668
add 'exp' to graph with value 54.598150033144236
add 'multiply' to graph with value 1096.6331584284585
add 'add' to graph with value 1108.6331584284585
add 'multiply' to graph with value 16
add 'true_divide' to graph with value 69.28957240177866
```

Graph of `ans`:

```
print(ans)
```

```
Fun: <ufunc 'true_divide'> Value: 69.29 Parents: [
Fun: <ufunc 'add'> Value: 1108.633 Parents: [
Fun: <ufunc 'multiply'> Value: 12 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566f28> Value: 3 Parents: [],
Fun: <function start_node.<locals>.<lambda> at 0x11c566730> Value: 4 Parents: []],
Fun: <ufunc 'multiply'> Value: 1096.633 Parents: [
Fun: <ufunc 'exp'> Value: 20.086 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566f28> Value: 3 Parents: []],
Fun: <ufunc 'exp'> Value: 54.598 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566730> Value: 4 Parents: []]]],
Fun: <ufunc 'multiply'> Value: 16 Parents: [
Fun: <function start_node.<locals>.<lambda> at 0x11c566730> Value: 4 Parents: []]]
```

The result looks complex, but that is because our `__repr__` method is basic: it doesn’t indent nested parents, and it prints a node again every time it is reached, so the start nodes for `x` and `y` show up more than once. Still, all the information is there, and we have successfully created a computational graph.
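Although the printout repeats the start nodes, the underlying structure is a directed acyclic graph in which those nodes are shared. A sketch that collects each distinct node once (by `id`) makes this visible. `collect_nodes` is a hypothetical helper, and the minimal `Node`, `primitive` and operator methods repeat the ones above without printing or `__repr__`.

```python
import numpy as np
from functools import wraps

class Node:
    """Minimal graph node with the operator overloads `somefun` needs."""
    def __init__(self, value, fun, parents):
        self.value, self.fun, self.parents = value, fun, parents
    def __add__(self, other): return add_new(self, other)
    def __radd__(self, other): return add_new(other, self)
    def __mul__(self, other): return mul_new(self, other)
    def __rmul__(self, other): return mul_new(other, self)
    def __truediv__(self, other): return div_new(self, other)
    def __rtruediv__(self, other): return div_new(other, self)

def primitive(f):
    @wraps(f)
    def inner(*args):
        vals = [a.value if isinstance(a, Node) else a for a in args]
        parents = [a for a in args if isinstance(a, Node)]
        return Node(f(*vals), f, parents)
    return inner

add_new, mul_new = primitive(np.add), primitive(np.multiply)
div_new, exp_new = primitive(np.divide), primitive(np.exp)

def collect_nodes(root):
    """Return every distinct node reachable from `root`, each exactly once."""
    seen, stack, out = set(), [root], []
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        out.append(node)
        stack.extend(node.parents)
    return out

def somefun(x, y): return (x*y + exp_new(x)*exp_new(y))/(4*y)

x, y = Node(3, None, []), Node(4, None, [])
ans = somefun(x, y)
nodes = collect_nodes(ans)
print(len(nodes))                  # 9 distinct nodes
print(sum(n is y for n in nodes))  # 1: y is stored once, though printed three times
print(round(float(ans.value), 2))  # 69.29
```

Sharing matters later on: when gradients flow backwards, a node like `y` that feeds several operations has to combine contributions from all of its children.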

## Next steps

At this point we can create functions using common operators and automatically trace their computation graph. Nice!

But we aren’t quite there yet. There are a few things missing:

- we don’t want to replace `np.add` with `add_new`, `np.exp` with `exp_new`, etc. everywhere. That’s a pain, especially if we have a lot of code to do that for.
- currently we have to implement primitives for every `numpy` function we want. Is there a way to get them all?
- how do we handle non-differentiable functions?

We’ll cover these in the next post!