Unverified commit a3436672, authored by xiongkun, committed by GitHub

[Dy2Static] Remove deprecated code in dy2static (#47148)

Parent 0e552c08
...
@@ -27,10 +27,8 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Br
 from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakTransformOptimizer
 from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
 from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer
-from paddle.fluid.dygraph.dygraph_to_static.grad_transformer import GradTransformer
 from paddle.fluid.dygraph.dygraph_to_static.typehint_transformer import TypeHintTransformer
 from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
-from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
 from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer
 from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
 from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
@@ -92,7 +90,6 @@ class DygraphToStaticAst(BaseTransformer):
             EarlyReturnTransformer,
             BasicApiTransformer,  # Basic Api
             TensorShapeTransformer,  # Tensor.shape -> layers.shape(Tensor)
-            #ListTransformer,  # List used in control flow
             BreakContinueTransformer,  # break/continue in loops
             ReturnTransformer,  # return in functions
             LogicalTransformer,  # logical and/or/not
@@ -103,7 +100,6 @@ class DygraphToStaticAst(BaseTransformer):
             PrintTransformer,  # print statement
             CallTransformer,  # transform call recursively
             CastTransformer,  # type casting statement
-            #GradTransformer,  # transform paddle.grad to paddle.gradients
             DecoratorTransformer,  # transform decorators to function call
             TypeHintTransformer,  # remove all typehint in gast.Name
         ]
...
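For context, DygraphToStaticAst applies the transformers in this list in order
over the wrapped AST root. A minimal sketch of that driver loop, assuming only
the __init__(wrapper_root) / transform() convention visible in the deleted
files below (the apply_transformers helper name is hypothetical, not Paddle
API):

def apply_transformers(wrapper_root, transformer_classes):
    # Each transformer mutates the gast AST held by wrapper_root in place.
    # Order matters: later passes rely on the rewrites of earlier ones
    # (e.g. break/continue is lowered before return handling).
    for transformer_class in transformer_classes:
        transformer_class(wrapper_root).transform()
    return wrapper_root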
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.utils import gast
import warnings
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer


class GradTransformer(BaseTransformer):
    """
    A class that transforms dygraph paddle.grad calls into static-graph
    paddle.gradients calls. The transformation is applied to support
    double grad mode.
    """

    def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of GradTransformer."
self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

    def transform(self):
        self.visit(self.root)

    def visit_Call(self, node):
self.generic_visit(node)
if not is_grad_api_node(node):
return node
dygraph_grad_parameters = [
"outputs", "inputs", "grad_outputs", "retain_graph", "create_graph",
"only_inputs", "allow_unused", "no_grad_vars"
]
to_static_grad_param = {
"outputs": "targets",
"inputs": "inputs",
"grad_outputs": "target_gradients",
"no_grad_vars": "no_grad_set"
}
static_keywords = []
for kw in node.keywords:
if kw.arg not in dygraph_grad_parameters or kw.arg not in to_static_grad_param:
warnings.warn("paddle.grad has unsupported parameter in jit: " +
kw.arg + ", jit will discard it")
continue
dygraph_grad_parameters.remove(kw.arg)
kw.arg = to_static_grad_param[kw.arg]
static_keywords.append(kw)
for i in range(len(node.args)):
arg_name = dygraph_grad_parameters[i]
if arg_name not in to_static_grad_param:
warnings.warn("paddle.grad has unsupported parameter in jit: " +
kw.arg + ", jit will discard it")
continue
kw = gast.keyword(arg=to_static_grad_param[arg_name],
value=node.args[i])
static_keywords.append(kw)
node.func = gast.parse('paddle.static.gradients').body[0].value
node.keywords = static_keywords
node.args = []
        return node


def is_grad_api_node(node):
assert isinstance(node, gast.Call)
api_name = utils.ast_to_source_code(node.func).strip()
if utils.is_paddle_api(node):
if 'no_grad' in api_name:
warnings.warn(
"paddle.no_grad is only supported for inference model, and not supported for training under @to_static."
)
return False
return api_name.endswith("grad")
return False
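For reference, the net effect of the deleted transformer is a keyword rename
plus a function swap. A runnable sketch of the renaming on a plain dict rather
than on a gast tree (rename_grad_keywords is a hypothetical helper; the
mapping mirrors to_static_grad_param above):

# paddle.grad(outputs=y, inputs=x, create_graph=True)
#   --> paddle.static.gradients(targets=y, inputs=x)   # create_graph dropped
to_static_grad_param = {
    "outputs": "targets",
    "inputs": "inputs",
    "grad_outputs": "target_gradients",
    "no_grad_vars": "no_grad_set",
}

def rename_grad_keywords(kwargs):
    # Keep only keywords paddle.static.gradients understands; the real
    # transformer warns before discarding anything else.
    return {to_static_grad_param[k]: v
            for k, v in kwargs.items() if k in to_static_grad_param}

print(rename_grad_keywords({"outputs": "y", "inputs": "x",
                            "create_graph": True}))
# {'targets': 'y', 'inputs': 'x'}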
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import astor
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer


class ListTransformer(BaseTransformer):
    """
    This class transforms Python lists used in control flow into
    static-graph tensor arrays.
    """

    def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.list_name_to_updated = dict()
self.list_nodes = set()
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
        self.node_to_wrapper_map = \
            self.static_analysis_visitor.get_node_to_wrapper_map()
        var_env = self.static_analysis_visitor.get_var_env()
        var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
        self.scope_var_type_dict = var_env.get_scope_var_type()

    def transform(self):
        self.visit(self.root)
        self.replace_list_with_tensor_array(self.root)

    def visit_Call(self, node):
if isinstance(node.func, gast.Attribute):
func_name = node.func.attr
if func_name == "pop":
node = self._replace_pop(node)
        return node

    def visit_Assign(self, node):
        if self._update_list_name_to_updated(node):
            return node
        if self._need_to_array_write_node(node):
            return self._transform_slice_to_tensor_write(node)
        self.generic_visit(node)
        return node

    def visit_If(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_While(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def visit_For(self, node):
        self.generic_visit(node)
        if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                        self.scope_var_type_dict):
            self._transform_list_append_in_control_flow(node)
        return node

    def replace_list_with_tensor_array(self, node):
        for child_node in gast.walk(node):
            if isinstance(child_node, gast.Assign):
                if self._need_to_create_tensor_array(child_node):
                    child_node.value = self._create_tensor_array(
                        child_node.value)

    def _transform_list_append_in_control_flow(self, node):
for child_node in gast.walk(node):
if self._need_to_array_write_node(child_node):
child_node.value = \
                    self._to_array_write_node(child_node.value)

    def _need_to_array_write_node(self, node):
if isinstance(node, gast.Expr):
if isinstance(node.value, gast.Call):
if self._is_list_append_tensor(node.value):
return True
if isinstance(node, gast.Assign):
target_node = node.targets[0]
if isinstance(target_node, gast.Subscript):
list_name = ast_to_source_code(target_node.value).strip()
if list_name in self.list_name_to_updated:
                    if self.list_name_to_updated[list_name]:
return True
        return False

    def _transform_slice_to_tensor_write(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
target_name = target_node.value.id
slice_node = target_node.slice
if isinstance(slice_node, gast.Slice):
pass
elif slice_is_num(target_node):
value_code = ast_to_source_code(node.value)
i = "paddle.cast(" \
"x=_jst.to_static_variable({})," \
"dtype='int64')".format(ast_to_source_code(slice_node))
assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
.format(target_name, value_code, i, target_name)
assign_node = gast.parse(assign_code).body[0]
return assign_node
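    # Illustrative example of the rewrite above: inside transformed control
    # flow, `a[i] = x` becomes (the names a, i, x are hypothetical)
    #     a = paddle.tensor.array_write(
    #         x=x,
    #         i=paddle.cast(x=_jst.to_static_variable(i), dtype='int64'),
    #         array=a)
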
def _is_list_append_tensor(self, node):
"""
a.append(b): a is list, b is Tensor
self.x.append(b): self.x is list, b is Tensor
"""
assert isinstance(node, gast.Call)
# 1. The func is `append`.
if not isinstance(node.func, gast.Attribute):
return False
if node.func.attr != 'append':
return False
# 2. It's a `python list` to call append().
value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
if value_name not in self.list_name_to_updated:
return False
# 3. The number of arg of append() is one
# Only one argument is supported in Python list.append()
if len(node.args) != 1:
return False
# TODO(liym27): The arg of append() should be Tensor. But because the type of arg is often wrong with static analysis,
# the arg is not required to be Tensor here.
# 4. The arg of append() is Tensor
# arg = node.args[0]
# if isinstance(arg, gast.Name):
# # TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function
# # Need a better way to confirm whether `arg.id` is a Tensor.
# try:
# var_type_set = self.scope_var_type_dict[arg.id]
# except KeyError:
# return False
# if NodeVarType.NUMPY_NDARRAY in var_type_set:
# return False
# if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
# return False
# # TODO: Consider that `arg` may be a gast.Call about Paddle Api. eg: list_a.append(paddle.reshape(x))
# # else:
# # return True
self.list_name_to_updated[value_name.strip()] = True
        return True

    def _need_to_create_tensor_array(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
try:
target_id = target_node.id
except AttributeError:
return False
if self.list_name_to_updated.get(target_id) and node in self.list_nodes:
return True
        return False

    def _create_tensor_array(self, value_node):
        # Although dtype is 'float32' here, other types such as 'int32' are also supported
init_value = ast_to_source_code(value_node).strip()
func_code = "paddle.tensor.create_array('float32', {})".format(
init_value)
func_node = gast.parse(func_code).body[0].value
        return func_node

    def _to_array_write_node(self, node):
assert isinstance(node, gast.Call)
array = astor.to_source(gast.gast_to_ast(node.func.value))
x = astor.to_source(gast.gast_to_ast(node.args[0]))
i = "paddle.tensor.array_length({})".format(array)
func_code = "paddle.tensor.array_write(x={}, i={}, array={})".format(
x, i, array)
        return gast.parse(func_code).body[0].value

    def _update_list_name_to_updated(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
# NOTE: Code like `x, y = a, []` has been transformed to `x=a; y=[]`
try:
target_id = target_node.id
except AttributeError:
return False
value_node = node.value
if isinstance(value_node, gast.List):
self.list_name_to_updated[target_id] = False
self.list_nodes.add(node)
return True
elif target_id in self.list_name_to_updated and \
self.list_name_to_updated[target_id] == False:
del self.list_name_to_updated[target_id]
        return False

    def _replace_pop(self, node):
"""
Replace a pop statement for a list or dict.
For example:
list_a = [0,1,2,3,4]
x = list_a.pop() # --> convert_pop(list_a)
y = list_a.pop(1) # --> convert_pop(list_a, 1)
dict_a = {"red":0, "blue":1, "yellow":2}
m = dict_a.pop("red") # --> convert_pop(dict_a, "red")
n = dict_a.pop("black", 3) # --> convert_pop(dict_a, "black", 3)
"""
assert isinstance(node, gast.Call)
assert isinstance(node.func, gast.Attribute)
target_node = node.func.value
target_str = ast_to_source_code(target_node).strip()
args_str = [ast_to_source_code(arg).strip() for arg in node.args]
# NOTE(liym27):
# 1. pop stmt for a list if len(args_str) == 0
# 2. pop stmt for a list or dict if len(args_str) == 1
# 3. pop stmt for a dict if len(args_str) == 2
if len(args_str) <= 2:
new_pop_str = "_jst.Pop({}, {})"\
.format(target_str, ",".join(args_str))
new_pop_node = gast.parse(new_pop_str).body[0].value
return new_pop_node
else:
return node
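To illustrate the overall rewrite the deleted ListTransformer performed, here
is a before/after sketch for a list appended to inside control flow (written
as comments; the names a, x and the loop are hypothetical):

# Before (dygraph Python):
#     a = []
#     for i in range(10):
#         a.append(x)
#
# After (static graph, per _create_tensor_array and _to_array_write_node):
#     a = paddle.tensor.create_array('float32', [])
#     for i in range(10):
#         paddle.tensor.array_write(
#             x=x, i=paddle.tensor.array_length(a), array=a)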
@@ -330,22 +330,6 @@ def is_numpy_api(node):
     return False

-def is_control_flow_to_transform(node,
-                                 static_analysis_visitor=None,
-                                 var_name_to_type=None):
-    """
-    Determines whether the node is a PaddlePaddle control flow statement that
-    needs to be transformed into a static graph control flow statement.
-    """
-    assert isinstance(node, gast.AST), \
-        "The type of input node must be gast.AST, but received %s." % type(node)
-    visitor = IsControlFlowVisitor(node,
-                                   static_analysis_visitor,
-                                   node_var_type_map=var_name_to_type)
-    need_to_transform = visitor.transform()
-    return need_to_transform

 def _delete_keywords_from(node):
     assert isinstance(node, gast.Call)
     func_src = astor.to_source(gast.gast_to_ast(node.func))
@@ -1001,31 +985,6 @@ def _compatible_non_tensor_spec(src_spec, desired_spec):
     return True

-def slice_is_num(slice_node):
-    # A slice_node.slice can be:
-    # (1) ast.Index, which is a simple number such as [1], [-2]
-    # (2) ast.Slice, which is represented by bounds such as [2:-1]
-    # (3) ast.Tuple, which includes the above two cases, such as [2:-1, 1]
-    # If the slice node is case (1), return True. Otherwise, return False.
-    #
-    # NOTE: In case (1), when gast>=0.4.0, gast.Index is no longer used; it is
-    # replaced by other gast nodes such as gast.Constant, gast.Name and
-    # gast.UnaryOp. For compatibility across gast versions, an ast node is
-    # used here to check whether the node is a num. For more details, please
-    # visit https://github.com/serge-sans-paille/gast
-    assert isinstance(slice_node, gast.Subscript)
-    slice_node_str = ast_to_source_code(slice_node).strip()
-    ast_node = ast.parse(slice_node_str).body[0].value
-    if isinstance(ast_node.slice, (ast.Tuple, ast.Slice)):
-        return False
-    if isinstance(ast_node.slice, ast.Index):
-        return True
-    return False

 class NameScope:
     def __init__(self):
...
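The deleted slice_is_num relied on the pre-3.9 ast.Index wrapper. A small
runnable sketch of the distinction it drew (output shown for Python >= 3.9;
the is_num heuristic below approximates the deleted check rather than
reproducing it exactly):

import ast

# On Python >= 3.9 a simple index appears directly as the value node
# (ast.Constant, ast.UnaryOp, ...) instead of being wrapped in ast.Index,
# which is exactly the incompatibility the NOTE above works around.
for src in ("a[1]", "a[-2]", "a[2:-1]"):
    subscript = ast.parse(src).body[0].value  # the Subscript expression
    sl = subscript.slice
    is_num = not isinstance(sl, (ast.Slice, ast.Tuple))
    print(src, "->", type(sl).__name__, "is_num:", is_num)
# a[1]    -> Constant  is_num: True
# a[-2]   -> UnaryOp   is_num: True
# a[2:-1] -> Slice     is_num: False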