未验证 提交 522c91ec 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Remove gast.Index for compatibility of gast 0.4.0 (#31358)

上级 62289fcc
...@@ -18,7 +18,10 @@ import astor ...@@ -18,7 +18,10 @@ import astor
import gast import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
...@@ -116,12 +119,13 @@ class ListTransformer(gast.NodeTransformer): ...@@ -116,12 +119,13 @@ class ListTransformer(gast.NodeTransformer):
def _transform_slice_to_tensor_write(self, node): def _transform_slice_to_tensor_write(self, node):
assert isinstance(node, gast.Assign) assert isinstance(node, gast.Assign)
target_node = node.targets[0] target_node = node.targets[0]
target_name = target_node.value.id target_name = target_node.value.id
slice_node = target_node.slice slice_node = target_node.slice
if isinstance(slice_node, gast.Slice): if isinstance(slice_node, gast.Slice):
pass pass
elif isinstance(slice_node, gast.Index): elif slice_is_num(target_node):
value_code = ast_to_source_code(node.value) value_code = ast_to_source_code(node.value)
i = "paddle.cast(" \ i = "paddle.cast(" \
"x=paddle.jit.dy2static.to_static_variable({})," \ "x=paddle.jit.dy2static.to_static_variable({})," \
......
...@@ -19,6 +19,7 @@ import gast ...@@ -19,6 +19,7 @@ import gast
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
...@@ -34,43 +35,42 @@ def create_convert_shape_node(var_shape_node, ...@@ -34,43 +35,42 @@ def create_convert_shape_node(var_shape_node,
if isinstance(var_shape_node, gast.Attribute): if isinstance(var_shape_node, gast.Attribute):
args = [ast_to_source_code(var_shape_node.value).strip()] args = [ast_to_source_code(var_shape_node.value).strip()]
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index # (1) A slice can be a simple number such as 1, -2, i.e. gast.Index or gast.Constant
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index # (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index or gast.Constant
# In (1) case, we pass the number as 'idx' argument in convert_var_shape # In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]` # In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if slice_node is not None and isinstance(slice_node, gast.Index): if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node).strip()) args.append(ast_to_source_code(slice_node.slice).strip())
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format( convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow) ",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value api_shape_node = gast.parse(convert_var_shape_func).body[0].value
if slice_node is not None and not isinstance(slice_node, gast.Index): if slice_node is not None and not slice_is_num(slice_node):
return gast.Subscript( return gast.Subscript(
value=api_shape_node, slice=slice_node, ctx=gast.Load()) value=api_shape_node, slice=slice_node.slice, ctx=gast.Load())
return api_shape_node return api_shape_node
if isinstance(var_shape_node, gast.Subscript): if isinstance(var_shape_node, gast.Subscript):
result_node = copy.deepcopy(var_shape_node) result_node = copy.deepcopy(var_shape_node)
result_node = create_convert_shape_node( result_node = create_convert_shape_node(result_node.value, result_node,
result_node.value, result_node.slice, in_control_flow) in_control_flow)
return result_node return result_node
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None): def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
# Note(Aurelius84): Add `locals()` to help `eval` to locate the variable correctly.
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format( eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format(
api_shape_name) api_shape_name)
args = [attr_shape_name, eval_exist_func] args = [attr_shape_name, eval_exist_func]
if slice_node is not None and isinstance(slice_node, gast.Index): if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node).strip()) args.append(ast_to_source_code(slice_node.slice).strip())
choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format( choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
",".join(args)) ",".join(args))
choose_shape_node = gast.parse(choose_shape_func).body[0].value choose_shape_node = gast.parse(choose_shape_func).body[0].value
if slice_node is not None and not isinstance(slice_node, gast.Index): if slice_node is not None and not slice_is_num(slice_node):
return gast.Subscript( return gast.Subscript(
value=choose_shape_node, slice=slice_node, ctx=gast.Load()) value=choose_shape_node, slice=slice_node.slice, ctx=gast.Load())
return choose_shape_node return choose_shape_node
...@@ -133,17 +133,15 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -133,17 +133,15 @@ class TensorShapeTransformer(gast.NodeTransformer):
if value_node.id in self.name_to_var_shape and self._used_by_paddle_api( if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
value_node): value_node):
return create_choose_shape_node( return create_choose_shape_node(
value_node.id, self.name_to_var_shape[value_node.id], value_node.id, self.name_to_var_shape[value_node.id], node)
slice_node)
elif isinstance(value_node, gast.Attribute): elif isinstance(value_node, gast.Attribute):
if self._used_by_paddle_api(value_node): if self._used_by_paddle_api(value_node):
value_name = ast_to_source_code(value_node).strip() value_name = ast_to_source_code(value_node).strip()
if value_name in self.name_to_var_shape: if value_name in self.name_to_var_shape:
return create_choose_shape_node( return create_choose_shape_node(
value_name, self.name_to_var_shape[value_name], value_name, self.name_to_var_shape[value_name], node)
slice_node)
if self._is_var_shape(value_node): if self._is_var_shape(value_node):
return create_convert_shape_node(value_node, slice_node) return create_convert_shape_node(value_node, node)
return node return node
def visit_Attribute(self, node): def visit_Attribute(self, node):
...@@ -315,14 +313,10 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -315,14 +313,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
static_shape_value_name = self.name_to_var_shape[ static_shape_value_name = self.name_to_var_shape[
value_node.id] value_node.id]
static_shape_value_node = gast.parse(
static_shape_value_name).body[0].value sub_node_str = "{}[{}]".format(static_shape_value_name,
index_value_node = gast.Constant(value=idx, kind=None) idx)
slice_index_node = gast.Index(value=index_value_node) sub_node = gast.parse(sub_node_str).body[0].value
sub_node = gast.Subscript(
value=static_shape_value_node,
slice=slice_index_node,
ctx=gast.Load())
update_static_shape_var_node.append( update_static_shape_var_node.append(
gast.Assign( gast.Assign(
...@@ -342,12 +336,11 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -342,12 +336,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
# x.shape becomes convert_var_shape_simple(x) # x.shape becomes convert_var_shape_simple(x)
static_shape_value_node = ShapeAttributeTransformer( static_shape_value_node = ShapeAttributeTransformer(
).visit(static_shape_value_node) ).visit(static_shape_value_node)
index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node) sub_node_str = "{}[{}]".format(
sub_node = gast.Subscript( ast_to_source_code(static_shape_value_node).strip(),
value=static_shape_value_node, idx)
slice=slice_index_node, sub_node = gast.parse(sub_node_str).body[0].value
ctx=gast.Load())
update_static_shape_var_node.append( update_static_shape_var_node.append(
gast.Assign( gast.Assign(
......
...@@ -921,18 +921,15 @@ class ForLoopTuplePreTransformer(gast.NodeTransformer): ...@@ -921,18 +921,15 @@ class ForLoopTuplePreTransformer(gast.NodeTransformer):
def tuple_to_stmts(self, node, tuple_name, idx=[]): def tuple_to_stmts(self, node, tuple_name, idx=[]):
if not isinstance(node, (gast.Tuple, gast.List)): if not isinstance(node, (gast.Tuple, gast.List)):
value_node = gast.Name( value_node_str = tuple_name
id=tuple_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)
for i in idx: for i in idx:
value_node = gast.Subscript( value_node_str = value_node_str + "[{}]".format(i)
value=value_node,
slice=gast.Index(value=gast.Constant( node_str = ast_to_source_code(node).strip()
value=i, kind=None)), assign_node_str = "{} = {}".format(node_str, value_node_str)
ctx=gast.Load()) assign_node = gast.parse(assign_node_str).body[0]
return [gast.Assign(targets=[node], value=value_node)] return [assign_node]
# isinstance(node, (gast.Tuple, gast.List)) # isinstance(node, (gast.Tuple, gast.List))
ret = [] ret = []
for i, element in enumerate(node.elts): for i, element in enumerate(node.elts):
...@@ -1240,14 +1237,9 @@ class ForNodeVisitor(object): ...@@ -1240,14 +1237,9 @@ class ForNodeVisitor(object):
value=step_node) value=step_node)
def _build_assign_var_slice_node(self): def _build_assign_var_slice_node(self):
var_slice_node = gast.Subscript( var_slice_str = "{}[{}]".format(
value=self.iter_node, ast_to_source_code(self.iter_node).strip(), self.iter_idx_name)
slice=gast.Index(value=gast.Name( var_slice_node = gast.parse(var_slice_str).body[0].value
id=self.iter_idx_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)),
ctx=gast.Load(), )
new_iter_var_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX) new_iter_var_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX)
target_node, assign_node = create_assign_node(new_iter_var_name, target_node, assign_node = create_assign_node(new_iter_var_name,
var_slice_node) var_slice_node)
...@@ -1422,3 +1414,28 @@ def input_specs_compatible(src_input_specs, desired_input_specs): ...@@ -1422,3 +1414,28 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
return False return False
return True return True
def slice_is_num(slice_node):
# A slice_node.slice can be a:
# (1) ast.Index, which is a simple number such as [1], [-2]
# (2) ast.Slice, which is represented by bounds such as [2:-1]
# (3) ast.Tuple, which includes the above two cases such as [2:-1, 1]
# If slice node is case (1), return True, Otherwise, return False.
#
# NOTE: In (1) case, when gast>=0.4.0, gast.Index is not used, which is replaced
# other gast node such as gast.Constant, gast.Name, gast.UnaryOp and so on.
# Considering the compatibility of gast, here use ast note to check whether the
# node is a num. For more details, please visit https://github.com/serge-sans-paille/gast
assert isinstance(slice_node, gast.Subscript)
slice_node_str = ast_to_source_code(slice_node).strip()
ast_node = ast.parse(slice_node_str).body[0].value
if isinstance(ast_node.slice, (ast.Tuple, ast.Slice)):
return False
if isinstance(ast_node.slice, ast.Index):
return True
return False
...@@ -97,7 +97,6 @@ class GastNodeTransformer(gast.NodeTransformer): ...@@ -97,7 +97,6 @@ class GastNodeTransformer(gast.NodeTransformer):
It will be generally represented by gast.Index or gast.Slice in gast. It will be generally represented by gast.Index or gast.Slice in gast.
Note: Paddle doesn't support PY3.8 currently. Note: Paddle doesn't support PY3.8 currently.
""" """
assert isinstance(node.slice, (gast.Index, gast.Slice))
self.generic_visit(node) self.generic_visit(node)
return node return node
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册