diff --git a/paddle/fluid/operators/rnn_op_xpu.cc b/paddle/fluid/operators/rnn_op_xpu.cc deleted file mode 100644 index ee81c0c148a1afed12cc4623eb059265f58db51a..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/rnn_op_xpu.cc +++ /dev/null @@ -1,571 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/utils.h" -#include "paddle/fluid/platform/device/device_wrapper.h" -#include "paddle/fluid/platform/device/xpu/xpu_header.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using DDim = framework::DDim; -using TensorList = std::vector; -template -void reset_parameter_vector(const std::vector& raw_params_vec, - const int& num_layers, - const bool& is_bidirec, - std::vector>* params_vec) { - // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers - // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to - // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers - const int& direction_num = is_bidirec ? 2 : 1; - const int& layer_weight_size = 4 * direction_num; - const int& all_weight_size = num_layers * layer_weight_size; - const int& bias_start_idx = all_weight_size / 2; - for (int i = 0; i < num_layers; i++) { - params_vec->at(i).resize(layer_weight_size); - for (int j = 0; j < layer_weight_size; j++) { - int k = j % 4; - const int& section = j / 4; - int tensor_idx = i * 2 * direction_num + section * 2 + k % 2; - if (k >= 2) { - tensor_idx += bias_start_idx; - } - using remove_cv_t = typename std::remove_cv::type; - params_vec->at(i)[j] = - raw_params_vec[tensor_idx]->template data(); - } - } -} - -template -class RnnXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - // Input - auto* input = ctx.Input("Input"); - auto pre_state = ctx.MultiInput("PreState"); - auto weight_list = ctx.MultiInput("WeightList"); - bool has_seq_length = ctx.HasInput("SequenceLength"); - // Output - auto state = ctx.MultiOutput("State"); - auto* output = ctx.Output("Out"); - auto* dropout_mask = ctx.Output("DropoutState"); - auto* reserve_data = ctx.Output("Reserve"); - // Attributes - const int& num_layers = ctx.Attr("num_layers"); - const bool& is_bidirec = ctx.Attr("is_bidirec"); - const int& hidden_size = ctx.Attr("hidden_size"); - const std::string& mode = ctx.Attr("mode"); - - const Tensor* sequence_length = nullptr; - if (has_seq_length) { - sequence_length = ctx.Input("SequenceLength"); - } - - if (dropout_mask->IsInitialized()) { - if (dropout_mask->numel() != output->numel()) dropout_mask->clear(); - } - dropout_mask->mutable_data(output->dims(), ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - phi::funcs::SetConstant ones; - ones(dev_ctx, dropout_mask, static_cast(1)); - - PADDLE_ENFORCE_EQ( - mode, - "LSTM", - platform::errors::InvalidArgument( - "XPU only support LSTM mode now, current mode is %s", mode)); - - auto init_h = pre_state[0]; - auto init_c = pre_state[1]; - auto last_h = state[0]; - auto last_c = state[1]; - - // check shape - const int& seq_len = input->dims()[0]; // time_step - const int& batch_size = input->dims()[1]; - const int& input_dim = input->dims()[2]; - const int& direction_num = is_bidirec ? 2 : 1; - - PADDLE_ENFORCE_EQ( - init_h->dims()[0], - num_layers * direction_num, - platform::errors::InvalidArgument("The num_layers of in RNN layer must" - " be the same as first dim of init " - "hidden, but received num_layers:%d," - " dim:%d", - num_layers, - init_h->dims()[0])); - - PADDLE_ENFORCE_EQ( - init_c->dims()[0], - num_layers * direction_num, - platform::errors::InvalidArgument( - "The num_layers of in RNN layer must" - " be the same as first dim of cell state hidden, but received" - " num_layers:%d, dim:%d", - num_layers, - init_c->dims()[0])); - // weightlist - std::vector> parameter_lists; - parameter_lists.resize(num_layers); - reset_parameter_vector( - weight_list, num_layers, is_bidirec, ¶meter_lists); - - // init the output and allocate the memory - output->mutable_data(ctx.GetPlace()); - last_h->mutable_data(ctx.GetPlace()); - last_c->mutable_data(ctx.GetPlace()); - int gate_num = 4; - int hidden_data_idx = (num_layers - 1); - hidden_data_idx += (gate_num + 1) * num_layers; - const int& block_size = direction_num * seq_len * batch_size * hidden_size; - reserve_data->Resize({hidden_data_idx, block_size}); - - reserve_data->mutable_data(ctx.GetPlace()); - // get ptr from tensor - auto x = input->data(); - auto init_h_ptr = init_h->data(); - auto init_c_ptr = init_c->data(); - auto y = output->data(); - auto last_h_ptr = last_h->data(); - auto last_c_ptr = last_c->data(); - auto i_f_g_o_ptr = reserve_data->data(); - auto c_ptr = - i_f_g_o_ptr + num_layers * block_size * 4; // 4 for i_f_g_o offset - auto hidden_data_ptr = - c_ptr + num_layers * block_size * 1; // 1 for c offset - - std::vector seq_len_tensor(batch_size, seq_len); - if (has_seq_length) { - seq_len_tensor = operators::GetDataFromTensor(sequence_length); - } - - int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2]; - - const T* cur_input_ptr = nullptr; - int cur_xdim = -1; - T* cur_output_ptr = y; - for (int i = 0; i < num_layers; i++) { - auto i_f_g_o = i_f_g_o_ptr + i * block_size * 4; - auto c = c_ptr + i * block_size; - - cur_output_ptr = y; - if (i < num_layers - 1 && num_layers > 1) { - cur_output_ptr = hidden_data_ptr + i * block_size; - } - - if (i == 0) { - cur_input_ptr = x; - cur_xdim = input_dim; - } else { - cur_input_ptr = hidden_data_ptr + (i - 1) * block_size; - cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; - } - - auto h_0 = init_h_ptr + direction_num * i * state_offset; - auto c_0 = init_c_ptr + direction_num * i * state_offset; - auto last_h = last_h_ptr + direction_num * i * state_offset; - auto last_c = last_c_ptr + direction_num * i * state_offset; - - auto w_x = parameter_lists[i][0]; - auto w_h = parameter_lists[i][1]; - auto b_x = parameter_lists[i][2]; - auto b_h = parameter_lists[i][3]; - if (is_bidirec) { - auto bw_x = parameter_lists[i][4]; - auto bw_h = parameter_lists[i][5]; - auto bb_x = parameter_lists[i][6]; - auto bb_h = parameter_lists[i][7]; - - int r = xpu::bilstm_train( - dev_ctx.x_context(), - (const T*)cur_input_ptr, - (const T*)h_0, - (const T*)c_0, - (const T*)w_x, - (const T*)w_h, - (const T*)b_x, - (const T*)b_h, - (const T*)bw_x, - (const T*)bw_h, - (const T*)bb_x, - (const T*)bb_h, - reinterpret_cast(cur_output_ptr), - reinterpret_cast(last_h), - reinterpret_cast(last_c), - batch_size, - cur_xdim, - hidden_size, - seq_len, - seq_len_tensor, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - reinterpret_cast(i_f_g_o), - reinterpret_cast(c)); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "bilstm_train"); - } else { - int r = - xpu::lstm_train(dev_ctx.x_context(), - (const T*)cur_input_ptr, - (const T*)h_0, - (const T*)c_0, - (const T*)w_x, - (const T*)w_h, - (const T*)b_x, - (const T*)b_h, - reinterpret_cast(cur_output_ptr), - reinterpret_cast(last_h), - reinterpret_cast(last_c), - batch_size, - cur_xdim, - hidden_size, - seq_len, - seq_len_tensor, - nullptr, - nullptr, - nullptr, - nullptr, - reinterpret_cast(i_f_g_o), - reinterpret_cast(c), - xpu::Activation_t::TANH, - xpu::Activation_t::SIGMOID); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_train"); - } - } - } -}; - -template -class RnnXPUGradKernel : public framework::OpKernel { - using XPUTyp = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& ctx) const override { - // get the tensor pointer for the input - auto* input = ctx.Input("Input"); - auto pre_state = ctx.MultiInput("PreState"); - auto weight_list = ctx.MultiInput("WeightList"); - auto* output = ctx.Input("Out"); - auto* reserve_data = ctx.Input("Reserve"); - const int& num_layers = ctx.Attr("num_layers"); - const bool& is_bidirec = ctx.Attr("is_bidirec"); - const float& dropout_prob = ctx.Attr("dropout_prob"); - const int& hidden_size = ctx.Attr("hidden_size"); - const std::string& mode = ctx.Attr("mode"); - - bool has_seq_length = ctx.HasInput("SequenceLength"); - const Tensor* sequence_length = nullptr; - if (has_seq_length) { - sequence_length = ctx.Input("SequenceLength"); - } - - PADDLE_ENFORCE_EQ( - mode, - "LSTM", - platform::errors::InvalidArgument( - "XPU only support LSTM mode now, current mode is %s", mode)); - - auto init_h = pre_state[0]; - auto init_c = pre_state[1]; - - auto output_grad = ctx.Input(framework::GradVarName("Out")); - auto state_grad = ctx.MultiInput(framework::GradVarName("State")); - auto last_h_grad = state_grad[0]; - auto last_c_grad = state_grad[1]; - - // get the tensor pointer for the output - auto* input_grad = ctx.Output(framework::GradVarName("Input")); - auto weight_grad_list = ctx.MultiOutput( - framework::GradVarName("WeightList")); - auto pre_state_grad = - ctx.MultiOutput(framework::GradVarName("PreState")); - Tensor* init_h_grad = nullptr; - Tensor* init_c_grad = nullptr; - if (pre_state_grad.size() > 0) { // has gradient - init_h_grad = pre_state_grad[0]; - init_c_grad = pre_state_grad[1]; - } - - // check shape - const int& seq_len = input->dims()[0]; - const int& batch_size = input->dims()[1]; - const int& input_dim = input->dims()[2]; - const int& direction_num = is_bidirec ? 2 : 1; - PADDLE_ENFORCE_EQ( - init_h->dims()[0], - num_layers * direction_num, - platform::errors::InvalidArgument("The num_layers of in RNN layer must" - " be the same as first dim of init " - "hidden, but received num_layers:%d," - " dim:%d", - num_layers, - init_h->dims()[0])); - - PADDLE_ENFORCE_EQ( - init_c->dims()[0], - num_layers * direction_num, - platform::errors::InvalidArgument( - "The num_layers of in RNN layer must" - " be the same as first dim of cell state hidden, but received" - " num_layers:%d, dim:%d", - num_layers, - init_c->dims()[0])); - - std::vector> parameter_lists; - parameter_lists.resize(num_layers); - reset_parameter_vector( - weight_list, num_layers, is_bidirec, ¶meter_lists); - - for (unsigned int i = 0; i < weight_grad_list.size(); ++i) { - weight_grad_list[i]->mutable_data(ctx.GetPlace()); - } - std::vector> parameter_lists_grad; - parameter_lists_grad.resize(num_layers); - reset_parameter_vector( - weight_grad_list, num_layers, is_bidirec, ¶meter_lists_grad); - - // allocate the memory and initization the input_grad - input_grad->mutable_data(input->dims(), ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - phi::funcs::SetConstant zero; - zero(dev_ctx, input_grad, static_cast(0.0)); - - Tensor a, b; - Tensor* dynamic_grad_pre_h = &a; - Tensor* dynamic_grad_pre_c = &b; - if (init_h_grad) { - init_h_grad->mutable_data(last_h_grad->dims(), ctx.GetPlace()); - zero(dev_ctx, init_h_grad, static_cast(0.0)); - } else { - dynamic_grad_pre_h->Resize(last_h_grad->dims()); - dynamic_grad_pre_h->mutable_data(ctx.GetPlace()); - zero(dev_ctx, dynamic_grad_pre_h, static_cast(0.0)); - init_h_grad = dynamic_grad_pre_h; - } - if (init_c_grad) { - init_c_grad->mutable_data(last_c_grad->dims(), ctx.GetPlace()); - } else { - dynamic_grad_pre_c->Resize(last_h_grad->dims()); - dynamic_grad_pre_c->mutable_data(ctx.GetPlace()); - init_c_grad = dynamic_grad_pre_c; - } - - Tensor temp_input_grad_1, temp_input_grad_2; - T* input_grad_1_ptr = nullptr; - T* input_grad_2_ptr = nullptr; - if (num_layers >= 2) { - temp_input_grad_1.Resize(output_grad->dims()); - input_grad_1_ptr = temp_input_grad_1.mutable_data(ctx.GetPlace()); - } - if (num_layers >= 3) { - temp_input_grad_2.Resize(output_grad->dims()); - input_grad_2_ptr = temp_input_grad_2.mutable_data(ctx.GetPlace()); - } - - // get ptr from tensor - auto x = input->data(); - auto init_h_ptr = init_h->data(); - auto init_c_ptr = init_c->data(); - auto y = output->data(); - auto y_grad = output_grad->data(); - auto last_h_grad_ptr = last_h_grad->data(); - auto last_c_grad_ptr = last_c_grad->data(); - auto x_grad = input_grad->data(); - auto init_h_grad_ptr = init_h_grad->data(); - auto init_c_grad_ptr = init_c_grad->data(); - const int& block_size = direction_num * seq_len * batch_size * hidden_size; - auto i_f_g_o_ptr = reserve_data->data(); - auto c_ptr = i_f_g_o_ptr + num_layers * block_size * 4; - auto hidden_data_ptr = c_ptr + num_layers * block_size * 1; - int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2]; - - std::vector seq_len_tensor(batch_size, seq_len); - if (has_seq_length) { - seq_len_tensor = operators::GetDataFromTensor(sequence_length); - } - - for (int i = num_layers - 1; i >= 0; --i) { - // the layer input output had saved, just use the data - auto w_x = parameter_lists[i][0]; - auto w_h = parameter_lists[i][1]; - auto bw_x = parameter_lists[i][4]; - auto bw_h = parameter_lists[i][5]; - - auto i_f_g_o = i_f_g_o_ptr + i * block_size * 4; - auto c = c_ptr + i * block_size; - - Tensor layer_input_t; - auto layer_input = x; - if (i > 0) { - layer_input_t.Resize(output->dims()); - layer_input = layer_input_t.mutable_data(ctx.GetPlace()); - float scale = static_cast(1.0f - dropout_prob); - auto hidden_data = hidden_data_ptr + (i - 1) * block_size; - int r = xpu::scale(dev_ctx.x_context(), - reinterpret_cast(hidden_data), - const_cast(layer_input), - output->numel(), - false, - scale, - 0.0f); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); - } else { - layer_input = x; - } - - auto layer_output = y; - if (i == num_layers - 1) { - layer_output = y; - } else { - layer_output = hidden_data_ptr + i * block_size; - } - - const T* cur_input_ptr = nullptr; - if (i == num_layers - 1) { - cur_input_ptr = y_grad; - } else if (i % 2 != 0) { - cur_input_ptr = input_grad_2_ptr; - } else { - cur_input_ptr = input_grad_1_ptr; - } - - T* cur_output_ptr = nullptr; - int cur_xdim = -1; - if (i == 0) { - cur_output_ptr = x_grad; - cur_xdim = input_dim; - } else if (i % 2 != 0) { - cur_output_ptr = input_grad_1_ptr; - cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; - } else { - cur_output_ptr = input_grad_2_ptr; - cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; - } - - auto w_x_grad = parameter_lists_grad[i][0]; - auto w_h_grad = parameter_lists_grad[i][1]; - auto b_x_grad = parameter_lists_grad[i][2]; - auto b_h_grad = parameter_lists_grad[i][3]; - - auto h_0 = init_h_ptr + direction_num * i * state_offset; - auto c_0 = init_c_ptr + direction_num * i * state_offset; - - auto h_0_grad = init_h_grad_ptr + direction_num * i * state_offset; - auto c_0_grad = init_c_grad_ptr + direction_num * i * state_offset; - auto h_t_grad = last_h_grad_ptr + direction_num * i * state_offset; - auto c_t_grad = last_c_grad_ptr + direction_num * i * state_offset; - - if (is_bidirec) { - auto bw_x_grad = parameter_lists_grad[i][4]; - auto bw_h_grad = parameter_lists_grad[i][5]; - auto bb_x_grad = parameter_lists_grad[i][6]; - auto bb_h_grad = parameter_lists_grad[i][7]; - - int r = xpu::bilstm_grad( - dev_ctx.x_context(), - (const T*)layer_input, - (const T*)h_0, - (const T*)c_0, - (const T*)w_x, - (const T*)w_h, - (const T*)bw_x, - (const T*)bw_h, - (const T*)layer_output, - (const T*)cur_input_ptr, - (const T*)h_t_grad, - (const T*)c_t_grad, - reinterpret_cast(cur_output_ptr), - reinterpret_cast(h_0_grad), - reinterpret_cast(c_0_grad), - w_x_grad, - w_h_grad, - b_x_grad, - b_h_grad, - bw_x_grad, - bw_h_grad, - bb_x_grad, - bb_h_grad, - batch_size, - cur_xdim, - hidden_size, - seq_len, - seq_len_tensor, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - i_f_g_o, - c); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "bilstm_grad"); - } else { - int r = - xpu::lstm_grad(dev_ctx.x_context(), - (const T*)layer_input, - (const T*)h_0, - (const T*)c_0, - (const T*)w_x, - (const T*)w_h, - (const T*)layer_output, - (const T*)cur_input_ptr, - (const T*)h_t_grad, - (const T*)c_t_grad, - reinterpret_cast(cur_output_ptr), - reinterpret_cast(h_0_grad), - reinterpret_cast(c_0_grad), - w_x_grad, - w_h_grad, - b_x_grad, - b_h_grad, - batch_size, - cur_xdim, - hidden_size, - seq_len, - seq_len_tensor, - nullptr, - nullptr, - nullptr, - nullptr, - i_f_g_o, - c); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_grad"); - } - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - rnn, ops::RnnXPUKernel); -REGISTER_OP_XPU_KERNEL( - rnn_grad, ops::RnnXPUGradKernel); - -#endif // PADDLE_WITH_XPU diff --git a/paddle/fluid/platform/device/npu/CMakeLists.txt b/paddle/fluid/platform/device/npu/CMakeLists.txt index 417b0f9ab6e1a2848fa1e0745639d34bc2c1f072..4f3d842cb0cdc36df5fa9ff9e7fa470dbd037e69 100644 --- a/paddle/fluid/platform/device/npu/CMakeLists.txt +++ b/paddle/fluid/platform/device/npu/CMakeLists.txt @@ -31,3 +31,9 @@ if(WITH_ASCEND_CL) SRCS npu_op_runner.cc DEPS operator npu_info) endif() + +# every source file that includes "dnnl.h" must depends on mkldnn +# or, the first one should depends on mkldnn +if(WITH_MKLDNN) + add_dependencies(npu_collective_helper mkldnn) +endif() diff --git a/paddle/phi/kernels/funcs/math_function.cc b/paddle/phi/kernels/funcs/math_function.cc index 84936d1e20c0ea695e42ac804ed69d0d45d190a8..19bbec124f2ca0319603c0eb5720f4bff1a610ba 100644 --- a/paddle/phi/kernels/funcs/math_function.cc +++ b/paddle/phi/kernels/funcs/math_function.cc @@ -67,6 +67,19 @@ template struct SetConstant>; template struct SetConstant>; + +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant>; +template struct SetConstant>; + #endif #define DEFINE_CPU_TRANS(RANK) \ diff --git a/paddle/phi/kernels/xpu/rnn_grad_kernel.cc b/paddle/phi/kernels/xpu/rnn_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc68fa6f15d2e5be1964b080ad6a5571e8decb6e --- /dev/null +++ b/paddle/phi/kernels/xpu/rnn_grad_kernel.cc @@ -0,0 +1,326 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/rnn_grad_kernel.h" +#include "paddle/fluid/operators/utils.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/xpu/rnn_util.h" + +namespace phi { + +template +void RnnGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& pre_state, + const std::vector& weight_list, + const paddle::optional& sequence_length, + const DenseTensor& out, + const DenseTensor& dropout_state, + const DenseTensor& reserve, + const DenseTensor& out_grad, + const std::vector& state_grad, + float dropout_prob, + bool is_bidirec, + int input_size, + int hidden_size, + int num_layers, + const std::string& mode, + int seed, + bool is_test, + DenseTensor* x_grad, + std::vector pre_state_grad, + std::vector weight_grad_list) { + using XPUTyp = typename XPUTypeTrait::Type; + + PADDLE_ENFORCE_EQ( + mode, + "LSTM", + errors::InvalidArgument( + "XPU only support LSTM mode now, current mode is %s", mode)); + + auto init_h = pre_state[0]; + auto init_c = pre_state[1]; + + auto last_h_grad = state_grad[0]; + auto last_c_grad = state_grad[1]; + + // get the tensor pointer for the output + DenseTensor* init_h_grad = nullptr; + DenseTensor* init_c_grad = nullptr; + if (pre_state_grad.size() > 0) { // has gradient + init_h_grad = pre_state_grad[0]; + init_c_grad = pre_state_grad[1]; + } + + // check shape + const int& seq_len = x.dims()[0]; + const int& batch_size = x.dims()[1]; + const int& input_dim = x.dims()[2]; + const int& direction_num = is_bidirec ? 2 : 1; + PADDLE_ENFORCE_EQ( + init_h->dims()[0], + num_layers * direction_num, + errors::InvalidArgument("The num_layers of in RNN layer must" + " be the same as first dim of init " + "hidden, but received num_layers:%d," + " dim:%d", + num_layers, + init_h->dims()[0])); + + PADDLE_ENFORCE_EQ( + init_c->dims()[0], + num_layers * direction_num, + errors::InvalidArgument( + "The num_layers of in RNN layer must" + " be the same as first dim of cell state hidden, but received" + " num_layers:%d, dim:%d", + num_layers, + init_c->dims()[0])); + + std::vector> parameter_lists; + parameter_lists.resize(num_layers); + reset_parameter_vector(weight_list, num_layers, is_bidirec, ¶meter_lists); + + for (unsigned int i = 0; i < weight_grad_list.size(); ++i) { + dev_ctx.template Alloc(weight_grad_list[i]); + } + std::vector> parameter_lists_grad; + parameter_lists_grad.resize(num_layers); + reset_parameter_vector( + weight_grad_list, num_layers, is_bidirec, ¶meter_lists_grad); + + // allocate the memory and initization the x_grad + x_grad->Resize(x.dims()); + dev_ctx.template Alloc(x_grad); + + phi::funcs::SetConstant zero; + zero(dev_ctx, x_grad, static_cast(0.0)); + + DenseTensor a, b; + DenseTensor* dynamic_grad_pre_h = &a; + DenseTensor* dynamic_grad_pre_c = &b; + if (init_h_grad) { + init_h_grad->Resize(last_h_grad->dims()); + dev_ctx.template Alloc(init_h_grad); + + zero(dev_ctx, init_h_grad, static_cast(0.0)); + } else { + dynamic_grad_pre_h->Resize(last_h_grad->dims()); + dev_ctx.template Alloc(dynamic_grad_pre_h); + + zero(dev_ctx, dynamic_grad_pre_h, static_cast(0.0)); + init_h_grad = dynamic_grad_pre_h; + } + if (init_c_grad) { + init_c_grad->Resize(last_c_grad->dims()); + dev_ctx.template Alloc(init_c_grad); + } else { + dynamic_grad_pre_c->Resize(last_h_grad->dims()); + dev_ctx.template Alloc(dynamic_grad_pre_c); + init_c_grad = dynamic_grad_pre_c; + } + + DenseTensor temp_input_grad_1, temp_input_grad_2; + T* input_grad_1_ptr = nullptr; + T* input_grad_2_ptr = nullptr; + if (num_layers >= 2) { + temp_input_grad_1.Resize(x_grad->dims()); + input_grad_1_ptr = dev_ctx.template Alloc(&temp_input_grad_1); + } + if (num_layers >= 3) { + temp_input_grad_2.Resize(x_grad->dims()); + input_grad_2_ptr = dev_ctx.template Alloc(&temp_input_grad_2); + } + + // get ptr from tensor + auto x_data = x.data(); + auto init_h_ptr = init_h->data(); + auto init_c_ptr = init_c->data(); + auto y = out.data(); + auto y_grad = out_grad.data(); + auto last_h_grad_ptr = last_h_grad->data(); + auto last_c_grad_ptr = last_c_grad->data(); + auto x_grad_data = x_grad->data(); + auto init_h_grad_ptr = init_h_grad->data(); + auto init_c_grad_ptr = init_c_grad->data(); + const int& block_size = direction_num * seq_len * batch_size * hidden_size; + auto i_f_g_o_ptr = reserve.data(); + auto c_ptr = i_f_g_o_ptr + num_layers * block_size * 4; + auto hidden_data_ptr = c_ptr + num_layers * block_size * 1; + int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2]; + + bool has_seq_length = sequence_length.is_initialized(); + std::vector seq_len_tensor(batch_size, seq_len); + if (has_seq_length) { + seq_len_tensor = + paddle::operators::GetDataFromTensor(sequence_length.get_ptr()); + } + + for (int i = num_layers - 1; i >= 0; --i) { + // the layer input output had saved, just use the data + auto w_x = parameter_lists[i][0]; + auto w_h = parameter_lists[i][1]; + auto bw_x = parameter_lists[i][4]; + auto bw_h = parameter_lists[i][5]; + + auto i_f_g_o = i_f_g_o_ptr + i * block_size * 4; + auto c = c_ptr + i * block_size; + + DenseTensor layer_input_t; + auto layer_input = x_data; + if (i > 0) { + layer_input_t.Resize(out.dims()); + layer_input = dev_ctx.template Alloc(&layer_input_t); + float scale = static_cast(1.0f - dropout_prob); + auto hidden_data = hidden_data_ptr + (i - 1) * block_size; + int r = xpu::scale(dev_ctx.x_context(), + reinterpret_cast(hidden_data), + const_cast(layer_input), + out.numel(), + false, + scale, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); + } else { + layer_input = x_data; + } + + auto layer_output = y; + if (i == num_layers - 1) { + layer_output = y; + } else { + layer_output = hidden_data_ptr + i * block_size; + } + + const T* cur_input_ptr = nullptr; + if (i == num_layers - 1) { + cur_input_ptr = y_grad; + } else if (i % 2 != 0) { + cur_input_ptr = input_grad_2_ptr; + } else { + cur_input_ptr = input_grad_1_ptr; + } + + T* cur_output_ptr = nullptr; + int cur_xdim = -1; + if (i == 0) { + cur_output_ptr = x_grad_data; + cur_xdim = input_dim; + } else if (i % 2 != 0) { + cur_output_ptr = input_grad_1_ptr; + cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; + } else { + cur_output_ptr = input_grad_2_ptr; + cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; + } + + auto w_x_grad = parameter_lists_grad[i][0]; + auto w_h_grad = parameter_lists_grad[i][1]; + auto b_x_grad = parameter_lists_grad[i][2]; + auto b_h_grad = parameter_lists_grad[i][3]; + + auto h_0 = init_h_ptr + direction_num * i * state_offset; + auto c_0 = init_c_ptr + direction_num * i * state_offset; + + auto h_0_grad = init_h_grad_ptr + direction_num * i * state_offset; + auto c_0_grad = init_c_grad_ptr + direction_num * i * state_offset; + auto h_t_grad = last_h_grad_ptr + direction_num * i * state_offset; + auto c_t_grad = last_c_grad_ptr + direction_num * i * state_offset; + + if (is_bidirec) { + auto bw_x_grad = parameter_lists_grad[i][4]; + auto bw_h_grad = parameter_lists_grad[i][5]; + auto bb_x_grad = parameter_lists_grad[i][6]; + auto bb_h_grad = parameter_lists_grad[i][7]; + + int r = + xpu::bilstm_grad(dev_ctx.x_context(), + (const T*)layer_input, + (const T*)h_0, + (const T*)c_0, + (const T*)w_x, + (const T*)w_h, + (const T*)bw_x, + (const T*)bw_h, + (const T*)layer_output, + (const T*)cur_input_ptr, + (const T*)h_t_grad, + (const T*)c_t_grad, + reinterpret_cast(cur_output_ptr), + reinterpret_cast(h_0_grad), + reinterpret_cast(c_0_grad), + w_x_grad, + w_h_grad, + b_x_grad, + b_h_grad, + bw_x_grad, + bw_h_grad, + bb_x_grad, + bb_h_grad, + batch_size, + cur_xdim, + hidden_size, + seq_len, + seq_len_tensor, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + i_f_g_o, + c); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "bilstm_grad"); + } else { + int r = + xpu::lstm_grad(dev_ctx.x_context(), + (const T*)layer_input, + (const T*)h_0, + (const T*)c_0, + (const T*)w_x, + (const T*)w_h, + (const T*)layer_output, + (const T*)cur_input_ptr, + (const T*)h_t_grad, + (const T*)c_t_grad, + reinterpret_cast(cur_output_ptr), + reinterpret_cast(h_0_grad), + reinterpret_cast(c_0_grad), + w_x_grad, + w_h_grad, + b_x_grad, + b_h_grad, + batch_size, + cur_xdim, + hidden_size, + seq_len, + seq_len_tensor, + nullptr, + nullptr, + nullptr, + nullptr, + i_f_g_o, + c); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_grad"); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(rnn_grad, XPU, ALL_LAYOUT, phi::RnnGradKernel, float) {} diff --git a/paddle/phi/kernels/xpu/rnn_kernel.cc b/paddle/phi/kernels/xpu/rnn_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d9234b38f16f5ad550a39bfe02c3f7d865fcda5 --- /dev/null +++ b/paddle/phi/kernels/xpu/rnn_kernel.cc @@ -0,0 +1,229 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/rnn_kernel.h" +#include "paddle/fluid/operators/utils.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/xpu/rnn_util.h" + +namespace phi { + +template +void RnnKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& pre_state, + const std::vector& weight_list, + const paddle::optional& sequence_length, + float dropout_prob, + bool is_bidirec, + int input_size, + int hidden_size, + int num_layers, + const std::string& mode, + int seed, + bool is_test, + DenseTensor* out, + DenseTensor* dropout_state, + std::vector state, + DenseTensor* reserve) { + using XPUTyp = typename XPUTypeTrait::Type; + if (dropout_state->IsInitialized()) { + if (dropout_state->numel() != out->numel()) dropout_state->clear(); + } + + dropout_state->Resize(out->dims()); + dev_ctx.template Alloc(dropout_state); + + phi::funcs::SetConstant ones; + ones(dev_ctx, dropout_state, static_cast(1)); + + PADDLE_ENFORCE_EQ( + mode, + "LSTM", + errors::InvalidArgument( + "XPU only support LSTM mode now, current mode is %s", mode)); + + auto init_h = pre_state[0]; + auto init_c = pre_state[1]; + auto last_h = state[0]; + auto last_c = state[1]; + + // check shape + const int& seq_len = x.dims()[0]; // time_step + const int& batch_size = x.dims()[1]; + const int& input_dim = x.dims()[2]; + const int& direction_num = is_bidirec ? 2 : 1; + + PADDLE_ENFORCE_EQ( + init_h->dims()[0], + num_layers * direction_num, + errors::InvalidArgument("The num_layers of in RNN layer must" + " be the same as first dim of init " + "hidden, but received num_layers:%d," + " dim:%d", + num_layers, + init_h->dims()[0])); + + PADDLE_ENFORCE_EQ( + init_c->dims()[0], + num_layers * direction_num, + errors::InvalidArgument( + "The num_layers of in RNN layer must" + " be the same as first dim of cell state hidden, but received" + " num_layers:%d, dim:%d", + num_layers, + init_c->dims()[0])); + // weightlist + std::vector> parameter_lists; + parameter_lists.resize(num_layers); + reset_parameter_vector(weight_list, num_layers, is_bidirec, ¶meter_lists); + + // init the output and allocate the memory + dev_ctx.template Alloc(out); + dev_ctx.template Alloc(last_h); + dev_ctx.template Alloc(last_c); + + int gate_num = 4; + int hidden_data_idx = (num_layers - 1); + hidden_data_idx += (gate_num + 1) * num_layers; + const int& block_size = direction_num * seq_len * batch_size * hidden_size; + reserve->Resize({hidden_data_idx, block_size}); + dev_ctx.template Alloc(reserve); + + // get ptr from tensor + auto x_data = x.data(); + auto init_h_ptr = init_h->data(); + auto init_c_ptr = init_c->data(); + auto y = out->data(); + auto last_h_ptr = last_h->data(); + auto last_c_ptr = last_c->data(); + auto i_f_g_o_ptr = reserve->data(); + auto c_ptr = + i_f_g_o_ptr + num_layers * block_size * 4; // 4 for i_f_g_o offset + auto hidden_data_ptr = c_ptr + num_layers * block_size * 1; // 1 for c offset + + std::vector seq_len_tensor(batch_size, seq_len); + + bool has_seq_length = sequence_length.is_initialized(); + + if (has_seq_length) { + seq_len_tensor = + paddle::operators::GetDataFromTensor(sequence_length.get_ptr()); + } + + int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2]; + + const T* cur_input_ptr = nullptr; + int cur_xdim = -1; + T* cur_output_ptr = y; + for (int i = 0; i < num_layers; i++) { + auto i_f_g_o = i_f_g_o_ptr + i * block_size * 4; + auto c = c_ptr + i * block_size; + + cur_output_ptr = y; + if (i < num_layers - 1 && num_layers > 1) { + cur_output_ptr = hidden_data_ptr + i * block_size; + } + + if (i == 0) { + cur_input_ptr = x_data; + cur_xdim = input_dim; + } else { + cur_input_ptr = hidden_data_ptr + (i - 1) * block_size; + cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; + } + + auto h_0 = init_h_ptr + direction_num * i * state_offset; + auto c_0 = init_c_ptr + direction_num * i * state_offset; + auto last_h = last_h_ptr + direction_num * i * state_offset; + auto last_c = last_c_ptr + direction_num * i * state_offset; + + auto w_x = parameter_lists[i][0]; + auto w_h = parameter_lists[i][1]; + auto b_x = parameter_lists[i][2]; + auto b_h = parameter_lists[i][3]; + if (is_bidirec) { + auto bw_x = parameter_lists[i][4]; + auto bw_h = parameter_lists[i][5]; + auto bb_x = parameter_lists[i][6]; + auto bb_h = parameter_lists[i][7]; + + int r = + xpu::bilstm_train(dev_ctx.x_context(), + (const T*)cur_input_ptr, + (const T*)h_0, + (const T*)c_0, + (const T*)w_x, + (const T*)w_h, + (const T*)b_x, + (const T*)b_h, + (const T*)bw_x, + (const T*)bw_h, + (const T*)bb_x, + (const T*)bb_h, + reinterpret_cast(cur_output_ptr), + reinterpret_cast(last_h), + reinterpret_cast(last_c), + batch_size, + cur_xdim, + hidden_size, + seq_len, + seq_len_tensor, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + reinterpret_cast(i_f_g_o), + reinterpret_cast(c)); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "bilstm_train"); + } else { + int r = + xpu::lstm_train(dev_ctx.x_context(), + (const T*)cur_input_ptr, + (const T*)h_0, + (const T*)c_0, + (const T*)w_x, + (const T*)w_h, + (const T*)b_x, + (const T*)b_h, + reinterpret_cast(cur_output_ptr), + reinterpret_cast(last_h), + reinterpret_cast(last_c), + batch_size, + cur_xdim, + hidden_size, + seq_len, + seq_len_tensor, + nullptr, + nullptr, + nullptr, + nullptr, + reinterpret_cast(i_f_g_o), + reinterpret_cast(c), + xpu::Activation_t::TANH, + xpu::Activation_t::SIGMOID); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_train"); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(rnn, XPU, ALL_LAYOUT, phi::RnnKernel, float) {} diff --git a/paddle/phi/kernels/xpu/rnn_util.h b/paddle/phi/kernels/xpu/rnn_util.h new file mode 100644 index 0000000000000000000000000000000000000000..c42cb1309d11336cd2cac6f70d713a980271779b --- /dev/null +++ b/paddle/phi/kernels/xpu/rnn_util.h @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace phi { + +template +void reset_parameter_vector(const std::vector& raw_params_vec, + const int& num_layers, + const bool& is_bidirec, + std::vector>* params_vec) { + // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers + // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to + // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers + const int& direction_num = is_bidirec ? 2 : 1; + const int& layer_weight_size = 4 * direction_num; + const int& all_weight_size = num_layers * layer_weight_size; + const int& bias_start_idx = all_weight_size / 2; + for (int i = 0; i < num_layers; i++) { + params_vec->at(i).resize(layer_weight_size); + for (int j = 0; j < layer_weight_size; j++) { + int k = j % 4; + const int& section = j / 4; + int tensor_idx = i * 2 * direction_num + section * 2 + k % 2; + if (k >= 2) { + tensor_idx += bias_start_idx; + } + using remove_cv_t = typename std::remove_cv::type; + params_vec->at(i)[j] = + raw_params_vec[tensor_idx]->template data(); + } + } +} + +} // namespace phi