diff --git a/Code/4_viewer/6_hook_for_grad_cam.py b/Code/4_viewer/6_hook_for_grad_cam.py index 476d40f7be885b9d4c7ef026b0d0a97bd6696731..bd01ba991c5168fbae2b0d9320dfc7e2c5562733 100644 --- a/Code/4_viewer/6_hook_for_grad_cam.py +++ b/Code/4_viewer/6_hook_for_grad_cam.py @@ -94,6 +94,8 @@ def comp_class_vec(ouput_vec, index=None): """ if not index: index = np.argmax(ouput_vec.cpu().data.numpy()) + else: + index = np.array(index) index = index[np.newaxis, np.newaxis] index = torch.from_numpy(index) one_hot = torch.zeros(1, 10).scatter_(1, index, 1) @@ -128,7 +130,8 @@ def gen_cam(feature_map, grads): if __name__ == '__main__': BASE_DIR = os.path.dirname(os.path.abspath(__file__)) - path_img = os.path.join(BASE_DIR, "../../Data/cam_img/", "test_img_1.png") + path_img = os.path.join(BASE_DIR, "../../Data/cam_img/", "test_img_8.png") + path_img = "/Users/tingsongyu/Desktop/t.png" path_net = os.path.join(BASE_DIR, "../../Data/", "net_params_72p.pkl") output_dir = os.path.join(BASE_DIR, "../../Result/backward_hook_cam/")