未验证 提交 5dc2716d 编写于 作者: S SunAhong1993 提交者: GitHub

Update visualize.py

上级 7e661ba9
...@@ -23,7 +23,7 @@ from .core.normlime_base import precompute_normlime_weights ...@@ -23,7 +23,7 @@ from .core.normlime_base import precompute_normlime_weights
def visualize(img_file, def visualize(img_file,
model, model,
normlime_dataset=None, dataset=None,
explanation_type='lime', explanation_type='lime',
num_samples=3000, num_samples=3000,
batch_size=50, batch_size=50,
...@@ -39,11 +39,11 @@ def visualize(img_file, ...@@ -39,11 +39,11 @@ def visualize(img_file,
img = np.expand_dims(img, axis=0) img = np.expand_dims(img, axis=0)
explaier = None explaier = None
if explanation_type == 'lime': if explanation_type == 'lime':
explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size) explaier = get_lime_explaier(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
elif explanation_type == 'normlime': elif explanation_type == 'normlime':
if normlime_dataset is None: if dataset is None:
raise Exception('The normlime_dataset is None. Cannot implement this kind of explanation') raise Exception('The dataset is None. Cannot implement this kind of explanation')
explaier = get_normlime_explaier(img, model, normlime_dataset, explaier = get_normlime_explaier(img, model, dataset,
num_samples=num_samples, batch_size=batch_size, num_samples=num_samples, batch_size=batch_size,
save_dir=save_dir) save_dir=save_dir)
else: else:
...@@ -52,7 +52,7 @@ def visualize(img_file, ...@@ -52,7 +52,7 @@ def visualize(img_file,
explaier.explain(img, save_dir=save_dir) explaier.explain(img, save_dir=save_dir)
def get_lime_explaier(img, model, num_samples=3000, batch_size=50): def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
def predict_func(image): def predict_func(image):
image = image.astype('float32') image = image.astype('float32')
for i in range(image.shape[0]): for i in range(image.shape[0]):
...@@ -60,14 +60,18 @@ def get_lime_explaier(img, model, num_samples=3000, batch_size=50): ...@@ -60,14 +60,18 @@ def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
model.test_transforms.transforms = model.test_transforms.transforms[-2:] model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image) out = model.explanation_predict(image)
return out[0] return out[0]
labels_name = None
if dataset is not None:
labels_name = dataset.labels
explaier = Explanation('lime', explaier = Explanation('lime',
predict_func, predict_func,
labels_name,
num_samples=num_samples, num_samples=num_samples,
batch_size=batch_size) batch_size=batch_size)
return explaier return explaier
def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'): def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
def precompute_predict_func(image): def precompute_predict_func(image):
image = image.astype('float32') image = image.astype('float32')
model.test_transforms.transforms = model.test_transforms.transforms[-2:] model.test_transforms.transforms = model.test_transforms.transforms[-2:]
...@@ -80,6 +84,9 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_ ...@@ -80,6 +84,9 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_
model.test_transforms.transforms = model.test_transforms.transforms[-2:] model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image) out = model.explanation_predict(image)
return out[0] return out[0]
labels_name = None
if dataset is not None:
labels_name = dataset.labels
root_path = os.environ['HOME'] root_path = os.environ['HOME']
root_path = osp.join(root_path, '.paddlex') root_path = osp.join(root_path, '.paddlex')
pre_models_path = osp.join(root_path, "pre_models") pre_models_path = osp.join(root_path, "pre_models")
...@@ -88,21 +95,22 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_ ...@@ -88,21 +95,22 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_
# TODO # TODO
# paddlex.utils.download_and_decompress(url, path=pre_models_path) # paddlex.utils.download_and_decompress(url, path=pre_models_path)
npy_dir = precompute_for_normlime(precompute_predict_func, npy_dir = precompute_for_normlime(precompute_predict_func,
normlime_dataset, dataset,
num_samples=num_samples, num_samples=num_samples,
batch_size=batch_size, batch_size=batch_size,
save_dir=save_dir) save_dir=save_dir)
explaier = Explanation('normlime', explaier = Explanation('normlime',
predict_func, predict_func,
labels_name,
num_samples=num_samples, num_samples=num_samples,
batch_size=batch_size, batch_size=batch_size,
normlime_weights=npy_dir) normlime_weights=npy_dir)
return explaier return explaier
def precompute_for_normlime(predict_func, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'): def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
image_list = [] image_list = []
for item in normlime_dataset.file_list: for item in dataset.file_list:
image_list.append(item[0]) image_list.append(item[0])
return precompute_normlime_weights( return precompute_normlime_weights(
image_list, image_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册