interpret.py 1.2 KB
Newer Older
S
sunyanfang01 已提交
1 2 3 4
import os
# 选择使用0号卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

S
sunyanfang01 已提交
5
import os.path as osp
S
sunyanfang01 已提交
6 7 8
import paddlex as pdx

# 下载和解压Imagenet果蔬分类数据集
S
sunyanfang01 已提交
9
veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
S
sunyanfang01 已提交
10 11
pdx.utils.download_and_decompress(veg_dataset, path='./')

J
jiangjiajun 已提交
12 13 14 15 16 17
# 下载和解压已训练好的MobileNetV2模型
model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
pdx.utils.download_and_decompress(model_file, path='./')

# 加载模型
model = pdx.load_model('mini_imagenet_veg_mobilenetv2')
S
sunyanfang01 已提交
18 19 20 21 22 23

# 定义测试所用的数据集
test_dataset = pdx.datasets.ImageNet(
    data_dir='mini_imagenet_veg',
    file_list=osp.join('mini_imagenet_veg', 'test_list.txt'),
    label_list=osp.join('mini_imagenet_veg', 'labels.txt'),
J
jiangjiajun 已提交
24
    transforms=model.test_transforms)
S
sunyanfang01 已提交
25 26

# 可解释性可视化
J
jiangjiajun 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
# LIME算法
pdx.interpret.visualize(
    'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
    model,
    test_dataset,
    algo='lime',
    save_dir='./')

# NormLIME算法
pdx.interpret.visualize(
    'mini_imagenet_veg/mushroom/n07734744_1106.JPEG',
    model,
    test_dataset,
    algo='normlime',
    save_dir='./')