diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index ec3c67d89548f68d705a9b5de80e28597e9829da..95968dede4d303a6eaed0bece21323d51d84b83d 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -41,17 +41,20 @@ class Inference(object): yield [each_result[field] for each_result in result] def infer(self, field='value', **kwargs): - retv = None - for result in self.iter_infer_field(field=field, **kwargs): - if retv is None: - retv = [[]] * len(result) - for i, item in enumerate(result): - retv[i].append(item) - retv = [numpy.concatenate(out) for out in retv] - if len(retv) == 1: - return retv[0] - else: - return retv + if not isinstance(field, list) and not isinstance(field, tuple): + field = [field] + + retv_list = [] + for each_field in field: + retv = None + for result in self.iter_infer_field(field=each_field, **kwargs): + if retv is None: + retv = [[]] * len(result) + for i, item in enumerate(result): + retv[i].append(item) + retv = [numpy.concatenate(out) for out in retv] + retv_list.append(retv[0] if len(retv) == 1 else retv) + return retv_list[0] if len(retv_list) == 1 else retv_list def infer(output_layer, parameters, input, feeding=None, field='value'):