From 1252f4bb3e574df80aa6d18c7ddae1b3a90bd81c Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Thu, 10 Feb 2022 21:16:14 +0800 Subject: [PATCH] [Dy2St]Handle `a, b = paddle.shape(x)` in Static Analysis (#39245) * refine Assign * add UT --- .../fluid/dygraph/dygraph_to_static/static_analysis.py | 7 +++++++ .../unittests/dygraph_to_static/test_static_analysis.py | 5 ++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index 45a42d481b..368a01de81 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -334,6 +334,13 @@ class StaticAnalysisVisitor(object): if isinstance(target, gast.Name): self.node_to_wrapper_map[target].node_var_type = ret_type self.var_env.set_var_type(target.id, ret_type) + # Handle statements like `a, b = paddle.shape(x)` + elif isinstance(target, gast.Tuple): + for sub_target in target.elts: + if isinstance(sub_target, gast.Name): + self.node_to_wrapper_map[ + sub_target].node_var_type = ret_type + self.var_env.set_var_type(sub_target.id, ret_type) return ret_type if isinstance(node, gast.AnnAssign): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py index afccaca693..eb545e5ca2 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py @@ -146,13 +146,16 @@ result_var_type6 = { def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float='diff'): a = True + e, f = paddle.shape(c) result_var_type7 = { 'a': {NodeVarType.BOOLEAN}, 'b': {NodeVarType.FLOAT}, 'c': {NodeVarType.TENSOR}, - 'd': {NodeVarType.STRING} + 'd': {NodeVarType.STRING}, + 'e': {NodeVarType.PADDLE_RETURN_TYPES}, + 'f': {NodeVarType.PADDLE_RETURN_TYPES} } test_funcs = [ -- GitLab