From 8b07ce0e9486fa46cf57b168f2a2bf910fb5b91c Mon Sep 17 00:00:00 2001 From: helen88 Date: Thu, 14 Apr 2022 15:06:39 +0800 Subject: [PATCH] support multi layer and bidirection of lstm_grad, *test=kunlun (#41742) * support multi layer and bidirection of lstm_grad, *test=kunlun * support multi layer and bidirection of lstm_grad, *test=kunlun --- paddle/fluid/operators/rnn_op_xpu.cc | 247 ++++++++++++------ .../fluid/platform/device/xpu/xpu2_op_list.h | 1 + .../tests/unittests/xpu/test_rnn_op_xpu.py | 11 +- 3 files changed, 180 insertions(+), 79 deletions(-) diff --git a/paddle/fluid/operators/rnn_op_xpu.cc b/paddle/fluid/operators/rnn_op_xpu.cc index a18d0ebfca9..220d91bf4fa 100644 --- a/paddle/fluid/operators/rnn_op_xpu.cc +++ b/paddle/fluid/operators/rnn_op_xpu.cc @@ -125,23 +125,13 @@ class RnnXPUKernel : public framework::OpKernel { output->mutable_data(ctx.GetPlace()); last_h->mutable_data(ctx.GetPlace()); last_c->mutable_data(ctx.GetPlace()); + int gate_num = 4; + int hidden_data_idx = (num_layers - 1); + hidden_data_idx += (gate_num + 1) * num_layers; + const int& block_size = direction_num * seq_len * batch_size * hidden_size; + reserve_data->Resize({hidden_data_idx, block_size}); - reserve_data->Resize( - {num_layers * direction_num * seq_len * batch_size * hidden_size * 5}); reserve_data->mutable_data(ctx.GetPlace()); - Tensor internal_output_1_tensor, internal_output_2_tensor; - T* internal_output_1_ptr = nullptr; - T* internal_output_2_ptr = nullptr; - if (num_layers >= 2) { - internal_output_1_tensor.Resize(output->dims()); - internal_output_1_ptr = - internal_output_1_tensor.mutable_data(ctx.GetPlace()); - } - if (num_layers >= 3) { - internal_output_2_tensor.Resize(output->dims()); - internal_output_2_ptr = - internal_output_2_tensor.mutable_data(ctx.GetPlace()); - } // get ptr from tensor auto x = input->data(); auto init_h_ptr = init_h->data(); @@ -151,8 +141,9 @@ class RnnXPUKernel : public framework::OpKernel { auto last_c_ptr = last_c->data(); auto i_f_g_o_ptr = reserve_data->data(); auto c_ptr = - i_f_g_o_ptr + - num_layers * direction_num * seq_len * batch_size * hidden_size * 4; + i_f_g_o_ptr + num_layers * block_size * 4; // 4 for i_f_g_o offset + auto hidden_data_ptr = + c_ptr + num_layers * block_size * 1; // 1 for c offset std::vector seq_len_tensor(batch_size, seq_len); if (has_seq_length) { @@ -161,33 +152,26 @@ class RnnXPUKernel : public framework::OpKernel { int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2]; + const T* cur_input_ptr = nullptr; + int cur_xdim = -1; + T* cur_output_ptr = y; for (int i = 0; i < num_layers; i++) { - auto i_f_g_o = i_f_g_o_ptr + - i * direction_num * seq_len * batch_size * hidden_size * 4; - auto c = c_ptr + i * direction_num * seq_len * batch_size * hidden_size; + auto i_f_g_o = i_f_g_o_ptr + i * block_size * 4; + auto c = c_ptr + i * block_size; + + cur_output_ptr = y; + if (i < num_layers - 1 && num_layers > 1) { + cur_output_ptr = hidden_data_ptr + i * block_size; + } - const T* cur_input_ptr = nullptr; - int cur_xdim = -1; if (i == 0) { cur_input_ptr = x; cur_xdim = input_dim; - } else if (i % 2 != 0) { - cur_input_ptr = internal_output_1_ptr; - cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; } else { - cur_input_ptr = internal_output_2_ptr; + cur_input_ptr = hidden_data_ptr + (i - 1) * block_size; cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; } - T* cur_output_ptr = nullptr; - if (i == num_layers - 1) { - cur_output_ptr = y; - } else if (i % 2 != 0) { - cur_output_ptr = internal_output_2_ptr; - } else { - cur_output_ptr = internal_output_1_ptr; - } - auto h_0 = init_h_ptr + direction_num * i * state_offset; auto c_0 = init_c_ptr + direction_num * i * state_offset; auto last_h = last_h_ptr + direction_num * i * state_offset; @@ -233,6 +217,8 @@ class RnnXPUKernel : public framework::OpKernel { template class RnnXPUGradKernel : public framework::OpKernel { + using XPUTyp = typename XPUTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& ctx) const override { // get the tensor pointer for the input @@ -243,6 +229,7 @@ class RnnXPUGradKernel : public framework::OpKernel { auto* reserve_data = ctx.Input("Reserve"); const int& num_layers = ctx.Attr("num_layers"); const bool& is_bidirec = ctx.Attr("is_bidirec"); + const float& dropout_prob = ctx.Attr("dropout_prob"); const int& hidden_size = ctx.Attr("hidden_size"); const std::string& mode = ctx.Attr("mode"); @@ -257,16 +244,6 @@ class RnnXPUGradKernel : public framework::OpKernel { platform::errors::InvalidArgument( "XPU only support LSTM mode now, current mode is %s", mode)); - PADDLE_ENFORCE_EQ(is_bidirec, false, - platform::errors::InvalidArgument( - "XPU only support unidirectional LSTM now")); - - PADDLE_ENFORCE_EQ( - num_layers, 1, - platform::errors::InvalidArgument( - "XPU only support 1 layer LSTM now, current layer num is %s", - num_layers)); - auto init_h = pre_state[0]; auto init_c = pre_state[1]; @@ -289,11 +266,12 @@ class RnnXPUGradKernel : public framework::OpKernel { } // check shape - int seq_len = input->dims()[0]; - int batch_size = input->dims()[1]; - int input_dim = input->dims()[2]; + const int& seq_len = input->dims()[0]; + const int& batch_size = input->dims()[1]; + const int& input_dim = input->dims()[2]; + const int& direction_num = is_bidirec ? 2 : 1; PADDLE_ENFORCE_EQ( - init_h->dims()[0], num_layers, + init_h->dims()[0], num_layers * direction_num, platform::errors::InvalidArgument("The num_layers of in RNN layer must" " be the same as first dim of init " "hidden, but received num_layers:%d," @@ -301,7 +279,7 @@ class RnnXPUGradKernel : public framework::OpKernel { num_layers, init_h->dims()[0])); PADDLE_ENFORCE_EQ( - init_c->dims()[0], num_layers, + init_c->dims()[0], num_layers * direction_num, platform::errors::InvalidArgument( "The num_layers of in RNN layer must" " be the same as first dim of cell state hidden, but received" @@ -323,52 +301,165 @@ class RnnXPUGradKernel : public framework::OpKernel { // allocate the memory and initization the input_grad input_grad->mutable_data(input->dims(), ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + phi::funcs::SetConstant zero; + zero(dev_ctx, input_grad, static_cast(0.0)); + + Tensor a, b; + Tensor* dynamic_grad_pre_h = &a; + Tensor* dynamic_grad_pre_c = &b; if (init_h_grad) { - init_h_grad->mutable_data(init_h->dims(), ctx.GetPlace()); + init_h_grad->mutable_data(last_h_grad->dims(), ctx.GetPlace()); + zero(dev_ctx, init_h_grad, static_cast(0.0)); + } else { + dynamic_grad_pre_h->Resize(last_h_grad->dims()); + dynamic_grad_pre_h->mutable_data(ctx.GetPlace()); + zero(dev_ctx, dynamic_grad_pre_h, static_cast(0.0)); + init_h_grad = dynamic_grad_pre_h; } if (init_c_grad) { - init_c_grad->mutable_data(init_c->dims(), ctx.GetPlace()); + init_c_grad->mutable_data(last_c_grad->dims(), ctx.GetPlace()); + } else { + dynamic_grad_pre_c->Resize(last_h_grad->dims()); + dynamic_grad_pre_c->mutable_data(ctx.GetPlace()); + init_c_grad = dynamic_grad_pre_c; + } + + Tensor temp_input_grad_1, temp_input_grad_2; + T* input_grad_1_ptr = nullptr; + T* input_grad_2_ptr = nullptr; + if (num_layers >= 2) { + temp_input_grad_1.Resize(output_grad->dims()); + input_grad_1_ptr = temp_input_grad_1.mutable_data(ctx.GetPlace()); + } + if (num_layers >= 3) { + temp_input_grad_2.Resize(output_grad->dims()); + input_grad_2_ptr = temp_input_grad_2.mutable_data(ctx.GetPlace()); } // get ptr from tensor auto x = input->data(); - auto h_0 = init_h->data(); - auto c_0 = init_c->data(); - auto w_x = parameter_lists[0][0]; - auto w_h = parameter_lists[0][1]; + auto init_h_ptr = init_h->data(); + auto init_c_ptr = init_c->data(); auto y = output->data(); auto y_grad = output_grad->data(); auto last_h_grad_ptr = last_h_grad->data(); auto last_c_grad_ptr = last_c_grad->data(); auto x_grad = input_grad->data(); - auto h_0_grad = init_h_grad ? init_h_grad->data() : nullptr; - auto c_0_grad = init_c_grad ? init_c_grad->data() : nullptr; - auto w_x_grad = parameter_lists_grad[0][0]; - auto w_h_grad = parameter_lists_grad[0][1]; - auto b_x_grad = parameter_lists_grad[0][2]; - auto b_h_grad = parameter_lists_grad[0][3]; - auto i_f_g_o = reserve_data->data(); - auto c = i_f_g_o + seq_len * batch_size * hidden_size * 4; + auto init_h_grad_ptr = init_h_grad->data(); + auto init_c_grad_ptr = init_c_grad->data(); + const int& block_size = direction_num * seq_len * batch_size * hidden_size; + auto i_f_g_o_ptr = reserve_data->data(); + auto c_ptr = i_f_g_o_ptr + num_layers * block_size * 4; + auto hidden_data_ptr = c_ptr + num_layers * block_size * 1; + int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2]; std::vector seq_len_tensor(batch_size, seq_len); if (has_seq_length) { seq_len_tensor = operators::GetDataFromTensor(sequence_length); } - auto& dev_ctx = ctx.template device_context(); - int r = xpu::lstm_grad( - dev_ctx.x_context(), (const T*)x, (const T*)h_0, (const T*)c_0, - (const T*)w_x, (const T*)w_h, (const T*)y, (const T*)y_grad, - (const T*)last_h_grad_ptr, (const T*)last_c_grad_ptr, - reinterpret_cast(x_grad), reinterpret_cast(h_0_grad), - reinterpret_cast(c_0_grad), w_x_grad, w_h_grad, b_x_grad, b_h_grad, - batch_size, input_dim, hidden_size, seq_len, seq_len_tensor, nullptr, - nullptr, nullptr, nullptr, i_f_g_o, c); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External("RnnXPUGrad(lstm) return wrong " - "value[%d %s]", - r, XPUAPIErrorMsg[r])); + for (int i = num_layers - 1; i >= 0; --i) { + // the layer input output had saved, just use the data + auto w_x = parameter_lists[i][0]; + auto w_h = parameter_lists[i][1]; + auto bw_x = parameter_lists[i][4]; + auto bw_h = parameter_lists[i][5]; + + auto i_f_g_o = i_f_g_o_ptr + i * block_size * 4; + auto c = c_ptr + i * block_size; + + Tensor layer_input_t; + auto layer_input = x; + if (i > 0) { + layer_input_t.Resize(output->dims()); + layer_input = layer_input_t.mutable_data(ctx.GetPlace()); + float scale = static_cast(1.0f - dropout_prob); + auto hidden_data = hidden_data_ptr + (i - 1) * block_size; + int r = xpu::scale(dev_ctx.x_context(), + reinterpret_cast(hidden_data), + const_cast(layer_input), output->numel(), + false, scale, 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); + } else { + layer_input = x; + } + + auto layer_output = y; + if (i == num_layers - 1) { + layer_output = y; + } else { + layer_output = hidden_data_ptr + i * block_size; + } + + const T* cur_input_ptr = nullptr; + if (i == num_layers - 1) { + cur_input_ptr = y_grad; + } else if (i % 2 != 0) { + cur_input_ptr = input_grad_2_ptr; + } else { + cur_input_ptr = input_grad_1_ptr; + } + + T* cur_output_ptr = nullptr; + int cur_xdim = -1; + if (i == 0) { + cur_output_ptr = x_grad; + cur_xdim = input_dim; + } else if (i % 2 != 0) { + cur_output_ptr = input_grad_1_ptr; + cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; + } else { + cur_output_ptr = input_grad_2_ptr; + cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; + } + + auto w_x_grad = parameter_lists_grad[i][0]; + auto w_h_grad = parameter_lists_grad[i][1]; + auto b_x_grad = parameter_lists_grad[i][2]; + auto b_h_grad = parameter_lists_grad[i][3]; + + auto h_0 = init_h_ptr + direction_num * i * state_offset; + auto c_0 = init_c_ptr + direction_num * i * state_offset; + + auto h_0_grad = init_h_grad_ptr + direction_num * i * state_offset; + auto c_0_grad = init_c_grad_ptr + direction_num * i * state_offset; + auto h_t_grad = last_h_grad_ptr + direction_num * i * state_offset; + auto c_t_grad = last_c_grad_ptr + direction_num * i * state_offset; + + if (is_bidirec) { + auto bw_x_grad = parameter_lists_grad[i][4]; + auto bw_h_grad = parameter_lists_grad[i][5]; + auto bb_x_grad = parameter_lists_grad[i][6]; + auto bb_h_grad = parameter_lists_grad[i][7]; + + int r = xpu::bilstm_grad( + dev_ctx.x_context(), (const T*)layer_input, (const T*)h_0, + (const T*)c_0, (const T*)w_x, (const T*)w_h, (const T*)bw_x, + (const T*)bw_h, (const T*)layer_output, (const T*)cur_input_ptr, + (const T*)h_t_grad, (const T*)c_t_grad, + reinterpret_cast(cur_output_ptr), + reinterpret_cast(h_0_grad), reinterpret_cast(c_0_grad), + w_x_grad, w_h_grad, b_x_grad, b_h_grad, bw_x_grad, bw_h_grad, + bb_x_grad, bb_h_grad, batch_size, cur_xdim, hidden_size, seq_len, + seq_len_tensor, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, i_f_g_o, c); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "bilstm_grad"); + } else { + int r = xpu::lstm_grad( + dev_ctx.x_context(), (const T*)layer_input, (const T*)h_0, + (const T*)c_0, (const T*)w_x, (const T*)w_h, (const T*)layer_output, + (const T*)cur_input_ptr, (const T*)h_t_grad, (const T*)c_t_grad, + reinterpret_cast(cur_output_ptr), + reinterpret_cast(h_0_grad), reinterpret_cast(c_0_grad), + w_x_grad, w_h_grad, b_x_grad, b_h_grad, batch_size, cur_xdim, + hidden_size, seq_len, seq_len_tensor, nullptr, nullptr, nullptr, + nullptr, i_f_g_o, c); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_grad"); + } + } } }; diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 9915b4d8d34..750a389940c 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -306,6 +306,7 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"roi_align_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, diff --git a/python/paddle/fluid/tests/unittests/xpu/test_rnn_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_rnn_op_xpu.py index 20a3fc69fe8..84edbab1eac 100755 --- a/python/paddle/fluid/tests/unittests/xpu/test_rnn_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_rnn_op_xpu.py @@ -122,7 +122,7 @@ class XPUTestRNNOp(XPUOpTestWrapper): def set_xpu(self): self.__class__.use_xpu = True - self.__class__.no_need_check_grad = True + self.__class__.no_need_check_grad = False self.__class__.op_type = self.in_type def test_check_output(self): @@ -130,6 +130,15 @@ class XPUTestRNNOp(XPUOpTestWrapper): self.place, atol=0.01, no_check_set=['Reserve', 'DropoutState']) + def test_grad(self): + if not self.is_test: + var_name_list = self.get_weight_names() + grad_check_list = ['Input', 'init_h', 'init_c'] + grad_check_list.extend(var_name_list) + self.check_grad_with_place(self.place, + set(grad_check_list), + ['Out', 'last_hidden', 'last_cell']) + def init_size(self): self.seq_length = 12 self.batch_size = 5 -- GitLab