diff --git a/Code/4_viewer/6_hook_for_grad_cam.py b/Code/4_viewer/6_hook_for_grad_cam.py new file mode 100644 index 0000000000000000000000000000000000000000..476d40f7be885b9d4c7ef026b0d0a97bd6696731 --- /dev/null +++ b/Code/4_viewer/6_hook_for_grad_cam.py @@ -0,0 +1,182 @@ +# coding: utf-8 +""" +通过实现Grad-CAM学习module中的forward_hook和backward_hook函数 +""" +import cv2 +import os +import numpy as np +from PIL import Image +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool1 = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.pool2 = nn.MaxPool2d(2, 2) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool1(F.relu(self.conv1(x))) + x = self.pool1(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +def img_transform(img_in, transform): + """ + 将img进行预处理,并转换成模型输入所需的形式—— B*C*H*W + :param img_roi: np.array + :return: + """ + img = img_in.copy() + img = Image.fromarray(np.uint8(img)) + img = transform(img) + img = img.unsqueeze(0) # C*H*W --> B*C*H*W + return img + + +def img_preprocess(img_in): + """ + 读取图片,转为模型可读的形式 + :param img_in: ndarray, [H, W, C] + :return: PIL.image + """ + img = img_in.copy() + img = cv2.resize(img,(32, 32)) + img = img[:, :, ::-1] # BGR --> RGB + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.4948052, 0.48568845, 0.44682974], [0.24580306, 0.24236229, 0.2603115]) + ]) + img_input = img_transform(img, transform) + return img_input + + +def backward_hook(module, grad_in, grad_out): + grad_block.append(grad_out[0].detach()) + + +def farward_hook(module, input, output): + fmap_block.append(output) + + +def show_cam_on_image(img, mask, out_dir): + heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET) + heatmap = np.float32(heatmap) / 255 + cam = heatmap + np.float32(img) + cam = cam / np.max(cam) + + path_cam_img = os.path.join(out_dir, "cam.jpg") + path_raw_img = os.path.join(out_dir, "raw.jpg") + if not os.path.exists(out_dir): + os.makedirs(out_dir) + cv2.imwrite(path_cam_img, np.uint8(255 * cam)) + cv2.imwrite(path_raw_img, np.uint8(255 * img)) + + +def comp_class_vec(ouput_vec, index=None): + """ + 计算类向量 + :param ouput_vec: tensor + :param index: int,指定类别 + :return: tensor + """ + if not index: + index = np.argmax(ouput_vec.cpu().data.numpy()) + index = index[np.newaxis, np.newaxis] + index = torch.from_numpy(index) + one_hot = torch.zeros(1, 10).scatter_(1, index, 1) + one_hot.requires_grad = True + class_vec = torch.sum(one_hot * output) # one_hot = 11.8605 + + return class_vec + + +def gen_cam(feature_map, grads): + """ + 依据梯度和特征图,生成cam + :param feature_map: np.array, in [C, H, W] + :param grads: np.array, in [C, H, W] + :return: np.array, [H, W] + """ + cam = np.zeros(feature_map.shape[1:], dtype=np.float32) # cam shape (H, W) + + weights = np.mean(grads, axis=(1, 2)) # + + for i, w in enumerate(weights): + cam += w * feature_map[i, :, :] + + cam = np.maximum(cam, 0) + cam = cv2.resize(cam, (32, 32)) + cam -= np.min(cam) + cam /= np.max(cam) + + return cam + + +if __name__ == '__main__': + + BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + path_img = os.path.join(BASE_DIR, "../../Data/cam_img/", "test_img_1.png") + path_net = os.path.join(BASE_DIR, "../../Data/", "net_params_72p.pkl") + output_dir = os.path.join(BASE_DIR, "../../Result/backward_hook_cam/") + + classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') + fmap_block = list() + grad_block = list() + + # 图片读取;网络加载 + img = cv2.imread(path_img, 1) # H*W*C + img_input = img_preprocess(img) + net = Net() + net.load_state_dict(torch.load(path_net)) + + # 注册hook + net.conv2.register_forward_hook(farward_hook) + net.conv2.register_backward_hook(backward_hook) + + # forward + output = net(img_input) + idx = np.argmax(output.cpu().data.numpy()) + print("predict: {}".format(classes[idx])) + + # backward + net.zero_grad() + class_loss = comp_class_vec(output) + class_loss.backward() + + # 生成cam + grads_val = grad_block[0].cpu().data.numpy().squeeze() + fmap = fmap_block[0].cpu().data.numpy().squeeze() + cam = gen_cam(fmap, grads_val) + + # 保存cam图片 + img_show = np.float32(cv2.resize(img, (32, 32))) / 255 + show_cam_on_image(img_show, cam, output_dir) + + + + + + + + + + + + + + + + diff --git a/Data/cam_img/.DS_Store b/Data/cam_img/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/Data/cam_img/.DS_Store differ diff --git a/Data/cam_img/test_img_1.png b/Data/cam_img/test_img_1.png new file mode 100644 index 0000000000000000000000000000000000000000..bbdb65791fd861e4265f869a090ea7d9fc6a5ce5 Binary files /dev/null and b/Data/cam_img/test_img_1.png differ diff --git a/Data/cam_img/test_img_2.png b/Data/cam_img/test_img_2.png new file mode 100644 index 0000000000000000000000000000000000000000..8da5599ed8c5f7eaa71cb72bacc00023520a8a6d Binary files /dev/null and b/Data/cam_img/test_img_2.png differ diff --git a/Data/cam_img/test_img_3.png b/Data/cam_img/test_img_3.png new file mode 100644 index 0000000000000000000000000000000000000000..bb054813786ba5e403a9c0030ab10db00a47a429 Binary files /dev/null and b/Data/cam_img/test_img_3.png differ diff --git a/Data/cam_img/test_img_4.png b/Data/cam_img/test_img_4.png new file mode 100644 index 0000000000000000000000000000000000000000..a3762d045707a5236b548041abcbf168de823632 Binary files /dev/null and b/Data/cam_img/test_img_4.png differ diff --git a/Data/cam_img/test_img_5.png b/Data/cam_img/test_img_5.png new file mode 100644 index 0000000000000000000000000000000000000000..1af7f3ec808b334f53e170f7eda42f34ff93a968 Binary files /dev/null and b/Data/cam_img/test_img_5.png differ diff --git a/Data/cam_img/test_img_6.png b/Data/cam_img/test_img_6.png new file mode 100644 index 0000000000000000000000000000000000000000..9c6e7f2285c46b77f7140c7681fe90469cb82398 Binary files /dev/null and b/Data/cam_img/test_img_6.png differ diff --git a/Data/cam_img/test_img_7.png b/Data/cam_img/test_img_7.png new file mode 100644 index 0000000000000000000000000000000000000000..5a50b63e809b48637174cc3b193ac91814dc685d Binary files /dev/null and b/Data/cam_img/test_img_7.png differ diff --git a/Data/cam_img/test_img_8.png b/Data/cam_img/test_img_8.png new file mode 100644 index 0000000000000000000000000000000000000000..0f23fdafeb4b2fc3ed4e0b38f121488b3e8bbe23 Binary files /dev/null and b/Data/cam_img/test_img_8.png differ