From eb3c7d00f029c6aa63167624fe2dd428f9040554 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Wed, 16 Feb 2022 14:47:19 +0800 Subject: [PATCH] [Dy2St]Refine AnnAssign in static_analysis (#39572) --- .../fluid/dygraph/dygraph_to_static/static_analysis.py | 6 +++++- .../unittests/dygraph_to_static/test_static_analysis.py | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index 368a01de81..98e76c0f46 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -349,7 +349,11 @@ class StaticAnalysisVisitor(object): ret_type = {NodeVarType.type_from_annotation(node.annotation)} # if annotation and value(Constant) are diffent type, we use value type if node.value: - ret_type = self.node_to_wrapper_map[node.value].node_var_type + node_value_type = self.node_to_wrapper_map[ + node.value].node_var_type + if not (node_value_type & + {NodeVarType.UNKNOWN, NodeVarType.STATEMENT}): + ret_type = node_value_type if isinstance(node.target, gast.Name): self.node_to_wrapper_map[node.target].node_var_type = ret_type self.var_env.set_var_type(node.target.id, ret_type) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py index eb545e5ca2..388291a51c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py @@ -147,6 +147,7 @@ result_var_type6 = { def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float='diff'): a = True e, f = paddle.shape(c) + g: paddle.Tensor = len(c) result_var_type7 = { @@ -155,7 +156,8 @@ result_var_type7 = { 'c': {NodeVarType.TENSOR}, 'd': {NodeVarType.STRING}, 'e': {NodeVarType.PADDLE_RETURN_TYPES}, - 'f': {NodeVarType.PADDLE_RETURN_TYPES} + 'f': {NodeVarType.PADDLE_RETURN_TYPES}, + 'g': {NodeVarType.TENSOR} } test_funcs = [ -- GitLab