Unverified commit b1adde3d, authored by z8hanghuan, committed by GitHub

use bilstm_train for rnn forward, * test=kunlun (#41671)

Parent f4cc5def
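This change replaces the per-direction `RunLSTMLayer` helper with direct XDNN calls: the bidirectional path now issues a single fused `xpu::bilstm_train` instead of two `xpu::lstm_train` passes plus an `xpu::concat`, and the per-layer slices of the reserve buffer are now computed from the base pointers `i_f_g_o_ptr` / `c_ptr` instead of by repeatedly advancing one pointer inside the layer loop. The old in-loop `i_f_g_o += i * stride` pattern drifts once there are three or more layers, because the increments accumulate across iterations. A standalone sketch of that arithmetic (the stride value and layer count here are hypothetical, not taken from the kernel):

```cpp
#include <cassert>
#include <cstddef>

int main() {
  const std::size_t stride = 160;  // hypothetical per-layer reserve size
  const int num_layers = 4;

  std::size_t drifted = 0;  // old pattern: i_f_g_o += i * stride each pass
  for (int i = 0; i < num_layers; ++i) {
    drifted += i * stride;
    const std::size_t absolute = i * stride;  // new pattern: base + i * stride
    // Layers 0 and 1 coincide; from layer 2 on the old pointer has moved
    // (0 + 1 + ... + i) * stride, past the slice it should address.
    assert(i < 2 ? drifted == absolute : drifted > absolute);
  }
  return 0;
}
```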
@@ -51,41 +51,6 @@ void reset_parameter_vector(const std::vector<TensorType>& raw_params_vec,
   }
 }
 
-template <typename DeviceContext, typename T>
-void RunLSTMLayer(const framework::ExecutionContext& ctx, int seq_len,
-                  int batch_size, int xdim, int hidden_size, const T* x, T* y,
-                  const T* init_h, const T* init_c, T* last_h, T* last_c,
-                  int state_offset, const std::vector<int>& seq_len_tensor,
-                  const std::vector<const T*>& param_list, T* i_f_g_o, T* c,
-                  bool is_bidirect, int layer_idx, int offset) {
-  bool is_reverse = false;
-  if (is_bidirect) {
-    layer_idx = 2 * layer_idx + offset;
-    if (offset > 0) {
-      is_reverse = true;
-    }
-  }
-  auto w_x = param_list[0 + offset * 4];
-  auto w_h = param_list[1 + offset * 4];
-  auto b_x = param_list[2 + offset * 4];
-  auto b_h = param_list[3 + offset * 4];
-  auto h_0 = init_h + layer_idx * state_offset;
-  auto c_0 = init_c + layer_idx * state_offset;
-  auto last_h_ptr = last_h + layer_idx * state_offset;
-  auto last_c_ptr = last_c + layer_idx * state_offset;
-  auto& dev_ctx = ctx.template device_context<DeviceContext>();
-  int r = xpu::lstm_train<T, T, int16_t>(
-      dev_ctx.x_context(), (const T*)x, (const T*)h_0, (const T*)c_0,
-      (const T*)w_x, (const T*)w_h, (const T*)b_x, (const T*)b_h,
-      reinterpret_cast<T*>(y), reinterpret_cast<T*>(last_h_ptr),
-      reinterpret_cast<T*>(last_c_ptr), batch_size, xdim, hidden_size, seq_len,
-      seq_len_tensor, is_reverse, nullptr, nullptr, nullptr, nullptr,
-      reinterpret_cast<T*>(i_f_g_o), reinterpret_cast<T*>(c),
-      xpu::Activation_t::TANH, xpu::Activation_t::SIGMOID);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_train");
-}
-
 template <typename DeviceContext, typename T>
 class RnnXPUKernel : public framework::OpKernel<T> {
  public:
@@ -184,9 +149,9 @@ class RnnXPUKernel : public framework::OpKernel<T> {
     auto y = output->data<T>();
     auto last_h_ptr = last_h->data<T>();
     auto last_c_ptr = last_c->data<T>();
-    auto i_f_g_o = reserve_data->data<T>();
-    auto c =
-        i_f_g_o +
+    auto i_f_g_o_ptr = reserve_data->data<T>();
+    auto c_ptr =
+        i_f_g_o_ptr +
         num_layers * direction_num * seq_len * batch_size * hidden_size * 4;
     std::vector<int> seq_len_tensor(batch_size, seq_len);
@@ -197,11 +162,12 @@ class RnnXPUKernel : public framework::OpKernel<T> {
     int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2];
 
     for (int i = 0; i < num_layers; i++) {
+      auto i_f_g_o = i_f_g_o_ptr +
+                     i * direction_num * seq_len * batch_size * hidden_size * 4;
+      auto c = c_ptr + i * direction_num * seq_len * batch_size * hidden_size;
       const T* cur_input_ptr = nullptr;
       int cur_xdim = -1;
-      i_f_g_o += i * direction_num * seq_len * batch_size * hidden_size * 4;
-      c += i * direction_num * seq_len * batch_size * hidden_size;
       if (i == 0) {
         cur_input_ptr = x;
         cur_xdim = input_dim;
@@ -222,41 +188,44 @@ class RnnXPUKernel : public framework::OpKernel<T> {
         cur_output_ptr = internal_output_1_ptr;
       }
 
+      auto h_0 = init_h_ptr + direction_num * i * state_offset;
+      auto c_0 = init_c_ptr + direction_num * i * state_offset;
+      auto last_h = last_h_ptr + direction_num * i * state_offset;
+      auto last_c = last_c_ptr + direction_num * i * state_offset;
+
+      auto w_x = parameter_lists[i][0];
+      auto w_h = parameter_lists[i][1];
+      auto b_x = parameter_lists[i][2];
+      auto b_h = parameter_lists[i][3];
+
       if (is_bidirec) {
-        std::vector<Tensor> output_vec(2);
-        std::vector<T*> output_ptr_vec(2);
-        for (int k = 0; k < 2; ++k) {
-          output_vec[k].Resize({seq_len, batch_size, output->dims()[2] / 2});
-          output_ptr_vec[k] = output_vec[k].mutable_data<T>(ctx.GetPlace());
-        }
-        RunLSTMLayer<DeviceContext, T>(
-            ctx, seq_len, batch_size, cur_xdim, hidden_size, cur_input_ptr,
-            output_ptr_vec[0], init_h_ptr, init_c_ptr, last_h_ptr, last_c_ptr,
-            state_offset, seq_len_tensor, parameter_lists[i], i_f_g_o, c,
-            is_bidirec, i, 0);
-
-        T* bw_i_f_g_o = i_f_g_o + seq_len * batch_size * hidden_size * 4;
-        T* bw_c = c + seq_len * batch_size * hidden_size;
-        RunLSTMLayer<DeviceContext, T>(
-            ctx, seq_len, batch_size, cur_xdim, hidden_size, cur_input_ptr,
-            output_ptr_vec[1], init_h_ptr, init_c_ptr, last_h_ptr, last_c_ptr,
-            state_offset, seq_len_tensor, parameter_lists[i], bw_i_f_g_o, bw_c,
-            is_bidirec, i, 1);
-
-        // concat
-        int r = xpu::concat<T>(
-            dev_ctx.x_context(), {output_ptr_vec[0], output_ptr_vec[1]},
-            cur_output_ptr, {{seq_len, batch_size, hidden_size},
-                             {seq_len, batch_size, hidden_size}},
-            2);
-        PADDLE_ENFORCE_XDNN_SUCCESS(r, "concat");
-        xpu_wait(dev_ctx.x_context()->xpu_stream);
+        auto bw_x = parameter_lists[i][4];
+        auto bw_h = parameter_lists[i][5];
+        auto bb_x = parameter_lists[i][6];
+        auto bb_h = parameter_lists[i][7];
+
+        int r = xpu::bilstm_train<T, T, int16_t>(
+            dev_ctx.x_context(), (const T*)cur_input_ptr, (const T*)h_0,
+            (const T*)c_0, (const T*)w_x, (const T*)w_h, (const T*)b_x,
+            (const T*)b_h, (const T*)bw_x, (const T*)bw_h, (const T*)bb_x,
+            (const T*)bb_h, reinterpret_cast<T*>(cur_output_ptr),
+            reinterpret_cast<T*>(last_h), reinterpret_cast<T*>(last_c),
+            batch_size, cur_xdim, hidden_size, seq_len, seq_len_tensor, nullptr,
+            nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
+            reinterpret_cast<T*>(i_f_g_o), reinterpret_cast<T*>(c));
+
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "bilstm_train");
       } else {
-        RunLSTMLayer<DeviceContext, T>(
-            ctx, seq_len, batch_size, cur_xdim, hidden_size, cur_input_ptr,
-            cur_output_ptr, init_h_ptr, init_c_ptr, last_h_ptr, last_c_ptr,
-            state_offset, seq_len_tensor, parameter_lists[i], i_f_g_o, c,
-            is_bidirec, i, 0);
+        int r = xpu::lstm_train<T, T, int16_t>(
+            dev_ctx.x_context(), (const T*)cur_input_ptr, (const T*)h_0,
+            (const T*)c_0, (const T*)w_x, (const T*)w_h, (const T*)b_x,
+            (const T*)b_h, reinterpret_cast<T*>(cur_output_ptr),
+            reinterpret_cast<T*>(last_h), reinterpret_cast<T*>(last_c),
+            batch_size, cur_xdim, hidden_size, seq_len, seq_len_tensor, nullptr,
+            nullptr, nullptr, nullptr, reinterpret_cast<T*>(i_f_g_o),
+            reinterpret_cast<T*>(c), xpu::Activation_t::TANH,
+            xpu::Activation_t::SIGMOID);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_train");
       }
     }
   }
...
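In the bidirectional branch, the old path materialized two temporary `[seq_len, batch_size, hidden_size]` outputs, ran the forward and reverse passes through `RunLSTMLayer` separately, then stitched them together with `xpu::concat` along the last axis and waited on the stream. The fused `xpu::bilstm_train` call instead takes the layer's eight parameter tensors (forward `w_x`/`w_h`/`b_x`/`b_h` at indices 0-3 of `parameter_lists[i]`, backward at 4-7) and writes the `[seq_len, batch_size, 2 * hidden_size]` output directly. A plain-CPU sketch of the output layout both variants produce (sizes are made up for illustration):

```cpp
#include <cassert>
#include <vector>

int main() {
  const int seq_len = 2, batch = 3, hidden = 4;
  std::vector<float> fw(seq_len * batch * hidden, 1.f);  // forward direction
  std::vector<float> bw(seq_len * batch * hidden, 2.f);  // backward direction
  std::vector<float> out(seq_len * batch * 2 * hidden);  // fused output

  // Concat along the last axis: for each (t, b), the hidden values of the
  // forward pass are immediately followed by those of the backward pass.
  for (int t = 0; t < seq_len; ++t)
    for (int b = 0; b < batch; ++b)
      for (int h = 0; h < hidden; ++h) {
        const int row = t * batch + b;
        out[(row * 2) * hidden + h] = fw[row * hidden + h];
        out[(row * 2 + 1) * hidden + h] = bw[row * hidden + h];
      }
  assert(out[0] == 1.f && out[hidden] == 2.f);
  return 0;
}
```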