From b1adde3dc1d74208a8bd2484adb700dc18ec1b8c Mon Sep 17 00:00:00 2001 From: helen88 Date: Wed, 13 Apr 2022 11:23:45 +0800 Subject: [PATCH] use bilstm_train for rnn forward, * test=kunlun (#41671) --- paddle/fluid/operators/rnn_op_xpu.cc | 117 ++++++++++----------------- 1 file changed, 43 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/operators/rnn_op_xpu.cc b/paddle/fluid/operators/rnn_op_xpu.cc index c75c24ab0ab..a18d0ebfca9 100644 --- a/paddle/fluid/operators/rnn_op_xpu.cc +++ b/paddle/fluid/operators/rnn_op_xpu.cc @@ -51,41 +51,6 @@ void reset_parameter_vector(const std::vector& raw_params_vec, } } -template -void RunLSTMLayer(const framework::ExecutionContext& ctx, int seq_len, - int batch_size, int xdim, int hidden_size, const T* x, T* y, - const T* init_h, const T* init_c, T* last_h, T* last_c, - int state_offset, const std::vector& seq_len_tensor, - const std::vector& param_list, T* i_f_g_o, T* c, - bool is_bidirect, int layer_idx, int offset) { - bool is_reverse = false; - if (is_bidirect) { - layer_idx = 2 * layer_idx + offset; - if (offset > 0) { - is_reverse = true; - } - } - auto w_x = param_list[0 + offset * 4]; - auto w_h = param_list[1 + offset * 4]; - auto b_x = param_list[2 + offset * 4]; - auto b_h = param_list[3 + offset * 4]; - - auto h_0 = init_h + layer_idx * state_offset; - auto c_0 = init_c + layer_idx * state_offset; - auto last_h_ptr = last_h + layer_idx * state_offset; - auto last_c_ptr = last_c + layer_idx * state_offset; - auto& dev_ctx = ctx.template device_context(); - int r = xpu::lstm_train( - dev_ctx.x_context(), (const T*)x, (const T*)h_0, (const T*)c_0, - (const T*)w_x, (const T*)w_h, (const T*)b_x, (const T*)b_h, - reinterpret_cast(y), reinterpret_cast(last_h_ptr), - reinterpret_cast(last_c_ptr), batch_size, xdim, hidden_size, seq_len, - seq_len_tensor, is_reverse, nullptr, nullptr, nullptr, nullptr, - reinterpret_cast(i_f_g_o), reinterpret_cast(c), - xpu::Activation_t::TANH, xpu::Activation_t::SIGMOID); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_train"); -} - template class RnnXPUKernel : public framework::OpKernel { public: @@ -184,9 +149,9 @@ class RnnXPUKernel : public framework::OpKernel { auto y = output->data(); auto last_h_ptr = last_h->data(); auto last_c_ptr = last_c->data(); - auto i_f_g_o = reserve_data->data(); - auto c = - i_f_g_o + + auto i_f_g_o_ptr = reserve_data->data(); + auto c_ptr = + i_f_g_o_ptr + num_layers * direction_num * seq_len * batch_size * hidden_size * 4; std::vector seq_len_tensor(batch_size, seq_len); @@ -197,11 +162,12 @@ class RnnXPUKernel : public framework::OpKernel { int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2]; for (int i = 0; i < num_layers; i++) { + auto i_f_g_o = i_f_g_o_ptr + + i * direction_num * seq_len * batch_size * hidden_size * 4; + auto c = c_ptr + i * direction_num * seq_len * batch_size * hidden_size; + const T* cur_input_ptr = nullptr; int cur_xdim = -1; - i_f_g_o += i * direction_num * seq_len * batch_size * hidden_size * 4; - c += i * direction_num * seq_len * batch_size * hidden_size; - if (i == 0) { cur_input_ptr = x; cur_xdim = input_dim; @@ -222,41 +188,44 @@ class RnnXPUKernel : public framework::OpKernel { cur_output_ptr = internal_output_1_ptr; } + auto h_0 = init_h_ptr + direction_num * i * state_offset; + auto c_0 = init_c_ptr + direction_num * i * state_offset; + auto last_h = last_h_ptr + direction_num * i * state_offset; + auto last_c = last_c_ptr + direction_num * i * state_offset; + + auto w_x = parameter_lists[i][0]; + auto w_h = parameter_lists[i][1]; + auto b_x = parameter_lists[i][2]; + auto b_h = parameter_lists[i][3]; if (is_bidirec) { - std::vector output_vec(2); - std::vector output_ptr_vec(2); - for (int k = 0; k < 2; ++k) { - output_vec[k].Resize({seq_len, batch_size, output->dims()[2] / 2}); - output_ptr_vec[k] = output_vec[k].mutable_data(ctx.GetPlace()); - } - RunLSTMLayer( - ctx, seq_len, batch_size, cur_xdim, hidden_size, cur_input_ptr, - output_ptr_vec[0], init_h_ptr, init_c_ptr, last_h_ptr, last_c_ptr, - state_offset, seq_len_tensor, parameter_lists[i], i_f_g_o, c, - is_bidirec, i, 0); - - T* bw_i_f_g_o = i_f_g_o + seq_len * batch_size * hidden_size * 4; - T* bw_c = c + seq_len * batch_size * hidden_size; - RunLSTMLayer( - ctx, seq_len, batch_size, cur_xdim, hidden_size, cur_input_ptr, - output_ptr_vec[1], init_h_ptr, init_c_ptr, last_h_ptr, last_c_ptr, - state_offset, seq_len_tensor, parameter_lists[i], bw_i_f_g_o, bw_c, - is_bidirec, i, 1); - - // concat - int r = xpu::concat( - dev_ctx.x_context(), {output_ptr_vec[0], output_ptr_vec[1]}, - cur_output_ptr, {{seq_len, batch_size, hidden_size}, - {seq_len, batch_size, hidden_size}}, - 2); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "concat"); - xpu_wait(dev_ctx.x_context()->xpu_stream); + auto bw_x = parameter_lists[i][4]; + auto bw_h = parameter_lists[i][5]; + auto bb_x = parameter_lists[i][6]; + auto bb_h = parameter_lists[i][7]; + + int r = xpu::bilstm_train( + dev_ctx.x_context(), (const T*)cur_input_ptr, (const T*)h_0, + (const T*)c_0, (const T*)w_x, (const T*)w_h, (const T*)b_x, + (const T*)b_h, (const T*)bw_x, (const T*)bw_h, (const T*)bb_x, + (const T*)bb_h, reinterpret_cast(cur_output_ptr), + reinterpret_cast(last_h), reinterpret_cast(last_c), + batch_size, cur_xdim, hidden_size, seq_len, seq_len_tensor, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, + reinterpret_cast(i_f_g_o), reinterpret_cast(c)); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "bilstm_train"); } else { - RunLSTMLayer( - ctx, seq_len, batch_size, cur_xdim, hidden_size, cur_input_ptr, - cur_output_ptr, init_h_ptr, init_c_ptr, last_h_ptr, last_c_ptr, - state_offset, seq_len_tensor, parameter_lists[i], i_f_g_o, c, - is_bidirec, i, 0); + int r = xpu::lstm_train( + dev_ctx.x_context(), (const T*)cur_input_ptr, (const T*)h_0, + (const T*)c_0, (const T*)w_x, (const T*)w_h, (const T*)b_x, + (const T*)b_h, reinterpret_cast(cur_output_ptr), + reinterpret_cast(last_h), reinterpret_cast(last_c), + batch_size, cur_xdim, hidden_size, seq_len, seq_len_tensor, nullptr, + nullptr, nullptr, nullptr, reinterpret_cast(i_f_g_o), + reinterpret_cast(c), xpu::Activation_t::TANH, + xpu::Activation_t::SIGMOID); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_train"); } } } -- GitLab