From b1ff0a34d07371a136e6aa8398074c5c3efd5bef Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Sat, 11 Jul 2020 08:38:56 +0000 Subject: [PATCH] fix scope bug for interpret --- paddlex/interpret/interpretation_predict.py | 16 ++++++++++++---- paddlex/interpret/visualize.py | 19 ++++++++++++++----- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/paddlex/interpret/interpretation_predict.py b/paddlex/interpret/interpretation_predict.py index 31b3b47..2ebe9a8 100644 --- a/paddlex/interpret/interpretation_predict.py +++ b/paddlex/interpret/interpretation_predict.py @@ -15,11 +15,17 @@ import numpy as np import cv2 import copy +import paddle.fluid as fluid +from paddlex.cv.transforms import arrange_transforms def interpretation_predict(model, images): images = images.astype('float32') - model.arrange_transforms(transforms=model.test_transforms, mode='test') + arrange_transforms( + model.model_type, + model.__class__.__name__, + transforms=model.test_transforms, + mode='test') tmp_transforms = copy.deepcopy(model.test_transforms.transforms) model.test_transforms.transforms = model.test_transforms.transforms[-2:] @@ -29,9 +35,11 @@ def interpretation_predict(model, images): new_imgs.append(model.test_transforms(images[i])[0]) new_imgs = np.array(new_imgs) - out = model.exe.run(model.test_prog, - feed={'image': new_imgs}, - fetch_list=list(model.interpretation_feats.values())) + with fluid.scope_guard(model.scope): + out = model.exe.run( + model.test_prog, + feed={'image': new_imgs}, + fetch_list=list(model.interpretation_feats.values())) model.test_transforms.transforms = tmp_transforms diff --git a/paddlex/interpret/visualize.py b/paddlex/interpret/visualize.py index 6c3570b..2d7c096 100644 --- a/paddlex/interpret/visualize.py +++ b/paddlex/interpret/visualize.py @@ -1,11 +1,11 @@ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -22,6 +22,7 @@ from .interpretation_predict import interpretation_predict from .core.interpretation import Interpretation from .core.normlime_base import precompute_global_classifier from .core._session_preparation import gen_user_home +from paddlex.cv.transforms import arrange_transforms def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'): @@ -48,7 +49,11 @@ def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'): 'The interpretation only can deal with the Normal model') if not osp.exists(save_dir): os.makedirs(save_dir) - model.arrange_transforms(transforms=model.test_transforms, mode='test') + arrange_transforms( + model.model_type, + model.__class__.__name__, + transforms=model.test_transforms, + mode='test') tmp_transforms = copy.deepcopy(model.test_transforms) tmp_transforms.transforms = tmp_transforms.transforms[:-2] img = tmp_transforms(img_file)[0] @@ -94,7 +99,11 @@ def normlime(img_file, 'The interpretation only can deal with the Normal model') if not osp.exists(save_dir): os.makedirs(save_dir) - model.arrange_transforms(transforms=model.test_transforms, mode='test') + arrange_transforms( + model.model_type, + model.__class__.__name__, + transforms=model.test_transforms, + mode='test') tmp_transforms = copy.deepcopy(model.test_transforms) tmp_transforms.transforms = tmp_transforms.transforms[:-2] img = tmp_transforms(img_file)[0] -- GitLab