未验证 提交 5a2ad684 编写于 作者: S SunAhong1993 提交者: GitHub

Update visualize.py

上级 a2c278a7
......@@ -57,8 +57,10 @@ def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
image = image.astype('float32')
for i in range(image.shape[0]):
image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
model.test_transforms.transforms = tmp_transforms
return out[0]
labels_name = None
if dataset is not None:
......@@ -74,15 +76,19 @@ def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
def precompute_predict_func(image):
image = image.astype('float32')
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
model.test_transforms.transforms = tmp_transforms
return out[0]
def predict_func(image):
image = image.astype('float32')
for i in range(image.shape[0]):
image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
model.test_transforms.transforms = tmp_transforms
return out[0]
labels_name = None
if dataset is not None:
......@@ -118,6 +124,4 @@ def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=
num_samples=num_samples,
batch_size=batch_size,
save_dir=save_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册