未验证 提交 5a202af9 编写于 作者: L liym27 提交者: GitHub

Support slice write in dygraph_to_static. test=develop (#23055)

上级 52575304
...@@ -17,12 +17,12 @@ from __future__ import print_function ...@@ -17,12 +17,12 @@ from __future__ import print_function
import gast import gast
import astor import astor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform, ast_to_source_code
class ListTransformer(gast.NodeTransformer): class ListTransformer(gast.NodeTransformer):
""" """
This class transforms python list used in control flow into Static Graph Ast This class transforms python list used in control flow into Static Graph Ast.
""" """
def __init__(self, wrapper_root): def __init__(self, wrapper_root):
...@@ -31,8 +31,8 @@ class ListTransformer(gast.NodeTransformer): ...@@ -31,8 +31,8 @@ class ListTransformer(gast.NodeTransformer):
), "Input non-AstNodeWrapper node for the initialization of ListTransformer." ), "Input non-AstNodeWrapper node for the initialization of ListTransformer."
self.wrapper_root = wrapper_root self.wrapper_root = wrapper_root
self.root = wrapper_root.node self.root = wrapper_root.node
self.name_of_list_set = set()
self.list_name_to_updated = dict() self.list_name_to_updated = dict()
self.list_nodes = set()
self.static_analysis_visitor = StaticAnalysisVisitor(self.root) self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
...@@ -46,7 +46,11 @@ class ListTransformer(gast.NodeTransformer): ...@@ -46,7 +46,11 @@ class ListTransformer(gast.NodeTransformer):
self.replace_list_with_tensor_array(self.root) self.replace_list_with_tensor_array(self.root)
def visit_Assign(self, node): def visit_Assign(self, node):
self._update_list_name_to_updated(node) if self._update_list_name_to_updated(node):
return node
if self._need_to_array_write_node(node):
return self._transform_slice_to_tensor_write(node)
return node return node
def visit_If(self, node): def visit_If(self, node):
...@@ -85,8 +89,33 @@ class ListTransformer(gast.NodeTransformer): ...@@ -85,8 +89,33 @@ class ListTransformer(gast.NodeTransformer):
if self._is_list_append_tensor(node.value): if self._is_list_append_tensor(node.value):
return True return True
if isinstance(node, gast.Assign):
target_node = node.targets[0]
if isinstance(target_node, gast.Subscript):
list_name = ast_to_source_code(target_node.value).strip()
if list_name in self.list_name_to_updated:
if self.list_name_to_updated[list_name] == True:
return True
return False return False
def _transform_slice_to_tensor_write(self, node):
assert isinstance(node, gast.Assign)
target_node = node.targets[0]
target_name = target_node.value.id
slice_node = target_node.slice
if isinstance(slice_node, gast.Slice):
pass
elif isinstance(slice_node, gast.Index):
value_code = ast_to_source_code(node.value)
i = "fluid.layers.cast(" \
"x=fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})," \
"dtype='int64')".format(ast_to_source_code(slice_node))
assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
.format(target_name, value_code, i, target_name)
assign_node = gast.parse(assign_code).body[0]
return assign_node
def _is_list_append_tensor(self, node): def _is_list_append_tensor(self, node):
""" """
a.append(b): a is list, b is Tensor a.append(b): a is list, b is Tensor
...@@ -135,7 +164,7 @@ class ListTransformer(gast.NodeTransformer): ...@@ -135,7 +164,7 @@ class ListTransformer(gast.NodeTransformer):
target_id = target_node.id target_id = target_node.id
except AttributeError: except AttributeError:
return False return False
if self.list_name_to_updated.get(target_id): if self.list_name_to_updated.get(target_id) and node in self.list_nodes:
return True return True
return False return False
...@@ -165,7 +194,8 @@ class ListTransformer(gast.NodeTransformer): ...@@ -165,7 +194,8 @@ class ListTransformer(gast.NodeTransformer):
value_node = node.value value_node = node.value
if isinstance(value_node, gast.List): if isinstance(value_node, gast.List):
self.list_name_to_updated[target_id] = False self.list_name_to_updated[target_id] = False
self.list_nodes.add(node)
return True return True
elif target_id in self.name_of_list_set: elif target_id in self.list_name_to_updated:
del self.list_name_to_updated[target_id] del self.list_name_to_updated[target_id]
return False return False
...@@ -44,7 +44,6 @@ def test_list_in_if(x): ...@@ -44,7 +44,6 @@ def test_list_in_if(x):
def test_list_in_for_loop(x, iter_num): def test_list_in_for_loop(x, iter_num):
# Note: for_loop can't be transformed before PR22867 merged.
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
a = [] a = []
for i in range(iter_num): for i in range(iter_num):
...@@ -53,7 +52,6 @@ def test_list_in_for_loop(x, iter_num): ...@@ -53,7 +52,6 @@ def test_list_in_for_loop(x, iter_num):
def test_list_in_for_loop_with_concat(x, iter_num): def test_list_in_for_loop_with_concat(x, iter_num):
# Note: for_loop can't be transformed before PR22867 merged.
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
a = [] a = []
for i in range(iter_num): for i in range(iter_num):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
SEED = 2020
np.random.seed(SEED)
def test_slice_without_control_flow(x):
# Python slice will not be transformed.
x = fluid.dygraph.to_variable(x)
a = [x]
a[0] = fluid.layers.fill_constant(shape=[2], value=2, dtype="float32")
return a
def test_slice_in_if(x):
x = fluid.dygraph.to_variable(x)
a = []
if x.numpy()[0] > 0:
a.append(x)
else:
a.append(
fluid.layers.fill_constant(
shape=[1, 2], value=9, dtype="int64"))
if x.numpy()[0] > 0:
a[0] = x
return a
def test_slice_in_while_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
iter_num = fluid.layers.fill_constant(
shape=[1], value=iter_num, dtype="int32")
a = []
i = 0
# Note: `i < iter_num` can't be supported in dygraph mode now,
# but PR22892 is fixing it https://github.com/PaddlePaddle/Paddle/pull/22892.
# If PR22892 merged, change `i < iter_num.numpy()[0]` to `i < iter_num`.
while i < iter_num.numpy()[0]:
a.append(x)
i += 1
i = 0
while i < iter_num.numpy()[0]:
a[i] = fluid.layers.fill_constant(shape=[2], value=2, dtype="float32")
i += 1
return a
def test_slice_in_for_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
a = []
for i in range(iter_num):
a.append(x)
for i in range(iter_num):
a[i] = x
return a
class TestSliceWithoutControlFlow(unittest.TestCase):
def setUp(self):
self.input = np.random.random((3)).astype('int32')
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.init_dygraph_func()
def init_dygraph_func(self):
self.dygraph_func = test_slice_without_control_flow
def run_dygraph_mode(self):
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, (list, tuple)):
res = res[0]
return res.numpy()
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
tensor_list = dygraph_to_static_graph(self.dygraph_func)(self.input)
exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=tensor_list[0])
return static_res[0]
def test_transformed_static_result(self):
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
self.assertTrue(
np.allclose(dygraph_res, static_res),
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
static_res))
class TestSliceInIf(TestSliceWithoutControlFlow):
def init_dygraph_func(self):
self.dygraph_func = test_slice_in_if
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
tensor_array = dygraph_to_static_graph(self.dygraph_func)(
self.input)
static_out = fluid.layers.array_read(
tensor_array,
i=fluid.layers.fill_constant(
shape=[1], value=0, dtype='int64'))
exe = fluid.Executor(self.place)
numpy_res = exe.run(main_program, fetch_list=static_out)
return numpy_res[0]
class TestSliceInWhileLoop(TestSliceWithoutControlFlow):
def setUp(self):
self.iter_num = 3
self.input = np.random.random((3)).astype('int32')
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.init_dygraph_func()
def init_dygraph_func(self):
self.dygraph_func = test_slice_in_while_loop
def run_dygraph_mode(self):
with fluid.dygraph.guard():
var_res = self.dygraph_func(self.input, self.iter_num)
numpy_res = [ele.numpy() for ele in var_res]
return numpy_res
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
tensor_array = dygraph_to_static_graph(self.dygraph_func)(
self.input, self.iter_num)
static_outs = []
for i in range(self.iter_num):
static_outs.append(
fluid.layers.array_read(
tensor_array,
i=fluid.layers.fill_constant(
shape=[1], value=i, dtype='int64')))
exe = fluid.Executor(self.place)
numpy_res = exe.run(main_program, fetch_list=static_outs)
return numpy_res
class TestSliceInForLoop(TestSliceInWhileLoop):
def init_dygraph_func(self):
self.dygraph_func = test_slice_in_for_loop
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册