未验证 提交 fd882edb 编写于 作者: S SunAhong1993 提交者: GitHub

Update classifier.py

上级 7822db4d
......@@ -63,7 +63,9 @@ class BaseClassifier(BaseAPI):
net_out = model(image, num_classes=self.num_classes)
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
inputs = OrderedDict([('image', image)])
outputs = OrderedDict([('predict', softmax_out), ('logits', net_out)])
outputs = OrderedDict([('predict', softmax_out)])
if mode == 'test':
self.explanation_feats = OrderedDict([('logits', net_out)])
if mode != 'test':
cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost)
......@@ -284,8 +286,8 @@ class BaseClassifier(BaseAPI):
result = self.exe.run(
self.test_prog,
feed={'image': new_imgs},
fetch_list=list(self.test_outputs.values()))
return result[1:]
fetch_list=list(self.explanation_feats.values()))
return result
class ResNet18(BaseClassifier):
def __init__(self, num_classes=1000):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册