未验证 提交 37ef7c13 编写于 作者: L liym27 提交者: GitHub

[dy2static]Fix a bug of is_dygraph_api and move BasicApiTransformer to a separate file(#23923)

* Move BasicApiTransformer to a separate file. test=develop

* Fix a bug: A api in module is not a real dygraph api in dygraph_to_static. test=develop
上级 c645d235
......@@ -26,6 +26,7 @@ import textwrap
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
......@@ -120,111 +121,6 @@ class DygraphToStaticAst(gast.NodeTransformer):
return feed_name_to_idx
class BasicApiTransformer(gast.NodeTransformer):
"""
Class to transform basic API from dygraph to static graph.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.class_node_dict = {}
# Used for transformation of data feed
self.feed_name_to_arg_id = {}
self.name_to_tensor_shape = {}
def transform(self):
self.visit(self.root)
return self.wrapper_root
def visit_Assign(self, node):
if self._update_class_node_dict(node):
return None
for child_node in gast.walk(node.value):
if isinstance(child_node, gast.Call):
self._visit_Call(child_node)
return node
def visit_Expr(self, node):
value_node = node.value
for child_node in gast.walk(value_node):
if isinstance(child_node, gast.Call):
if is_dygraph_api(child_node):
return
else:
self._visit_Call(child_node)
return node
def _visit_Call(self, node):
assert isinstance(node, gast.Call)
# Replace API `to_variable` with `fluid.layers.assign`
if is_to_variable(node):
self._update_feed_dict(node)
node = to_assign_node(node)
return node
func_name = astor.to_source(gast.gast_to_ast(node.func))
if self._is_dygraph_forward(func_name):
class_node = self._get_class_node(func_name)
static_node = to_static_ast(node, class_node)
return static_node
else:
return node
def _is_dygraph_forward(self, func_id):
return func_id in self.class_node_dict
def _get_class_node(self, func_id):
return self.class_node_dict[func_id]
def _update_class_node_dict(self, node):
assert isinstance(node, gast.Assign)
node_value = node.value
if isinstance(node_value, gast.Call):
if is_to_variable(node_value):
return False
if is_dygraph_api(node_value):
dygraph_api = node_value.func.attr
if not dygraph_class_to_static_api.get(dygraph_api):
return False
update_args_of_func(node_value, node_value, "__init__")
target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
self.class_node_dict[target_str] = node_value
return True
# TODO: node.value is not dygraph class
return False
def _update_feed_dict(self, node):
assert isinstance(node, gast.Call)
value_node = None
for kw in node.keywords:
if kw.arg == 'value':
value_node = kw.value # eg: `a` for "value=a "
if not value_node:
value_node = node.args[0]
if not isinstance(value_node, gast.Name):
return
else:
var_name = value_node.id
feed_var_name = unique_name.generate(var_name) # eg: "a_0"
self.feed_name_to_arg_id[
feed_var_name] = var_name # eg: "a_0" : "a"
def get_feed_name_to_arg_id(self):
return self.feed_name_to_arg_id
def convert_to_static(dyfunc):
"""
Converts dygraph function into static function.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import astor
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
class BasicApiTransformer(gast.NodeTransformer):
"""
Class to transform basic API from dygraph to static graph.
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.class_node_dict = {}
# Used for transformation of data feed
self.feed_name_to_arg_id = {}
self.name_to_tensor_shape = {}
def transform(self):
self.visit(self.root)
return self.wrapper_root
def visit_Assign(self, node):
if self._update_class_node_dict(node):
return None
for child_node in gast.walk(node.value):
if isinstance(child_node, gast.Call):
self._visit_Call(child_node)
return node
def visit_Expr(self, node):
value_node = node.value
for child_node in gast.walk(value_node):
if isinstance(child_node, gast.Call):
# TODO(liym27):
# Considers that a dygraph api which modifies the input or has a output.
if is_dygraph_api(child_node):
return
else:
self._visit_Call(child_node)
return node
def _visit_Call(self, node):
assert isinstance(node, gast.Call)
# Replace API `to_variable` with `fluid.layers.assign`
if is_to_variable(node):
self._update_feed_dict(node)
node = to_assign_node(node)
return node
func_name = astor.to_source(gast.gast_to_ast(node.func))
if self._is_dygraph_forward(func_name):
class_node = self._get_class_node(func_name)
static_node = to_static_ast(node, class_node)
return static_node
else:
return node
def _is_dygraph_forward(self, func_id):
return func_id in self.class_node_dict
def _get_class_node(self, func_id):
return self.class_node_dict[func_id]
def _update_class_node_dict(self, node):
assert isinstance(node, gast.Assign)
node_value = node.value
if isinstance(node_value, gast.Call):
if is_to_variable(node_value):
return False
if is_dygraph_api(node_value):
dygraph_api = node_value.func.attr
if not dygraph_class_to_static_api.get(dygraph_api):
return False
update_args_of_func(node_value, node_value, "__init__")
target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
self.class_node_dict[target_str] = node_value
return True
# TODO: node.value is not dygraph class
return False
def _update_feed_dict(self, node):
assert isinstance(node, gast.Call)
value_node = None
for kw in node.keywords:
if kw.arg == 'value':
value_node = kw.value # eg: `a` for "value=a "
if not value_node:
value_node = node.args[0]
if not isinstance(value_node, gast.Name):
return
else:
var_name = value_node.id
feed_var_name = unique_name.generate(var_name) # eg: "a_0"
self.feed_name_to_arg_id[
feed_var_name] = var_name # eg: "a_0" : "a"
def get_feed_name_to_arg_id(self):
return self.feed_name_to_arg_id
......@@ -61,6 +61,10 @@ def is_api_in_module(node, module_prefix):
def is_dygraph_api(node):
# Note: A api in module dygraph_to_static is not a real dygraph api.
if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"):
return False
return is_api_in_module(node, "paddle.fluid.dygraph")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册