From 09e903eb9417745952ced6db532594fd4a759d74 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Tue, 29 Aug 2017 13:44:51 +0800 Subject: [PATCH] fix v2 infer interface. --- paddle/gserver/layers/CrossEntropyOverBeam.cpp | 1 - python/paddle/v2/inference.py | 7 +++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/gserver/layers/CrossEntropyOverBeam.cpp b/paddle/gserver/layers/CrossEntropyOverBeam.cpp index 500cd6ff8cc..bffcc301543 100644 --- a/paddle/gserver/layers/CrossEntropyOverBeam.cpp +++ b/paddle/gserver/layers/CrossEntropyOverBeam.cpp @@ -39,7 +39,6 @@ void CostForOneSequence::calValidExpandStep() { if (start + beamSize_ == findEnd) return; goldColIds_[i] = findEnd - start; } - if (goldColIds_[beams_->expansionCount - 1] != -1) goldAsExtraPath_ = false; } diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 4dcc3ab57e7..8acea6155c5 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -70,7 +70,7 @@ class Inference(object): item = [each_result[each_field] for each_field in field] yield item - def infer(self, input, field='value', **kwargs): + def infer(self, input, field='value', flatten_result=True, **kwargs): """ Infer a data by model. :param input: input data batch. Should be python iterable object. @@ -83,7 +83,10 @@ class Inference(object): retv = [[] for i in xrange(len(result))] for i, item in enumerate(result): retv[i].append(item) - retv = [numpy.concatenate(out) for out in retv] + + if flatten_result: + retv = [numpy.concatenate(out) for out in retv] + if len(retv) == 1: return retv[0] else: -- GitLab