未验证 提交 522c91ec 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Remove gast.Index for compatibility of gast 0.4.0 (#31358)

上级 62289fcc
......@@ -18,7 +18,10 @@ import astor
import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
......@@ -116,12 +119,13 @@ class ListTransformer(gast.NodeTransformer):
def _transform_slice_to_tensor_write(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
target_name = target_node.value.id
slice_node = target_node.slice
if isinstance(slice_node, gast.Slice):
pass
elif isinstance(slice_node, gast.Index):
elif slice_is_num(target_node):
value_code = ast_to_source_code(node.value)
i = "paddle.cast(" \
"x=paddle.jit.dy2static.to_static_variable({})," \
......
......@@ -19,6 +19,7 @@ import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
......@@ -34,43 +35,42 @@ def create_convert_shape_node(var_shape_node,
if isinstance(var_shape_node, gast.Attribute):
args = [ast_to_source_code(var_shape_node.value).strip()]
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index or gast.Constant
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index or gast.Constant
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if slice_node is not None and isinstance(slice_node, gast.Index):
args.append(ast_to_source_code(slice_node).strip())
if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip())
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
if slice_node is not None and not isinstance(slice_node, gast.Index):
if slice_node is not None and not slice_is_num(slice_node):
return gast.Subscript(
value=api_shape_node, slice=slice_node, ctx=gast.Load())
value=api_shape_node, slice=slice_node.slice, ctx=gast.Load())
return api_shape_node
if isinstance(var_shape_node, gast.Subscript):
result_node = copy.deepcopy(var_shape_node)
result_node = create_convert_shape_node(
result_node.value, result_node.slice, in_control_flow)
result_node = create_convert_shape_node(result_node.value, result_node,
in_control_flow)
return result_node
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
# Note(Aurelius84): Add `locals()` to help `eval` to locate the variable correctly.
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format(
api_shape_name)
args = [attr_shape_name, eval_exist_func]
if slice_node is not None and isinstance(slice_node, gast.Index):
args.append(ast_to_source_code(slice_node).strip())
if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip())
choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
",".join(args))
choose_shape_node = gast.parse(choose_shape_func).body[0].value
if slice_node is not None and not isinstance(slice_node, gast.Index):
if slice_node is not None and not slice_is_num(slice_node):
return gast.Subscript(
value=choose_shape_node, slice=slice_node, ctx=gast.Load())
value=choose_shape_node, slice=slice_node.slice, ctx=gast.Load())
return choose_shape_node
......@@ -133,17 +133,15 @@ class TensorShapeTransformer(gast.NodeTransformer):
if value_node.id in self.name_to_var_shape and self._used_by_paddle_api(
value_node):
return create_choose_shape_node(
value_node.id, self.name_to_var_shape[value_node.id],
slice_node)
value_node.id, self.name_to_var_shape[value_node.id], node)
elif isinstance(value_node, gast.Attribute):
if self._used_by_paddle_api(value_node):
value_name = ast_to_source_code(value_node).strip()
if value_name in self.name_to_var_shape:
return create_choose_shape_node(
value_name, self.name_to_var_shape[value_name],
slice_node)
value_name, self.name_to_var_shape[value_name], node)
if self._is_var_shape(value_node):
return create_convert_shape_node(value_node, slice_node)
return create_convert_shape_node(value_node, node)
return node
def visit_Attribute(self, node):
......@@ -315,14 +313,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
static_shape_value_name = self.name_to_var_shape[
value_node.id]
static_shape_value_node = gast.parse(
static_shape_value_name).body[0].value
index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node)
sub_node = gast.Subscript(
value=static_shape_value_node,
slice=slice_index_node,
ctx=gast.Load())
sub_node_str = "{}[{}]".format(static_shape_value_name,
idx)
sub_node = gast.parse(sub_node_str).body[0].value
update_static_shape_var_node.append(
gast.Assign(
......@@ -342,12 +336,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
# x.shape becomes convert_var_shape_simple(x)
static_shape_value_node = ShapeAttributeTransformer(
).visit(static_shape_value_node)
index_value_node = gast.Constant(value=idx, kind=None)
slice_index_node = gast.Index(value=index_value_node)
sub_node = gast.Subscript(
value=static_shape_value_node,
slice=slice_index_node,
ctx=gast.Load())
sub_node_str = "{}[{}]".format(
ast_to_source_code(static_shape_value_node).strip(),
idx)
sub_node = gast.parse(sub_node_str).body[0].value
update_static_shape_var_node.append(
gast.Assign(
......
......@@ -921,18 +921,15 @@ class ForLoopTuplePreTransformer(gast.NodeTransformer):
def tuple_to_stmts(self, node, tuple_name, idx=[]):
if not isinstance(node, (gast.Tuple, gast.List)):
value_node = gast.Name(
id=tuple_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)
value_node_str = tuple_name
for i in idx:
value_node = gast.Subscript(
value=value_node,
slice=gast.Index(value=gast.Constant(
value=i, kind=None)),
ctx=gast.Load())
return [gast.Assign(targets=[node], value=value_node)]
value_node_str = value_node_str + "[{}]".format(i)
node_str = ast_to_source_code(node).strip()
assign_node_str = "{} = {}".format(node_str, value_node_str)
assign_node = gast.parse(assign_node_str).body[0]
return [assign_node]
# isinstance(node, (gast.Tuple, gast.List))
ret = []
for i, element in enumerate(node.elts):
......@@ -1240,14 +1237,9 @@ class ForNodeVisitor(object):
value=step_node)
def _build_assign_var_slice_node(self):
var_slice_node = gast.Subscript(
value=self.iter_node,
slice=gast.Index(value=gast.Name(
id=self.iter_idx_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)),
ctx=gast.Load(), )
var_slice_str = "{}[{}]".format(
ast_to_source_code(self.iter_node).strip(), self.iter_idx_name)
var_slice_node = gast.parse(var_slice_str).body[0].value
new_iter_var_name = unique_name.generate(FOR_ITER_VAR_NAME_PREFIX)
target_node, assign_node = create_assign_node(new_iter_var_name,
var_slice_node)
......@@ -1422,3 +1414,28 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
return False
return True
def slice_is_num(slice_node):
# A slice_node.slice can be a:
# (1) ast.Index, which is a simple number such as [1], [-2]
# (2) ast.Slice, which is represented by bounds such as [2:-1]
# (3) ast.Tuple, which includes the above two cases such as [2:-1, 1]
# If slice node is case (1), return True, Otherwise, return False.
#
# NOTE: In (1) case, when gast>=0.4.0, gast.Index is not used, which is replaced
# other gast node such as gast.Constant, gast.Name, gast.UnaryOp and so on.
# Considering the compatibility of gast, here use ast note to check whether the
# node is a num. For more details, please visit https://github.com/serge-sans-paille/gast
assert isinstance(slice_node, gast.Subscript)
slice_node_str = ast_to_source_code(slice_node).strip()
ast_node = ast.parse(slice_node_str).body[0].value
if isinstance(ast_node.slice, (ast.Tuple, ast.Slice)):
return False
if isinstance(ast_node.slice, ast.Index):
return True
return False
......@@ -97,7 +97,6 @@ class GastNodeTransformer(gast.NodeTransformer):
It will be generally represented by gast.Index or gast.Slice in gast.
Note: Paddle doesn't support PY3.8 currently.
"""
assert isinstance(node.slice, (gast.Index, gast.Slice))
self.generic_visit(node)
return node
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册