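1. The Concept of Hook Functions
The hook mechanism adds extra functionality without changing the main body of the code, much like a game cheat or plug-in.
Hooks fall into two broad categories: those registered on torch.Tensor (for tensors, such as parameters) and those registered on torch.nn.Module (for models):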
torch.Tensor.register_hook(hook)
torch.nn.Module.register_forward_hook
torch.nn.Module.register_forward_pre_hook
torch.nn.Module.register_backward_hook
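For the torch.nn.Module hooks, module.py shows how they are executed. Calling any nn.Module instance enters its __call__ function.
__call__ has four parts, which run in this order: the forward pre-hooks, forward, the forward hooks, and the backward hooks. Strictly speaking, the backward hooks are only set up during the call; they fire later, during the backward pass. Depending on where a hook needs to run, we register it at the corresponding position.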
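The call order described above can be illustrated with a toy, pure-Python model of the dispatch. This is a simplified sketch, not PyTorch's actual implementation: the class ToyModule and its helper _register are hypothetical, and backward hooks are left out because they fire during the backward pass rather than inside the call.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical stand-in for nn.Module that only models hook dispatch."""

    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._next_id = 0

    def _register(self, registry, hook):
        hook_id = self._next_id
        self._next_id += 1
        registry[hook_id] = hook
        # Return a callable that unregisters the hook (stand-in for handle.remove()).
        return lambda: registry.pop(hook_id, None)

    def register_forward_pre_hook(self, hook):
        return self._register(self._forward_pre_hooks, hook)

    def register_forward_hook(self, hook):
        return self._register(self._forward_hooks, hook)

    def forward(self, x):
        return x + 1

    def __call__(self, *inputs):
        # 1) Pre-hooks run first and may replace the inputs.
        for hook in list(self._forward_pre_hooks.values()):
            result = hook(self, inputs)
            if result is not None:
                inputs = result if isinstance(result, tuple) else (result,)
        # 2) The actual computation.
        output = self.forward(*inputs)
        # 3) Forward hooks run last and may replace the output.
        for hook in list(self._forward_hooks.values()):
            result = hook(self, inputs, output)
            if result is not None:
                output = result
        return output


calls = []


def pre_hook(module, inputs):
    calls.append("pre")
    return (inputs[0] * 10,)


def post_hook(module, inputs, output):
    calls.append("post")


m = ToyModule()
remove_pre = m.register_forward_pre_hook(pre_hook)
m.register_forward_hook(post_hook)

first = m(2)    # pre-hook turns 2 into 20, forward adds 1
remove_pre()    # unregister the pre-hook
second = m(2)   # only forward and the forward hook run
print(first, second, calls)
```

The main body, forward, is never edited; all extra behaviour comes from the registered callables, which is the point of the hook mechanism.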
1.1 torch.Tensor.register_hook
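```text
torch.Tensor.register_hook(hook)
```
Purpose: registers a backward hook on a Tensor.
Parameter: hook, a function of the form
```text
hook(grad) -> Tensor or None
```
The hook takes a single argument, the gradient of the tensor, so it can define operations on that gradient. If it returns a tensor, that tensor is used in place of the original gradient.
register_hook returns a handle; calling handle.remove() removes the hook from the tensor.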
Experiment:
```python
import torch
import torch.nn as nn
# set_seed: helper from the author's own tools/common_tools.py (not part of PyTorch)
from tools.common_tools import set_seed

set_seed(1)

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()


def grad_hook(grad):
    # Store the gradient of a when it is computed during backward
    a_grad.append(grad)


handle = a.register_hook(grad_hook)

y.backward()

# Check the gradients
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()
```
```text
gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]: tensor([2.])
```
The gradients of the non-leaf nodes (a, b, and y) have been released automatically; the hook is what preserved the gradient of the tensor a.
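```python
import torch
import torch.nn as nn
# set_seed: helper from the author's own tools/common_tools.py (not part of PyTorch)
from tools.common_tools import set_seed

set_seed(1)

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)


def grad_hook(grad):
    # Return a new tensor rather than modifying grad in place;
    # the returned value replaces the gradient of w.
    return grad * 2


handle = w.register_hook(grad_hook)

y.backward()

# Check the gradient
print("w.grad: ", w.grad)
handle.remove()
```
A hook can not only read grad but also modify it (like a cheat): here the gradient of w is doubled, from tensor([5.]) to tensor([10.]).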
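To see what the handle returned by register_hook is for, the following sketch runs the same computation twice, once with the hook registered and once after handle.remove(). It is a minimal illustration under the same values of w and x as above; the helper names forward_and_backward and double_grad are made up for this example.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([2.0], requires_grad=True)


def forward_and_backward():
    """Rebuild y = (w + x) * (w + 1), run backward, and return a copy of w.grad."""
    w.grad = None  # clear the gradient accumulated by any earlier call
    y = torch.mul(torch.add(w, x), torch.add(w, 1))
    y.backward()
    return w.grad.clone()


def double_grad(grad):
    # Return a new tensor instead of modifying grad in place.
    return grad * 2


handle = w.register_hook(double_grad)
grad_with_hook = forward_and_backward()      # the hook doubles dy/dw

handle.remove()
grad_without_hook = forward_and_backward()   # plain dy/dw, the hook is gone

print(grad_with_hook, grad_without_hook)
```

Removing hooks once they are no longer needed keeps later backward passes unaffected.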
1.2 torch.nn.Module.register_forward_hook
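```text
Module.register_forward_hook(hook)
```
Purpose: registers a forward hook on a module.
Parameter: hook, a function.
The hook is called every time forward() has computed an output. It should have the following form:
```text
hook(module, input, output) -> None or modified output
```
register_forward_hook returns a handle; calling handle.remove() removes the hook from the module.
Hook arguments:
module: the current layer
input: the input to the current layer (passed as a tuple)
output: the output of the current layer
Experiment: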
```python
import torch
import torch.nn as nn
# set_seed: helper from the author's own tools/common_tools.py (not part of PyTorch)
from tools.common_tools import set_seed

set_seed(1)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x


# Initialize the network with fixed weights
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()

# Register the hook
fmap_block = list()
input_block = list()


def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)


net.conv1.register_forward_hook(forward_hook)

# Inference
fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
output = net(fake_img)

# Inspect what the hook captured
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
```
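[Figure: the network's input and convolution kernel parameters]
Because the hook is registered on net.conv1 (net.conv1.register_forward_hook(forward_hook)), it is called automatically whenever the forward() of net.conv1 runs.
The printed input and output of conv1 are: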
```text
feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9.,  9.],
          [ 9.,  9.]],

         [[18., 18.],
          [18., 18.]]]], grad_fn=<MkldnnConvolutionBackward>)

input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)
```
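The hook signature also allows a modified output to be returned. The sketch below, which assumes a single 3x3 convolution with all-ones weights rather than the Net class above, stores a detached view of the feature map and returns a shifted output; the hook name capture_and_shift is made up for illustration.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, 3)
with torch.no_grad():
    conv.weight.fill_(1.0)
    conv.bias.zero_()

captured = []


def capture_and_shift(module, inputs, output):
    # Keep a tensor that is detached from the autograd graph.
    captured.append(output.detach())
    # A non-None return value replaces the module's output.
    return output + 1


handle = conv.register_forward_hook(capture_and_shift)
x = torch.ones(1, 1, 3, 3)
hooked = conv(x)    # output after the hook has modified it

handle.remove()
plain = conv(x)     # output without any hook

print(captured[0].item(), hooked.item(), plain.item())
```

Storing output.detach() instead of output is a common choice when the feature map is only inspected, since it avoids keeping the autograd graph alive through the saved list.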
1.3 torch.nn.Module.register_forward_pre_hook
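```text
Module.register_forward_pre_hook(hook)
```
Purpose: registers a hook that runs before a module's forward pass.
Parameter: hook, a function.
The hook is called before forward(). It should have the following form:
```text
hook(module, input) -> None or modified input
```
register_forward_pre_hook returns a handle; calling handle.remove() removes the hook from the module.
Hook arguments:
module: the current layer
input: the input to the current layer (passed as a tuple)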
```python
import torch
import torch.nn as nn
# set_seed: helper from the author's own tools/common_tools.py (not part of PyTorch)
from tools.common_tools import set_seed

set_seed(1)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x


# Initialize the network with fixed weights
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()


def forward_pre_hook(module, data_input):
    print("forward_pre_hook input:{}".format(data_input))


net.conv1.register_forward_pre_hook(forward_pre_hook)

# Inference
fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
output = net(fake_img)
```
Output:
```text
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)
```
The hook printed the input of conv1 before its forward pass ran.
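A forward pre-hook can also change what the layer receives by returning a new input tuple. The following is a minimal sketch on a hypothetical nn.Linear layer with all-ones weights (not the Net class above); the hook name scale_input is made up for illustration.

```python
import torch
import torch.nn as nn

linear = nn.Linear(2, 1, bias=False)
with torch.no_grad():
    linear.weight.fill_(1.0)


def scale_input(module, inputs):
    # inputs is a tuple holding the positional arguments passed to the module.
    return (inputs[0] * 2,)


handle = linear.register_forward_pre_hook(scale_input)
x = torch.tensor([[1.0, 2.0]])
scaled = linear(x)   # forward() sees 2 * x

handle.remove()
plain = linear(x)    # forward() sees x unchanged

print(scaled.item(), plain.item())
```

This pattern fits cases such as input normalization that should be applied to one layer without editing the model's forward() code.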
1.4 torch.nn.Module.register_backward_hook
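```text
Module.register_backward_hook(hook)
```
Purpose: registers a backward hook on a module.
Parameter: hook, a function.
The hook is called during the backward pass, once the gradients with respect to the module's inputs have been computed. It should have the following form:
```text
hook(module, grad_input, grad_output) -> None or modified grad_input
```
register_backward_hook returns a handle; calling handle.remove() removes the hook from the module.
Hook arguments:
module: the current layer
grad_input: the gradients with respect to the layer's inputs
grad_output: the gradients with respect to the layer's outputs
Note: recent PyTorch releases deprecate register_backward_hook in favor of register_full_backward_hook.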
```python
import torch
import torch.nn as nn
# set_seed: helper from the author's own tools/common_tools.py (not part of PyTorch)
from tools.common_tools import set_seed

set_seed(1)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x


# Initialize the network with fixed weights
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()


def backward_hook(module, grad_input, grad_output):
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))


net.conv1.register_backward_hook(backward_hook)

# Forward pass
fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
output = net(fake_img)

# Backward pass, which triggers the hook
loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()
```
Output:
```text
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],


        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
          [0.0000, 0.0000]],

         [[0.5000, 0.0000],
          [0.0000, 0.0000]]]]),)
```
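For comparison, here is a minimal sketch of register_full_backward_hook, the replacement that recent PyTorch releases (1.8 and later) recommend over register_backward_hook. It assumes a hypothetical nn.Linear layer rather than the Net class above and only records gradient shapes; the hook name record_shapes is made up for illustration. With the full hook, grad_input holds only the gradients with respect to the module's inputs, whereas the legacy output above appears to also carry the weight and bias gradients.

```python
import torch
import torch.nn as nn

linear = nn.Linear(3, 2)
shapes = {}


def record_shapes(module, grad_input, grad_output):
    # Both arguments are tuples; entries can be None when no gradient is needed.
    shapes["grad_input"] = [None if g is None else tuple(g.shape) for g in grad_input]
    shapes["grad_output"] = [None if g is None else tuple(g.shape) for g in grad_output]


handle = linear.register_full_backward_hook(record_shapes)

x = torch.ones(4, 3, requires_grad=True)
linear(x).sum().backward()   # the hook fires during this backward pass

handle.remove()
print(shapes)
```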
2. Hook Functions and Feature Map Extraction
Code:
```python
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
# set_seed: helper from the author's own tools/common_tools.py (not part of PyTorch)
from tools.common_tools import set_seed
import torchvision.models as models

set_seed(1)

# TensorBoard writer
writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

# Data
path_img = "./lena.png"
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]

norm_transform = transforms.Normalize(normMean, normStd)
img_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    norm_transform
])

img_pil = Image.open(path_img).convert('RGB')
if img_transforms is not None:
    img_tensor = img_transforms(img_pil)
img_tensor.unsqueeze_(0)    # chw -> bchw

# Model (newer torchvision versions prefer the weights= argument over pretrained=True)
alexnet = models.alexnet(pretrained=True)

# Register a forward hook on every convolution layer
fmap_dict = dict()
for name, sub_module in alexnet.named_modules():
    if isinstance(sub_module, nn.Conv2d):
        key_name = str(sub_module.weight.shape)
        fmap_dict.setdefault(key_name, list())

        n1, n2 = name.split(".")

        def hook_func(m, i, o):
            # Feature maps are keyed by the shape of the layer's weight
            key_name = str(m.weight.shape)
            fmap_dict[key_name].append(o)

        alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

# Forward pass, which fills fmap_dict through the hooks
output = alexnet(img_tensor)

# Write the feature maps to TensorBoard
for layer_name, fmap_list in fmap_dict.items():
    fmap = fmap_list[0]
    fmap.transpose_(0, 1)   # (1, C, H, W) -> (C, 1, H, W): one grayscale image per channel

    nrow = int(np.sqrt(fmap.shape[0]))
    fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
    writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

writer.close()   # flush pending events to disk
```
[Figure: the feature maps as visualized in TensorBoard]
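The same extraction pattern can be tried without a pretrained model or an image file. The sketch below is an assumption-laden variant, not the code above: it uses a hypothetical two-convolution nn.Sequential in place of alexnet, random input in place of lena.png, and keys the feature maps by module name (via functools.partial) instead of by weight shape, which avoids collisions when two layers share a weight shape.

```python
from functools import partial

import torch
import torch.nn as nn

# Hypothetical toy network; stands in for alexnet so that nothing is downloaded.
net = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 8, 3, padding=1),
)

fmaps = {}


def save_fmap(name, module, inputs, output):
    # name is bound by partial(); the remaining arguments come from PyTorch.
    fmaps[name] = output.detach()


# Register one hook per convolution layer and keep the handles for cleanup.
handles = [
    sub_module.register_forward_hook(partial(save_fmap, name))
    for name, sub_module in net.named_modules()
    if isinstance(sub_module, nn.Conv2d)
]

with torch.no_grad():
    net(torch.randn(1, 3, 16, 16))

for handle in handles:
    handle.remove()

fmap_shapes = {name: tuple(fmap.shape) for name, fmap in fmaps.items()}
print(fmap_shapes)
```

Each stored tensor could then be passed through make_grid and add_image in the same way as in the AlexNet example.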