From e8869a907b890b2b96c451ae35eb984620df138d Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 30 Apr 2020 09:54:00 +0800 Subject: [PATCH] Fix bug in ProgramTranslator.get_output, convert all items into VarBase in nested list. (#24267) --- .../fluid/dygraph/dygraph_to_static/program_translator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index a4d1131ab8..bb8fbeb50a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -27,8 +27,9 @@ from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_st from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check -from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.data_feeder import check_type +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.layers.utils import map_structure __all__ = ['ProgramTranslator', 'convert_function_with_cache'] @@ -403,10 +404,11 @@ class ProgramTranslator(object): if not program_cache.in_build_process: outputs = self._run(*args, **kwargs) with guard(): + outputs = map_structure(to_variable, outputs) if len(outputs) == 1: - outputs = to_variable(outputs[0]) + outputs = outputs[0] else: - outputs = tuple(to_variable(x) for x in outputs) + outputs = tuple(outputs) return outputs def get_func(self, dygraph_func): -- GitLab