From 8d8b44f2b83c67cab01bc75bda2be6042fb26d0f Mon Sep 17 00:00:00 2001 From: sunyanfang01 Date: Wed, 20 May 2020 12:49:49 +0800 Subject: [PATCH] modify the interpret --- docs/apis/visualize.md | 49 ++++++++++++++------ paddlex/interpret/__init__.py | 3 +- paddlex/interpret/visualize.py | 78 ++++++++++++++++++++++---------- tutorials/interpret/interpret.py | 7 +-- 4 files changed, 93 insertions(+), 44 deletions(-) diff --git a/docs/apis/visualize.md b/docs/apis/visualize.md index c4dba1c..1c67592 100755 --- a/docs/apis/visualize.md +++ b/docs/apis/visualize.md @@ -114,19 +114,42 @@ pdx.slim.visualize(model, 'mobilenetv2.sensitivities', save_dir='./') # 可视化结果保存在./sensitivities.png ``` -## 可解释性结果可视化 +## LIME可解释性结果可视化 ``` -paddlex.interpret.visualize(img_file, - model, - dataset=None, - algo='lime', - num_samples=3000, - batch_size=50, - save_dir='./') +paddlex.interpret.lime(img_file, + model, + num_samples=3000, + batch_size=50, + save_dir='./') ``` -将模型预测结果的可解释性可视化,支持LIME和NormLIME两种可解释性算法。 -LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,得到每个输入维度的权重,以此来解释模型。 -NormLIME则是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。 +使用LIME算法将模型预测结果的可解释性可视化。 +LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心,在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系,得到每个输入维度的权重,以此来解释模型。 + +**注意:** 可解释性结果可视化目前只支持分类模型。 + +### 参数 +>* **img_file** (str): 预测图像路径。 +>* **model** (paddlex.cv.models): paddlex中的模型。 +>* **num_samples** (int): LIME用于学习线性模型的采样数,默认为3000。 +>* **batch_size** (int): 预测数据batch大小,默认为50。 +>* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 + + +### 使用示例 +> 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/interpret.py)。 + + +## LIME可解释性结果可视化 +``` +paddlex.interpret.normlime(img_file, + model, + dataset=None, + num_samples=3000, + batch_size=50, + save_dir='./') +``` +使用NormLIME算法将模型预测结果的可解释性可视化。 +NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。 **注意:** 可解释性结果可视化目前只支持分类模型。 @@ -134,11 +157,11 @@ NormLIME则是利用一定数量的样本来出一个全局的解释。NormLIME >* **img_file** (str): 预测图像路径。 >* **model** (paddlex.cv.models): paddlex中的模型。 >* **dataset** (paddlex.datasets): 数据集读取器,默认为None。 ->* **algo** (str): 可解释性方式,当前可选'lime'和'normlime'。 >* **num_samples** (int): LIME用于学习线性模型的采样数,默认为3000。 >* **batch_size** (int): 预测数据batch大小,默认为50。 >* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 -**注意:** `dataset`参数只有在`algo`为"normlime"的情况下才使用,`dataset`读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。 +**注意:** dataset`读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。 ### 使用示例 > 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/interpret.py)。 + diff --git a/paddlex/interpret/__init__.py b/paddlex/interpret/__init__.py index 00a70c4..1b8ae70 100644 --- a/paddlex/interpret/__init__.py +++ b/paddlex/interpret/__init__.py @@ -15,4 +15,5 @@ from __future__ import absolute_import from . import visualize -visualize = visualize.visualize \ No newline at end of file +lime = visualize.lime +normlime = visualize.normlime \ No newline at end of file diff --git a/paddlex/interpret/visualize.py b/paddlex/interpret/visualize.py index 7d86073..de8e915 100644 --- a/paddlex/interpret/visualize.py +++ b/paddlex/interpret/visualize.py @@ -22,32 +22,65 @@ from .interpretation_predict import interpretation_predict from .core.interpretation import Interpretation from .core.normlime_base import precompute_normlime_weights from .core._session_preparation import gen_user_home - -def visualize(img_file, + +def lime(img_file, + model, + num_samples=3000, + batch_size=50, + save_dir='./'): + """使用LIME算法将模型预测结果的可解释性可视化。 + + LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心, + 在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入 + 和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系, + 得到每个输入维度的权重,以此来解释模型。 + + 注意:LIME可解释性结果可视化目前只支持分类模型。 + + Args: + img_file (str): 预测图像路径。 + model (paddlex.cv.models): paddlex中的模型。 + num_samples (int): LIME用于学习线性模型的采样数,默认为3000。 + batch_size (int): 预测数据batch大小,默认为50。 + save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 + """ + assert model.model_type == 'classifier', \ + 'Now the interpretation visualize only be supported in classifier!' + if model.status != 'Normal': + raise Exception('The interpretation only can deal with the Normal model') + if not osp.exists(save_dir): + os.makedirs(save_dir) + model.arrange_transforms( + transforms=model.test_transforms, mode='test') + tmp_transforms = copy.deepcopy(model.test_transforms) + tmp_transforms.transforms = tmp_transforms.transforms[:-2] + img = tmp_transforms(img_file)[0] + img = np.around(img).astype('uint8') + img = np.expand_dims(img, axis=0) + interpreter = None + interpreter = get_lime_interpreter(img, model, num_samples=num_samples, batch_size=batch_size) + img_name = osp.splitext(osp.split(img_file)[-1])[0] + interpreter.interpret(img, save_dir=save_dir) + + +def normlime(img_file, model, dataset=None, - algo='lime', num_samples=3000, batch_size=50, save_dir='./'): - """可解释性可视化。 + """使用NormLIME算法将模型预测结果的可解释性可视化。 - 将模型预测结果的可解释性可视化,支持LIME和NormLIME两种可解释性算法。 - LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心, - 在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入 - 和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系, - 得到每个输入维度的权重,以此来解释模型。 - NormLIME则是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测 + NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测 试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。 - 注意:dataset参数只有在algo为"normlime"的情况下才使用,dataset读取的是一个数据集, - 该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。 + 注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。 + 注意2:NormLIME可解释性结果可视化目前只支持分类模型。 Args: img_file (str): 预测图像路径。 model (paddlex.cv.models): paddlex中的模型。 dataset (paddlex.datasets): 数据集读取器,默认为None。 - algo (str): 可解释性方式,当前可选'lime'和'normlime'。 num_samples (int): LIME用于学习线性模型的采样数,默认为3000。 batch_size (int): 预测数据batch大小,默认为50。 save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 @@ -66,21 +99,16 @@ def visualize(img_file, img = np.around(img).astype('uint8') img = np.expand_dims(img, axis=0) interpreter = None - if algo == 'lime': - interpreter = get_lime_interpreter(img, model, dataset, num_samples=num_samples, batch_size=batch_size) - elif algo == 'normlime': - if dataset is None: - raise Exception('The dataset is None. Cannot implement this kind of interpretation') - interpreter = get_normlime_interpreter(img, model, dataset, - num_samples=num_samples, batch_size=batch_size, + if dataset is None: + raise Exception('The dataset is None. Cannot implement this kind of interpretation') + interpreter = get_normlime_interpreter(img, model, dataset, + num_samples=num_samples, batch_size=batch_size, save_dir=save_dir) - else: - raise Exception('The {} interpretation method is not supported yet!'.format(algo)) img_name = osp.splitext(osp.split(img_file)[-1])[0] interpreter.interpret(img, save_dir=save_dir) -def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50): +def get_lime_interpreter(img, model, num_samples=3000, batch_size=50): def predict_func(image): image = image.astype('float32') for i in range(image.shape[0]): @@ -91,8 +119,8 @@ def get_lime_interpreter(img, model, dataset, num_samples=3000, batch_size=50): model.test_transforms.transforms = tmp_transforms return out[0] labels_name = None - if dataset is not None: - labels_name = dataset.labels + if hasattr(model, 'labels'): + labels_name = model.labels interpreter = Interpretation('lime', predict_func, labels_name, diff --git a/tutorials/interpret/interpret.py b/tutorials/interpret/interpret.py index f52d105..c90870d 100644 --- a/tutorials/interpret/interpret.py +++ b/tutorials/interpret/interpret.py @@ -24,15 +24,12 @@ test_dataset = pdx.datasets.ImageNet( transforms=model.test_transforms) # 可解释性可视化 -pdx.interpret.visualize( +pdx.interpret.lime( 'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', model, - test_dataset, - algo='lime', save_dir='./') -pdx.interpret.visualize( +pdx.interpret.normlime( 'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', model, test_dataset, - algo='normlime', save_dir='./') -- GitLab