autograd—自动求导系统

1. torch.autograd.backward

1
2
3
4
5
torch.autograd.backward(tensors, 
grad_tensors=None,
retain_graph=None,
create_graph=False,
grad_variables=None)

功能:自动求取梯度

  • tensors: 用于求导的张量,如 loss
  • retain_graph : 保存计算图
  • create_graph : 创建导数计算图,用于高阶求导
  • grad_tensors:多梯度权重
1
2
3
4
5
6
7
8
9
10
11
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()
print(w.grad)
1
tensor([5.])

注意:使用自动求导,必须在建立变量时声明requires_grad=True

y这个tensor的类里面的方法backward只有一个函数,就是torch.autograd.backward。

那么实际上我们对loss调用backward方法,就可以直接对叶子节点调用tensor.grad属性得到他们的梯度了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
torch.manual_seed(10)

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x) # retain_grad()
b = torch.add(w, 1)

y0 = torch.mul(a, b) # y0 = (x+w) * (w+1)
y1 = torch.add(a, b) # y1 = (x+w) + (w+1) dy1/dw = 2

loss = torch.cat([y0, y1], dim=0) # [y0, y1]
grad_tensors = torch.tensor([1., 2.])

loss.backward(gradient=grad_tensors) # gradient 传入 torch.autograd.backward()中的grad_tensors

print(w.grad)
1
tensor([9.])

这里loss是两个数的拼接,那么使用grad_tensors可以给这两数求加权和,再计算总的梯度。上面代码中,dy1/dw总是等于2,所以总的梯度就是y0的梯度5x1 加上y1的梯度2x2,结果就是9

2. torch.autograd.grad

1
2
3
4
5
6
7
torch.autograd.grad(outputs, 
inputs,
grad_outputs=None,
retain_graph=None,
create_graph=False,
only_inputs=True,
allow_unused=False)

功能:求取梯度

  • outputs: 用于求导的张量,如 loss(dy/dx的y)
  • inputs : 需要梯度的张量(dy/dx的x)
  • create_graph : 创建导数计算图,用于高阶求导
  • retain_graph : 保存计算图
  • grad_outputs:多梯度权重
1
2
3
4
5
6
7
8
9
10
11
import torch
torch.manual_seed(10)

x = torch.tensor([3.], requires_grad=True)
y = torch.pow(x, 2) # y = x**2

grad_1 = torch.autograd.grad(y, x, create_graph=True) # grad_1 = dy/dx = 2x = 2 * 3 = 6
print(grad_1)

grad_2 = torch.autograd.grad(grad_1[0], x) # grad_2 = d(dy/dx)/dx = d(2x)/dx = 2
print(grad_2)
1
2
(tensor([6.], grad_fn=<MulBackward0>),)
(tensor([2.]),)

grad_1中create_graph=True创建了计算图,使得可以继续计算grad_1相对于x的导数,实现二阶求导

3. autograd的注意事项

3.1 梯度不自动清零

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

for i in range(4):
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()
print(w.grad)

# w.grad.zero_()
1
2
3
4
tensor([5.])
tensor([10.])
tensor([15.])
tensor([20.])

只有开启了w.grad.zero_()以后,梯度才会清零

3.2 依赖于叶子结点的结点,requires_grad默认为True

1
2
3
4
5
6
7
8
9
10
11
import torch
torch.manual_seed(10)

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

print(a.requires_grad, b.requires_grad, y.requires_grad)
1
True True True

w和x是叶子节点,则a和b直接与叶子节点相连,requires_grad=True;y与a和b相连,间接依赖于叶子节点的requires_grad=True。

其实,在计算图中所有的节点,估计都是requires_grad=True的,因为都需要他们的导数进行反向传播

所以只需要设置叶子节点requires_grad=True即可

3.3 叶子结点不可执行in-place

3.3.1 in-place运算

1
2
3
4
5
6
7
8
import torch
torch.manual_seed(10)

a = torch.ones((1, ))
print(id(a), a)

a += torch.ones((1, ))
print(id(a), a)
1
2
2627959083480 tensor([1.])
2627959083480 tensor([2.])

自加运算是原位运算,不改变地址

1
2
3
4
5
6
7
8
import torch
torch.manual_seed(10)

a = torch.ones((1, ))
print(id(a), a)

a.add_(1)
print(id(a), a)
1
2
3110703681928 tensor([1.])
3110703681928 tensor([2.])

add_()也是原位运算哦,通常函数后带下划线的都是原位运算。add()就不是原位运算。

1
2
3
4
5
6
7
8
import torch
torch.manual_seed(10)

a = torch.ones((1, ))
print(id(a), a)

a = a + torch.ones((1, ))
print(id(a), a)
1
2
2088678674824 tensor([1.])
2088680712216 tensor([2.])

这样就不是原位运算了,地址被改变

3.3.2 叶子节点不能做in-place运算

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
torch.manual_seed(10)

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

w.add_(1)

y.backward()
1
2
3
4
Traceback (most recent call last):
<module>
w.add_(1)
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

报错了,w.add_(1)是原位运算。在计算图中,各个节点的导数计算都依赖于叶子节点的值。因此不能在计算图中对叶子节点做原位运算。