diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 5e23acfe4249c4c74448810a7c3647591aa18028..ff7a9a2a957b634054d2d43c0f66a7bc5cbaf62c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -27,10 +27,8 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Br from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakTransformOptimizer from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer -from paddle.fluid.dygraph.dygraph_to_static.grad_transformer import GradTransformer from paddle.fluid.dygraph.dygraph_to_static.typehint_transformer import TypeHintTransformer from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer -from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer @@ -92,7 +90,6 @@ class DygraphToStaticAst(BaseTransformer): EarlyReturnTransformer, BasicApiTransformer, # Basic Api TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor) - #ListTransformer, # List used in control flow BreakContinueTransformer, # break/continue in loops ReturnTransformer, # return in functions LogicalTransformer, # logical and/or/not @@ -103,7 +100,6 @@ class DygraphToStaticAst(BaseTransformer): PrintTransformer, # print statement CallTransformer, # transform call recursively CastTransformer, # type casting statement - #GradTransformer, # transform paddle.grad to paddle.gradients DecoratorTransformer, # transform decorators to function call TypeHintTransformer, # remove all typehint in gast.Name ] diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py deleted file mode 100644 index 4f1b1c44752698a4feda12da64d76ad0479834ba..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddle.utils import gast -import warnings - -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper -from paddle.fluid.dygraph.dygraph_to_static import utils -from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer - - -class GradTransformer(BaseTransformer): - """ - A class transforms dygraph paddle.grad to static graph paddle.gradients. The - transformation is applied to support double grad mode. - """ - - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of GradTransformer." - self.wrapper_root = wrapper_root - self.root = wrapper_root.node - - def transform(self): - self.visit(self.root) - - def visit_Call(self, node): - self.generic_visit(node) - if not is_grad_api_node(node): - return node - - dygraph_grad_parameters = [ - "outputs", "inputs", "grad_outputs", "retain_graph", "create_graph", - "only_inputs", "allow_unused", "no_grad_vars" - ] - to_static_grad_param = { - "outputs": "targets", - "inputs": "inputs", - "grad_outputs": "target_gradients", - "no_grad_vars": "no_grad_set" - } - static_keywords = [] - - for kw in node.keywords: - if kw.arg not in dygraph_grad_parameters or kw.arg not in to_static_grad_param: - warnings.warn("paddle.grad has unsupported parameter in jit: " + - kw.arg + ", jit will discard it") - continue - dygraph_grad_parameters.remove(kw.arg) - kw.arg = to_static_grad_param[kw.arg] - static_keywords.append(kw) - - for i in range(len(node.args)): - arg_name = dygraph_grad_parameters[i] - if arg_name not in to_static_grad_param: - warnings.warn("paddle.grad has unsupported parameter in jit: " + - kw.arg + ", jit will discard it") - continue - kw = gast.keyword(arg=to_static_grad_param[arg_name], - value=node.args[i]) - static_keywords.append(kw) - - node.func = gast.parse('paddle.static.gradients').body[0].value - node.keywords = static_keywords - node.args = [] - return node - - -def is_grad_api_node(node): - assert isinstance(node, gast.Call) - api_name = utils.ast_to_source_code(node.func).strip() - if utils.is_paddle_api(node): - if 'no_grad' in api_name: - warnings.warn( - "paddle.no_grad is only supported for inference model, and not supported for training under @to_static." - ) - return False - return api_name.endswith("grad") - return False diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py deleted file mode 100644 index 540189e646001e264a981b020b8e3de56b5d8093..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import astor -from paddle.utils import gast - -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor -from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code -from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num -from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform -from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer - - -class ListTransformer(BaseTransformer): - """ - This class transforms python list used in control flow into Static Graph Ast. - """ - - def __init__(self, wrapper_root): - assert isinstance( - wrapper_root, AstNodeWrapper - ), "Input non-AstNodeWrapper node for the initialization of ListTransformer." - self.wrapper_root = wrapper_root - self.root = wrapper_root.node - self.list_name_to_updated = dict() - self.list_nodes = set() - - self.static_analysis_visitor = StaticAnalysisVisitor(self.root) - self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( - ) - var_env = self.static_analysis_visitor.get_var_env() - var_env.cur_scope = var_env.cur_scope.sub_scopes[0] - self.scope_var_type_dict = var_env.get_scope_var_type() - - def transform(self): - self.visit(self.root) - self.replace_list_with_tensor_array(self.root) - - def visit_Call(self, node): - if isinstance(node.func, gast.Attribute): - func_name = node.func.attr - if func_name == "pop": - node = self._replace_pop(node) - return node - - def visit_Assign(self, node): - if self._update_list_name_to_updated(node): - return node - - if self._need_to_array_write_node(node): - return self._transform_slice_to_tensor_write(node) - - self.generic_visit(node) - return node - - def visit_If(self, node): - self.generic_visit(node) - if is_control_flow_to_transform(node, self.static_analysis_visitor, - self.scope_var_type_dict): - self._transform_list_append_in_control_flow(node) - return node - - def visit_While(self, node): - self.generic_visit(node) - if is_control_flow_to_transform(node, self.static_analysis_visitor, - self.scope_var_type_dict): - self._transform_list_append_in_control_flow(node) - return node - - def visit_For(self, node): - self.generic_visit(node) - if is_control_flow_to_transform(node, self.static_analysis_visitor, - self.scope_var_type_dict): - self._transform_list_append_in_control_flow(node) - return node - - def replace_list_with_tensor_array(self, node): - for child_node in gast.walk(node): - if isinstance(child_node, gast.Assign): - if self._need_to_create_tensor_array(child_node): - child_node.value = self._create_tensor_array( - child_node.value) - - def _transform_list_append_in_control_flow(self, node): - for child_node in gast.walk(node): - if self._need_to_array_write_node(child_node): - child_node.value = \ - self._to_array_write_node(child_node.value) - - def _need_to_array_write_node(self, node): - if isinstance(node, gast.Expr): - if isinstance(node.value, gast.Call): - if self._is_list_append_tensor(node.value): - return True - - if isinstance(node, gast.Assign): - target_node = node.targets[0] - if isinstance(target_node, gast.Subscript): - list_name = ast_to_source_code(target_node.value).strip() - if list_name in self.list_name_to_updated: - if self.list_name_to_updated[list_name] == True: - return True - return False - - def _transform_slice_to_tensor_write(self, node): - assert isinstance(node, gast.Assign) - target_node = node.targets[0] - - target_name = target_node.value.id - slice_node = target_node.slice - - if isinstance(slice_node, gast.Slice): - pass - elif slice_is_num(target_node): - value_code = ast_to_source_code(node.value) - i = "paddle.cast(" \ - "x=_jst.to_static_variable({})," \ - "dtype='int64')".format(ast_to_source_code(slice_node)) - assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \ - .format(target_name, value_code, i, target_name) - assign_node = gast.parse(assign_code).body[0] - return assign_node - - def _is_list_append_tensor(self, node): - """ - a.append(b): a is list, b is Tensor - self.x.append(b): self.x is list, b is Tensor - """ - assert isinstance(node, gast.Call) - # 1. The func is `append`. - if not isinstance(node.func, gast.Attribute): - return False - if node.func.attr != 'append': - return False - - # 2. It's a `python list` to call append(). - value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip() - if value_name not in self.list_name_to_updated: - return False - - # 3. The number of arg of append() is one - # Only one argument is supported in Python list.append() - if len(node.args) != 1: - return False - - # TODO(liym27): The arg of append() should be Tensor. But because the type of arg is often wrong with static analysis, - # the arg is not required to be Tensor here. - # 4. The arg of append() is Tensor - # arg = node.args[0] - # if isinstance(arg, gast.Name): - # # TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function - # # Need a better way to confirm whether `arg.id` is a Tensor. - # try: - # var_type_set = self.scope_var_type_dict[arg.id] - # except KeyError: - # return False - # if NodeVarType.NUMPY_NDARRAY in var_type_set: - # return False - # if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set: - # return False - # # TODO: Consider that `arg` may be a gast.Call about Paddle Api. eg: list_a.append(paddle.reshape(x)) - # # else: - # # return True - self.list_name_to_updated[value_name.strip()] = True - return True - - def _need_to_create_tensor_array(self, node): - assert isinstance(node, gast.Assign) - target_node = node.targets[0] - try: - target_id = target_node.id - except AttributeError: - return False - if self.list_name_to_updated.get(target_id) and node in self.list_nodes: - return True - return False - - def _create_tensor_array(self, value_node): - # Although `dtype='float32'`, other types such as `int32` can also be supported - init_value = ast_to_source_code(value_node).strip() - func_code = "paddle.tensor.create_array('float32', {})".format( - init_value) - func_node = gast.parse(func_code).body[0].value - return func_node - - def _to_array_write_node(self, node): - assert isinstance(node, gast.Call) - array = astor.to_source(gast.gast_to_ast(node.func.value)) - x = astor.to_source(gast.gast_to_ast(node.args[0])) - i = "paddle.tensor.array_length({})".format(array) - func_code = "paddle.tensor.array_write(x={}, i={}, array={})".format( - x, i, array) - return gast.parse(func_code).body[0].value - - def _update_list_name_to_updated(self, node): - assert isinstance(node, gast.Assign) - target_node = node.targets[0] - # NOTE: Code like `x, y = a, []` has been transformed to `x=a; y=[]` - try: - target_id = target_node.id - except AttributeError: - return False - value_node = node.value - if isinstance(value_node, gast.List): - self.list_name_to_updated[target_id] = False - self.list_nodes.add(node) - return True - elif target_id in self.list_name_to_updated and \ - self.list_name_to_updated[target_id] == False: - del self.list_name_to_updated[target_id] - return False - - def _replace_pop(self, node): - """ - Replace a pop statement for a list or dict. - For example: - - list_a = [0,1,2,3,4] - x = list_a.pop() # --> convert_pop(list_a) - y = list_a.pop(1) # --> convert_pop(list_a, 1) - - dict_a = {"red":0, "blue":1, "yellow":2} - m = dict_a.pop("red") # --> convert_pop(dict_a, "red") - n = dict_a.pop("black", 3) # --> convert_pop(dict_a, "black", 3) - - """ - assert isinstance(node, gast.Call) - assert isinstance(node.func, gast.Attribute) - - target_node = node.func.value - target_str = ast_to_source_code(target_node).strip() - - args_str = [ast_to_source_code(arg).strip() for arg in node.args] - - # NOTE(liym27): - # 1. pop stmt for a list if len(args_str) == 0 - # 2. pop stmt for a list or dict if len(args_str) == 1 - # 3. pop stmt for a dict if len(args_str) == 2 - if len(args_str) <= 2: - new_pop_str = "_jst.Pop({}, {})"\ - .format(target_str, ",".join(args_str)) - new_pop_node = gast.parse(new_pop_str).body[0].value - return new_pop_node - else: - return node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index f96feaee9d739d387f354ead62eded25f52d64d8..8e1950b21fca77d8fc543b62e024ac54f5543a94 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -330,22 +330,6 @@ def is_numpy_api(node): return False -def is_control_flow_to_transform(node, - static_analysis_visitor=None, - var_name_to_type=None): - """ - Determines whether the node is a PaddlePaddle control flow statement which needs to - be transformed into a static graph control flow statement. - """ - assert isinstance(node, gast.AST), \ - "The type of input node must be gast.AST, but received %s." % type(node) - visitor = IsControlFlowVisitor(node, - static_analysis_visitor, - node_var_type_map=var_name_to_type) - need_to_transform = visitor.transform() - return need_to_transform - - def _delete_keywords_from(node): assert isinstance(node, gast.Call) func_src = astor.to_source(gast.gast_to_ast(node.func)) @@ -1001,31 +985,6 @@ def _compatible_non_tensor_spec(src_spec, desired_spec): return True -def slice_is_num(slice_node): - # A slice_node.slice can be a: - # (1) ast.Index, which is a simple number such as [1], [-2] - # (2) ast.Slice, which is represented by bounds such as [2:-1] - # (3) ast.Tuple, which includes the above two cases such as [2:-1, 1] - # If slice node is case (1), return True, Otherwise, return False. - # - # NOTE: In (1) case, when gast>=0.4.0, gast.Index is not used, which is replaced - # other gast node such as gast.Constant, gast.Name, gast.UnaryOp and so on. - # Considering the compatibility of gast, here use ast note to check whether the - # node is a num. For more details, please visit https://github.com/serge-sans-paille/gast - - assert isinstance(slice_node, gast.Subscript) - slice_node_str = ast_to_source_code(slice_node).strip() - ast_node = ast.parse(slice_node_str).body[0].value - - if isinstance(ast_node.slice, (ast.Tuple, ast.Slice)): - return False - - if isinstance(ast_node.slice, ast.Index): - return True - - return False - - class NameScope: def __init__(self):