提交 09e903eb 编写于 作者: C caoying03

fix v2 infer interface.

上级 45ced9da
...@@ -39,7 +39,6 @@ void CostForOneSequence::calValidExpandStep() { ...@@ -39,7 +39,6 @@ void CostForOneSequence::calValidExpandStep() {
if (start + beamSize_ == findEnd) return; if (start + beamSize_ == findEnd) return;
goldColIds_[i] = findEnd - start; goldColIds_[i] = findEnd - start;
} }
if (goldColIds_[beams_->expansionCount - 1] != -1) goldAsExtraPath_ = false; if (goldColIds_[beams_->expansionCount - 1] != -1) goldAsExtraPath_ = false;
} }
......
...@@ -70,7 +70,7 @@ class Inference(object): ...@@ -70,7 +70,7 @@ class Inference(object):
item = [each_result[each_field] for each_field in field] item = [each_result[each_field] for each_field in field]
yield item yield item
def infer(self, input, field='value', **kwargs): def infer(self, input, field='value', flatten_result=True, **kwargs):
""" """
Infer a data by model. Infer a data by model.
:param input: input data batch. Should be python iterable object. :param input: input data batch. Should be python iterable object.
...@@ -83,7 +83,10 @@ class Inference(object): ...@@ -83,7 +83,10 @@ class Inference(object):
retv = [[] for i in xrange(len(result))] retv = [[] for i in xrange(len(result))]
for i, item in enumerate(result): for i, item in enumerate(result):
retv[i].append(item) retv[i].append(item)
retv = [numpy.concatenate(out) for out in retv]
if flatten_result:
retv = [numpy.concatenate(out) for out in retv]
if len(retv) == 1: if len(retv) == 1:
return retv[0] return retv[0]
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册