Hook mechanism in Pytorch.

钩子编程(hooking),也称作“挂钩”,是计算机程序设计术语,指通过拦截软件模块间的函数调用、消息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)

HookPyTorch中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的特征或梯度,从而诊断神经网络中可能出现的问题,分析网络有效性。

1. Hook for Tensors

本节介绍张量的hook。在PyTorch计算图(computation graph)中,只有叶节点(leaf node)的变量会保留梯度,而所有中间变量的梯度只在反向传播中使用,一旦反向传播完成,中间变量的梯度将自动释放,从而节约内存。

下图是一个简单的计算图,其中$x,y,w$是叶节点(直接给定数值的变量),$z,o$是中间变量(由其他变量计算得到的变量)。

import torch

x = torch.Tensor([0, 1, 2, 3]).requires_grad_()
y = torch.Tensor([4, 5, 6, 7]).requires_grad_()
w = torch.Tensor([1, 2, 3, 4]).requires_grad_()
z = x + y
o = w.matmul(z)
o.backward()

print('x.requires_grad:', x.requires_grad)  # True
print('y.requires_grad:', y.requires_grad)  # True
print('z.requires_grad:', z.requires_grad)  # True
print('w.requires_grad:', w.requires_grad)  # True
print('o.requires_grad:', o.requires_grad)  # True

print('x.grad:', x.grad)  # tensor([1., 2., 3., 4.])
print('y.grad:', y.grad)  # tensor([1., 2., 3., 4.])
print('w.grad:', w.grad)  # tensor([4., 6., 8., 10.])
print('z.grad:', z.grad)  # None
print('o.grad:', o.grad)  # None

从上面的例子中可以看出,由于$z,o$是中间变量,它们虽然requires_grad的参数都是True,但反向传播后其梯度并没有保存下来,而是直接删除了,因此为None。如果想在反向传播后保留他们的梯度,则需要特殊指定:

z.retain_grad()
o.retain_grad()

print('z.requires_grad:', z.requires_grad) # True
print('o.requires_grad:', o.requires_grad) # True
print('z.grad:', z.grad)  # tensor([1., 2., 3., 4.])
print('o.grad:', o.grad)  # tensor(1.)

但这种使用retain_grad()的方案会增加内存的占用,并不是一个好的方法。可以使用hook保存中间变量的梯度。

对于中间变量$z$,hook的使用方法为:z.register_hook(hook_fn),其中hook_fn为一个用户自定义的函数:

def hook_fn(grad): -> Tensor or None

该函数输入为变量$z$的梯度,输出为一个TensorNoneNone一般用于直接打印梯度)。反向传播时,梯度传播到变量$z$后,再继续往前传播之前,将会传入hook_fn函数。如果hook_fn的返回值是None,则梯度不改变,继续向前传播;如果hook_fn的返回值是Tensor类型,则该Tensor将取代变量$z$原有的梯度,继续向前传播。

下面的例子中hook_fn打印梯度值并修改为原来的两倍:

def hook_fn(grad):
    print(g)
    g = 2 * grad
    return g

z.register_hook(hook_fn)

o.backward()  # tensor([1., 2., 3., 4.])
print('z.grad:', z.grad)  # None

在实际代码中,为简化表示,也可以用lambda表达式代替函数,简写如下:

z.register_hook(lambda x: print(x))
z.register_hook(lambda x: 2*x)

注意到一个变量可以绑定多个hook_fn函数,反向传播时,按绑定顺序依次执行。

2. Hook for Modules

本节介绍模块的hook。模块不像上一节介绍的Tensor一样拥有显式的变量名可以访问,而是被封装在神经网络中。通常只能获得网络整体的输入和输出,而对于网络中间的模块,不仅很难得到它输入和输出的梯度,甚至连输入输出的数值都无法获得。比较麻烦的做法是,在forward函数的返回值中包含中间模块的输出;或者把网络按照模块的名称拆分再组合,提取中间层的特征。

Pytorch设计了两种hookregister_forward_hookregister_backward_hook,分别用来获取前向传播和反向传播时中间层模块的输入和输出特征及梯度,从而大大降低了获取模型内部信息流的难度。

register_forward_hook

register_forward_hook的作用是获取前向传播过程中,网络各模块的输入和输出。对于模块module,其使用方法为:module.register_forward_hook(hook_fn),其中hook_fn为一个用户自定义的函数:

def hook_fn(module, input, output): -> Tensor or None

hook_fn函数的输入变量分别为模块、模块的输入和模块的输出。输出为NonePytorch1.2.0之后的版本也可以返回张量,用于修改模块的输出。借助这个hook,可以方便的使用预训练的神经网络提取特征,而不用改变预训练网络的结构。下面是一个简单的例子:

import torch
from torch import nn

#  全局变量,用于存储中间层的特征
total_feat_out = []
total_feat_in = []

#  定义 forward hook function
def hook_fn_forward(module, input, output):
    print(module)  # 打印模块名,用于区分模块
    print('input', input)   # 打印该模块的输入
    print('output', output) # 打印该模块的输出
    total_feat_out.append(output) # 保存该模块的输出
    total_feat_in.append(input)   # 保存该模块的输入

model = Model()

modules = model.named_children()
for name, module in modules:
    module.register_forward_hook(hook_fn_forward)

#  注意下面代码中 x 的维度,第一维是 batch size
#  forward hook 中看不出来,但是 backward hook 中是必要的。
x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_() 

register_backward_hook

register_backward_hook的作用是获取反向传播过程中,网络各模块输入端和输出端的梯度值。对于模块module,其使用方法为:module.register_backward_hook(hook_fn),其中hook_fn为一个用户自定义的函数:

def hook_fn(module, grad_input, grad_output): -> Tensor or None

hook_fn函数的输入变量分别为模块、模块输入端的梯度和模块输出端的梯度(这里的输入端和输出端是站在前向传播的角度来说的)。如果模块有多个输入端或输出端,则对应的梯度是tuple类型(例如对于线性模块,其grad_input是一个三元组,排列顺序分别为:对bias的导数、对输入x的导数、对权重W的导数)。下面是一个简单的例子:

import torch
from torch import nn

#  全局变量,用于存储中间层的梯度
total_grad_out = []
total_grad_in = []

# 定义 backward hook function
def hook_fn_backward(module, grad_input, grad_output):
    print(module)  # 打印模块名,用于区分模块
    print('grad_output', grad_output)  # 打印该模块输出端的梯度
    print('grad_input', grad_input)    # 打印该模块输入端的梯度
    total_grad_in.append(grad_input)   # 保存该模块输入端的梯度
    total_grad_out.append(grad_output) # 保存该模块输出端的梯度

model = Model()

modules = model.named_children()
for name, module in modules:
    module.register_backward_hook(hook_fn_backward)

#  这里的 requires_grad 很重要,如果不加,backward hook
#  执行到第一层,对 x 的导数将为 None 。
#  此外再强调一遍 x 的维度,第一维一定是 batch size
x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_()

注意事项

register_backward_hook在全连接层和卷积层中的表现是不一致的,具体如下:

特别地,如果已知某个模块的类型,也可以用下面的方式对其加hook

for name, module in modules:
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(forward_hook_fn)
        module.register_backward_hook(backward_hook_fn)