From ac843bb8790410afa95a7150a0ee909695a8126f Mon Sep 17 00:00:00 2001
From: Tao Luo <luotao02@baidu.com>
Date: Wed, 12 Apr 2017 18:14:12 +0800
Subject: [PATCH] Update with comments

---
 python/paddle/v2/inference.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py
index 95968dede4..32636c5505 100644
--- a/python/paddle/v2/inference.py
+++ b/python/paddle/v2/inference.py
@@ -38,23 +38,26 @@ class Inference(object):
 
     def iter_infer_field(self, field, **kwargs):
         for result in self.iter_infer(**kwargs):
-            yield [each_result[field] for each_result in result]
+            yield [
+                each_result[each_field]
+                for each_result in result for each_field in field
+            ]
 
     def infer(self, field='value', **kwargs):
         if not isinstance(field, list) and not isinstance(field, tuple):
             field = [field]
 
-        retv_list = []
-        for each_field in field:
-            retv = None
-            for result in self.iter_infer_field(field=each_field, **kwargs):
-                if retv is None:
-                    retv = [[]] * len(result)
-                for i, item in enumerate(result):
-                    retv[i].append(item)
-            retv = [numpy.concatenate(out) for out in retv]
-            retv_list.append(retv[0] if len(retv) == 1 else retv)
-        return retv_list[0] if len(retv_list) == 1 else retv_list
+        retv = None
+        for result in self.iter_infer_field(field=field, **kwargs):
+            if retv is None:
+                retv = [[]] * len(result)
+            for i, item in enumerate(result):
+                retv[i].append(item)
+        retv = [numpy.concatenate(out) for out in retv]
+        if len(retv) == 1:
+            return retv[0]
+        else:
+            return retv
 
 
 def infer(output_layer, parameters, input, feeding=None, field='value'):
-- 
GitLab