From ac843bb8790410afa95a7150a0ee909695a8126f Mon Sep 17 00:00:00 2001 From: Tao Luo <luotao02@baidu.com> Date: Wed, 12 Apr 2017 18:14:12 +0800 Subject: [PATCH] Update with comments --- python/paddle/v2/inference.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 95968dede4..32636c5505 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -38,23 +38,26 @@ class Inference(object): def iter_infer_field(self, field, **kwargs): for result in self.iter_infer(**kwargs): - yield [each_result[field] for each_result in result] + yield [ + each_result[each_field] + for each_result in result for each_field in field + ] def infer(self, field='value', **kwargs): if not isinstance(field, list) and not isinstance(field, tuple): field = [field] - retv_list = [] - for each_field in field: - retv = None - for result in self.iter_infer_field(field=each_field, **kwargs): - if retv is None: - retv = [[]] * len(result) - for i, item in enumerate(result): - retv[i].append(item) - retv = [numpy.concatenate(out) for out in retv] - retv_list.append(retv[0] if len(retv) == 1 else retv) - return retv_list[0] if len(retv_list) == 1 else retv_list + retv = None + for result in self.iter_infer_field(field=field, **kwargs): + if retv is None: + retv = [[]] * len(result) + for i, item in enumerate(result): + retv[i].append(item) + retv = [numpy.concatenate(out) for out in retv] + if len(retv) == 1: + return retv[0] + else: + return retv def infer(output_layer, parameters, input, feeding=None, field='value'): -- GitLab