This article introduces how to use hook functions to extract a network's feature maps for visualization, and CAM (class activation map).
Hook Function Concept
A hook function adds extra functionality without changing the main body of the code. Because PyTorch is built on dynamic computation graphs, intermediate variables such as the gradients of non-leaf nodes and feature maps are released once an iteration of computation finishes. To extract and record these intermediate variables, you need hook functions.
PyTorch provides 4 kinds of hook functions.
torch.Tensor.register_hook(hook)
Purpose: registers a backward-pass hook function. The hook takes a single argument: the tensor's gradient.
The hook function:
hook(grad)
Parameters:
- grad: the gradient of the tensor
The code is as follows:
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

# list for saving the gradient
a_grad = list()

# define the hook function, which appends the gradient to the list
def grad_hook(grad):
    a_grad.append(grad)

# register the hook function on a tensor
handle = a.register_hook(grad_hook)

y.backward()

# check the gradients
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
# check the gradient recorded in the list by the hook function
print("a_grad[0]: ", a_grad[0])
handle.remove()
The result is:
gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]: tensor([2.])
After backpropagation finishes, the gradients of non-leaf tensors are released; the gradient recorded by the hook function, however, can still be inspected.
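As an aside, if you only need to keep a non-leaf tensor's gradient and don't need any extra logic, torch.Tensor.retain_grad() is a simpler alternative to a hook. A minimal sketch:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
a.retain_grad()  # ask autograd to keep this non-leaf tensor's gradient
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward()
print("a.grad:", a.grad)  # tensor([2.]) instead of None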
A hook function can also modify the gradient: changing grad in place takes effect even without a return value, and if the hook returns a tensor, that return value is assigned as the new gradient. The code is as follows:
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

def grad_hook(grad):
    grad *= 2
    return grad * 3

handle = w.register_hook(grad_hook)
y.backward()

# check the gradient
print("w.grad: ", w.grad)
handle.remove()
The result is:
w.grad: tensor([30.])
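A common practical use of register_hook is per-tensor gradient clipping during backpropagation. A minimal sketch, where the threshold of 1.0 is an arbitrary illustrative choice:

import torch

w = torch.tensor([1.], requires_grad=True)
# clamp the incoming gradient to [-1, 1] before it is written to w.grad
handle = w.register_hook(lambda grad: grad.clamp(-1., 1.))
y = torch.mul(w, 10.).sum()
y.backward()
print("w.grad: ", w.grad)  # tensor([1.]) instead of tensor([10.])
handle.remove()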
torch.nn.Module.register_forward_hook(hook)
Purpose: registers a forward hook on a module; it can be used to capture intermediate feature maps.
The hook function:
hook(module, input, output)
Parameters:
- module: the current layer
- input: the input data of the current layer
- output: the output data of the current layer
The code below performs a $3 \times 3$ convolution and $2 \times 2$ max pooling. We use register_forward_hook() to record the input and output feature maps of the intermediate convolutional layer.
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)

# initialize the network
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()

# register the hook
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4))  # batch size * channel * H * W
output = net(fake_img)

# inspect
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
The output is:
output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],
[[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9., 9.],
[ 9., 9.]],
[[18., 18.],
[18., 18.]]]], grad_fn=<ThnnConv2DBackward>)
input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)
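Note that if a forward hook returns a value other than None, that value replaces the module's output (this is visible in the __call__ source shown later). A minimal sketch:

import torch
import torch.nn as nn

relu = nn.ReLU()
# the hook's return value replaces the module's output
handle = relu.register_forward_hook(lambda module, inp, out: out * 2)
print(relu(torch.tensor([-1., 1.])))  # tensor([0., 2.]) instead of tensor([0., 1.])
handle.remove()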
torch.nn.Module.register_forward_pre_hook(hook)
Purpose: registers a hook that runs before the module's forward pass; it can be used to inspect the input data.
The hook function:
hook(module, input)
Parameters:
- module: the current layer
- input: the input data of the current layer
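No standalone example is given for this hook above, so here is a minimal sketch; the layer and the print statement are purely illustrative:

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, 3)

def forward_pre_hook(module, data_input):
    # runs before conv.forward; data_input is a tuple of the positional inputs
    print("pre-hook input shape:", data_input[0].shape)

handle = conv.register_forward_pre_hook(forward_pre_hook)
output = conv(torch.ones(1, 1, 4, 4))  # prints: pre-hook input shape: torch.Size([1, 1, 4, 4])
handle.remove()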
torch.nn.Module.register_backward_hook(hook)
Purpose: registers a backward hook on a module; it can be used to capture gradients.
The hook function:
hook(module, grad_input, grad_output)
Parameters:
- module: the current layer
- grad_input: the gradients with respect to the layer's inputs
- grad_output: the gradients with respect to the layer's outputs
The code is as follows:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)

def forward_pre_hook(module, data_input):
    print("forward_pre_hook input:{}".format(data_input))

def backward_hook(module, grad_input, grad_output):
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))

# initialize the network
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()

# register the hooks
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4))  # batch size * channel * H * W
output = net(fake_img)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()
The output is:
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]]),)
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]],
[[[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000],
[0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
[0.0000, 0.0000]],
[[0.5000, 0.0000],
[0.0000, 0.0000]]]]),)
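Note: in PyTorch 1.8 and later, Module.register_backward_hook is deprecated, because its grad_input/grad_output can be incorrect for modules composed of several operations; register_full_backward_hook accepts a hook with the same signature and is the recommended replacement.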
How Hook Functions Are Implemented
Hooks work by interception inside the module's __call__() function, which can be divided into 4 parts:
- Part 1 runs the _forward_pre_hooks
- Part 2 runs forward() to perform the forward pass
- Part 3 runs the _forward_hooks
- Part 4 registers the _backward_hooks on the output's grad_fn
Since a convolutional layer is also a module, its _forward_hooks can be used to record feature maps in exactly this way.
def __call__(self, *input, **kwargs):
    # part 1: run the _forward_pre_hooks
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    # part 2: run forward() to perform the forward pass
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    # part 3: run the _forward_hooks
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    # part 4: register the _backward_hooks on the output's grad_fn
    if len(self._backward_hooks) > 0:
        var = result
        # walk into the output until a tensor is found
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result
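This listing reflects an older PyTorch release; in recent versions the same logic lives in Module._call_impl (which __call__ is assigned to), but the ordering of pre-hooks, forward, forward hooks, and backward-hook registration is essentially unchanged.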
Using Hook Functions to Extract a Network's Feature Maps
Below we use a hook function to capture the feature maps produced by every convolutional layer of AlexNet. The string form of each conv layer's weight shape serves as the dict key, and the value is a list of feature maps captured for that layer. We then take each layer's first recorded feature map, of shape [1, out_channels, h, w], transpose it to [out_channels, 1, h, w] so that every channel is rendered as one grayscale image, and visualize the grid with TensorBoard. The code is as follows:
import numpy as np
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

# data
path_img = "imgs/lena.png"  # your path to image
normMean = [0.49139968, 0.48215827, 0.44653124]
normStd = [0.24703233, 0.24348505, 0.26158768]
norm_transform = transforms.Normalize(normMean, normStd)
img_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    norm_transform,
])
img_pil = Image.open(path_img).convert('RGB')
if img_transforms is not None:
    img_tensor = img_transforms(img_pil)
img_tensor.unsqueeze_(0)  # chw --> bchw

# model
alexnet = models.alexnet(pretrained=True)

# register hooks
fmap_dict = dict()
for name, sub_module in alexnet.named_modules():
    if isinstance(sub_module, nn.Conv2d):
        key_name = str(sub_module.weight.shape)
        fmap_dict.setdefault(key_name, list())
        # AlexNet wraps its layers in nn.Sequential, so names look like features.0, features.3, ...
        n1, n2 = name.split(".")

        def hook_func(m, i, o):
            key_name = str(m.weight.shape)
            fmap_dict[key_name].append(o)

        alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

# forward
output = alexnet(img_tensor)

# add image
for layer_name, fmap_list in fmap_dict.items():
    fmap = fmap_list[0]  # take the first feature map recorded for this layer
    fmap.transpose_(0, 1)  # BCHW --> CBHW, so each channel becomes one image
    nrow = int(np.sqrt(fmap.shape[0]))
    fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
    writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)
Each convolutional layer's feature maps can then be viewed in TensorBoard as an image grid.
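To view the images, start TensorBoard with tensorboard --logdir=./runs (SummaryWriter writes to the runs/ directory by default) and open the printed URL in a browser.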