未验证 提交 4ea95b6f 编写于 作者: L liym27 提交者: GitHub

Support Tensor.shape in dygraph_to_static (#22830)

* support basic tensor.shape. 

* Support tensor.shape with dependencies. 
上级 1644926a
...@@ -14,19 +14,21 @@ ...@@ -14,19 +14,21 @@
from __future__ import print_function from __future__ import print_function
import copy
import inspect
import textwrap
import astor import astor
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). # gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions, # It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module. # as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/ # See details in https://github.com/serge-sans-paille/gast/
import gast import gast
import textwrap
import inspect
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func
from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from .utils import * from .utils import *
__all__ = ['DygraphToStaticAst', 'convert_to_static'] __all__ = ['DygraphToStaticAst', 'convert_to_static']
...@@ -121,8 +123,10 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -121,8 +123,10 @@ class DygraphToStaticAst(gast.NodeTransformer):
def get_static_ast(self, root): def get_static_ast(self, root):
# save root for some analysis may need global AST # save root for some analysis may need global AST
self.root = root self.root = root
self.static_analysis_root = StaticAnalysisVisitor( self.static_analysis_visitor = StaticAnalysisVisitor(root)
root).get_node_wrapper_root() self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
)
self.decorate_func_name = None self.decorate_func_name = None
self.arg_name_to_idx = {} self.arg_name_to_idx = {}
self.transfer_from_node_type(self.static_analysis_root) self.transfer_from_node_type(self.static_analysis_root)
...@@ -133,7 +137,8 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -133,7 +137,8 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.visit(node_wrapper.node) self.visit(node_wrapper.node)
# Transform basic api of dygraph to static graph # Transform basic api of dygraph to static graph
basic_api_trans = BasicApiTransformer(node_wrapper) basic_api_trans = BasicApiTransformer(node_wrapper,
self.static_analysis_visitor)
basic_api_trans.ast_visit() basic_api_trans.ast_visit()
self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id() self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()
...@@ -178,14 +183,31 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -178,14 +183,31 @@ class BasicApiTransformer(gast.NodeTransformer):
Class to transform basic API from dygraph to static graph. Class to transform basic API from dygraph to static graph.
""" """
def __init__(self, wrapper_root): def __init__(self, wrapper_root, static_analysis_visitor):
assert isinstance( assert isinstance(
wrapper_root, AstNodeWrapper wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer." ), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self.wrapper_root = wrapper_root self.wrapper_root = wrapper_root
self.root = wrapper_root.node self.root = wrapper_root.node
self.class_node_dict = {} self.class_node_dict = {}
# Used for transformation of data feed
self.feed_name_to_arg_id = {} self.feed_name_to_arg_id = {}
self.name_to_tensor_shape = {}
# Used for transformation of Tensor.shape
self.static_analysis_visitor = static_analysis_visitor
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
self.scope_var_type_dict = {}
self._run_static_visitor()
def _run_static_visitor(self):
var_env = copy.deepcopy(self.static_analysis_visitor.get_var_env())
# TODO: Consider that Tensor.shape is used in sub function and sub_scopes is empty
var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
self.scope_var_type_dict = var_env.get_scope_var_type()
def ast_visit(self): def ast_visit(self):
self.visit(self.root) self.visit(self.root)
...@@ -204,11 +226,12 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -204,11 +226,12 @@ class BasicApiTransformer(gast.NodeTransformer):
if self._update_class_node_dict(node): if self._update_class_node_dict(node):
return None return None
value_node = node.value if self._update_name_to_tensor_shape(node):
for child_node in gast.walk(value_node): return node
for child_node in gast.walk(node.value):
if isinstance(child_node, gast.Call): if isinstance(child_node, gast.Call):
self._visit_Call(child_node) self._visit_Call(child_node)
return node return node
def visit_Expr(self, node): def visit_Expr(self, node):
...@@ -219,19 +242,41 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -219,19 +242,41 @@ class BasicApiTransformer(gast.NodeTransformer):
return return
else: else:
self._visit_Call(child_node) self._visit_Call(child_node)
return node
def visit_Attribute(self, node):
if self._used_by_paddle_api(node):
if self.is_tensor_shape(node):
return create_api_shape_node(node)
return node
def visit_Name(self, node):
if node.id in self.name_to_tensor_shape:
if self._used_by_paddle_api(node):
tensor_shape_node = self.name_to_tensor_shape[node.id]
if isinstance(tensor_shape_node, gast.Attribute):
return create_api_shape_node(tensor_shape_node)
elif isinstance(tensor_shape_node, gast.Subscript):
result_node = copy.deepcopy(tensor_shape_node)
result_node.value = create_api_shape_node(
tensor_shape_node.value)
return result_node
return node return node
def _visit_Call(self, node): def _visit_Call(self, node):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
# Replace API `to_variable` with `fluid.layers.assign` # Replace API `to_variable` with `fluid.layers.assign`
if is_to_variable(node): if is_to_variable(node):
self._update_feed_dict(node) self._update_feed_dict(node)
node = to_assign_node(node) node = to_assign_node(node)
return node return node
if is_paddle_api(node):
# Visit gast.Attribute and gast.Name to replace tensor.shape if necessary
self.generic_visit(node)
func_name = astor.to_source(gast.gast_to_ast(node.func)) func_name = astor.to_source(gast.gast_to_ast(node.func))
if self._is_dygraph_forward(func_name): if self._is_dygraph_forward(func_name):
class_node = self._get_class_node(func_name) class_node = self._get_class_node(func_name)
static_node = to_static_ast(node, class_node) static_node = to_static_ast(node, class_node)
...@@ -239,6 +284,53 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -239,6 +284,53 @@ class BasicApiTransformer(gast.NodeTransformer):
else: else:
return node return node
def is_tensor_shape(self, node):
"""
Return True if node is like `x.shape` and x is Tensor, return False otherwise.
"""
assert isinstance(node, gast.Attribute)
if node.attr != 'shape':
return False
try:
value_id = node.value.id
except AttributeError:
return False
if value_id in self.name_to_tensor_shape:
return True
# TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
# Need a better way to confirm whether `value_id` is a Tensor.
try:
var_type_set = self.scope_var_type_dict[value_id]
except KeyError:
return False
if NodeVarType.NUMPY_NDARRAY in var_type_set:
return False
if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
return False
return True
def _used_by_paddle_api(self, node):
assert isinstance(node, (gast.Attribute, gast.Name))
wrapper_node = self.node_to_wrapper_map.get(node)
if not wrapper_node:
# Transformed node is not in node_to_wrapper_map
return False
while wrapper_node.parent:
parent_node = wrapper_node.parent.node
if isinstance(parent_node, gast.Call):
if is_paddle_api(parent_node):
return True
else:
return False
wrapper_node = wrapper_node.parent
return False
def _is_dygraph_forward(self, func_id): def _is_dygraph_forward(self, func_id):
return func_id in self.class_node_dict return func_id in self.class_node_dict
...@@ -280,6 +372,32 @@ class BasicApiTransformer(gast.NodeTransformer): ...@@ -280,6 +372,32 @@ class BasicApiTransformer(gast.NodeTransformer):
def get_feed_name_to_arg_id(self): def get_feed_name_to_arg_id(self):
return self.feed_name_to_arg_id return self.feed_name_to_arg_id
def _update_name_to_tensor_shape(self, node):
assert isinstance(node, gast.Assign)
# TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
target_node = node.targets[0]
try:
target_id = target_node.id
except AttributeError:
return False
value_node = node.value
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_tensor_shape:
self.name_to_tensor_shape[
target_id] = self.name_to_tensor_shape[value_node.id]
return True
if isinstance(value_node, gast.Attribute):
if self.is_tensor_shape(value_node): # eg: x.shape
self.name_to_tensor_shape[target_id] = value_node
return True
if isinstance(value_node, gast.Subscript):
if isinstance(value_node.value, gast.Attribute):
if self.is_tensor_shape(value_node.value): # eg: x.shape[0]
self.name_to_tensor_shape[target_id] = value_node
return True
return False
def convert_to_static(dyfunc): def convert_to_static(dyfunc):
""" """
......
...@@ -360,7 +360,9 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True): ...@@ -360,7 +360,9 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
# TODO(Aurelius84): more elegant way to transform ast into callable object # TODO(Aurelius84): more elegant way to transform ast into callable object
import_str = "import paddle\n" \ import_str = "import paddle\n" \
"import paddle.fluid as fluid\n" \ "import paddle.fluid as fluid\n" \
"import paddle.fluid.layers as layers\n" "import paddle.fluid.layers as layers\n" \
"import numpy as np\n" \
"import numpy\n"
with f: with f:
module_name = os.path.basename(f.name[:-3]) module_name = os.path.basename(f.name[:-3])
f.write(import_str) f.write(import_str)
......
...@@ -181,3 +181,12 @@ def update_args_of_func(node, dygraph_node, method_name): ...@@ -181,3 +181,12 @@ def update_args_of_func(node, dygraph_node, method_name):
node.args = [] node.args = []
node.keywords = added_keywords + node.keywords node.keywords = added_keywords + node.keywords
def create_api_shape_node(tensor_shape_node):
assert isinstance(tensor_shape_node, gast.Attribute)
api_shape_node = gast.Call(
func=gast.parse('fluid.layers.shape').body[0].value,
args=[tensor_shape_node.value],
keywords=[])
return api_shape_node
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
def dyfunc_tensor_shape_1(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.reshape(x, shape=x.shape)
return res
def dyfunc_tensor_shape_2(x):
x = fluid.dygraph.to_variable(x)
shape = x.shape
shape2 = shape
res = fluid.layers.reshape(x, shape2)
return res
def dyfunc_tensor_shape_3(x):
# Don't transform y.shape because y is numpy.ndarray
x = fluid.dygraph.to_variable(x)
y = numpy.ones(5)
res = fluid.layers.reshape(x, shape=y.shape)
return res
def dyfunc_tensor_shape_4(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.reshape(x, shape=(-1, x.shape[0], len(x.shape)))
return res
def dyfunc_tensor_shape_5(x):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1, fluid.layers.shape(x)[0]))`
x = fluid.dygraph.to_variable(x)
s = x.shape[0]
res = fluid.layers.reshape(x, shape=(-1, s))
return res
test_funcs = [
dyfunc_tensor_shape_1, dyfunc_tensor_shape_2, dyfunc_tensor_shape_3,
dyfunc_tensor_shape_4, dyfunc_tensor_shape_5
]
class TestTensorShape(unittest.TestCase):
def setUp(self):
self.input = numpy.ones(5).astype("int32")
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def get_dygraph_output(self):
with fluid.dygraph.guard():
res = self.dygraph_func(self.input).numpy()
return res
def get_static_output(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
static_out = dygraph_to_static_graph(self.dygraph_func)(self.input)
exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=static_out)
return static_res[0]
def test_transformed_static_result(self):
for func in test_funcs:
self.dygraph_func = func
static_res = self.get_static_output()
dygraph_res = self.get_dygraph_output()
self.assertTrue(
numpy.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册