diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py index 7e4c6ca33cb72d09534aee1ffb98daba951a49a0..a3311765a996f6592a680d2fdb878c13006143c1 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py @@ -18,7 +18,10 @@ import astor import gast from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor -from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num +from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform + from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer @@ -116,12 +119,13 @@ class ListTransformer(gast.NodeTransformer): def _transform_slice_to_tensor_write(self, node): assert isinstance(node, gast.Assign) target_node = node.targets[0] + target_name = target_node.value.id slice_node = target_node.slice if isinstance(slice_node, gast.Slice): pass - elif isinstance(slice_node, gast.Index): + elif slice_is_num(target_node): value_code = ast_to_source_code(node.value) i = "paddle.cast(" \ "x=paddle.jit.dy2static.to_static_variable({})," \ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 2a0b2cadb5979673a736f153b8c88e18b397e8d1..ffa1d65e6280af9e5c4d3eac8b29351c8177db69 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -19,6 +19,7 @@ import gast from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code +from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper @@ -34,43 +35,42 @@ def create_convert_shape_node(var_shape_node, if isinstance(var_shape_node, gast.Attribute): args = [ast_to_source_code(var_shape_node.value).strip()] - # (1) A slice can be a simple number such as 1, -2, i.e. gast.Index - # (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index + # (1) A slice can be a simple number such as 1, -2, i.e. gast.Index or gast.Constant + # (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index or gast.Constant # In (1) case, we pass the number as 'idx' argument in convert_var_shape # In (2) case, we have to make it like `convert_var_shape(x)[slice]` - if slice_node is not None and isinstance(slice_node, gast.Index): - args.append(ast_to_source_code(slice_node).strip()) + if slice_node is not None and slice_is_num(slice_node): + args.append(ast_to_source_code(slice_node.slice).strip()) convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format( ",".join(args), in_control_flow) api_shape_node = gast.parse(convert_var_shape_func).body[0].value - if slice_node is not None and not isinstance(slice_node, gast.Index): + if slice_node is not None and not slice_is_num(slice_node): return gast.Subscript( - value=api_shape_node, slice=slice_node, ctx=gast.Load()) + value=api_shape_node, slice=slice_node.slice, ctx=gast.Load()) return api_shape_node if isinstance(var_shape_node, gast.Subscript): result_node = copy.deepcopy(var_shape_node) - result_node = create_convert_shape_node( - result_node.value, result_node.slice, in_control_flow) + result_node = create_convert_shape_node(result_node.value, result_node, + in_control_flow) return result_node def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): - # Note(Aurelius84): Add `locals()` to help `eval` to locate the variable correctly. eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format( api_shape_name) args = [attr_shape_name, eval_exist_func] - if slice_node is not None and isinstance(slice_node, gast.Index): - args.append(ast_to_source_code(slice_node).strip()) + if slice_node is not None and slice_is_num(slice_node): + args.append(ast_to_source_code(slice_node.slice).strip()) choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format( ",".join(args)) choose_shape_node = gast.parse(choose_shape_func).body[0].value - if slice_node is not None and not isinstance(slice_node, gast.Index): + if slice_node is not None and not slice_is_num(slice_node): return gast.Subscript( - value=choose_shape_node, slice=slice_node, ctx=gast.Load()) + value=choose_shape_node, slice=slice_node.slice, ctx=gast.Load()) return choose_shape_node @@ -133,17 +133,15 @@ class TensorShapeTransformer(gast.NodeTransformer): if value_node.id in self.name_to_var_shape and self._used_by_paddle_api( value_node): return create_choose_shape_node( - value_node.id, self.name_to_var_shape[value_node.id], - slice_node) + value_node.id, self.name_to_var_shape[value_node.id], node) elif isinstance(value_node, gast.Attribute): if self._used_by_paddle_api(value_node): value_name = ast_to_source_code(value_node).strip() if value_name in self.name_to_var_shape: return create_choose_shape_node( - value_name, self.name_to_var_shape[value_name], - slice_node) + value_name, self.name_to_var_shape[value_name], node) if self._is_var_shape(value_node): - return create_convert_shape_node(value_node, slice_node) + return create_convert_shape_node(value_node, node) return node def visit_Attribute(self, node): @@ -315,14 +313,10 @@ class TensorShapeTransformer(gast.NodeTransformer): static_shape_value_name = self.name_to_var_shape[ value_node.id] - static_shape_value_node = gast.parse( - static_shape_value_name).body[0].value - index_value_node = gast.Constant(value=idx, kind=None) - slice_index_node = gast.Index(value=index_value_node) - sub_node = gast.Subscript( - value=static_shape_value_node, - slice=slice_index_node, - ctx=gast.Load()) + + sub_node_str = "{}[{}]".format(static_shape_value_name, + idx) + sub_node = gast.parse(sub_node_str).body[0].value update_static_shape_var_node.append( gast.Assign( @@ -342,12 +336,11 @@ class TensorShapeTransformer(gast.NodeTransformer): # x.shape becomes convert_var_shape_simple(x) static_shape_value_node = ShapeAttributeTransformer( ).visit(static_shape_value_node) - index_value_node = gast.Constant(value=idx, kind=None) - slice_index_node = gast.Index(value=index_value_node) - sub_node = gast.Subscript( - value=static_shape_value_node, - slice=slice_index_node, - ctx=gast.Load()) + + sub_node_str = "{}[{}]".format( + ast_to_source_code(static_shape_value_node).strip(), + idx) + sub_node = gast.parse(sub_node_str).body[0].value update_static_shape_var_node.append( gast.Assign( diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index e9f8afc06c7ca9cd611bc8dff4b00b8b16b48225..1071fc1350bfeb8f8b768a81204f134b345101b5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -921,18 +921,15 @@ class ForLoopTuplePreTransformer(gast.NodeTransformer): def tuple_to_stmts(self, node, tuple_name, idx=[]): if not isinstance(node, (gast.Tuple, gast.List)): - value_node = gast.Name( - id=tuple_name, - ctx=gast.Load(), - annotation=None, - type_comment=None) + value_node_str = tuple_name for i in idx: - value_node = gast.Subscript( - value=value_node, - slice=gast.Index(value=gast.Constant( - value=i, kind=None)), - ctx=gast.Load()) - return [gast.Assign(targets=[node], value=value_node)] + value_node_str = value_node_str + "[{}]".format(i) + + node_str = ast_to_source_code(node).strip() + assign_node_str = "{} = {}".format(node_str, value_node_str) + assign_node = gast.parse(assign_node_str).body[0] + return [assign_node] + # isinstance(node, (gast.Tuple, gast.List)) ret = [] for i, element in enumerate(node.elts): @@ -1240,14 +1237,9 @@ class ForNodeVisitor(object): value=step_node) def _build_assign_var_slice_node(self): - var_slice_node = gast.Subscript( - value=self.iter_node, - slice=gast.Index(value=gast.Name( - id=self.iter_idx_name, - ctx=gast.Load(), - annotation=None, - type_comment=None)), - ctx=gast.Load(), ) + var_slice_str = "{}[{}]".format( + ast_to_source_code(self.iter_node).strip(), self.iter_idx_name) + var_slice_node = gast.parse(var_slice_str).body[0].value new_iter_var_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX) target_node, assign_node = create_assign_node(new_iter_var_name, var_slice_node) @@ -1422,3 +1414,28 @@ def input_specs_compatible(src_input_specs, desired_input_specs): return False return True + + +def slice_is_num(slice_node): + # A slice_node.slice can be a: + # (1) ast.Index, which is a simple number such as [1], [-2] + # (2) ast.Slice, which is represented by bounds such as [2:-1] + # (3) ast.Tuple, which includes the above two cases such as [2:-1, 1] + # If slice node is case (1), return True, Otherwise, return False. + # + # NOTE: In (1) case, when gast>=0.4.0, gast.Index is not used, which is replaced + # other gast node such as gast.Constant, gast.Name, gast.UnaryOp and so on. + # Considering the compatibility of gast, here use ast note to check whether the + # node is a num. For more details, please visit https://github.com/serge-sans-paille/gast + + assert isinstance(slice_node, gast.Subscript) + slice_node_str = ast_to_source_code(slice_node).strip() + ast_node = ast.parse(slice_node_str).body[0].value + + if isinstance(ast_node.slice, (ast.Tuple, ast.Slice)): + return False + + if isinstance(ast_node.slice, ast.Index): + return True + + return False diff --git a/python/paddle/fluid/tests/unittests/test_gast_with_compatibility.py b/python/paddle/fluid/tests/unittests/test_gast_with_compatibility.py index c176ff09e024db90ea5a81bcf2afe18939c4f538..17ba6869534fe72d6d82062d957cf9a546f672aa 100644 --- a/python/paddle/fluid/tests/unittests/test_gast_with_compatibility.py +++ b/python/paddle/fluid/tests/unittests/test_gast_with_compatibility.py @@ -97,7 +97,6 @@ class GastNodeTransformer(gast.NodeTransformer): It will be generally represented by gast.Index or gast.Slice in gast. Note: Paddle doesn't support PY3.8 currently. """ - assert isinstance(node.slice, (gast.Index, gast.Slice)) self.generic_visit(node) return node