提交 ac843bb8 编写于 作者: T Tao Luo 提交者: GitHub

Update with comments

上级 a503f3ca
...@@ -38,23 +38,26 @@ class Inference(object): ...@@ -38,23 +38,26 @@ class Inference(object):
def iter_infer_field(self, field, **kwargs): def iter_infer_field(self, field, **kwargs):
for result in self.iter_infer(**kwargs): for result in self.iter_infer(**kwargs):
yield [each_result[field] for each_result in result] yield [
each_result[each_field]
for each_result in result for each_field in field
]
def infer(self, field='value', **kwargs): def infer(self, field='value', **kwargs):
if not isinstance(field, list) and not isinstance(field, tuple): if not isinstance(field, list) and not isinstance(field, tuple):
field = [field] field = [field]
retv_list = [] retv = None
for each_field in field: for result in self.iter_infer_field(field=field, **kwargs):
retv = None if retv is None:
for result in self.iter_infer_field(field=each_field, **kwargs): retv = [[]] * len(result)
if retv is None: for i, item in enumerate(result):
retv = [[]] * len(result) retv[i].append(item)
for i, item in enumerate(result): retv = [numpy.concatenate(out) for out in retv]
retv[i].append(item) if len(retv) == 1:
retv = [numpy.concatenate(out) for out in retv] return retv[0]
retv_list.append(retv[0] if len(retv) == 1 else retv) else:
return retv_list[0] if len(retv_list) == 1 else retv_list return retv
def infer(output_layer, parameters, input, feeding=None, field='value'): def infer(output_layer, parameters, input, feeding=None, field='value'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部