From 3a59ede964922aa009e4bbbf7ebee36798e4a72c Mon Sep 17 00:00:00 2001 From: Chenxiao Niu Date: Fri, 1 Jul 2022 10:49:50 +0800 Subject: [PATCH] [MLU] add rnn backward kernel. (#43969) --- paddle/fluid/operators/mlu/mlu_baseop.cc | 82 ++++ paddle/fluid/operators/mlu/mlu_baseop.h | 24 ++ paddle/fluid/operators/rnn_op_mlu.cc | 397 +++++++++++++++++- .../tests/unittests/mlu/test_rnn_op_mlu.py | 55 +-- 4 files changed, 522 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 972bdefdf02..5531250f363 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -4616,6 +4616,88 @@ MLURNNDesc::~MLURNNDesc() { reservespace_size)); } +/* static */ void MLUCnnl::RNNBackward(const ExecutionContext& ctx, + const cnnlRNNDescriptor_t rnn_desc, + cnnlWgradMode_t add_grad, + const int dev_seq_lengths[], + const void* weight_param_ptr, + void* dweight_param_ptr, + size_t weightspace_size, + const cnnlSeqDataDescriptor_t x_desc, + const void* x, + void* dx, + const cnnlSeqDataDescriptor_t y_desc, + const void* y, + const void* dy, + const cnnlTensorDescriptor_t hx_desc, + const void* hx, + const void* dhy, + void* dhx, + const cnnlTensorDescriptor_t cx_desc, + const void* cx, + const void* dcy, + void* dcx, + void* reservespace_ptr, + size_t reservespace_size) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + PADDLE_ENFORCE_NOT_NULL( + rnn_desc, + paddle::platform::errors::Fatal( + "MLU RNNBackward failed. rnn_desc initializing failed.")); + PADDLE_ENFORCE_NOT_NULL( + x_desc, + paddle::platform::errors::Fatal( + "MLU RNNBackward failed. 
x_desc initializing failed.")); + auto& dev_ctx = GetDevCtxFromCTX(ctx); + size_t workspace_size; + Tensor workspace; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetRNNTempSizes( + handle, rnn_desc, x_desc, &workspace_size, &reservespace_size)); + workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlRNNBackwardData(handle, + rnn_desc, + dev_seq_lengths, + y_desc, + y, + dy, + x_desc, + dx, + hx_desc, + hx, + dhy, + dhx, + cx_desc, + cx, + dcy, + dcx, + weight_param_ptr, + weightspace_size, + workspace_ptr, + workspace_size, + reservespace_ptr, + reservespace_size)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlRNNBackwardWeights(handle, + rnn_desc, + add_grad, + dev_seq_lengths, + x_desc, + x, + hx_desc, + hx, + y_desc, + y, + dweight_param_ptr, + weightspace_size, + workspace_ptr, + workspace_size, + reservespace_ptr, + reservespace_size)); +} + /* static */ void MLUCnnl::Mask(const ExecutionContext& ctx, cnnlMaskedOp_t masked_mode, const cnnlTensorDescriptor_t input_desc, diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 85f4439c3b9..07c5031ee2e 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -1924,6 +1924,30 @@ class MLUCnnl { void* cy, void* reservespace_ptr); + static void RNNBackward(const ExecutionContext& ctx, + const cnnlRNNDescriptor_t rnn_desc, + cnnlWgradMode_t add_grad, + const int dev_seq_lengths[], + const void* weight_param_ptr, + void* dweight_param_ptr, + size_t weightspace_size, + const cnnlSeqDataDescriptor_t x_desc, + const void* x, + void* dx, + const cnnlSeqDataDescriptor_t y_desc, + const void* y, + const void* dy, + const cnnlTensorDescriptor_t hx_desc, + const void* hx, + const void* dhy, + void* dhx, + const cnnlTensorDescriptor_t cx_desc, + const void* cx, + const void* dcy, + void* dcx, + void* reservespace_ptr, + size_t 
reservespace_size); + static void Mask(const ExecutionContext& ctx, cnnlMaskedOp_t masked_mode, const cnnlTensorDescriptor_t input_desc, diff --git a/paddle/fluid/operators/rnn_op_mlu.cc b/paddle/fluid/operators/rnn_op_mlu.cc index 653c50c83b8..fe567333b6d 100644 --- a/paddle/fluid/operators/rnn_op_mlu.cc +++ b/paddle/fluid/operators/rnn_op_mlu.cc @@ -28,7 +28,7 @@ void reset_parameter_vector( const std::vector& raw_params_vec, const int& num_layers, const bool& is_bidirec, - std::vector>>* params_vec) { + std::vector>>* params_vec) { // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers @@ -47,7 +47,8 @@ void reset_parameter_vector( } using remove_cv_t = typename std::remove_cv::type; params_vec->at(i)[j] = std::make_pair( - raw_params_vec[tensor_idx]->template data(), + const_cast( + raw_params_vec[tensor_idx]->template data()), raw_params_vec[tensor_idx]->numel() * sizeof(T)); } } @@ -66,7 +67,6 @@ class RNNMLUKernel : public framework::OpKernel { // Output auto state = ctx.MultiOutput("State"); auto* output = ctx.Output("Out"); - // auto* dropout_mask = ctx.Output("DropoutState"); auto* reserve_data = ctx.Output("Reserve"); // Attributes const int& num_layers = ctx.Attr("num_layers"); @@ -79,14 +79,6 @@ class RNNMLUKernel : public framework::OpKernel { sequence_length = ctx.Input("SequenceLength"); } - // if (dropout_mask->IsInitialized()) { - // if (dropout_mask->numel() != output->numel()) dropout_mask->clear(); - // } - // dropout_mask->mutable_data(output->dims(), ctx.GetPlace()); - // auto& dev_ctx = ctx.template device_context(); - // phi::funcs::SetConstant ones; - // ones(dev_ctx, dropout_mask, static_cast(1)); - auto init_h = pre_state[0]; // -> hx auto init_c = pre_state[1]; // -> cx auto last_h = state[0]; @@ -143,7 +135,7 @@ class RNNMLUKernel : public framework::OpKernel { 
init_c->dims()[0])); // weightlist - std::vector>> parameter_lists; + std::vector>> parameter_lists; parameter_lists.resize(num_layers); reset_parameter_vector( weight_list, num_layers, is_bidirec, &parameter_lists); @@ -363,9 +355,390 @@ class RNNMLUKernel : public framework::OpKernel { } }; +template +class RNNMLUGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + auto stream = ctx.template device_context().stream(); + // get the tensor pointer for the input + auto* input = ctx.Input("Input"); + auto pre_state = ctx.MultiInput("PreState"); + auto weight_list = ctx.MultiInput("WeightList"); + auto* output = ctx.Input("Out"); + auto* reserve_data = ctx.Input("Reserve"); + const int& num_layers = ctx.Attr("num_layers"); + const bool& is_bidirec = ctx.Attr("is_bidirec"); + const int& hidden_size = ctx.Attr("hidden_size"); + const std::string& mode = ctx.Attr("mode"); + + bool has_seq_length = ctx.HasInput("SequenceLength"); + const Tensor* sequence_length = nullptr; + if (has_seq_length) { + sequence_length = ctx.Input("SequenceLength"); + } + + PADDLE_ENFORCE_EQ( + mode, + "LSTM", + platform::errors::InvalidArgument( + "MLU only support LSTM mode now, current mode is %s", mode)); + + auto init_h = pre_state[0]; // -> hx + auto init_c = pre_state[1]; // -> cx + + auto output_grad = ctx.Input(framework::GradVarName("Out")); + auto state_grad = ctx.MultiInput(framework::GradVarName("State")); + auto last_h_grad = state_grad[0]; // -> dhy + auto last_c_grad = state_grad[1]; // -> dcy + + // get the tensor pointer for the output + auto* input_grad = ctx.Output(framework::GradVarName("Input")); + auto weight_grad_list = ctx.MultiOutput( + framework::GradVarName("WeightList")); + auto pre_state_grad = + ctx.MultiOutput(framework::GradVarName("PreState")); + Tensor* init_h_grad = nullptr; + Tensor* init_c_grad = nullptr; + if (pre_state_grad.size() > 0) { 
// has gradient + init_h_grad = pre_state_grad[0]; // -> dhx + init_c_grad = pre_state_grad[1]; // -> dcx + } + + // check shape + const int in_out_dim_num = input->dims().size(); + const int& seq_len = input->dims()[0]; + const int& batch_size = input->dims()[1]; + const int& input_dim = input->dims()[2]; + const int& direction_num = is_bidirec ? 2 : 1; + int in_dim_arr[in_out_dim_num] = {seq_len, batch_size, input_dim}; + int out_dim_arr[in_out_dim_num] = { + seq_len, batch_size, direction_num * hidden_size}; + int proj_size = hidden_size; + PADDLE_ENFORCE_EQ( + num_layers, + 1, + platform::errors::InvalidArgument( + "MLU only support 1 num_layers, current num_layers is %s", + num_layers)); + PADDLE_ENFORCE_EQ( + init_h->dims()[0], + num_layers * direction_num, + platform::errors::InvalidArgument("The num_layers of in RNN layer must" + " be the same as first dim of init" + "hidden, but received num_layers:%d," + " dim:%d", + num_layers, + init_h->dims()[0])); + PADDLE_ENFORCE_EQ( + init_c->dims()[0], + num_layers * direction_num, + platform::errors::InvalidArgument( + "The num_layers of in RNN layer must" + " be the same as first dim of cell state hidden, but received" + " num_layers:%d, dim:%d", + num_layers, + init_c->dims()[0])); + + std::vector>> parameter_lists; + parameter_lists.resize(num_layers); + reset_parameter_vector( + weight_list, num_layers, is_bidirec, ¶meter_lists); + + for (unsigned int i = 0; i < weight_grad_list.size(); ++i) { + weight_grad_list[i]->mutable_data(ctx.GetPlace()); + } + std::vector>> parameter_lists_grad; + parameter_lists_grad.resize(num_layers); + reset_parameter_vector( + weight_grad_list, num_layers, is_bidirec, ¶meter_lists_grad); + + // allocate the memory and initization the input_grad + input_grad->mutable_data(input->dims(), ctx.GetPlace()); + FillMLUTensorWithHostValue(ctx, static_cast(0.0), input_grad); + + Tensor a, b; + Tensor* dynamic_grad_pre_h = &a; + Tensor* dynamic_grad_pre_c = &b; + if (init_h_grad) { + 
init_h_grad->mutable_data(last_h_grad->dims(), ctx.GetPlace()); + FillMLUTensorWithHostValue(ctx, static_cast(0.0), init_h_grad); + } else { + dynamic_grad_pre_h->Resize(last_h_grad->dims()); + dynamic_grad_pre_h->mutable_data(ctx.GetPlace()); + FillMLUTensorWithHostValue(ctx, static_cast(0.0), dynamic_grad_pre_h); + init_h_grad = dynamic_grad_pre_h; + } + if (init_c_grad) { + init_c_grad->mutable_data(last_c_grad->dims(), ctx.GetPlace()); + } else { + dynamic_grad_pre_c->Resize(last_h_grad->dims()); + dynamic_grad_pre_c->mutable_data(ctx.GetPlace()); + init_c_grad = dynamic_grad_pre_c; + } + + std::vector seq_len_vec(batch_size, seq_len); + if (has_seq_length) { + seq_len_vec = operators::GetDataFromTensor(sequence_length); + } + cnnlDirectionMode_t direction = + is_bidirec ? CNNL_RNN_BIDIRECTIONAL : CNNL_RNN_UNIDIRECTIONAL; + + MLUSeqDataDesc input_seq_data_desc(CNNL_SEQDATA_TNC, + ToCnnlDataType(input->dtype()), + in_out_dim_num, + in_dim_arr, + static_cast(seq_len_vec.size()), + seq_len_vec.data(), + nullptr); + MLUSeqDataDesc out_seq_data_desc(CNNL_SEQDATA_TNC, + ToCnnlDataType(input->dtype()), + in_out_dim_num, + out_dim_arr, + static_cast(seq_len_vec.size()), + seq_len_vec.data(), + nullptr); + MLUCnnlTensorDesc hx_desc(*init_h); + MLUCnnlTensorDesc cx_desc(*init_c); + MLURNNDesc rnn_desc(CNNL_LSTM, + CNNL_RNN_DOUBLE_BIAS, + direction, + CNNL_RNN_LINEAR_INPUT, + ToCnnlDataType(input->dtype()), + ToCnnlDataType(input->dtype()), + input_dim, + hidden_size, + /*projection*/ proj_size, + num_layers, + nullptr, + CNNL_RNN_PADDED_IO_DISABLED); + rnn_desc.SetRNNMaskMode(CNNL_LSTM_MASK_ENABLED); + + // copy weight + size_t weightspace_size; + framework::Tensor weightspace, dweightspace; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetRNNWeightSpaceSize( + GetHandleFromCTX(ctx), rnn_desc.get(), &weightspace_size)); + + weightspace = ctx.AllocateTmpTensor( + {static_cast(weightspace_size)}, dev_ctx); + dweightspace = ctx.AllocateTmpTensor( + {static_cast(weightspace_size)}, 
dev_ctx); + void* weightspace_ptr = weightspace.mutable_data(ctx.GetPlace()); + auto w_x = parameter_lists[0][0]; + auto w_h = parameter_lists[0][1]; + auto b_x = parameter_lists[0][2]; + auto b_h = parameter_lists[0][3]; + auto actual_total_w_size = + w_x.second + w_h.second + b_x.second + b_h.second; + + void* w_x_ptr = weightspace_ptr; + void* w_h_ptr = static_cast(weightspace_ptr) + w_x.second; + void* b_x_ptr = + static_cast(weightspace_ptr) + w_x.second + w_h.second; + void* b_h_ptr = static_cast(weightspace_ptr) + w_x.second + + w_h.second + b_x.second; + + memory::Copy(weightspace.place(), + w_x_ptr, + weightspace.place(), + w_x.first, + w_x.second, + stream); + memory::Copy(weightspace.place(), + w_h_ptr, + weightspace.place(), + w_h.first, + w_h.second, + stream); + memory::Copy(weightspace.place(), + b_x_ptr, + weightspace.place(), + b_x.first, + b_x.second, + stream); + memory::Copy(weightspace.place(), + b_h_ptr, + weightspace.place(), + b_h.first, + b_h.second, + stream); + + if (is_bidirec) { + auto bw_x = parameter_lists[0][4]; + auto bw_h = parameter_lists[0][5]; + auto bb_x = parameter_lists[0][6]; + auto bb_h = parameter_lists[0][7]; + void* bw_x_ptr = + static_cast(weightspace_ptr) + actual_total_w_size; + void* bw_h_ptr = static_cast(weightspace_ptr) + + actual_total_w_size + bw_x.second; + void* bb_x_ptr = static_cast(weightspace_ptr) + + actual_total_w_size + bw_x.second + bw_h.second; + void* bb_h_ptr = static_cast(weightspace_ptr) + + actual_total_w_size + bw_x.second + bw_h.second + + bb_x.second; + actual_total_w_size += + bw_x.second + bw_h.second + bb_x.second + bb_h.second; + + memory::Copy(weightspace.place(), + bw_x_ptr, + weightspace.place(), + bw_x.first, + bw_x.second, + stream); + memory::Copy(weightspace.place(), + bw_h_ptr, + weightspace.place(), + bw_h.first, + bw_h.second, + stream); + memory::Copy(weightspace.place(), + bb_x_ptr, + weightspace.place(), + bb_x.first, + bb_x.second, + stream); + 
memory::Copy(weightspace.place(), + bb_h_ptr, + weightspace.place(), + bb_h.first, + bb_h.second, + stream); + } + dev_ctx.Wait(); + + PADDLE_ENFORCE_EQ(weightspace_size, + actual_total_w_size, + platform::errors::InvalidArgument( + "The weightsize doesn't match" + " weightspace_size:%d, actual_total_w_size:%d", + weightspace_size, + actual_total_w_size)); + + MLUCnnl::RNNBackward(ctx, + rnn_desc.get(), + CNNL_WGRAD_MODE_SET, + seq_len_vec.data(), + GetBasePtr(&weightspace), + GetBasePtr(&dweightspace), + weightspace.numel() * sizeof(T), + input_seq_data_desc.get(), + GetBasePtr(input), + GetBasePtr(input_grad), + out_seq_data_desc.get(), + GetBasePtr(output), + GetBasePtr(output_grad), + hx_desc.get(), + GetBasePtr(init_h), + GetBasePtr(last_h_grad), + GetBasePtr(init_h_grad), + cx_desc.get(), + GetBasePtr(init_c), + GetBasePtr(last_c_grad), + GetBasePtr(init_c_grad), + const_cast(GetBasePtr(reserve_data)), + reserve_data->numel() * sizeof(T)); + + void* dweightspace_ptr = dweightspace.mutable_data(ctx.GetPlace()); + auto dw_x = parameter_lists_grad[0][0]; + auto dw_h = parameter_lists_grad[0][1]; + auto db_x = parameter_lists_grad[0][2]; + auto db_h = parameter_lists_grad[0][3]; + auto dactual_total_w_size = + dw_x.second + dw_h.second + db_x.second + db_h.second; + + void* dw_x_ptr = dweightspace_ptr; + void* dw_h_ptr = static_cast(dweightspace_ptr) + dw_x.second; + void* db_x_ptr = + static_cast(dweightspace_ptr) + dw_x.second + dw_h.second; + void* db_h_ptr = static_cast(dweightspace_ptr) + dw_x.second + + dw_h.second + db_x.second; + + memory::Copy(weightspace.place(), + dw_x.first, + weightspace.place(), + dw_x_ptr, + dw_x.second, + stream); + memory::Copy(weightspace.place(), + dw_h.first, + weightspace.place(), + dw_h_ptr, + dw_h.second, + stream); + memory::Copy(weightspace.place(), + db_x.first, + weightspace.place(), + db_x_ptr, + db_x.second, + stream); + memory::Copy(weightspace.place(), + db_h.first, + weightspace.place(), + db_h_ptr, + db_h.second, 
+ stream); + + if (is_bidirec) { + auto dbw_x = parameter_lists_grad[0][4]; + auto dbw_h = parameter_lists_grad[0][5]; + auto dbb_x = parameter_lists_grad[0][6]; + auto dbb_h = parameter_lists_grad[0][7]; + void* dbw_x_ptr = + static_cast(dweightspace_ptr) + dactual_total_w_size; + void* dbw_h_ptr = static_cast(dweightspace_ptr) + + dactual_total_w_size + dbw_x.second; + void* dbb_x_ptr = static_cast(dweightspace_ptr) + + dactual_total_w_size + dbw_x.second + dbw_h.second; + void* dbb_h_ptr = static_cast(dweightspace_ptr) + + dactual_total_w_size + dbw_x.second + dbw_h.second + + dbb_x.second; + dactual_total_w_size += + dbw_x.second + dbw_h.second + dbb_x.second + dbb_h.second; + + memory::Copy(weightspace.place(), + dbw_x.first, + weightspace.place(), + dbw_x_ptr, + dbw_x.second, + stream); + memory::Copy(weightspace.place(), + dbw_h.first, + weightspace.place(), + dbw_h_ptr, + dbw_h.second, + stream); + memory::Copy(weightspace.place(), + dbb_x.first, + weightspace.place(), + dbb_x_ptr, + dbb_x.second, + stream); + memory::Copy(weightspace.place(), + dbb_h.first, + weightspace.place(), + dbb_h_ptr, + dbb_h.second, + stream); + } + dev_ctx.Wait(); + + PADDLE_ENFORCE_EQ(weightspace_size, + dactual_total_w_size, + platform::errors::InvalidArgument( + "The weightsize doesn't match" + " weightspace_size:%d, dactual_total_w_size:%d", + weightspace_size, + dactual_total_w_size)); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP_MLU_KERNEL( rnn, ops::RNNMLUKernel); +REGISTER_OP_MLU_KERNEL( + rnn_grad, ops::RNNMLUGradKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py index f1aabbd3b60..917597daf3a 100644 --- a/python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py +++ b/python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py @@ -135,43 +135,50 @@ class TestRNNOp(OpTest): def test_output(self): 
self.check_output_with_place( - self.place, no_check_set=['Reserve', 'DropoutState', 'State']) + self.place, + atol=1e-4, + no_check_set=['Reserve', 'DropoutState', 'State']) def set_attrs(self): pass - # def test_grad(self): - # if not self.is_test: - # var_name_list = self.get_weight_names() - # grad_check_list = ['Input', 'init_h', 'init_c'] - # grad_check_list.extend(var_name_list) - # self.check_grad_with_place(self.place, set(grad_check_list), - # ['Out', 'last_hidden', 'last_cell']) + def test_grad(self): + if not self.is_test and self.sequence_length is None: + # if not self.is_test: + var_name_list = self.get_weight_names() + grad_check_list = ['Input', 'init_h', 'init_c'] + grad_check_list.extend(var_name_list) + self.check_grad_with_place(self.place, set(grad_check_list), + ['Out', 'last_hidden', 'last_cell']) -# class TestRNNOp1(TestRNNOp): +class TestRNNOp1(TestRNNOp): -# def set_attrs(self): -# self.sequence_length = None + def set_attrs(self): + self.sequence_length = None -# class TestRNNOp2(TestRNNOp): -# def set_attrs(self): -# self.sequence_length = None -# self.is_bidirec = True +class TestRNNOp2(TestRNNOp): -# class TestRNNOp3(TestRNNOp): + def set_attrs(self): + self.sequence_length = None + self.is_bidirec = True -# def set_attrs(self): -# self.is_test = True -# self.sequence_length = None -# class TestRNNOp4(TestRNNOp): +class TestRNNOp3(TestRNNOp): + + def set_attrs(self): + self.is_test = True + self.sequence_length = None + + +class TestRNNOp4(TestRNNOp): + + def set_attrs(self): + self.is_test = True + self.sequence_length = None + self.is_bidirec = True -# def set_attrs(self): -# self.is_test = True -# self.sequence_length = None -# self.is_bidirec = True #TODO(chenxiao): cnnl doesn't support num_layers > 1 case # class TestRNNOp5(TestRNNOp): -- GitLab