diff --git a/.gitignore b/.gitignore
index 85f6b5657bf5c462542c06f631f9717ba019b73d..1fabfa03fe068f5bde8bdc1616a5262f9cc84dfa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -84,3 +84,4 @@ paddle/fluid/pybind/tmp_eager_op_function_impl.h
 paddle/fluid/pybind/eager_op_function_impl.h
 paddle/fluid/pybind/eager_op_function_impl.h
 paddle/fluid/pybind/op_function_impl.h
+paddle/fluid/pybind/*final_state_op_function_impl.h
diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc
index 5fe51425dc44e761bce9a11433b5842ec5a25aa6..a5e3183774f990b37ffe216920fea1e325e4a482 100644
--- a/paddle/fluid/operators/controlflow/while_op.cc
+++ b/paddle/fluid/operators/controlflow/while_op.cc
@@ -305,6 +305,7 @@ class WhileOp : public framework::OperatorBase {
         cond_data = GetCondData(
             scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
       }
+      scope.DeleteScope(&current_scope);
     }
   }
 
@@ -367,6 +368,7 @@ class WhileGradOp : public framework::OperatorBase {
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
     auto *program = block->Program();
+    auto *parent_block = block->ParentBlock();
 
     auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
     VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
@@ -428,6 +430,35 @@ class WhileGradOp : public framework::OperatorBase {
           continue;
         }
 
+        if (cur_scope_iter == step_scopes->rbegin()) {
+          auto &og_outside = *scope.FindVar(outside_og_name);
+          if (og_outside.IsType<phi::DenseTensor>() &&
+              !og_outside.GetMutable<phi::DenseTensor>()->IsInitialized()) {
+            auto *var_desc = parent_block->FindVarRecursive(outside_og_name);
+            PADDLE_ENFORCE_NOT_NULL(var_desc,
+                                    platform::errors::PreconditionNotMet(
+                                        "Var `%s` is not found in parent "
+                                        "block, can't fill constant.",
+                                        outside_og_name));
+            auto shape = var_desc->GetShape();
+            VLOG(8) << "Found uninitialized tensor " << outside_og_name
+                    << " in step 0, fill it with 0.0f. dims="
+                    << phi::make_ddim(shape);
+            framework::AttributeMap attrs;
+            attrs["dtype"] = var_desc->GetDataType();
+            attrs["shape"] = phi::vectorize(phi::make_ddim(shape));
+            attrs["value"] = 0.0f;
+
+            auto var_name = outside_og_name;
+            auto zero_op =
+                framework::OpRegistry::CreateOp("fill_constant",
+                                                framework::VariableNameMap{},
+                                                {{"Out", {var_name}}},
+                                                attrs);
+            zero_op->Run(scope, dev_place);
+          }
+        }
+
         auto &og_outside = *scope.FindVar(outside_og_name);
         auto &og_inside = *cur_scope.Var(inside_og_name);
         if (og_outside.IsType<phi::DenseTensor>()) {
@@ -534,9 +565,10 @@ class WhileGradOp : public framework::OperatorBase {
         //    continue;
         //  }
 
-        auto var_iter = std::find(outside_og_names.begin(),
-                                  outside_og_names.end(),
-                                  pg_ig_names[param_id]);
+        auto is_var_input_and_output =
+            std::find(outside_og_names.begin(),
+                      outside_og_names.end(),
+                      pg_ig_names[param_id]) != outside_og_names.end();
 
         // zero gradient variable in step 0
         if (cur_scope_iter == step_scopes->rbegin()) {
@@ -555,8 +587,7 @@ class WhileGradOp : public framework::OperatorBase {
                   inside_grad_name,
                   framework::ToTypeName(var->Type())));
 
-          if ((var_iter == outside_og_names.end()) &&
-              var->IsType<phi::DenseTensor>()) {
+          if (!is_var_input_and_output && var->IsType<phi::DenseTensor>()) {
             auto &inside_tensor = var->Get<phi::DenseTensor>();
             framework::AttributeMap attrs;
             attrs["dtype"] =
@@ -575,10 +606,7 @@ class WhileGradOp : public framework::OperatorBase {
                   inside_tensor.lod());
             }
           }
-          auto var_outside = scope.FindVar(pg_ig_names[param_id]);
-          if ((var_iter == outside_og_names.end()) ||
-              ((var_iter != outside_og_names.end()) &&
-               var_outside->IsType<framework::LoDTensorArray>())) {
+          if (!is_var_input_and_output) {
            auto new_inside_name = cur_scope.Rename(inside_grad_name);
            auto sum_op = framework::OpRegistry::CreateOp(
                "sum",
@@ -587,6 +615,8 @@ class WhileGradOp : public framework::OperatorBase {
                framework::AttributeMap{{"use_mkldnn", {false}}});
            sum_op->Run(cur_scope, dev_place);
            cur_scope.Rename(new_inside_name, inside_grad_name);
+          } else {
+            ShareVariable(cur_scope, scope, pg_ig_names[param_id]);
           }
         }
         dev_ctx.Wait();
@@ -595,6 +625,29 @@ class WhileGradOp : public framework::OperatorBase {
     step_scopes->clear();
   }
 
+  void ShareVariable(const framework::Scope &source,
+                     const framework::Scope &dest,
+                     std::string name) const {
+    auto from_var = source.FindVar(name);
+    auto to_var = dest.FindVar(name);
+    if (from_var->IsType<phi::DenseTensor>()) {
+      if (from_var->Get<phi::DenseTensor>().IsInitialized()) {
+        to_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
+            from_var->Get<phi::DenseTensor>());
+      }
+    } else if (from_var->IsType<framework::LoDTensorArray>()) {
+      auto from_arr = from_var->GetMutable<framework::LoDTensorArray>();
+      auto to_arr = to_var->GetMutable<framework::LoDTensorArray>();
+      to_arr->clear();
+      to_arr->resize(from_arr->size());
+      for (size_t i = 0; i < to_arr->size(); ++i) {
+        if (from_arr->at(i).IsInitialized()) {
+          to_arr->at(i).ShareDataWith(from_arr->at(i));
+        }
+      }
+    }
+  }
+
  private:
   mutable std::shared_ptr<framework::Executor> executor_{nullptr};
   mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
@@ -646,6 +699,7 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
       block_ins.insert(o);
     }
     std::unordered_set<std::string> output_grads;
+
     for (const auto *op : grad_block->AllOps()) {
       for (auto &input_name : op->InputArgumentNames()) {
         // If the input of Op has been recorded or is generated by the forward
@@ -658,7 +712,6 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
              parent_block->FindVarRecursive(input_name) != nullptr)) {
           continue;
         }
-
         output_grads.insert(input_name);
       }
       for (auto &output_name : op->OutputArgumentNames()) {
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 8f8b4bfa7311538f6209bece93d143dfb50b12d6..df975d06a45d4d568f136dcdaf777887066db8de 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -2220,6 +2220,10 @@ def _find_op_path_(
             op.desc.output_arg_names(), output_names
         ):
             relevant_op_flags[i] = True
+            if core.has_non_empty_grad_op_maker(op.type):
+                for name in op.desc.input_arg_names():
+                    if name not in no_grad_set:
+                        output_names.add(name)
 
     op_path = [
         block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i]
diff --git a/python/paddle/fluid/tests/unittests/rnn/test_rnn_api.py b/python/paddle/fluid/tests/unittests/rnn/test_rnn_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..f138ab3216eb49affe6331f24005e19c9d7c6ef3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/rnn/test_rnn_api.py
@@ -0,0 +1,362 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import paddle
+
+paddle.set_default_dtype("float64")
+
+import unittest
+
+import numpy as np
+
+from paddle import fluid
+from paddle.fluid import framework
+
+bidirectional_list = ["bidirectional", "bidirect"]
+
+
+class TestSimpleRNN(unittest.TestCase):
+    def __init__(self, time_major=True, direction="forward", place="cpu"):
+        super().__init__("runTest")
+        self.time_major = time_major
+        self.direction = direction
+        self.num_directions = 2 if direction in bidirectional_list else 1
+        self.place = place
+        self.batch_size = 4
+        self.input_size = 16
+        self.hidden_size = 16
+        self.seq_len = 12
+        self.seed = 1234
+
+    def setUp(self):
+        # Since `set_device` is global, set `set_device` in `setUp` rather than
+        # `__init__` to avoid using a wrong device set by another test case.
+
+        place = paddle.set_device(self.place)
+        paddle.disable_static(self.place)
+        paddle.seed(self.seed)
+        paddle.framework.random._manual_program_seed(self.seed)
+        cell_dy = paddle.nn.SimpleRNNCell(self.input_size, self.hidden_size)
+        self.rnn_net = paddle.nn.RNN(cell_dy, time_major=self.time_major)
+
+        paddle.enable_static()
+
+        with paddle.fluid.unique_name.guard():
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with paddle.static.program_guard(
+                main_program=main_program, startup_program=startup_program
+            ):
+                paddle.seed(self.seed)
+                paddle.framework.random._manual_program_seed(self.seed)
+
+                self.exe = fluid.Executor(
+                    fluid.CPUPlace()
+                    if self.place == "cpu"
+                    else fluid.CUDAPlace(0)
+                )
+
+                rnn_in_data = paddle.static.data(
+                    "x",
+                    [None, self.batch_size, self.hidden_size],
+                    dtype="float64",
+                )
+                pre_h_data = paddle.static.data(
+                    "pre_h",
+                    [self.batch_size, self.hidden_size],
+                    dtype="float64",
+                )
+                seq_len_data = paddle.static.data(
+                    "seq_len", [self.batch_size], dtype="int64"
+                )
+                cell_st = paddle.nn.SimpleRNNCell(
+                    self.input_size, self.hidden_size
+                )
+                self.rnn_st = paddle.nn.RNN(cell_st, time_major=self.time_major)
+                st_out, st_last_h = self.rnn_st(
+                    rnn_in_data, pre_h_data, sequence_length=seq_len_data
+                )
+
+                self.fetch_list = [st_out, st_last_h]
+
+                self.exe.run(framework.default_startup_program())
+
+                self.main_program = framework.default_main_program()
+
+        paddle.disable_static(self.place)
+
+    def test_base(self, test_seq_len=False):
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+        prev_h = np.random.randn(4, 16)
+
+        paddle.disable_static(self.place)
+        if test_seq_len:
+            seq_len = np.array([9, 10, 8, 12], "int64")
+        else:
+            seq_len = np.array([12, 12, 12, 12], "int64")
+
+        y1, h1 = self.rnn_net(
+            paddle.to_tensor(x),
+            paddle.to_tensor(prev_h),
+            sequence_length=paddle.to_tensor(seq_len),
+        )
+
+        paddle.enable_static()
+        out = self.exe.run(
+            self.main_program,
+            feed={"x": x, "pre_h": prev_h, "seq_len": seq_len},
+            fetch_list=[self.fetch_list],
+        )
+
+        y2, h2 = out
+
+        np.testing.assert_allclose(y1.numpy(), y2, atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(h1.numpy(), h2, atol=1e-8, rtol=1e-5)
+
+    def runTest(self):
+        self.test_base()
+        self.test_base(True)
+
+
+class TestGRU(unittest.TestCase):
+    def __init__(self, time_major=True, direction="forward", place="cpu"):
+        super().__init__("runTest")
+        self.time_major = time_major
+        self.direction = direction
+        self.num_directions = 2 if direction in bidirectional_list else 1
+        self.place = place
+        self.batch_size = 4
+        self.input_size = 16
+        self.hidden_size = 16
+        self.seq_len = 12
+        self.seed = 1234
+
+    def setUp(self):
+        # Since `set_device` is global, set `set_device` in `setUp` rather than
+        # `__init__` to avoid using a wrong device set by another test case.
+
+        place = paddle.set_device(self.place)
+        paddle.disable_static(self.place)
+        paddle.seed(self.seed)
+        paddle.framework.random._manual_program_seed(self.seed)
+        cell_dy = paddle.nn.GRUCell(self.input_size, self.hidden_size)
+        self.rnn_net = paddle.nn.RNN(cell_dy, time_major=self.time_major)
+
+        paddle.enable_static()
+
+        with paddle.fluid.unique_name.guard():
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with paddle.static.program_guard(
+                main_program=main_program, startup_program=startup_program
+            ):
+                paddle.seed(self.seed)
+                paddle.framework.random._manual_program_seed(self.seed)
+
+                self.exe = fluid.Executor(
+                    fluid.CPUPlace()
+                    if self.place == "cpu"
+                    else fluid.CUDAPlace(0)
+                )
+
+                rnn_in_data = paddle.static.data(
+                    "x",
+                    [None, self.batch_size, self.hidden_size],
+                    dtype="float64",
+                )
+                pre_h_data = paddle.static.data(
+                    "pre_h",
+                    [self.batch_size, self.hidden_size],
+                    dtype="float64",
+                )
+                seq_len_data = paddle.static.data(
+                    "seq_len", [self.batch_size], dtype="int64"
+                )
+                cell_st = paddle.nn.GRUCell(self.input_size, self.hidden_size)
+                self.rnn_st = paddle.nn.RNN(cell_st, time_major=self.time_major)
+                st_out, st_last_h = self.rnn_st(
+                    rnn_in_data, pre_h_data, sequence_length=seq_len_data
+                )
+
+                self.fetch_list = [st_out, st_last_h]
+
+                self.exe.run(framework.default_startup_program())
+
+                self.main_program = framework.default_main_program()
+
+        paddle.disable_static(self.place)
+
+    def test_base(self, test_seq_len=False):
+        x = np.random.randn(12, 4, 16)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+        prev_h = np.random.randn(4, 16)
+
+        paddle.disable_static(self.place)
+        if test_seq_len:
+            seq_len = np.array([9, 10, 8, 12], "int64")
+        else:
+            seq_len = np.array([12, 12, 12, 12], "int64")
+
+        y1, h1 = self.rnn_net(
+            paddle.to_tensor(x),
+            paddle.to_tensor(prev_h),
+            sequence_length=paddle.to_tensor(seq_len),
+        )
+
+        paddle.enable_static()
+        out = self.exe.run(
+            self.main_program,
+            feed={"x": x, "pre_h": prev_h, "seq_len": seq_len},
+            fetch_list=[self.fetch_list],
+        )
+
+        y2, h2 = out
+
+        np.testing.assert_allclose(y1.numpy(), y2, atol=1e-8, rtol=1e-5)
+        np.testing.assert_allclose(h1.numpy(), h2, atol=1e-8, rtol=1e-5)
+
+    def runTest(self):
+        self.test_base()
+        self.test_base(True)
+
+
+class TestGRUBackward(unittest.TestCase):
+    def __init__(self, time_major=True, direction="forward", place="cpu"):
+        super().__init__("runTest")
+        self.time_major = time_major
+        self.direction = direction
+        self.num_directions = 2 if direction in bidirectional_list else 1
+        self.place = place
+        self.batch_size = 4
+        self.input_size = 4
+        self.hidden_size = 4
+        self.seq_len = 12
+        self.seed = 1234
+
+    def setUp(self):
+        # Since `set_device` is global, set `set_device` in `setUp` rather than
+        # `__init__` to avoid using a wrong device set by another test case.
+
+        place = paddle.set_device(self.place)
+        paddle.disable_static(self.place)
+        paddle.seed(self.seed)
+        paddle.framework.random._manual_program_seed(self.seed)
+        cell_dy = paddle.nn.SimpleRNNCell(self.input_size, self.hidden_size)
+        self.rnn_net = paddle.nn.RNN(cell_dy, time_major=self.time_major)
+
+        paddle.enable_static()
+
+        with paddle.fluid.unique_name.guard():
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with paddle.static.program_guard(
+                main_program=main_program, startup_program=startup_program
+            ):
+                paddle.seed(self.seed)
+                paddle.framework.random._manual_program_seed(self.seed)
+
+                self.exe = paddle.fluid.Executor(
+                    fluid.CPUPlace()
+                    if self.place == "cpu"
+                    else fluid.CUDAPlace(0)
+                )
+
+                rnn_in_data = paddle.static.data(
+                    "x",
+                    [None, self.batch_size, self.hidden_size],
+                    dtype="float64",
+                )
+                pre_h_data = paddle.static.data(
+                    "pre_h",
+                    [self.batch_size, self.hidden_size],
+                    dtype="float64",
+                )
+                seq_len_data = paddle.static.data(
+                    "seq_len", [self.batch_size], dtype="int64"
+                )
+
+                pre_h_data.stop_gradient = False
+                rnn_in_data.stop_gradient = False
+
+                cell_st = paddle.nn.SimpleRNNCell(
+                    self.input_size, self.hidden_size
+                )
+                self.rnn_st = paddle.nn.RNN(cell_st, time_major=self.time_major)
+
+                st_out, st_last_h = self.rnn_st(
+                    rnn_in_data, pre_h_data, sequence_length=seq_len_data
+                )
+                loss = paddle.sum(st_out)
+                sgd = paddle.optimizer.SGD(0.0)
+                sgd.minimize(loss)
+                self.fetch_list = [st_out, st_last_h, "pre_h@GRAD", "x@GRAD"]
+
+                self.exe.run(framework.default_startup_program())
+
+                self.main_program = framework.default_main_program()
+
+        paddle.disable_static(self.place)
+
+    def test_base(self, test_seq_len=False):
+        x = np.random.randn(12, 4, self.hidden_size)
+        if not self.time_major:
+            x = np.transpose(x, [1, 0, 2])
+        prev_h = np.random.randn(4, self.hidden_size)
+
+        paddle.disable_static(self.place)
+        if test_seq_len:
+            seq_len = np.array([9, 10, 8, 12], "int64")
+        else:
+            seq_len = np.array([12, 12, 12, 12], "int64")
+
+        x_in = paddle.to_tensor(x)
+        h_in = paddle.to_tensor(prev_h)
+        x_in.stop_gradient = False
+        h_in.stop_gradient = False
+        y1, h1 = self.rnn_net(
+            x_in,
+            h_in,
+            sequence_length=paddle.to_tensor(seq_len),
+        )
+        loss = y1.sum()
+        loss.backward()
+
+        h1_grad = h_in.gradient()
+
+        paddle.enable_static()
+        out = self.exe.run(
+            self.main_program,
+            feed={"x": x, "pre_h": prev_h, "seq_len": seq_len},
+            fetch_list=[self.fetch_list],
+        )
+
+        y2, h2, g1, g2 = out
+
+        np.testing.assert_allclose(h1_grad, g1, atol=1e-8, rtol=1e-5)
+
+    def runTest(self):
+
+        self.test_base(True)
+        self.test_base()
+        self.test_base()
+        self.test_base(True)
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index 59dfa2413d5c393fd7569f5c704b4d31daadc8c8..7fbce5abc0b13f6ac1635a514ad81df14dc9cc78 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -248,8 +248,8 @@
     if not time_major:
         inputs = map_structure(_transpose_batch_time, inputs)
 
+    max_seq_len = paddle.shape(flatten(inputs)[0])[0]
     if sequence_length:
-        max_seq_len = paddle.shape(flatten(inputs)[0])[0]
         mask = sequence_lod.sequence_mask(
             sequence_length,
             maxlen=max_seq_len,
@@ -260,30 +260,77 @@
         inputs = map_structure(lambda x: paddle.reverse(x, axis=[0]), inputs)
         mask = paddle.reverse(mask, axis=[0]) if sequence_length else None
 
-    # StaticRNN
-    rnn = control_flow.StaticRNN()
-    with rnn.step():
-        inputs = map_structure(rnn.step_input, inputs)
-        states = map_structure(rnn.memory, initial_states)
-        copy_states = map_structure(lambda x: x, states)
-        outputs, new_states = cell(inputs, copy_states, **kwargs)
-        utils.assert_same_structure(states, new_states)
+    with paddle.fluid.framework.device_guard("cpu"):
+        start_i = paddle.zeros([1], dtype="int64")
+        end = max_seq_len
+
+        end = paddle.cast(end, "int64")
+        cond = start_i < end
+        while_op = control_flow.While(cond)
+
+    out_array = paddle.tensor.create_array(dtype=flatten(inputs)[0].dtype)
+
+    init_array = map_structure(
+        lambda x: paddle.tensor.create_array(dtype=x.dtype), initial_states
+    )
+
+    map_structure(
+        lambda x, y: paddle.tensor.array_write(x, start_i, y),
+        initial_states,
+        init_array,
+    )
+
+    with while_op.block():
+
+        step_in = inputs[start_i]
+        # step_in = paddle.fluid.layers.Print( step_in, message="step in")
+        pre_state = map_structure(
+            lambda x: paddle.tensor.array_read(x, start_i), init_array
+        )
+        # pre_state = paddle.fluid.layers.Print( pre_state, message="pre")
+        outputs, new_states = cell(step_in, pre_state, **kwargs)
+        assert isinstance(outputs, paddle.fluid.framework.Variable)
+        utils.assert_same_structure(new_states, pre_state)
         if sequence_length:
-            step_mask = rnn.step_input(mask)
+            step_mask = paddle.unsqueeze(mask[start_i], 1)
+            # paddle.fluid.layers.Print( step_mask, message="mask")
+            # new_states = map_structure(
+            #     partial(_maybe_copy, step_mask=step_mask),
+            #     pre_state, new_states
+            # )
             new_states = map_structure(
-                partial(_maybe_copy, step_mask=step_mask), states, new_states
+                lambda x, y: (x * step_mask + y * (1.0 - step_mask)),
+                new_states,
+                pre_state,
             )
-        map_structure(rnn.update_memory, states, new_states)
-        flat_outputs = flatten(outputs)
-        map_structure(rnn.step_output, outputs)
-        map_structure(rnn.step_output, new_states)
+        paddle.tensor.array_write(outputs, start_i, out_array)
+
+        with paddle.fluid.framework.device_guard("cpu"):
+
+            start_i = paddle.tensor.increment(x=start_i, value=1)
+        map_structure(
+            lambda x, y: paddle.tensor.array_write(x, start_i, y),
+            new_states,
+            init_array,
+        )
+
+        with paddle.fluid.framework.device_guard("cpu"):
+            new_cond = paddle.tensor.less_than(start_i, end)
+            paddle.fluid.layers.assign(new_cond, cond)
+
+    out, _ = paddle.fluid.layers.tensor_array_to_tensor(
+        out_array, axis=0, use_stack=True
+    )
 
-    rnn_out = rnn()
-    final_outputs = rnn_out[: len(flat_outputs)]
-    final_outputs = utils.pack_sequence_as(outputs, final_outputs)
-    final_states = map_structure(lambda x: x[-1], rnn_out[len(flat_outputs) :])
-    final_states = utils.pack_sequence_as(new_states, final_states)
+    all_state = map_structure(
+        lambda x: paddle.fluid.layers.tensor_array_to_tensor(
+            x, axis=0, use_stack=True
+        )[0],
+        init_array,
+    )
+    final_outputs = out
+    final_states = map_structure(lambda x: x[-1], all_state)
 
     if is_reverse:
         final_outputs = map_structure(