normlime.py 1.1 KB
Newer Older
S
sunyanfang01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
import os
# 选择使用0号卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import os.path as osp
import paddlex as pdx

# 下载和解压Imagenet果蔬分类数据集
veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
pdx.utils.download_and_decompress(veg_dataset, path='./')

# 下载和解压已训练好的MobileNetV2模型
model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
pdx.utils.download_and_decompress(model_file, path='./')

# 加载模型
S
seven 已提交
17 18
model_file = 'mini_imagenet_veg_mobilenetv2'
model = pdx.load_model(model_file)
S
sunyanfang01 已提交
19 20

# 定义测试所用的数据集
S
seven 已提交
21
dataset = 'mini_imagenet_veg'
S
sunyanfang01 已提交
22
test_dataset = pdx.datasets.ImageNet(
S
seven 已提交
23 24 25
    data_dir=dataset,
    file_list=osp.join(dataset, 'test_list.txt'),
    label_list=osp.join(dataset, 'labels.txt'),
S
sunyanfang01 已提交
26 27
    transforms=model.test_transforms)

S
seven 已提交
28 29 30 31 32 33 34 35
# 可解释性可视化
pdx.interpret.normlime(
    test_dataset.file_list[0][0],
    model,
    test_dataset,
    save_dir='./',
    normlime_weights_file='{}_{}.npy'.format(
        dataset.split('/')[-1], model.model_name))