提交 2f92c61b 编写于 作者: S sunyanfang01

fix the interpret

上级 8887832e
...@@ -442,3 +442,6 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000): ...@@ -442,3 +442,6 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
save_outdir, f_out save_outdir, f_out
) )
) )
print('The image of intrepretation result save in {}'.format(os.path.join(
save_outdir, f_out
)))
...@@ -36,6 +36,7 @@ from skimage.color import gray2rgb ...@@ -36,6 +36,7 @@ from skimage.color import gray2rgb
from sklearn.linear_model import Ridge, lars_path from sklearn.linear_model import Ridge, lars_path
from sklearn.utils import check_random_state from sklearn.utils import check_random_state
import tqdm
import copy import copy
from functools import partial from functools import partial
from skimage.segmentation import quickshift from skimage.segmentation import quickshift
...@@ -509,7 +510,7 @@ class LimeImageInterpreter(object): ...@@ -509,7 +510,7 @@ class LimeImageInterpreter(object):
labels = [] labels = []
data[0, :] = 1 data[0, :] = 1
imgs = [] imgs = []
for row in data: for row in tqdm.tqdm(data):
temp = copy.deepcopy(image) temp = copy.deepcopy(image)
zeros = np.where(row == 0)[0] zeros = np.where(row == 0)[0]
mask = np.zeros(segments.shape).astype(bool) mask = np.zeros(segments.shape).astype(bool)
......
...@@ -44,6 +44,8 @@ def visualize(img_file, ...@@ -44,6 +44,8 @@ def visualize(img_file,
'Now the interpretation visualize only be supported in classifier!' 'Now the interpretation visualize only be supported in classifier!'
if model.status != 'Normal': if model.status != 'Normal':
raise Exception('The interpretation only can deal with the Normal model') raise Exception('The interpretation only can deal with the Normal model')
if not osp.exists(save_dir):
os.makedirs(save_dir)
model.arrange_transforms( model.arrange_transforms(
transforms=model.test_transforms, mode='test') transforms=model.test_transforms, mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms) tmp_transforms = copy.deepcopy(model.test_transforms)
...@@ -108,12 +110,12 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5 ...@@ -108,12 +110,12 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
if dataset is not None: if dataset is not None:
labels_name = dataset.labels labels_name = dataset.labels
root_path = os.environ['HOME'] root_path = os.environ['HOME']
root_path = osp.join(root_path, '.paddlex') root_path = osp.join(root_path, '.paddlex0')
pre_models_path = osp.join(root_path, "pre_models") pre_models_path = osp.join(root_path, "pre_models")
if not osp.exists(pre_models_path): if not osp.exists(pre_models_path):
os.makedirs(pre_models_path) os.makedirs(root_path)
url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz" url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
pdx.utils.download_and_decompress(url, path=pre_models_path) pdx.utils.download_and_decompress(url, path=root_path)
npy_dir = precompute_for_normlime(precompute_predict_func, npy_dir = precompute_for_normlime(precompute_predict_func,
dataset, dataset,
num_samples=num_samples, num_samples=num_samples,
......
...@@ -4,37 +4,40 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0' ...@@ -4,37 +4,40 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import os.path as osp import os.path as osp
import paddlex as pdx import paddlex as pdx
from paddlex.cls import transforms
# 下载和解压Imagenet果蔬分类数据集 # 下载和解压Imagenet果蔬分类数据集
veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz' veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
pdx.utils.download_and_decompress(veg_dataset, path='./') pdx.utils.download_and_decompress(veg_dataset, path='./')
# 下载和解压已训练好的MobileNetV2模型 # 定义测试集的transform
model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz' test_transforms = transforms.Compose([
pdx.utils.download_and_decompress(model_file, path='./') transforms.ResizeByShort(short_size=256),
transforms.CenterCrop(crop_size=224),
# 加载模型 transforms.Normalize()
model = pdx.load_model('mini_imagenet_veg_mobilenetv2') ])
# 定义测试所用的数据集 # 定义测试所用的数据集
test_dataset = pdx.datasets.ImageNet( test_dataset = pdx.datasets.ImageNet(
data_dir='mini_imagenet_veg', data_dir='mini_imagenet_veg',
file_list=osp.join('mini_imagenet_veg', 'test_list.txt'), file_list=osp.join('mini_imagenet_veg', 'test_list.txt'),
label_list=osp.join('mini_imagenet_veg', 'labels.txt'), label_list=osp.join('mini_imagenet_veg', 'labels.txt'),
transforms=model.test_transforms) transforms=test_transforms)
# 下载和解压已训练好的MobileNetV2模型
model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
pdx.utils.download_and_decompress(model_file, path='./')
# 导入模型
model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
# 可解释性可视化 # 可解释性可视化
# LIME算法 pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
pdx.interpret.visualize(
'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
model, model,
test_dataset, test_dataset,
algo='lime', algo='lime',
save_dir='./') save_dir='./')
pdx.interpret.visualize('mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
# NormLIME算法
pdx.interpret.visualize(
'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
model, model,
test_dataset, test_dataset,
algo='normlime', algo='normlime',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册