未验证 提交 08e80d17 编写于 作者: L liym27 提交者: GitHub

Support list in control flow for dygraph_to_static (#22902)

* support list in control flow if. test=develop

* support list in for/while and supplement tests. test=develop
上级 137d6563
...@@ -30,6 +30,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func ...@@ -30,6 +30,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
...@@ -67,6 +68,7 @@ class DygraphToStaticAst(gast.NodeTransformer): ...@@ -67,6 +68,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
basic_api_trans.ast_visit() basic_api_trans.ast_visit()
self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id() self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()
ListTransformer(node_wrapper).transform()
# Transform all if/else statement of Dygraph into Static Graph. # Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).transform() IfElseTransformer(node_wrapper).transform()
......
...@@ -424,7 +424,9 @@ def create_cond_node(return_name_ids, pred, true_func, false_func): ...@@ -424,7 +424,9 @@ def create_cond_node(return_name_ids, pred, true_func, false_func):
func=cond_api, func=cond_api,
args=[pred, true_func_lambda, false_func_lambda], args=[pred, true_func_lambda, false_func_lambda],
keywords=[]) keywords=[])
targets = [generate_name_node(return_name_ids, ctx=gast.Store())] if return_name_ids:
assign_node = gast.Assign(targets=targets, value=cond_layer) targets = [generate_name_node(return_name_ids, ctx=gast.Store())]
assign_node = gast.Assign(targets=targets, value=cond_layer)
return assign_node return assign_node
else:
return gast.Expr(value=cond_layer)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
import astor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
class ListTransformer(gast.NodeTransformer):
"""
This class transforms python list used in control flow into Static Graph Ast
"""
def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
self.name_of_list_set = set()
self.list_name_to_updated = dict()
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
)
var_env = self.static_analysis_visitor.get_var_env()
var_env.cur_scope = var_env.cur_scope.sub_scopes[0]
self.scope_var_type_dict = var_env.get_scope_var_type()
def transform(self):
self.visit(self.root)
self.replace_list_with_tensor_array(self.root)
def visit_Assign(self, node):
self._update_list_name_to_updated(node)
return node
def visit_If(self, node):
self.generic_visit(node)
if is_control_flow_to_transform(node, self.scope_var_type_dict):
self._transform_list_append_in_control_flow(node)
return node
def visit_While(self, node):
self.generic_visit(node)
if is_control_flow_to_transform(node, self.scope_var_type_dict):
self._transform_list_append_in_control_flow(node)
return node
def visit_For(self, node):
self.generic_visit(node)
if is_control_flow_to_transform(node, self.scope_var_type_dict):
self._transform_list_append_in_control_flow(node)
return node
def replace_list_with_tensor_array(self, node):
for child_node in gast.walk(node):
if isinstance(child_node, gast.Assign):
if self._need_to_create_tensor_array(child_node):
child_node.value = self._create_tensor_array()
def _transform_list_append_in_control_flow(self, node):
for child_node in gast.walk(node):
if self._need_to_array_write_node(child_node):
child_node.value = \
self._to_array_write_node(child_node.value)
def _need_to_array_write_node(self, node):
if isinstance(node, gast.Expr):
if isinstance(node.value, gast.Call):
if self._is_list_append_tensor(node.value):
return True
return False
def _is_list_append_tensor(self, node):
"""
a.append(b): a is list, b is Tensor
self.x.append(b): self.x is list, b is Tensor
"""
assert isinstance(node, gast.Call)
# 1. The func is `append`.
if not isinstance(node.func, gast.Attribute):
return False
if node.func.attr != 'append':
return False
# 2. It's a `python list` to call append().
value_name = astor.to_source(gast.gast_to_ast(node.func.value)).strip()
if value_name not in self.list_name_to_updated:
return False
# 3. The arg of append() is one `Tensor`
# Only one argument is supported in Python list.append()
if len(node.args) != 1:
return False
arg = node.args[0]
if isinstance(arg, gast.Name):
# TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function
# Need a better way to confirm whether `arg.id` is a Tensor.
try:
var_type_set = self.scope_var_type_dict[arg.id]
except KeyError:
return False
if NodeVarType.NUMPY_NDARRAY in var_type_set:
return False
if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
return False
# else:
# Todo: Consider that `arg` may be a gast.Call about Paddle Api.
# eg: list_a.append(fluid.layers.reshape(x))
# return True
self.list_name_to_updated[value_name.strip()] = True
return True
def _need_to_create_tensor_array(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
try:
target_id = target_node.id
except AttributeError:
return False
if self.list_name_to_updated.get(target_id):
return True
return False
def _create_tensor_array(self):
# Although `dtype='float32'`, other types such as `int32` can also be supported
func_code = "fluid.layers.create_array(dtype='float32')"
func_node = gast.parse(func_code).body[0].value
return func_node
def _to_array_write_node(self, node):
assert isinstance(node, gast.Call)
array = astor.to_source(gast.gast_to_ast(node.func.value))
x = astor.to_source(gast.gast_to_ast(node.args[0]))
i = "fluid.layers.array_length({})".format(array)
func_code = "fluid.layers.array_write(x={}, i={}, array={})".format(
x, i, array)
return gast.parse(func_code).body[0].value
def _update_list_name_to_updated(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
# TODO: Consider node has more than one target. eg: x, y = a, []
try:
target_id = target_node.id
except AttributeError:
return False
value_node = node.value
if isinstance(value_node, gast.List):
self.list_name_to_updated[target_id] = False
return True
elif target_id in self.name_of_list_set:
del self.list_name_to_updated[target_id]
return False
...@@ -77,6 +77,29 @@ def is_numpy_api(node): ...@@ -77,6 +77,29 @@ def is_numpy_api(node):
return False return False
def is_control_flow_to_transform(node, var_name_to_type):
"""
Determines whether the node is a Paddle control flow statement which needs to
transform into a static graph control flow statement.
"""
assert isinstance(node, gast.AST), \
"The type of input node must be gast.AST, but received %s." % type(node)
if isinstance(node, gast.If):
# TODO: make a better condition
return True
if isinstance(node, gast.For):
# TODO: make a better condition
return True
if isinstance(node, gast.While):
# TODO: make a better condition
return True
return False
def _delete_keywords_from(node): def _delete_keywords_from(node):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
func_src = astor.to_source(gast.gast_to_ast(node.func)) func_src = astor.to_source(gast.gast_to_ast(node.func))
...@@ -255,7 +278,8 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids): ...@@ -255,7 +278,8 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
""" """
nodes = copy.copy(nodes) nodes = copy.copy(nodes)
# add return statement # add return statement
nodes.append(gast.Return(value=generate_name_node(return_name_ids))) if return_name_ids:
nodes.append(gast.Return(value=generate_name_node(return_name_ids)))
func_def_node = gast.FunctionDef( func_def_node = gast.FunctionDef(
name=name, name=name,
args=input_args, args=input_args,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
SEED = 2020
np.random.seed(SEED)
def test_list_without_control_flow(x):
# Python list will not be transformed.
x = fluid.dygraph.to_variable(x)
a = []
a.append(x)
return a
def test_list_in_if(x):
x = fluid.dygraph.to_variable(x)
a = []
if x.numpy()[0] > 0:
a.append(x)
else:
a.append(
fluid.layers.fill_constant(
shape=[1, 2], value=9, dtype="int64"))
return a
def test_list_in_for_loop(x, iter_num):
# Note: for_loop can't be transformed before PR22867 merged.
x = fluid.dygraph.to_variable(x)
a = []
for i in range(iter_num):
a.append(x)
return a
def test_list_in_while_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
iter_num = fluid.layers.fill_constant(
shape=[1], value=iter_num, dtype="int32")
a = []
i = 0
# Note: `i < iter_num` can't be supported in dygraph mode now,
# but PR22892 is fixing it https://github.com/PaddlePaddle/Paddle/pull/22892.
# If PR22892 merged, change `i < iter_num.numpy()[0]` to `i < iter_num`.
while i < iter_num.numpy()[0]:
a.append(x)
i += 1
return a
class TestListWithoutControlFlow(unittest.TestCase):
def setUp(self):
self.input = np.random.random((3)).astype('int32')
self.dygraph_func = test_list_without_control_flow
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def run_dygraph_mode(self):
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, (list, tuple)):
res = res[0]
return res.numpy()
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
tensor_list = dygraph_to_static_graph(self.dygraph_func)(self.input)
exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=tensor_list[0])
return static_res[0]
def test_transformed_static_result(self):
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
class TestListInIf(TestListWithoutControlFlow):
def setUp(self):
self.input = np.random.random((3)).astype('int32')
self.dygraph_func = test_list_in_if
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
tensor_array = dygraph_to_static_graph(self.dygraph_func)(
self.input)
static_out = fluid.layers.array_read(
tensor_array,
i=fluid.layers.fill_constant(
shape=[1], value=0, dtype='int64'))
exe = fluid.Executor(self.place)
numpy_res = exe.run(main_program, fetch_list=static_out)
return numpy_res[0]
class TestListInWhileLoop(unittest.TestCase):
def setUp(self):
self.iter_num = 3
self.input = np.random.random((3)).astype('int32')
self.dygraph_func = test_list_in_while_loop
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def run_dygraph_mode(self):
with fluid.dygraph.guard():
var_res = self.dygraph_func(self.input, self.iter_num)
numpy_res = [ele.numpy() for ele in var_res]
return numpy_res
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
tensor_array = dygraph_to_static_graph(self.dygraph_func)(
self.input, self.iter_num)
static_outs = []
for i in range(self.iter_num):
static_outs.append(
fluid.layers.array_read(
tensor_array,
i=fluid.layers.fill_constant(
shape=[1], value=i, dtype='int64')))
exe = fluid.Executor(self.place)
numpy_res = exe.run(main_program, fetch_list=static_outs)
return numpy_res
def test_transformed_static_result(self):
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
self.assertTrue(
np.array_equal(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
class TestListInForLoop(unittest.TestCase):
def setUp(self):
self.iter_num = 3
self.input = np.random.random((3)).astype('int32')
self.dygraph_func = test_list_in_for_loop
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def run_dygraph_mode(self):
with fluid.dygraph.guard():
var_res = self.dygraph_func(self.input, self.iter_num)
numpy_res = [ele.numpy() for ele in var_res]
return numpy_res
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
tensor_array = dygraph_to_static_graph(self.dygraph_func)(
self.input, self.iter_num)
static_outs = []
for i in range(self.iter_num):
static_outs.append(
fluid.layers.array_read(
tensor_array,
i=fluid.layers.fill_constant(
shape=[1], value=i, dtype='int64')))
exe = fluid.Executor(self.place)
numpy_res = exe.run(main_program, fetch_list=static_outs)
return numpy_res
def test_transformed_static_result(self):
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
self.assertTrue(
np.array_equal(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册