diff --git a/docs/apis/visualize.md b/docs/apis/visualize.md index 069913274580f1e8bd5fdb5ee6e6e642c977b3ce..8fe45d4abb82c01c859f0be60bf6f52706eb4e52 100755 --- a/docs/apis/visualize.md +++ b/docs/apis/visualize.md @@ -146,10 +146,11 @@ paddlex.interpret.normlime(img_file, dataset=None, num_samples=3000, batch_size=50, - save_dir='./') + save_dir='./', + normlime_weights_file=None) ``` 使用NormLIME算法将模型预测结果的可解释性可视化。 -NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。 +NormLIME是利用一定数量的样本来出一个全局的解释。由于NormLIME计算量较大,此处采用一种简化的方式:使用一定数量的测试样本(目前默认使用所有测试样本),对每个样本进行特征提取,映射到同一个特征空间;然后以此特征做为输入,以模型输出做为输出,使用线性回归对其进行拟合,得到一个全局的输入和输出的关系。之后,对一测试样本进行解释时,使用NormLIME全局的解释,来对LIME的结果进行滤波,使最终的可视化结果更加稳定。 **注意:** 可解释性结果可视化目前只支持分类模型。 @@ -159,9 +160,10 @@ NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会 >* **dataset** (paddlex.datasets): 数据集读取器,默认为None。 >* **num_samples** (int): LIME用于学习线性模型的采样数,默认为3000。 >* **batch_size** (int): 预测数据batch大小,默认为50。 ->* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 +>* **save_dir** (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 +>* **normlime_weights_file** (str): NormLIME初始化文件名,若不存在,则计算一次,保存于该路径;若存在,则直接载入。 -**注意:** dataset`读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。 +**注意:** dataset`读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。NormLIME可解释性结果可视化目前只支持分类模型。 ### 使用示例 > 对预测可解释性结果可视化的过程可参见[代码](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/normlime.py)。 diff --git a/docs/appendix/interpret.md b/docs/appendix/interpret.md index 886620df2fa98c03abda4717dea627277715b2d9..43ecd48e23810c2e3ed3cd1652bf06b6e1fc04f7 100644 --- a/docs/appendix/interpret.md +++ b/docs/appendix/interpret.md @@ -20,9 +20,20 @@ LIME的使用方式可参见[代码示例](https://github.com/PaddlePaddle/Paddl ## NormLIME NormLIME是在LIME上的改进,LIME的解释是局部性的,是针对当前样本给的特定解释,而NormLIME是利用一定数量的样本对当前样本的一个全局性的解释,有一定的降噪效果。其实现步骤如下所示: 1. 下载Kmeans模型参数和ResNet50_vc网络前三层参数。(ResNet50_vc的参数是在ImageNet上训练所得网络的参数;使用ImageNet图像作为数据集,每张图像从ResNet50_vc的第三层输出提取对应超象素位置上的平均特征和质心上的特征,训练将得到此处的Kmeans模型) -2. 计算测试集中每张图像的LIME结果。(如无测试集,可用验证集代替) -3. 使用Kmeans模型对所有图像中的所有像素进行聚类。 -4. 对在同一个簇的超像素(相同的特征)进行权重的归一化,得到每个超像素的权重,以此来解释模型。 +2. 使用测试集中的数据计算normlime的权重信息(如无测试集,可用验证集代替): + 对每张图像的处理: + (1) 获取图像的超像素。 + (2) 使用ResNet50_vc获取第三层特征,针对每个超像素位置,组合质心特征和均值特征`F`。 + (3) 把`F`作为Kmeans模型的输入,计算每个超像素位置的聚类中心。 + (4) 使用训练好的分类模型,预测该张图像的`label`。 + 对所有图像的处理: + (1) 以每张图像的聚类中心信息组成的向量(若某聚类中心出现在盖章途中设置为1,反之为0)为输入, + 预测的`label`为输出,构建逻辑回归函数`regression_func`。 + (2) 由`regression_func`可获得每个聚类中心不同类别下的权重,并对权重进行归一化。 +3. 使用Kmeans模型获取需要可视化图像的每个超像素的聚类中心。 +4. 对需要可视化的图像的超像素进行随机遮掩构成新的图像。 +5. 对每张构造的图像使用预测模型预测label。 +6. 根据normlime的权重信息,每个超像素可获不同的权重,选取最高的权重为最终的权重,以此来解释模型。 NormLIME的使用方式可参见[代码示例](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/interpret/normlime.py)和[api介绍](../apis/visualize.html#normlime)。在使用时,参数中的`num_samples`设置尤为重要,其表示上述步骤2中的随机采样的个数,若设置过小会影响可解释性结果的稳定性,若设置过大则将在上述步骤3耗费较长时间;参数`batch_size`则表示在计算上述步骤3时,预测的batch size,若设置过小将在上述步骤3耗费较长时间,而上限则根据机器配置决定;而`dataset`则是由测试集或验证集构造的数据。 diff --git a/docs/images/lime.png b/docs/images/lime.png index de435a2e2375a788319f0d80a4cce7a21d395e41..801be69b57c80ad92dcc0ca69bf1a0a4de074b0f 100644 Binary files a/docs/images/lime.png and b/docs/images/lime.png differ diff --git a/docs/images/normlime.png b/docs/images/normlime.png index 4e5099347f261d3f5ce47b93d28cfa484c1d3776..dd9a2f8f96a3ade26179010f340c7c5185bf0656 100644 Binary files a/docs/images/normlime.png and b/docs/images/normlime.png differ diff --git a/paddlex/interpret/visualize.py b/paddlex/interpret/visualize.py index f0158402b69ad3eb90aac6b11a134889fda6dc2b..ff07a65e07a0408fefdf0eabd75fda2db305dfe1 100644 --- a/paddlex/interpret/visualize.py +++ b/paddlex/interpret/visualize.py @@ -70,8 +70,10 @@ def normlime(img_file, normlime_weights_file=None): """使用NormLIME算法将模型预测结果的可解释性可视化。 - NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测 - 试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。 + NormLIME是利用一定数量的样本来出一个全局的解释。由于NormLIME计算量较大,此处采用一种简化的方式: + 使用一定数量的测试样本(目前默认使用所有测试样本),对每个样本进行特征提取,映射到同一个特征空间; + 然后以此特征做为输入,以模型输出做为输出,使用线性回归对其进行拟合,得到一个全局的输入和输出的关系。 + 之后,对一测试样本进行解释时,使用NormLIME全局的解释,来对LIME的结果进行滤波,使最终的可视化结果更加稳定。 注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。 注意2:NormLIME可解释性结果可视化目前只支持分类模型。