From fd882edb2657bae1bbf2861bc42f3b743d438c04 Mon Sep 17 00:00:00 2001 From: SunAhong1993 <48579383+SunAhong1993@users.noreply.github.com> Date: Sat, 16 May 2020 16:12:29 +0800 Subject: [PATCH] Update classifier.py --- paddlex/cv/models/classifier.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/paddlex/cv/models/classifier.py b/paddlex/cv/models/classifier.py index 793a8e9..3a94df7 100644 --- a/paddlex/cv/models/classifier.py +++ b/paddlex/cv/models/classifier.py @@ -63,7 +63,9 @@ class BaseClassifier(BaseAPI): net_out = model(image, num_classes=self.num_classes) softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) inputs = OrderedDict([('image', image)]) - outputs = OrderedDict([('predict', softmax_out), ('logits', net_out)]) + outputs = OrderedDict([('predict', softmax_out)]) + if mode == 'test': + self.explanation_feats = OrderedDict([('logits', net_out)]) if mode != 'test': cost = fluid.layers.cross_entropy(input=softmax_out, label=label) avg_cost = fluid.layers.mean(cost) @@ -284,8 +286,8 @@ class BaseClassifier(BaseAPI): result = self.exe.run( self.test_prog, feed={'image': new_imgs}, - fetch_list=list(self.test_outputs.values())) - return result[1:] + fetch_list=list(self.explanation_feats.values())) + return result class ResNet18(BaseClassifier): def __init__(self, num_classes=1000): -- GitLab