From 2f92c61b7eb0b3a160d85a30e4f269a857873a65 Mon Sep 17 00:00:00 2001 From: sunyanfang01 Date: Tue, 19 May 2020 20:03:30 +0800 Subject: [PATCH] fix the interpret --- .../core/interpretation_algorithms.py | 3 ++ paddlex/interpret/core/lime_base.py | 3 +- paddlex/interpret/visualize.py | 8 ++-- tutorials/interpret/interpret.py | 47 ++++++++++--------- 4 files changed, 35 insertions(+), 26 deletions(-) diff --git a/paddlex/interpret/core/interpretation_algorithms.py b/paddlex/interpret/core/interpretation_algorithms.py index 51f03c2..601a96d 100644 --- a/paddlex/interpret/core/interpretation_algorithms.py +++ b/paddlex/interpret/core/interpretation_algorithms.py @@ -442,3 +442,6 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000): save_outdir, f_out ) ) + print('The image of intrepretation result save in {}'.format(os.path.join( + save_outdir, f_out + ))) diff --git a/paddlex/interpret/core/lime_base.py b/paddlex/interpret/core/lime_base.py index d0b2a79..57844ba 100644 --- a/paddlex/interpret/core/lime_base.py +++ b/paddlex/interpret/core/lime_base.py @@ -36,6 +36,7 @@ from skimage.color import gray2rgb from sklearn.linear_model import Ridge, lars_path from sklearn.utils import check_random_state +import tqdm import copy from functools import partial from skimage.segmentation import quickshift @@ -509,7 +510,7 @@ class LimeImageInterpreter(object): labels = [] data[0, :] = 1 imgs = [] - for row in data: + for row in tqdm.tqdm(data): temp = copy.deepcopy(image) zeros = np.where(row == 0)[0] mask = np.zeros(segments.shape).astype(bool) diff --git a/paddlex/interpret/visualize.py b/paddlex/interpret/visualize.py index 819f0de..2810846 100644 --- a/paddlex/interpret/visualize.py +++ b/paddlex/interpret/visualize.py @@ -44,6 +44,8 @@ def visualize(img_file, 'Now the interpretation visualize only be supported in classifier!' if model.status != 'Normal': raise Exception('The interpretation only can deal with the Normal model') + if not osp.exists(save_dir): + os.makedirs(save_dir) model.arrange_transforms( transforms=model.test_transforms, mode='test') tmp_transforms = copy.deepcopy(model.test_transforms) @@ -108,12 +110,12 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5 if dataset is not None: labels_name = dataset.labels root_path = os.environ['HOME'] - root_path = osp.join(root_path, '.paddlex') + root_path = osp.join(root_path, '.paddlex0') pre_models_path = osp.join(root_path, "pre_models") if not osp.exists(pre_models_path): - os.makedirs(pre_models_path) + os.makedirs(root_path) url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz" - pdx.utils.download_and_decompress(url, path=pre_models_path) + pdx.utils.download_and_decompress(url, path=root_path) npy_dir = precompute_for_normlime(precompute_predict_func, dataset, num_samples=num_samples, diff --git a/tutorials/interpret/interpret.py b/tutorials/interpret/interpret.py index a052ad8..c3673d6 100644 --- a/tutorials/interpret/interpret.py +++ b/tutorials/interpret/interpret.py @@ -4,38 +4,41 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0' import os.path as osp import paddlex as pdx +from paddlex.cls import transforms # 下载和解压Imagenet果蔬分类数据集 veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz' pdx.utils.download_and_decompress(veg_dataset, path='./') -# 下载和解压已训练好的MobileNetV2模型 -model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz' -pdx.utils.download_and_decompress(model_file, path='./') - -# 加载模型 -model = pdx.load_model('mini_imagenet_veg_mobilenetv2') +# 定义测试集的transform +test_transforms = transforms.Compose([ + transforms.ResizeByShort(short_size=256), + transforms.CenterCrop(crop_size=224), + transforms.Normalize() +]) # 定义测试所用的数据集 test_dataset = pdx.datasets.ImageNet( data_dir='mini_imagenet_veg', file_list=osp.join('mini_imagenet_veg', 'test_list.txt'), label_list=osp.join('mini_imagenet_veg', 'labels.txt'), - transforms=model.test_transforms) + transforms=test_transforms) -# 可解释性可视化 -# LIME算法 -pdx.interpret.visualize( - 'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', - model, - test_dataset, - algo='lime', - save_dir='./') +# 下载和解压已训练好的MobileNetV2模型 +model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz' +pdx.utils.download_and_decompress(model_file, path='./') -# NormLIME算法 -pdx.interpret.visualize( - 'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', - model, - test_dataset, - algo='normlime', - save_dir='./') +# 导入模型 +model = pdx.load_model('mini_imagenet_veg_mobilenetv2') + +# 可解释性可视化 +pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG', + model, + test_dataset, + algo='lime', + save_dir='./') +pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG', + model, + test_dataset, + algo='normlime', + save_dir='./') \ No newline at end of file -- GitLab