未验证 提交 5a2ad684 编写于 作者: S SunAhong1993 提交者: GitHub

Update visualize.py

上级 a2c278a7
...@@ -57,8 +57,10 @@ def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50): ...@@ -57,8 +57,10 @@ def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
image = image.astype('float32') image = image.astype('float32')
for i in range(image.shape[0]): for i in range(image.shape[0]):
image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR) image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:] model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image) out = model.explanation_predict(image)
model.test_transforms.transforms = tmp_transforms
return out[0] return out[0]
labels_name = None labels_name = None
if dataset is not None: if dataset is not None:
...@@ -74,15 +76,19 @@ def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50): ...@@ -74,15 +76,19 @@ def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'): def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
def precompute_predict_func(image): def precompute_predict_func(image):
image = image.astype('float32') image = image.astype('float32')
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:] model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image) out = model.explanation_predict(image)
model.test_transforms.transforms = tmp_transforms
return out[0] return out[0]
def predict_func(image): def predict_func(image):
image = image.astype('float32') image = image.astype('float32')
for i in range(image.shape[0]): for i in range(image.shape[0]):
image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR) image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR)
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:] model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image) out = model.explanation_predict(image)
model.test_transforms.transforms = tmp_transforms
return out[0] return out[0]
labels_name = None labels_name = None
if dataset is not None: if dataset is not None:
...@@ -118,6 +124,4 @@ def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size= ...@@ -118,6 +124,4 @@ def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=
num_samples=num_samples, num_samples=num_samples,
batch_size=batch_size, batch_size=batch_size,
save_dir=save_dir) save_dir=save_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册