提交 b1ff0a34 编写于 作者: J jiangjiajun

fix scope bug for interpret

上级 75596064
......@@ -15,11 +15,17 @@
import numpy as np
import cv2
import copy
import paddle.fluid as fluid
from paddlex.cv.transforms import arrange_transforms
def interpretation_predict(model, images):
images = images.astype('float32')
model.arrange_transforms(transforms=model.test_transforms, mode='test')
arrange_transforms(
model.model_type,
model.__class__.__name__,
transforms=model.test_transforms,
mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms.transforms)
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
......@@ -29,9 +35,11 @@ def interpretation_predict(model, images):
new_imgs.append(model.test_transforms(images[i])[0])
new_imgs = np.array(new_imgs)
out = model.exe.run(model.test_prog,
feed={'image': new_imgs},
fetch_list=list(model.interpretation_feats.values()))
with fluid.scope_guard(model.scope):
out = model.exe.run(
model.test_prog,
feed={'image': new_imgs},
fetch_list=list(model.interpretation_feats.values()))
model.test_transforms.transforms = tmp_transforms
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -22,6 +22,7 @@ from .interpretation_predict import interpretation_predict
from .core.interpretation import Interpretation
from .core.normlime_base import precompute_global_classifier
from .core._session_preparation import gen_user_home
from paddlex.cv.transforms import arrange_transforms
def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
......@@ -48,7 +49,11 @@ def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
'The interpretation only can deal with the Normal model')
if not osp.exists(save_dir):
os.makedirs(save_dir)
model.arrange_transforms(transforms=model.test_transforms, mode='test')
arrange_transforms(
model.model_type,
model.__class__.__name__,
transforms=model.test_transforms,
mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms)
tmp_transforms.transforms = tmp_transforms.transforms[:-2]
img = tmp_transforms(img_file)[0]
......@@ -94,7 +99,11 @@ def normlime(img_file,
'The interpretation only can deal with the Normal model')
if not osp.exists(save_dir):
os.makedirs(save_dir)
model.arrange_transforms(transforms=model.test_transforms, mode='test')
arrange_transforms(
model.model_type,
model.__class__.__name__,
transforms=model.test_transforms,
mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms)
tmp_transforms.transforms = tmp_transforms.transforms[:-2]
img = tmp_transforms(img_file)[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册