未验证 提交 5dc2716d 编写于 作者: S SunAhong1993 提交者: GitHub

Update visualize.py

上级 7e661ba9
......@@ -23,7 +23,7 @@ from .core.normlime_base import precompute_normlime_weights
def visualize(img_file,
model,
normlime_dataset=None,
dataset=None,
explanation_type='lime',
num_samples=3000,
batch_size=50,
......@@ -39,11 +39,11 @@ def visualize(img_file,
img = np.expand_dims(img, axis=0)
explaier = None
if explanation_type == 'lime':
explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size)
explaier = get_lime_explaier(img, model, dataset, num_samples=num_samples, batch_size=batch_size)
elif explanation_type == 'normlime':
if normlime_dataset is None:
raise Exception('The normlime_dataset is None. Cannot implement this kind of explanation')
explaier = get_normlime_explaier(img, model, normlime_dataset,
if dataset is None:
raise Exception('The dataset is None. Cannot implement this kind of explanation')
explaier = get_normlime_explaier(img, model, dataset,
num_samples=num_samples, batch_size=batch_size,
save_dir=save_dir)
else:
......@@ -52,7 +52,7 @@ def visualize(img_file,
explaier.explain(img, save_dir=save_dir)
def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50):
def predict_func(image):
image = image.astype('float32')
for i in range(image.shape[0]):
......@@ -60,14 +60,18 @@ def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
return out[0]
labels_name = None
if dataset is not None:
labels_name = dataset.labels
explaier = Explanation('lime',
predict_func,
labels_name,
num_samples=num_samples,
batch_size=batch_size)
return explaier
def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'):
def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'):
def precompute_predict_func(image):
image = image.astype('float32')
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
......@@ -80,6 +84,9 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
return out[0]
labels_name = None
if dataset is not None:
labels_name = dataset.labels
root_path = os.environ['HOME']
root_path = osp.join(root_path, '.paddlex')
pre_models_path = osp.join(root_path, "pre_models")
......@@ -88,21 +95,22 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_
# TODO
# paddlex.utils.download_and_decompress(url, path=pre_models_path)
npy_dir = precompute_for_normlime(precompute_predict_func,
normlime_dataset,
dataset,
num_samples=num_samples,
batch_size=batch_size,
save_dir=save_dir)
explaier = Explanation('normlime',
predict_func,
labels_name,
num_samples=num_samples,
batch_size=batch_size,
normlime_weights=npy_dir)
return explaier
def precompute_for_normlime(predict_func, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'):
def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'):
image_list = []
for item in normlime_dataset.file_list:
for item in dataset.file_list:
image_list.append(item[0])
return precompute_normlime_weights(
image_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册