diff --git a/docs/apis/visualize.md b/docs/apis/visualize.md index 8a253fc04310f43b8e424729e7d75f2712e8dc19..069913274580f1e8bd5fdb5ee6e6e642c977b3ce 100755 --- a/docs/apis/visualize.md +++ b/docs/apis/visualize.md @@ -114,27 +114,54 @@ pdx.slim.visualize(model, 'mobilenetv2.sensitivities', save_dir='./') # 可视化结果保存在./sensitivities.png ``` -## 可解释性结果可视化 +## LIME可解释性结果可视化 ``` -paddlex.interpret.visualize(img_file, - model, - dataset=None, - algo='lime', - num_samples=3000, - batch_size=50, - save_dir='./') +paddlex.interpret.lime(img_file, + model, + num_samples=3000, + batch_size=50, + save_dir='./') ``` -将模型预测结果的可解释性可视化,目前只支持分类模型。 +使用LIME算法将模型预测结果的可解释性可视化。 +LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,得到每个输入维度的权重,以此来解释模型。 + +**注意:** 可解释性结果可视化目前只支持分类模型。 ### 参数 >* **img_file** (str): 预测图像路径。 >* **model** (paddlex.cv.models): paddlex中的模型。 ->* **dataset** (paddlex.datasets): 数据集读取器,默认为None。 ->* **algo** (str): 可解释性方式,当前可选'lime'和'normlime'。 >* **num_samples** (int): LIME用于学习线性模型的采样数,默认为3000。 >* **batch_size** (int): 预测数据batch大小,默认为50。 >* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 ### 使用示例 -> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/interpret.py)。 +> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/lime.py)。 + + +## NormLIME可解释性结果可视化 +``` +paddlex.interpret.normlime(img_file, + model, + dataset=None, + num_samples=3000, + batch_size=50, + save_dir='./') +``` +使用NormLIME算法将模型预测结果的可解释性可视化。 +NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。 + +**注意:** 可解释性结果可视化目前只支持分类模型。 + +### 参数 +>* **img_file** (str): 预测图像路径。 +>* **model** (paddlex.cv.models): paddlex中的模型。 +>* **dataset** (paddlex.datasets): 数据集读取器,默认为None。 +>* **num_samples** (int): LIME用于学习线性模型的采样数,默认为3000。 +>* **batch_size** (int): 预测数据batch大小,默认为50。 +>* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 + +**注意:** dataset`读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。 +### 使用示例 +> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/normlime.py)。 + diff --git a/paddlex/interpret/__init__.py b/paddlex/interpret/__init__.py index 00a70c496d67194f22ef3b92d579da8f02792b14..1b8ae70a0c581bb24a9402e6a73ad7a4cb55fca6 100644 --- a/paddlex/interpret/__init__.py +++ b/paddlex/interpret/__init__.py @@ -15,4 +15,5 @@ from __future__ import absolute_import from . import visualize -visualize = visualize.visualize \ No newline at end of file +lime = visualize.lime +normlime = visualize.normlime \ No newline at end of file diff --git a/paddlex/interpret/core/normlime_base.py b/paddlex/interpret/core/normlime_base.py index 6fdd2597df319fa641146b33542439a9d87a6d05..df470b6f218f1591f6b648eb86189f425ef294e5 100644 --- a/paddlex/interpret/core/normlime_base.py +++ b/paddlex/interpret/core/normlime_base.py @@ -116,9 +116,8 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav if os.path.exists(save_path): logging.info(save_path + ' exists, not computing this one.', use_color=True) continue - - logging.info('processing'+each_data_ if isinstance(each_data_, str) else data_index + \ - f'+{data_index}/{len(list_data_)}', use_color=True) + img_file_name = each_data_ if isinstance(each_data_, str) else data_index + logging.info('processing '+ img_file_name + ' [{}/{}]'.format(data_index, len(list_data_)), use_color=True) image_show = read_image(each_data_) result = predict_fn(image_show) diff --git a/paddlex/interpret/visualize.py b/paddlex/interpret/visualize.py index 1b2e12110efef909abe18a18463bc8e2417d83df..de8e9151b9417fd3307c74d7bb67767bed1845c7 100644 --- a/paddlex/interpret/visualize.py +++ b/paddlex/interpret/visualize.py @@ -22,20 +22,65 @@ from .interpretation_predict import interpretation_predict from .core.interpretation import Interpretation from .core.normlime_base import precompute_normlime_weights from .core._session_preparation import gen_user_home - -def visualize(img_file, + +def lime(img_file, + model, + num_samples=3000, + batch_size=50, + save_dir='./'): + """使用LIME算法将模型预测结果的可解释性可视化。 + + LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心, + 在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入 + 和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系, + 得到每个输入维度的权重,以此来解释模型。 + + 注意:LIME可解释性结果可视化目前只支持分类模型。 + + Args: + img_file (str): 预测图像路径。 + model (paddlex.cv.models): paddlex中的模型。 + num_samples (int): LIME用于学习线性模型的采样数,默认为3000。 + batch_size (int): 预测数据batch大小,默认为50。 + save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 + """ + assert model.model_type == 'classifier', \ + 'Now the interpretation visualize only be supported in classifier!' + if model.status != 'Normal': + raise Exception('The interpretation only can deal with the Normal model') + if not osp.exists(save_dir): + os.makedirs(save_dir) + model.arrange_transforms( + transforms=model.test_transforms, mode='test') + tmp_transforms = copy.deepcopy(model.test_transforms) + tmp_transforms.transforms = tmp_transforms.transforms[:-2] + img = tmp_transforms(img_file)[0] + img = np.around(img).astype('uint8') + img = np.expand_dims(img, axis=0) + interpreter = None + interpreter = get_lime_interpreter(img, model, num_samples=num_samples, batch_size=batch_size) + img_name = osp.splitext(osp.split(img_file)[-1])[0] + interpreter.interpret(img, save_dir=save_dir) + + +def normlime(img_file, model, dataset=None, - algo='lime', num_samples=3000, batch_size=50, save_dir='./'): - """可解释性可视化。 + """使用NormLIME算法将模型预测结果的可解释性可视化。 + + NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测 + 试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。 + + 注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。 + 注意2:NormLIME可解释性结果可视化目前只支持分类模型。 + Args: img_file (str): 预测图像路径。 model (paddlex.cv.models): paddlex中的模型。 dataset (paddlex.datasets): 数据集读取器,默认为None。 - algo (str): 可解释性方式,当前可选'lime'和'normlime'。 num_samples (int): LIME用于学习线性模型的采样数,默认为3000。 batch_size (int): 预测数据batch大小,默认为50。 save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 @@ -54,21 +99,16 @@ def visualize(img_file, img = np.around(img).astype('uint8') img = np.expand_dims(img, axis=0) interpreter = None - if algo == 'lime': - interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size) - elif algo == 'normlime': - if dataset is None: - raise Exception('The dataset is None. Cannot implement this kind of interpretation') - interpreter = get_normlime_interpreter(img, model, dataset, - num_samples=num_samples, batch_size=batch_size, + if dataset is None: + raise Exception('The dataset is None. Cannot implement this kind of interpretation') + interpreter = get_normlime_interpreter(img, model, dataset, + num_samples=num_samples, batch_size=batch_size, save_dir=save_dir) - else: - raise Exception('The {} interpretation method is not supported yet!'.format(algo)) img_name = osp.splitext(osp.split(img_file)[-1])[0] interpreter.interpret(img, save_dir=save_dir) -def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50): +def get_lime_interpreter(img, model, num_samples=3000, batch_size=50): def predict_func(image): image = image.astype('float32') for i in range(image.shape[0]): @@ -79,8 +119,8 @@ def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50): model.test_transforms.transforms = tmp_transforms return out[0] labels_name = None - if dataset is not None: - labels_name = dataset.labels + if hasattr(model, 'labels'): + labels_name = model.labels interpreter = Interpretation('lime', predict_func, labels_name, diff --git a/tutorials/interpret/lime.py b/tutorials/interpret/lime.py new file mode 100644 index 0000000000000000000000000000000000000000..ae862aa9e41f4ad95c335c8e2a6de5a3b76a4ea2 --- /dev/null +++ b/tutorials/interpret/lime.py @@ -0,0 +1,23 @@ +import os +# 选择使用0号卡 +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +import os.path as osp +import paddlex as pdx + +# 下载和解压Imagenet果蔬分类数据集 +veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz' +pdx.utils.download_and_decompress(veg_dataset, path='./') + +# 下载和解压已训练好的MobileNetV2模型 +model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz' +pdx.utils.download_and_decompress(model_file, path='./') + +# 加载模型 +model = pdx.load_model('mini_imagenet_veg_mobilenetv2') + +# 可解释性可视化 +pdx.interpret.lime( + 'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', + model, + save_dir='./') diff --git a/tutorials/interpret/interpret.py b/tutorials/interpret/normlime.py similarity index 81% rename from tutorials/interpret/interpret.py rename to tutorials/interpret/normlime.py index f52d1053f5dcb1b2f1a585f50e9e0f2b1cb13ef2..3e501388e44aeab8548ae123831bc3211b08cea7 100644 --- a/tutorials/interpret/interpret.py +++ b/tutorials/interpret/normlime.py @@ -24,15 +24,8 @@ test_dataset = pdx.datasets.ImageNet( transforms=model.test_transforms) # 可解释性可视化 -pdx.interpret.visualize( - 'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', - model, - test_dataset, - algo='lime', - save_dir='./') -pdx.interpret.visualize( +pdx.interpret.normlime( 'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', model, test_dataset, - algo='normlime', save_dir='./')