diff --git a/paddlex/cv/models/explanation/visualize.py b/paddlex/cv/models/explanation/visualize.py index 9da7ba381c9276c595f092d671ffdf97edae6f20..3e7b45fb58318119a4ef0a2d693e82cb6346d1f6 100644 --- a/paddlex/cv/models/explanation/visualize.py +++ b/paddlex/cv/models/explanation/visualize.py @@ -23,7 +23,7 @@ from .core.normlime_base import precompute_normlime_weights def visualize(img_file, model, - normlime_dataset=None, + dataset=None, explanation_type='lime', num_samples=3000, batch_size=50, @@ -39,11 +39,11 @@ def visualize(img_file, img = np.expand_dims(img, axis=0) explaier = None if explanation_type == 'lime': - explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size) + explaier = get_lime_explaier(img, model, dataset, num_samples=num_samples, batch_size=batch_size) elif explanation_type == 'normlime': - if normlime_dataset is None: - raise Exception('The normlime_dataset is None. Cannot implement this kind of explanation') - explaier = get_normlime_explaier(img, model, normlime_dataset, + if dataset is None: + raise Exception('The dataset is None. Cannot implement this kind of explanation') + explaier = get_normlime_explaier(img, model, dataset, num_samples=num_samples, batch_size=batch_size, save_dir=save_dir) else: @@ -52,7 +52,7 @@ def visualize(img_file, explaier.explain(img, save_dir=save_dir) -def get_lime_explaier(img, model, num_samples=3000, batch_size=50): +def get_lime_explaier(img, model, dataset, num_samples=3000, batch_size=50): def predict_func(image): image = image.astype('float32') for i in range(image.shape[0]): @@ -60,14 +60,18 @@ def get_lime_explaier(img, model, num_samples=3000, batch_size=50): model.test_transforms.transforms = model.test_transforms.transforms[-2:] out = model.explanation_predict(image) return out[0] + labels_name = None + if dataset is not None: + labels_name = dataset.labels explaier = Explanation('lime', predict_func, + labels_name, num_samples=num_samples, batch_size=batch_size) return explaier -def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'): +def get_normlime_explaier(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'): def precompute_predict_func(image): image = image.astype('float32') model.test_transforms.transforms = model.test_transforms.transforms[-2:] @@ -80,6 +84,9 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_ model.test_transforms.transforms = model.test_transforms.transforms[-2:] out = model.explanation_predict(image) return out[0] + labels_name = None + if dataset is not None: + labels_name = dataset.labels root_path = os.environ['HOME'] root_path = osp.join(root_path, '.paddlex') pre_models_path = osp.join(root_path, "pre_models") @@ -88,21 +95,22 @@ def get_normlime_explaier(img, model, normlime_dataset, num_samples=3000, batch_ # TODO # paddlex.utils.download_and_decompress(url, path=pre_models_path) npy_dir = precompute_for_normlime(precompute_predict_func, - normlime_dataset, + dataset, num_samples=num_samples, batch_size=batch_size, save_dir=save_dir) explaier = Explanation('normlime', predict_func, + labels_name, num_samples=num_samples, batch_size=batch_size, normlime_weights=npy_dir) return explaier -def precompute_for_normlime(predict_func, normlime_dataset, num_samples=3000, batch_size=50, save_dir='./'): +def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'): image_list = [] - for item in normlime_dataset.file_list: + for item in dataset.file_list: image_list.append(item[0]) return precompute_normlime_weights( image_list,