提交 8f095942 编写于 作者: S sunyanfang01

fix the tutorial

上级 0d54ff26
...@@ -17,6 +17,7 @@ import cv2 ...@@ -17,6 +17,7 @@ import cv2
import copy import copy
import os.path as osp import os.path as osp
import numpy as np import numpy as np
import paddlex as pdx
from .interpretation_predict import interpretation_predict from .interpretation_predict import interpretation_predict
from .core.interpretation import Interpretation from .core.interpretation import Interpretation
from .core.normlime_base import precompute_normlime_weights from .core.normlime_base import precompute_normlime_weights
...@@ -35,7 +36,7 @@ def visualize(img_file, ...@@ -35,7 +36,7 @@ def visualize(img_file,
model (paddlex.cv.models): paddlex中的模型。 model (paddlex.cv.models): paddlex中的模型。
dataset (paddlex.datasets): 数据集读取器,默认为None。 dataset (paddlex.datasets): 数据集读取器,默认为None。
algo (str): 可解释性方式,当前可选'lime'和'normlime'。 algo (str): 可解释性方式,当前可选'lime'和'normlime'。
num_samples (int): 随机采样数量,默认为3000。 num_samples (int): LIME用于学习线性模型的采样数,默认为3000。
batch_size (int): 预测数据batch大小,默认为50。 batch_size (int): 预测数据batch大小,默认为50。
save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。
""" """
...@@ -111,8 +112,8 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5 ...@@ -111,8 +112,8 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5
pre_models_path = osp.join(root_path, "pre_models") pre_models_path = osp.join(root_path, "pre_models")
if not osp.exists(pre_models_path): if not osp.exists(pre_models_path):
os.makedirs(pre_models_path) os.makedirs(pre_models_path)
# TODO url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
# paddlex.utils.download_and_decompress(url, path=pre_models_path) pdx.utils.download_and_decompress(url, path=pre_models_path)
npy_dir = precompute_for_normlime(precompute_predict_func, npy_dir = precompute_for_normlime(precompute_predict_func,
dataset, dataset,
num_samples=num_samples, num_samples=num_samples,
......
...@@ -2,11 +2,12 @@ import os ...@@ -2,11 +2,12 @@ import os
# 选择使用0号卡 # 选择使用0号卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import os.path as osp
import paddlex as pdx import paddlex as pdx
from paddlex.cla import transforms from paddlex.cls import transforms
# 下载和解压Imagenet果蔬分类数据集 # 下载和解压Imagenet果蔬分类数据集
veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/mini_imagenet_veg.tar.gz' veg_dataset = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg.tar.gz'
pdx.utils.download_and_decompress(veg_dataset, path='./') pdx.utils.download_and_decompress(veg_dataset, path='./')
# 定义测试集的transform # 定义测试集的transform
...@@ -24,7 +25,7 @@ test_dataset = pdx.datasets.ImageNet( ...@@ -24,7 +25,7 @@ test_dataset = pdx.datasets.ImageNet(
transforms=test_transforms) transforms=test_transforms)
# 下载和解压已训练好的MobileNetV2模型 # 下载和解压已训练好的MobileNetV2模型
model_file = 'https://bj.bcebos.com/paddlex/models/mini_imagenet_veg_mobilenetv2.tar.gz' model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilenetv2.tar.gz'
pdx.utils.download_and_decompress(model_file, path='./') pdx.utils.download_and_decompress(model_file, path='./')
# 导入模型 # 导入模型
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册