提交 b1ff0a34 编写于 作者: J jiangjiajun

fix scope bug for interpret

上级 75596064
...@@ -15,11 +15,17 @@ ...@@ -15,11 +15,17 @@
import numpy as np import numpy as np
import cv2 import cv2
import copy import copy
import paddle.fluid as fluid
from paddlex.cv.transforms import arrange_transforms
def interpretation_predict(model, images): def interpretation_predict(model, images):
images = images.astype('float32') images = images.astype('float32')
model.arrange_transforms(transforms=model.test_transforms, mode='test') arrange_transforms(
model.model_type,
model.__class__.__name__,
transforms=model.test_transforms,
mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms.transforms) tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:] model.test_transforms.transforms = model.test_transforms.transforms[-2:]
...@@ -29,7 +35,9 @@ def interpretation_predict(model, images): ...@@ -29,7 +35,9 @@ def interpretation_predict(model, images):
new_imgs.append(model.test_transforms(images[i])[0]) new_imgs.append(model.test_transforms(images[i])[0])
new_imgs = np.array(new_imgs) new_imgs = np.array(new_imgs)
out = model.exe.run(model.test_prog, with fluid.scope_guard(model.scope):
out = model.exe.run(
model.test_prog,
feed={'image': new_imgs}, feed={'image': new_imgs},
fetch_list=list(model.interpretation_feats.values())) fetch_list=list(model.interpretation_feats.values()))
......
...@@ -22,6 +22,7 @@ from .interpretation_predict import interpretation_predict ...@@ -22,6 +22,7 @@ from .interpretation_predict import interpretation_predict
from .core.interpretation import Interpretation from .core.interpretation import Interpretation
from .core.normlime_base import precompute_global_classifier from .core.normlime_base import precompute_global_classifier
from .core._session_preparation import gen_user_home from .core._session_preparation import gen_user_home
from paddlex.cv.transforms import arrange_transforms
def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'): def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
...@@ -48,7 +49,11 @@ def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'): ...@@ -48,7 +49,11 @@ def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
'The interpretation only can deal with the Normal model') 'The interpretation only can deal with the Normal model')
if not osp.exists(save_dir): if not osp.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
model.arrange_transforms(transforms=model.test_transforms, mode='test') arrange_transforms(
model.model_type,
model.__class__.__name__,
transforms=model.test_transforms,
mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms) tmp_transforms = copy.deepcopy(model.test_transforms)
tmp_transforms.transforms = tmp_transforms.transforms[:-2] tmp_transforms.transforms = tmp_transforms.transforms[:-2]
img = tmp_transforms(img_file)[0] img = tmp_transforms(img_file)[0]
...@@ -94,7 +99,11 @@ def normlime(img_file, ...@@ -94,7 +99,11 @@ def normlime(img_file,
'The interpretation only can deal with the Normal model') 'The interpretation only can deal with the Normal model')
if not osp.exists(save_dir): if not osp.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
model.arrange_transforms(transforms=model.test_transforms, mode='test') arrange_transforms(
model.model_type,
model.__class__.__name__,
transforms=model.test_transforms,
mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms) tmp_transforms = copy.deepcopy(model.test_transforms)
tmp_transforms.transforms = tmp_transforms.transforms[:-2] tmp_transforms.transforms = tmp_transforms.transforms[:-2]
img = tmp_transforms(img_file)[0] img = tmp_transforms(img_file)[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册