1. 计算图
1.1 用计算图表示:
计算图是用来描述运算的有向无环图,计算图有两个主要元素:结点(Node)和边(Edge)
结点表示数据,如向量,矩阵,张量,边表示运算,如加减乘除卷积等
y=(x+w)∗(w+1)a=x+wb=w+1y=a∗b
1.2 计算图与梯度求导
∂w∂y=∂a∂y∂w∂a+∂b∂y∂w∂b=b∗1+a∗1=b+a=(w+1)+(x+w)=2∗w+x+1=2∗1+2+1=5
1.3 叶子结点:
用户创建的结点称为叶子结点,如X与W
is_leaf: 指示张量是否为叶子结点
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| import torch
w = torch.tensor([1.], requires_grad=True) x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1) y = torch.mul(a, b)
y.backward() print(w.grad)
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)
print("grad_fn:\n", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)
|
1 2 3 4 5 6 7
| tensor([5.]) is_leaf: True True False False False gradient: tensor([5.]) tensor([2.]) None None None grad_fn: None None <AddBackward0 object at 0x000001B105634A08> <AddBackward0 object at 0x000001B17AAE8708> <MulBackward0 object at 0x000001B17AAE6148>
|
返回的y的导数是5;只有w和x是叶子节点,非叶子节点的梯度不会被保存(自动及时释放);grad_fn可以查看该节点是用的什么方法计算梯度,这与计算图有关;如果使用retain_grad()可以保存非叶子节点的梯度。
2. 动态图
2.1 动态图vs 静态图
根据计算图搭建方式,可将计算图分为动态图和静态图
2.2 动态图 PyTorch
2.3 静态图 TensorFlow