Hook mechanism in Pytorch.
钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)。
Hook是PyTorch中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的特征或梯度,从而诊断神经网络中可能出现的问题,分析网络有效性。
1. Hook for Tensors
本节介绍张量的hook。在PyTorch的计算图(computation graph)中,只有叶节点(leaf node)的变量会保留梯度,而所有中间变量的梯度只在反向传播中使用,一旦反向传播完成,中间变量的梯度将自动释放,从而节约内存。
下图是一个简单的计算图,其中$x,y,w$是叶节点(直接给定数值的变量),$z,o$是中间变量(由其他变量计算得到的变量)。
import torch
x = torch.Tensor([0, 1, 2, 3]).requires_grad_()
y = torch.Tensor([4, 5, 6, 7]).requires_grad_()
w = torch.Tensor([1, 2, 3, 4]).requires_grad_()
z = x + y
o = w.matmul(z)
o.backward()
print('x.requires_grad:', x.requires_grad) # True
print('y.requires_grad:', y.requires_grad) # True
print('z.requires_grad:', z.requires_grad) # True
print('w.requires_grad:', w.requires_grad) # True
print('o.requires_grad:', o.requires_grad) # True
print('x.grad:', x.grad) # tensor([1., 2., 3., 4.])
print('y.grad:', y.grad) # tensor([1., 2., 3., 4.])
print('w.grad:', w.grad) # tensor([4., 6., 8., 10.])
print('z.grad:', z.grad) # None
print('o.grad:', o.grad) # None
从上面的例子中可以看出,由于$z,o$是中间变量,它们虽然requires_grad
的参数都是True
,但反向传播后其梯度并没有保存下来,而是直接删除了,因此为None
。如果想在反向传播后保留他们的梯度,则需要特殊指定:
z.retain_grad()
o.retain_grad()
print('z.requires_grad:', z.requires_grad) # True
print('o.requires_grad:', o.requires_grad) # True
print('z.grad:', z.grad) # tensor([1., 2., 3., 4.])
print('o.grad:', o.grad) # tensor(1.)
但这种使用retain_grad()
的方案会增加内存的占用,并不是一个好的方法。可以使用hook保存中间变量的梯度。
对于中间变量$z$,hook的使用方法为:z.register_hook(hook_fn)
,其中hook_fn
为一个用户自定义的函数:
def hook_fn(grad): -> Tensor or None
该函数输入为变量$z$的梯度,输出为一个Tensor或None
(None
一般用于直接打印梯度)。反向传播时,梯度传播到变量$z$后,再继续往前传播之前,将会传入hook_fn
函数。如果hook_fn
的返回值是None
,则梯度不改变,继续向前传播;如果hook_fn
的返回值是Tensor类型,则该Tensor将取代变量$z$原有的梯度,继续向前传播。
下面的例子中hook_fn
打印梯度值并修改为原来的两倍:
def hook_fn(grad):
print(g)
g = 2 * grad
return g
z.register_hook(hook_fn)
o.backward() # tensor([1., 2., 3., 4.])
print('z.grad:', z.grad) # None
在实际代码中,为简化表示,也可以用lambda
表达式代替函数,简写如下:
z.register_hook(lambda x: print(x))
z.register_hook(lambda x: 2*x)
注意到一个变量可以绑定多个hook_fn
函数,反向传播时,按绑定顺序依次执行。
2. Hook for Modules
本节介绍模块的hook。模块不像上一节介绍的Tensor一样拥有显式的变量名可以访问,而是被封装在神经网络中。通常只能获得网络整体的输入和输出,而对于网络中间的模块,不仅很难得到它输入和输出的梯度,甚至连输入输出的数值都无法获得。比较麻烦的做法是,在forward函数的返回值中包含中间模块的输出;或者把网络按照模块的名称拆分再组合,提取中间层的特征。
Pytorch设计了两种hook:register_forward_hook
和register_backward_hook
,分别用来获取前向传播和反向传播时中间层模块的输入和输出特征及梯度,从而大大降低了获取模型内部信息流的难度。
register_forward_hook
register_forward_hook
的作用是获取前向传播过程中,网络各模块的输入和输出。对于模块module
,其使用方法为:module.register_forward_hook(hook_fn)
,其中hook_fn
为一个用户自定义的函数:
def hook_fn(module, input, output): -> Tensor or None
hook_fn
函数的输入变量分别为模块、模块的输入和模块的输出。输出为None,Pytorch1.2.0之后的版本也可以返回张量,用于修改模块的输出。借助这个hook,可以方便的使用预训练的神经网络提取特征,而不用改变预训练网络的结构。下面是一个简单的例子:
import torch
from torch import nn
# 全局变量,用于存储中间层的特征
total_feat_out = []
total_feat_in = []
# 定义 forward hook function
def hook_fn_forward(module, input, output):
print(module) # 打印模块名,用于区分模块
print('input', input) # 打印该模块的输入
print('output', output) # 打印该模块的输出
total_feat_out.append(output) # 保存该模块的输出
total_feat_in.append(input) # 保存该模块的输入
model = Model()
modules = model.named_children()
for name, module in modules:
module.register_forward_hook(hook_fn_forward)
# 注意下面代码中 x 的维度,第一维是 batch size
# forward hook 中看不出来,但是 backward hook 中是必要的。
x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_()
register_backward_hook
register_backward_hook
的作用是获取反向传播过程中,网络各模块输入端和输出端的梯度值。对于模块module
,其使用方法为:module.register_backward_hook(hook_fn)
,其中hook_fn
为一个用户自定义的函数:
def hook_fn(module, grad_input, grad_output): -> Tensor or None
hook_fn
函数的输入变量分别为模块、模块输入端的梯度和模块输出端的梯度(这里的输入端和输出端是站在前向传播的角度来说的)。如果模块有多个输入端或输出端,则对应的梯度是tuple类型(例如对于线性模块,其grad_input
是一个三元组,排列顺序分别为:对bias
的导数、对输入x
的导数、对权重W
的导数)。下面是一个简单的例子:
import torch
from torch import nn
# 全局变量,用于存储中间层的梯度
total_grad_out = []
total_grad_in = []
# 定义 backward hook function
def hook_fn_backward(module, grad_input, grad_output):
print(module) # 打印模块名,用于区分模块
print('grad_output', grad_output) # 打印该模块输出端的梯度
print('grad_input', grad_input) # 打印该模块输入端的梯度
total_grad_in.append(grad_input) # 保存该模块输入端的梯度
total_grad_out.append(grad_output) # 保存该模块输出端的梯度
model = Model()
modules = model.named_children()
for name, module in modules:
module.register_backward_hook(hook_fn_backward)
# 这里的 requires_grad 很重要,如果不加,backward hook
# 执行到第一层,对 x 的导数将为 None 。
# 此外再强调一遍 x 的维度,第一维一定是 batch size
x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_()
注意事项
register_backward_hook
在全连接层和卷积层中的表现是不一致的,具体如下:
- 形状不一致
- 在卷积层中,weight的梯度和weight的形状相同;
- 在全连接层中,weight的梯度的形状是weight形状的转置。
grad_input
元组中梯度的顺序不一致- 在卷积层中,梯度的顺序为:(对feature的梯度,对weight的梯度,对bias的梯度);
- 在全连接层中,梯度的顺序为:(对bias的梯度,对feature的梯度,对weight的梯度)。
- 当batch size大于$1$时,对bias的梯度处理不一致
- 在卷积层中,对bias的梯度为整个batch的数据在bias上的梯度之和:(对feature的梯度,对weight的梯度,对bias的梯度);
- 在全连接层中,对bias的梯度是分开的,batch中的每个数据对应一个bias的梯度:((data1对bias的梯度,data2对bias的梯度…),对feature的梯度,对weight的梯度)。
特别地,如果已知某个模块的类型,也可以用下面的方式对其加hook:
for name, module in modules:
if isinstance(module, nn.ReLU):
module.register_forward_hook(forward_hook_fn)
module.register_backward_hook(backward_hook_fn)