提交 09e903eb 编写于 作者: C caoying03

fix v2 infer interface.

上级 45ced9da
......@@ -39,7 +39,6 @@ void CostForOneSequence::calValidExpandStep() {
if (start + beamSize_ == findEnd) return;
goldColIds_[i] = findEnd - start;
}
if (goldColIds_[beams_->expansionCount - 1] != -1) goldAsExtraPath_ = false;
}
......
......@@ -70,7 +70,7 @@ class Inference(object):
item = [each_result[each_field] for each_field in field]
yield item
def infer(self, input, field='value', **kwargs):
def infer(self, input, field='value', flatten_result=True, **kwargs):
"""
Infer a data by model.
:param input: input data batch. Should be python iterable object.
......@@ -83,7 +83,10 @@ class Inference(object):
retv = [[] for i in xrange(len(result))]
for i, item in enumerate(result):
retv[i].append(item)
retv = [numpy.concatenate(out) for out in retv]
if flatten_result:
retv = [numpy.concatenate(out) for out in retv]
if len(retv) == 1:
return retv[0]
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册