Hook函数

1. Hook函数概念

Hook函数机制:不改变主体,实现额外功能,就是开挂的意思

大致分为两类:torch.Tensor(对参数)和torch.nn.Module(对模型)的register_hook函数

  1. torch.Tensor.register_hook(hook)
  2. torch.nn.Module.register_forward_hook
  3. torch.nn.Module.register_forward_pre_hook
  4. torch.nn.Module.register_backward_hook

关于torch.nn.Module的hook函数,我们可以在module.py看到是怎么执行的。对于一个nn.module类,在调用时都会进入__call__函数

可以看到,__call__函数是分为四个部分的,依次执行顺序是register_forward_pre_hook,forward,forward_hook,backward_hook。根据需要hook函数执行的位置,我们可以选择不同位置进行hook注册。

image-20200809225312341

1.1 torch.Tensor.register_hook

1
torch.Tensor.register_hook(hook)

功能:注册一个Tensor反向传播的hook函数

参数:hook是一个函数

1
hook(grad) -> Tensor or None

hook函数仅一个输入参数(张量的梯度),可以定义一些对参数梯度的操作。

此函数返回带有方法的句柄,使用remove()可以从模块中移除挂钩。

实验:

image-20200809202429113
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1) # 设置随机种子

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
a_grad.append(grad)

handle = a.register_hook(grad_hook)

y.backward()

# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()
1
2
gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]: tensor([2.])

可以看到非叶子节点的梯度都已经被系统自动释放了;我们用hook函数把参数a的梯度保留了下来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1) # 设置随机种子

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
grad *= 2
# return grad*2


handle = w.register_hook(grad_hook)

y.backward()

# 查看梯度
print("w.grad: ", w.grad)
handle.remove()
1
w.grad:  tensor([10.])

hook函数不仅可以获取grad,甚至可以修改grad(像外挂一样)

1.2 torch.nn.Module.register_forward_hook

1
Module.register_forward_hook(hook)

功能:注册module的前向传播hook函数

参数:hook函数

模型的每次forward()计算出一个输出后,都会调用hook函数。hook函数应该有如下形式:

1
hook(module, input, output) -> None or modified output

此函数返回带有方法的句柄,使用remove()可以从模块中移除挂钩。

参数:

  • module: 当前网络层
  • input:当前网络层输入数据
  • output:当前网络层输出数据

实验:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1) # 设置随机种子

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)

def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1) # 直接修改conv网络层的参数,由于输出通道为2,所以有两个卷积核
net.conv1.weight[1].detach().fill_(2) # weight[0]是第一个卷积核 weight[1]是第二个卷积核
net.conv1.bias.data.detach().zero_() # bias设为0

# 注册hook
fmap_block = list()
input_block = list()

def forward_hook(module, data_input, data_output):
fmap_block.append(data_output) # 收集各层的输出特征图
input_block.append(data_input) # 收集各层的输入

net.conv1.register_forward_hook(forward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img)

# 观察
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

可视化这个网络的输入和卷积核参数如下所示:

image-20200809215220453

由于我们是对net.conv1注册的hook函数(net.conv1.register_forward_hook(forward_hook)),因此在运行到net.conv1的forward()函数时,会自动调用hook函数

print输出conv1的输入和输出如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9., 9.],
[ 9., 9.]],

[[18., 18.],
[18., 18.]]]], grad_fn=<MkldnnConvolutionBackward>)

input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)

1.3 torch.nn.Module.register_forward_pre_hook

1
Module.register_forward_pre_hook(hook)

功能:注册module前向传播前的hook函数

参数:hook函数

模型在forward()之前调用hook函数。hook函数应该有如下形式:

1
hook(module, input) -> None or modified output

此函数返回带有方法的句柄,使用remove()可以从模块中移除挂钩。

参数:

  • module: 当前网络层
  • input:当前网络层输入数据
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1) # 设置随机种子

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)

def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()

# 注册hook
def forward_pre_hook(module, data_input):
print("forward_pre_hook input:{}".format(data_input))

net.conv1.register_forward_pre_hook(forward_pre_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img)

输出:

1
2
3
4
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)

可以看到hook函数将conv1的输入打印出来了

1.4 torch.nn.Module.register_backward_hook

1
Module.register_backward_hook(hook)

功能:注册module反向传播的hook函数

参数:hook函数

模型在反向传播之后调用hook函数。hook函数应该有如下形式:

1
hook(module, grad_input, grad_output) -> None or modified output

此函数返回带有方法的句柄,使用remove()可以从模块中移除挂钩。

参数:

  • module: 当前网络层
  • grad_input:当前网络层输入梯度数据
  • grad_output:当前网络层输出梯度数据
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1) # 设置随机种子

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)

def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x

# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()

# 注册hook
def backward_hook(module, grad_input, grad_output):
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output))

net.conv1.register_backward_hook(backward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]],


[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
[0.0000, 0.0000]],

[[0.5000, 0.0000],
[0.0000, 0.0000]]]]),)

2. Hook函数与特征图提取

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tools.common_tools import set_seed
import torchvision.models as models

set_seed(1) # 设置随机种子

# ----------------------------------- feature map visualization -----------------------------------


writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

# 数据
path_img = "./lena.png" # your path to image
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]

norm_transform = transforms.Normalize(normMean, normStd)
img_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
norm_transform
])

img_pil = Image.open(path_img).convert('RGB')
if img_transforms is not None:
img_tensor = img_transforms(img_pil)
img_tensor.unsqueeze_(0) # chw --> bchw

# 模型
alexnet = models.alexnet(pretrained=True)

# 注册hook
fmap_dict = dict()
for name, sub_module in alexnet.named_modules():

if isinstance(sub_module, nn.Conv2d):
key_name = str(sub_module.weight.shape) # 按卷积核的尺寸命名
fmap_dict.setdefault(key_name, list()) # 初始化fmap_dict,键为卷积核尺寸,值为空列表

n1, n2 = name.split(".")

def hook_func(m, i, o): # hook函数把网络层的输出按卷积核尺寸分类加入字典
key_name = str(m.weight.shape)
fmap_dict[key_name].append(o)

alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func) # 由于alexnet是用Sequential容器装载网络层的,
# 所以先索引容器名,再索引容器內部的网络层。
# 这样就为每个Conv2d网络层都注册了hook函数

# forward
output = alexnet(img_tensor)

# add image
for layer_name, fmap_list in fmap_dict.items():
fmap = fmap_list[0]
fmap.transpose_(0, 1) # bchw→cbhw,这是因为我们的batch=1,但是通道数c由于在不同的卷积层,通道数不一样,
# 对于一个超过3通道的卷积核我们只能把每个通道用灰度显示

nrow = int(np.sqrt(fmap.shape[0]))
fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

tensorboard可视化特征图的效果如下:

image-20200810013148726

image-20200810013158604

image-20200810013210146

image-20200810013220562