From ac9a7eeea42470e940e64650275ccc8e2991f70f Mon Sep 17 00:00:00 2001
From: liym27 <33742067+liym27@users.noreply.github.com>
Date: Wed, 6 May 2020 17:09:50 +0800
Subject: [PATCH] [Dy2Stat] Support list pop (#24250)

* Replace dygraph_to_static_func with @declarative or program_translator.get_func in test_list.py
* Add comments in ConditionalBlock.
* Support popping the last item of a list.
* Support popping the i-th item.
* Support an empty tensor array as Input in assign op, and set the kernel type to float.
---
 paddle/fluid/framework/operator.cc            |   2 +-
 paddle/fluid/framework/operator_test.cc       |  18 +-
 paddle/fluid/operators/assign_op.cc           |  11 +
 paddle/fluid/operators/slice_op.cc            |   3 +-
 .../dygraph/dygraph_to_static/__init__.py     |   4 +
 .../dygraph_to_static/list_transformer.py     | 117 ++++++++-
 .../dygraph_to_static/program_translator.py   |   8 +-
 python/paddle/fluid/layers/control_flow.py    |   3 +
 .../unittests/dygraph_to_static/test_list.py  | 232 +++++++++++-------
 9 files changed, 295 insertions(+), 103 deletions(-)

diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index a53e8ec09a..874d6a2bb0 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1326,7 +1326,7 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
   PADDLE_ENFORCE_NE(
       data_type, dafault_data_type,
       "The Input Variable(%s) of %s Op used to determine kernel data type "
-      "is empty or not LoDTensor or SelectedRows.",
+      "is empty or not LoDTensor or SelectedRows or LoDTensorArray.",
       name, Type());
   return data_type;
 }
diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc
index 6fbaca7174..3c1fb819da 100644
--- a/paddle/fluid/framework/operator_test.cc
+++ b/paddle/fluid/framework/operator_test.cc
@@ -476,14 +476,14 @@ TEST(IndicateVarDataTypeTest, other) {
   paddle::framework::InitDevices(true);
   paddle::framework::proto::OpDesc op_desc;
   op_desc.set_type("indicate_other_data_type_test");
-  BuildVar("Other", {"lod_tensor_array_1"}, op_desc.add_inputs());
+  BuildVar("Other", {"lod_rank_table_1"}, op_desc.add_inputs());
 
   paddle::platform::CPUPlace cpu_place;
   paddle::framework::Scope scope;
 
   auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
-  auto* var = scope.Var("lod_tensor_array_1");
-  var->GetMutable<paddle::framework::LoDTensorArray>();
+  auto* var = scope.Var("lod_rank_table_1");
+  var->GetMutable<paddle::framework::LoDRankTable>();
 
   bool caught = false;
   try {
@@ -491,11 +491,13 @@ TEST(IndicateVarDataTypeTest, other) {
   } catch (paddle::platform::EnforceNotMet& err) {
     caught = true;
     std::string ex_msg = err.what();
-    EXPECT_TRUE(ex_msg.find("The Input Variable(Other) of "
-                            "indicate_other_data_type_test Op used to "
-                            "determine kernel data type "
-                            "is empty or not LoDTensor or SelectedRows") !=
-                std::string::npos);
+    EXPECT_TRUE(
+        ex_msg.find(
+            "The Input Variable(Other) of "
+            "indicate_other_data_type_test Op used to "
+            "determine kernel data type "
+            "is empty or not LoDTensor or SelectedRows or LoDTensorArray") !=
+        std::string::npos);
   }
   ASSERT_TRUE(caught);
 }
diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc
index cb47511796..f8c1216e97 100644
--- a/paddle/fluid/operators/assign_op.cc
+++ b/paddle/fluid/operators/assign_op.cc
@@ -58,6 +58,17 @@ class AssignOp : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
+    const framework::Variable *var = ctx.InputVar("X");
+    if (var->IsType<framework::LoDTensorArray>()) {
+      auto t_arr = var->Get<framework::LoDTensorArray>();
+      // NOTE(liym27): Support an empty tensor array as Input,
+      // and set the kernel type to float.
+      if (t_arr.size() == 0) {
+        return framework::OpKernelType(framework::proto::VarType::FP32,
+                                       ctx.device_context());
+      }
+    }
+
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"),
         ctx.device_context());
diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc
index 17ee3d7188..418159b341 100644
--- a/paddle/fluid/operators/slice_op.cc
+++ b/paddle/fluid/operators/slice_op.cc
@@ -47,7 +47,8 @@ class SliceOp : public framework::OperatorWithKernel {
       // the output shape is determined by SliceKernel:Compute in runtime.
       return;
     } else {
-      // NOTE: A better way is needed to get accurate dims of tensor array.
+      // NOTE(liym27): A better way is needed to get accurate dims of tensor
+      // array.
       // The resulted dim of GetInputDim("Input") is the dim of the
       // last item written into TensorArray "Input". Maybe it's a bug to fix.
       ctx->SetOutputDim("Out", ctx->GetInputDim("Input"));
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py
index d2d03d65b1..c1a884f3ba 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/__init__.py
@@ -32,6 +32,9 @@ from .program_translator import *
 from . import convert_call_func
 from .convert_call_func import *
 
+from . import list_transformer
+from .list_transformer import *
+
 __all__ = []
 __all__ += ast_transformer.__all__
 __all__ += loop_transformer.__all__
@@ -39,3 +42,4 @@ __all__ += static_analysis.__all__
 __all__ += variable_trans_func.__all__
 __all__ += program_translator.__all__
 __all__ += convert_call_func.__all__
+__all__ += list_transformer.__all__
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
index 4780325bae..4136838dec 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
@@ -14,10 +14,96 @@
 
 from __future__ import print_function
 
-import gast
 import astor
+import gast
+
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
-from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform, ast_to_source_code
+from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code, is_control_flow_to_transform
+from paddle.fluid.framework import core, default_main_program, Variable
+from paddle.fluid.layers import array_length, array_read, array_write, create_array
+from paddle.fluid.layers import assign, cast, fill_constant, slice
+from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
+
+__all__ = ['convert_list_pop']
+
+
+def create_array_in_parent_block(null_array):
+    # TODO(liym27): Create a null tensor_array with the same name in the parent block to avoid a bug in control flow,
+    # because in `null_array = create_array("float32")`, `null_array` is not an output of a real OP.
+    # See class ConditionalBlock for details.
+    prog = default_main_program()
+    parent_idx = prog.current_block().parent_idx
+    while parent_idx != -1:
+        parent_block = prog.block(parent_idx)
+        parent_block.create_var(
+            name=null_array.name,
+            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
+            dtype="float32")
+        parent_idx = parent_block.parent_idx
+
+
+# TODO(liym27): A better way to slice tensor array.
+# Maybe support start == end for slice op.
+def slice_tensor_array(array, start, end):
+    end = cast(end, "int32")
+
+    def true_fn():
+        null_array = create_array("float32")
+        create_array_in_parent_block(null_array)
+        return null_array
+
+    def false_fn(array, start, end):
+        new_array = slice(array, starts=[start], ends=[end], axes=[0])
+        return new_array
+
+    new_array = cond(start == end, true_fn, lambda: false_fn(array, start, end))
+    return new_array
+
+
+def tensor_array_pop(array, idx):
+    assert isinstance(idx, int)
+
+    def cond(i, new_array):
+        return less_than(i, arr_len)
+
+    def body(i, new_array):
+        item = array_read(array=array, i=i)
+        array_write(item, array_length(new_array), new_array)
+        i = increment(i)
+        return i, new_array
+
+    arr_len = array_length(array)
+    if idx < 0:
+        idx = idx + arr_len
+    else:
+        idx = fill_constant(shape=[1], dtype="int64", value=idx)
+
+    pop_item = array_read(array, idx)
+
+    new_array = slice_tensor_array(array, 0, idx)
+    i = idx + 1
+    _, new_array = while_loop(cond, body, [i, new_array])
+    assign(input=new_array, output=array)
+
+    return pop_item
+
+
+def convert_list_pop(target, idx=None):
+    """
+    Convert list pop.
+    """
+
+    if idx is None:
+        idx = -1
+
+    is_variable = isinstance(target, Variable)
+    if is_variable:
+        is_tensor_array = target.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
+    if is_variable and is_tensor_array:
+        result = tensor_array_pop(target, idx)
+    else:
+        result = target.pop(idx)
+    return result
 
 
 class ListTransformer(gast.NodeTransformer):
@@ -45,12 +131,21 @@ class ListTransformer(gast.NodeTransformer):
         self.visit(self.root)
         self.replace_list_with_tensor_array(self.root)
 
+    def visit_Call(self, node):
+        if isinstance(node.func, gast.Attribute):
+            func_name = node.func.attr
+            if func_name == "pop":
+                node = self._replace_list_pop(node)
+        return node
+
     def visit_Assign(self, node):
         if self._update_list_name_to_updated(node):
             return node
 
         if self._need_to_array_write_node(node):
             return self._transform_slice_to_tensor_write(node)
+
+        self.generic_visit(node)
         return node
 
     def visit_If(self, node):
@@ -203,3 +298,21 @@ class ListTransformer(gast.NodeTransformer):
                 self.list_name_to_updated[target_id] == False:
             del self.list_name_to_updated[target_id]
         return False
+
+    def _replace_list_pop(self, node):
+        assert isinstance(node, gast.Call)
+        assert isinstance(node.func, gast.Attribute)
+
+        target_node = node.func.value
+        target_str = ast_to_source_code(target_node).strip()
+
+        if node.args:
+            idx_node = node.args[0]
+            idx_str = ast_to_source_code(idx_node).strip()
+        else:
+            idx_str = "None"
+
+        new_call_str = "fluid.dygraph.dygraph_to_static.convert_list_pop({}, {})".format(
+            target_str, idx_str)
+        new_call_node = gast.parse(new_call_str).body[0].value
+        return new_call_node
diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
index bb8fbeb50a..4c365a9d83 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -344,10 +344,10 @@ class ProgramTranslator(object):
                 prog_trans.enable(False)
 
                 x = np.ones([1, 2])
-                # The declarative is disabled so the func is run in dygraph
+                # The declarative is disabled so the func is run in dygraph
                 with fluid.dygraph.guard():
                     print(func(x).numpy()) # [[2. 2.]]
-
+
         """
         check_type(enable_declarative, "enable_declarative", bool,
                    "ProgramTranslator.enable")
@@ -361,7 +361,7 @@ class ProgramTranslator(object):
 
         Args:
             dygraph_func (callable): the dygraph function.
-            *args, **kwargs : the input argument of dygraph_func.
+            *args, **kwargs : the input argument of dygraph_func.
 
         Returns:
             VarBase or tuple of VarBase: the dygraph VarBase containing digital
@@ -763,7 +763,7 @@ class ProgramTranslator(object):
 
         assert abs(index_of_loss) < len(outputs), \
             "index_of_loss: {} shall not exceed the length of outputs: {}.".format(
-            index_of_loss, len(outputs))
+                index_of_loss, len(outputs))
 
         loss_var = outputs[index_of_loss]
         check_type(loss_var, "loss_var", framework.Variable,
diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py
index dc31ec3b60..31430a1f32 100755
--- a/python/paddle/fluid/layers/control_flow.py
+++ b/python/paddle/fluid/layers/control_flow.py
@@ -2001,6 +2001,9 @@ class ConditionalBlock(object):
         intermediate = set()
         params = set()
 
+        # NOTE: Here it is assumed that all variables are inputs or outputs of Ops,
+        # but some variables are created without appending a real op.
+        # For example, in `arr = create_array(dtype)`, `arr` is not an output of an op.
         for each_op in inside_block.ops:
             assert isinstance(each_op, Operator)
             for iname in each_op.input_names:
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
index 3e65492524..a289df6d27 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
@@ -15,15 +15,20 @@
 from __future__ import print_function
 
 import unittest
+from functools import partial
+
 import numpy as np
 
 import paddle.fluid as fluid
-from paddle.fluid.dygraph.jit import dygraph_to_static_func
+from paddle.fluid.dygraph.jit import declarative
+from paddle.fluid.layers.utils import map_structure
 
 SEED = 2020
 np.random.seed(SEED)
 
 
-def test_list_without_control_flow(x):
+# Situation 1: Test list append
+@declarative
+def test_list_append_without_control_flow(x):
     # Python list will not be transformed.
     x = fluid.dygraph.to_variable(x)
     a = []
@@ -33,7 +38,8 @@ def test_list_without_control_flow(x):
     return a
 
 
-def test_list_in_if(x):
+@declarative
+def test_list_append_in_if(x):
     x = fluid.dygraph.to_variable(x)
     a = []
     if x.numpy()[0] > 0:
@@ -45,7 +51,8 @@ def test_list_in_if(x):
     return a
 
 
-def test_list_in_for_loop(x, iter_num):
+@declarative
+def test_list_append_in_for_loop(x, iter_num):
     x = fluid.dygraph.to_variable(x)
     # Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
     iter_num = fluid.layers.fill_constant(
@@ -57,7 +64,8 @@ def test_list_in_for_loop(x, iter_num):
     return a
 
 
-def test_list_in_for_loop_with_concat(x, iter_num):
+@declarative
+def test_list_append_in_for_loop_with_concat(x, iter_num):
     x = fluid.dygraph.to_variable(x)
     a = []
     # Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
@@ -70,22 +78,21 @@ def test_list_in_for_loop_with_concat(x, iter_num):
     return a
 
 
-def test_list_in_while_loop(x, iter_num):
+@declarative
+def test_list_append_in_while_loop(x, iter_num):
     x = fluid.dygraph.to_variable(x)
     iter_num = fluid.layers.fill_constant(
         shape=[1], value=iter_num, dtype="int32")
     a = []
     i = 0
-    # Note: `i < iter_num` can't be supported in dygraph mode now,
-    # but PR22892 is fixing it https://github.com/PaddlePaddle/Paddle/pull/22892.
-    # If PR22892 merged, change `i < iter_num.numpy()[0]` to `i < iter_num`.
-    while i < iter_num.numpy()[0]:
+    while i < iter_num:
         a.append(x)
         i += 1
     return a
 
 
-def test_list_in_while_loop_with_stack(x, iter_num):
+@declarative
+def test_list_append_in_while_loop_with_stack(x, iter_num):
     x = fluid.dygraph.to_variable(x)
     iter_num = fluid.layers.fill_constant(
         shape=[1], value=iter_num, dtype="int32")
     a = []
     i = 0
@@ -98,121 +105,172 @@ def test_list_in_while_loop_with_stack(x, iter_num):
     return out
 
 
+# Situation 2: Test list pop
+@declarative
+def test_list_pop_without_control_flow_1(x):
+    x = fluid.dygraph.to_variable(x)
+    a = []
+    if 2 > 1:
+        a.append(x)
+    a.pop()
+    return a
+
+
+@declarative
+def test_list_pop_without_control_flow_2(x):
+    x = fluid.dygraph.to_variable(x)
+    a = []
+    if 2 > 1:
+        a.append(x)
+        a.append(x + 1)
+    last_item = a.pop(1)
+    return last_item
+
+
+@declarative
+def test_list_pop_in_if(x):
+    x = fluid.dygraph.to_variable(x)
+    a = []
+    if x.numpy()[0] > 0:
+        a.append(x)
+        a.append(fluid.layers.fill_constant(shape=[1], value=1, dtype="int64"))
+    else:
+        a.append(x + 1)
+        a.append(fluid.layers.fill_constant(shape=[2], value=2, dtype="int64"))
+    item1 = a.pop(1)
+    a.pop()
+    return a, item1
+
+
+@declarative
+def test_list_pop_in_for_loop(x, iter_num):
+    x = fluid.dygraph.to_variable(x)
+    # Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
+    iter_num = fluid.layers.fill_constant(
+        shape=[1], value=iter_num, dtype="int32"
+    )  # TODO(liym27): Delete it if the type of parameter iter_num can be resolved
+
+    a = []
+    for i in range(iter_num):
+        a.append(x + i)
+
+    one = fluid.layers.ones(shape=[1], dtype="int32")
+    for i in range(one.numpy()[0]):
+        item = a.pop()
+
+    return a, item
+
+
+@declarative
+def test_list_pop_in_while_loop(x, iter_num):
+    x = fluid.dygraph.to_variable(x)
+    iter_num = fluid.layers.fill_constant(
+        shape=[1], value=iter_num, dtype="int32")
+    a = []
+    i = 0
+    while i < iter_num:
+        a.append(x + i)
+        i += 1
+        if i % 2 == 1:
+            a.pop()
+    return a
+
+
 class TestListWithoutControlFlow(unittest.TestCase):
     def setUp(self):
-        self.input = np.random.random((3)).astype('int32')
         self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
+
+        self.init_data()
         self.init_dygraph_func()
 
+    def init_data(self):
+        self.input = np.random.random((3)).astype('int32')
+
     def init_dygraph_func(self):
-        self.dygraph_func = test_list_without_control_flow
+        self.all_dygraph_funcs = [
+            test_list_append_without_control_flow,
+            test_list_pop_without_control_flow_1,
+            test_list_pop_without_control_flow_2,
+        ]
+
+    def varbase_to_numpy(self, res):
+        if isinstance(res, (list, tuple)):
+            res = map_structure(lambda x: x.numpy(), res)
+        else:
+            res = [res.numpy()]
+        return res
 
     def run_dygraph_mode(self):
         with fluid.dygraph.guard():
             res = self.dygraph_func(self.input)
-            if isinstance(res, (list, tuple)):
-                res = res[0]
-            return res.numpy()
+        return self.varbase_to_numpy(res)
 
     def run_static_mode(self):
         main_program = fluid.Program()
         with fluid.program_guard(main_program):
-            tensor_list = dygraph_to_static_func(self.dygraph_func)(self.input)
-        exe = fluid.Executor(self.place)
-        static_res = exe.run(main_program, fetch_list=tensor_list[0])
-
-        return static_res[0]
+            res = self.dygraph_func(self.input)
+        return self.varbase_to_numpy(res)
 
     def test_transformed_static_result(self):
-        static_res = self.run_static_mode()
-        dygraph_res = self.run_dygraph_mode()
-        self.assertTrue(
-            np.allclose(dygraph_res, static_res),
-            msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
-                                                             static_res))
+        for dyfunc in self.all_dygraph_funcs:
+            self.dygraph_func = dyfunc
+            static_res_list = self.run_static_mode()
+            dygraph_res_list = self.run_dygraph_mode()
+
+            self.assertEqual(len(static_res_list), len(dygraph_res_list))
+            for stat_res, dy_res in zip(static_res_list, dygraph_res_list):
+                self.assertTrue(
+                    np.allclose(stat_res, dy_res),
+                    msg='dygraph_res is {}\nstatic_res is {}'.format(dy_res,
+                                                                     stat_res))
 
 
 class TestListInIf(TestListWithoutControlFlow):
     def init_dygraph_func(self):
-        self.dygraph_func = test_list_in_if
-
-    def run_static_mode(self):
-        main_program = fluid.Program()
-        with fluid.program_guard(main_program):
-            tensor_array = dygraph_to_static_func(self.dygraph_func)(self.input)
-            static_out = fluid.layers.array_read(
-                tensor_array,
-                i=fluid.layers.fill_constant(
-                    shape=[1], value=0, dtype='int64'))
-        exe = fluid.Executor(self.place)
-        numpy_res = exe.run(main_program, fetch_list=static_out)
-        return numpy_res[0]
+        self.all_dygraph_funcs = [test_list_append_in_if, test_list_pop_in_if]
 
 
 class TestListInWhileLoop(TestListWithoutControlFlow):
-    def setUp(self):
-        self.iter_num = 3
+    def init_data(self):
         self.input = np.random.random((3)).astype('int32')
-        self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        self.init_dygraph_func()
+        self.iter_num = 3
 
     def init_dygraph_func(self):
-        self.dygraph_func = test_list_in_while_loop
-
-    def run_dygraph_mode(self):
-        with fluid.dygraph.guard():
-            var_res = self.dygraph_func(self.input, self.iter_num)
-            numpy_res = [ele.numpy() for ele in var_res]
-        return numpy_res
-
-    def run_static_mode(self):
-        main_program = fluid.Program()
-        with fluid.program_guard(main_program):
-            tensor_array = dygraph_to_static_func(self.dygraph_func)(
-                self.input, self.iter_num)
-            static_outs = []
-            for i in range(self.iter_num):
-                static_outs.append(
-                    fluid.layers.array_read(
-                        tensor_array,
-                        i=fluid.layers.fill_constant(
-                            shape=[1], value=i, dtype='int64')))
-
-        exe = fluid.Executor(self.place)
-        numpy_res = exe.run(main_program, fetch_list=static_outs)
-        return numpy_res
+        self.all_dygraph_funcs = [
+            partial(
+                test_list_append_in_while_loop, iter_num=self.iter_num),
+            partial(
+                test_list_pop_in_while_loop, iter_num=self.iter_num),
+        ]
 
 
 class TestListInWhileLoopWithStack(TestListInWhileLoop):
     def init_dygraph_func(self):
-        self.dygraph_func = test_list_in_while_loop_with_stack
-
-    def run_dygraph_mode(self):
-        with fluid.dygraph.guard():
-            var_res = self.dygraph_func(self.input, self.iter_num)
-        numpy_res = var_res.numpy()
-        return numpy_res
-
-    def run_static_mode(self):
-        main_program = fluid.Program()
-        with fluid.program_guard(main_program):
-            out_var = dygraph_to_static_func(self.dygraph_func)(self.input,
-                                                                self.iter_num)
-        exe = fluid.Executor(self.place)
-        numpy_res = exe.run(main_program, fetch_list=out_var)
-        return numpy_res[0]
+        self.all_dygraph_funcs = [
+            partial(
+                test_list_append_in_while_loop_with_stack,
+                iter_num=self.iter_num)
+        ]
 
 
 class TestListInForLoop(TestListInWhileLoop):
     def init_dygraph_func(self):
-        self.dygraph_func = test_list_in_for_loop
+        self.all_dygraph_funcs = [
+            partial(
+                test_list_append_in_for_loop, iter_num=self.iter_num),
+            partial(
+                test_list_pop_in_for_loop, iter_num=self.iter_num),
+        ]
 
 
 class TestListInForLoopWithConcat(TestListInWhileLoopWithStack):
     def init_dygraph_func(self):
-        self.dygraph_func = test_list_in_for_loop_with_concat
+        self.all_dygraph_funcs = [
+            partial(
+                test_list_append_in_for_loop_with_concat,
+                iter_num=self.iter_num)
+        ]
 
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
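
For reference, a minimal usage sketch of the feature this patch adds, modeled on test_list_pop_without_control_flow_2 above. It is not part of the patch; the input value and the printed result are illustrative:

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.jit import declarative


    @declarative
    def pop_second_item(x):
        x = fluid.dygraph.to_variable(x)
        a = []
        if 2 > 1:
            a.append(x)
            a.append(x + 1)
        # ListTransformer rewrites this call into
        # fluid.dygraph.dygraph_to_static.convert_list_pop(a, 1), which pops
        # from a plain Python list or from a LoDTensorArray, depending on
        # whether the list was converted to a tensor array.
        return a.pop(1)


    with fluid.dygraph.guard():
        res = pop_second_item(np.ones([1]).astype('int32'))
        print(res.numpy())  # expected: the second item, x + 1, i.e. [2]

For a tensor array, tensor_array_pop implements the pop by reading item idx, keeping the slice [0, idx), then appending the remaining items back with a while_loop and assigning the result over the original array.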