计算图与动态图机制

1. 计算图

1.1 用计算图表示:

image-20200721135913837

计算图是用来描述运算的有向无环图,计算图有两个主要元素:结点(Node)和边(Edge)

结点表示数据,如向量,矩阵,张量,边表示运算,如加减乘除卷积等

y=(x+w)(w+1)a=x+wb=w+1y=aby = (x+ w) * (w+1)\\ a = x + w \\ b = w + 1\\ y = a * b

1.2 计算图与梯度求导

image-20200721135439573

yw=yaaw+ybbw=b1+a1=b+a=(w+1)+(x+w)=2w+x+1=21+2+1=5\begin{aligned} \frac{\partial y}{\partial w} &=\frac{\partial y}{\partial a} \frac{\partial a}{\partial w}+\frac{\partial y}{\partial b} \frac{\partial b}{\partial w} \\ &=b * 1+a * 1 \\ &=b+a \\ &=(w+1)+(x+w) \\ &=2 * w+x+1 \\ &=2 * 1+2+1=5 \end{aligned}

1.3 叶子结点:

用户创建的结点称为叶子结点,如X与W
is_leaf: 指示张量是否为叶子结点

image-20200721135546915
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
# a.retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()
print(w.grad)

# 查看叶子结点
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)

# 查看梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

# 查看 grad_fn
print("grad_fn:\n", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)
1
2
3
4
5
6
7
tensor([5.])
is_leaf:
True True False False False
gradient:
tensor([5.]) tensor([2.]) None None None
grad_fn:
None None <AddBackward0 object at 0x000001B105634A08> <AddBackward0 object at 0x000001B17AAE8708> <MulBackward0 object at 0x000001B17AAE6148>

返回的y的导数是5;只有w和x是叶子节点,非叶子节点的梯度不会被保存(自动及时释放);grad_fn可以查看该节点是用的什么方法计算梯度,这与计算图有关;如果使用retain_grad()可以保存非叶子节点的梯度。

2. 动态图

2.1 动态图vs 静态图

根据计算图搭建方式,可将计算图分为动态图和静态图

image-20200721141030875

2.2 动态图 PyTorch

image-20200721141104702

2.3 静态图 TensorFlow

image-20200721141113247