未验证 提交 4ea95b6f 编写于 作者: L liym27 提交者: GitHub

Support Tensor.shape in dygraph_to_static (#22830)

* support basic tensor.shape. 

* Support tensor.shape with dependencies. 
上级 1644926a
......@@ -14,19 +14,21 @@
from __future__ import print_function
import copy
import inspect
import textwrap
import astor
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
import gast
import textwrap
import inspect
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func
from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from .static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from .utils import *
__all__ = ['DygraphToStaticAst', 'convert_to_static']
......@@ -121,8 +123,10 @@ class DygraphToStaticAst(gast.NodeTransformer):
def get_static_ast(self, root):
# save root for some analysis may need global AST
self.root = root
self.static_analysis_root = StaticAnalysisVisitor(
root).get_node_wrapper_root()
self.static_analysis_visitor = StaticAnalysisVisitor(root)
self.static_analysis_root = self.static_analysis_visitor.get_node_wrapper_root(
)
self.decorate_func_name = None
self.arg_name_to_idx = {}
self.transfer_from_node_type(self.static_analysis_root)
......@@ -133,7 +137,8 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.visit(node_wrapper.node)
# Transform basic api of dygraph to static graph
basic_api_trans = BasicApiTransformer(node_wrapper)
basic_api_trans = BasicApiTransformer(node_wrapper,
self.static_analysis_visitor)
basic_api_trans.ast_visit()
self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()
......@@ -178,14 +183,31 @@ class BasicApiTransformer(gast.NodeTransformer):
Class to transform basic API from dygraph to static graph.
"""
def __init__(self, wrapper_root):
def __init__(self, wrapper_root, static_analysis_visitor):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.class_node_dict = {}
# Used for transformation of data feed
self.feed_name_to_arg_id = {}
self.name_to_tensor_shape = {}
# Used for transformation of Tensor.shape
self.static_analysis_visitor = static_analysis_visitor
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
self.scope_var_type_dict = {}
self._run_static_visitor()
def _run_static_visitor(self):
var_env = copy.deepcopy(self.static_analysis_visitor.get_var_env())
# TODO: Consider that Tensor.shape is used in sub function and sub_scopes is empty
var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
self.scope_var_type_dict = var_env.get_scope_var_type()
def ast_visit(self):
self.visit(self.root)
......@@ -204,11 +226,12 @@ class BasicApiTransformer(gast.NodeTransformer):
if self._update_class_node_dict(node):
return None
value_node = node.value
for child_node in gast.walk(value_node):
if self._update_name_to_tensor_shape(node):
return node
for child_node in gast.walk(node.value):
if isinstance(child_node, gast.Call):
self._visit_Call(child_node)
return node
def visit_Expr(self, node):
......@@ -219,19 +242,41 @@ class BasicApiTransformer(gast.NodeTransformer):
return
else:
self._visit_Call(child_node)
return node
def visit_Attribute(self, node):
if self._used_by_paddle_api(node):
if self.is_tensor_shape(node):
return create_api_shape_node(node)
return node
def visit_Name(self, node):
if node.id in self.name_to_tensor_shape:
if self._used_by_paddle_api(node):
tensor_shape_node = self.name_to_tensor_shape[node.id]
if isinstance(tensor_shape_node, gast.Attribute):
return create_api_shape_node(tensor_shape_node)
elif isinstance(tensor_shape_node, gast.Subscript):
result_node = copy.deepcopy(tensor_shape_node)
result_node.value = create_api_shape_node(
tensor_shape_node.value)
return result_node
return node
def _visit_Call(self, node):
assert isinstance(node, gast.Call)
# Replace API `to_variable` with `fluid.layers.assign`
if is_to_variable(node):
self._update_feed_dict(node)
node = to_assign_node(node)
return node
if is_paddle_api(node):
# Visit gast.Attribute and gast.Name to replace tensor.shape if necessary
self.generic_visit(node)
func_name = astor.to_source(gast.gast_to_ast(node.func))
if self._is_dygraph_forward(func_name):
class_node = self._get_class_node(func_name)
static_node = to_static_ast(node, class_node)
......@@ -239,6 +284,53 @@ class BasicApiTransformer(gast.NodeTransformer):
else:
return node
def is_tensor_shape(self, node):
"""
Return True if node is like `x.shape` and x is Tensor, return False otherwise.
"""
assert isinstance(node, gast.Attribute)
if node.attr != 'shape':
return False
try:
value_id = node.value.id
except AttributeError:
return False
if value_id in self.name_to_tensor_shape:
return True
# TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function
# Need a better way to confirm whether `value_id` is a Tensor.
try:
var_type_set = self.scope_var_type_dict[value_id]
except KeyError:
return False
if NodeVarType.NUMPY_NDARRAY in var_type_set:
return False
if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
return False
return True
def _used_by_paddle_api(self, node):
assert isinstance(node, (gast.Attribute, gast.Name))
wrapper_node = self.node_to_wrapper_map.get(node)
if not wrapper_node:
# Transformed node is not in node_to_wrapper_map
return False
while wrapper_node.parent:
parent_node = wrapper_node.parent.node
if isinstance(parent_node, gast.Call):
if is_paddle_api(parent_node):
return True
else:
return False
wrapper_node = wrapper_node.parent
return False
def _is_dygraph_forward(self, func_id):
return func_id in self.class_node_dict
......@@ -280,6 +372,32 @@ class BasicApiTransformer(gast.NodeTransformer):
def get_feed_name_to_arg_id(self):
return self.feed_name_to_arg_id
def _update_name_to_tensor_shape(self, node):
assert isinstance(node, gast.Assign)
# TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1]
target_node = node.targets[0]
try:
target_id = target_node.id
except AttributeError:
return False
value_node = node.value
if isinstance(value_node, gast.Name):
if value_node.id in self.name_to_tensor_shape:
self.name_to_tensor_shape[
target_id] = self.name_to_tensor_shape[value_node.id]
return True
if isinstance(value_node, gast.Attribute):
if self.is_tensor_shape(value_node): # eg: x.shape
self.name_to_tensor_shape[target_id] = value_node
return True
if isinstance(value_node, gast.Subscript):
if isinstance(value_node.value, gast.Attribute):
if self.is_tensor_shape(value_node.value): # eg: x.shape[0]
self.name_to_tensor_shape[target_id] = value_node
return True
return False
def convert_to_static(dyfunc):
"""
......
......@@ -360,7 +360,9 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
# TODO(Aurelius84): more elegant way to transform ast into callable object
import_str = "import paddle\n" \
"import paddle.fluid as fluid\n" \
"import paddle.fluid.layers as layers\n"
"import paddle.fluid.layers as layers\n" \
"import numpy as np\n" \
"import numpy\n"
with f:
module_name = os.path.basename(f.name[:-3])
f.write(import_str)
......
......@@ -181,3 +181,12 @@ def update_args_of_func(node, dygraph_node, method_name):
node.args = []
node.keywords = added_keywords + node.keywords
def create_api_shape_node(tensor_shape_node):
assert isinstance(tensor_shape_node, gast.Attribute)
api_shape_node = gast.Call(
func=gast.parse('fluid.layers.shape').body[0].value,
args=[tensor_shape_node.value],
keywords=[])
return api_shape_node
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
def dyfunc_tensor_shape_1(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.reshape(x, shape=x.shape)
return res
def dyfunc_tensor_shape_2(x):
x = fluid.dygraph.to_variable(x)
shape = x.shape
shape2 = shape
res = fluid.layers.reshape(x, shape2)
return res
def dyfunc_tensor_shape_3(x):
# Don't transform y.shape because y is numpy.ndarray
x = fluid.dygraph.to_variable(x)
y = numpy.ones(5)
res = fluid.layers.reshape(x, shape=y.shape)
return res
def dyfunc_tensor_shape_4(x):
x = fluid.dygraph.to_variable(x)
res = fluid.layers.reshape(x, shape=(-1, x.shape[0], len(x.shape)))
return res
def dyfunc_tensor_shape_5(x):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1, fluid.layers.shape(x)[0]))`
x = fluid.dygraph.to_variable(x)
s = x.shape[0]
res = fluid.layers.reshape(x, shape=(-1, s))
return res
test_funcs = [
dyfunc_tensor_shape_1, dyfunc_tensor_shape_2, dyfunc_tensor_shape_3,
dyfunc_tensor_shape_4, dyfunc_tensor_shape_5
]
class TestTensorShape(unittest.TestCase):
def setUp(self):
self.input = numpy.ones(5).astype("int32")
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def get_dygraph_output(self):
with fluid.dygraph.guard():
res = self.dygraph_func(self.input).numpy()
return res
def get_static_output(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
static_out = dygraph_to_static_graph(self.dygraph_func)(self.input)
exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=static_out)
return static_res[0]
def test_transformed_static_result(self):
for func in test_funcs:
self.dygraph_func = func
static_res = self.get_static_output()
dygraph_res = self.get_dygraph_output()
self.assertTrue(
numpy.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册