diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index ec3c67d89548f68d705a9b5de80e28597e9829da..31a5d26e6ebc07a0105e02fd8fd5cc8181d5424e 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -37,8 +37,13 @@ class Inference(object): self.__gradient_machine__.finish() def iter_infer_field(self, field, **kwargs): + if not isinstance(field, list) and not isinstance(field, tuple): + field = [field] + for result in self.iter_infer(**kwargs): - yield [each_result[field] for each_result in result] + for each_result in result: + item = [each_result[each_field] for each_field in field] + yield item def infer(self, field='value', **kwargs): retv = None