未验证 提交 8b07ce0e 编写于 作者: z8hanghuan's avatar z8hanghuan 提交者: GitHub

support multi layer and bidirection of lstm_grad, *test=kunlun (#41742)

* support multi layer and bidirection of lstm_grad, *test=kunlun

* support multi layer and bidirection of lstm_grad, *test=kunlun
上级 bda4965a
...@@ -125,23 +125,13 @@ class RnnXPUKernel : public framework::OpKernel<T> { ...@@ -125,23 +125,13 @@ class RnnXPUKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace()); output->mutable_data<T>(ctx.GetPlace());
last_h->mutable_data<T>(ctx.GetPlace()); last_h->mutable_data<T>(ctx.GetPlace());
last_c->mutable_data<T>(ctx.GetPlace()); last_c->mutable_data<T>(ctx.GetPlace());
int gate_num = 4;
int hidden_data_idx = (num_layers - 1);
hidden_data_idx += (gate_num + 1) * num_layers;
const int& block_size = direction_num * seq_len * batch_size * hidden_size;
reserve_data->Resize({hidden_data_idx, block_size});
reserve_data->Resize(
{num_layers * direction_num * seq_len * batch_size * hidden_size * 5});
reserve_data->mutable_data<T>(ctx.GetPlace()); reserve_data->mutable_data<T>(ctx.GetPlace());
Tensor internal_output_1_tensor, internal_output_2_tensor;
T* internal_output_1_ptr = nullptr;
T* internal_output_2_ptr = nullptr;
if (num_layers >= 2) {
internal_output_1_tensor.Resize(output->dims());
internal_output_1_ptr =
internal_output_1_tensor.mutable_data<T>(ctx.GetPlace());
}
if (num_layers >= 3) {
internal_output_2_tensor.Resize(output->dims());
internal_output_2_ptr =
internal_output_2_tensor.mutable_data<T>(ctx.GetPlace());
}
// get ptr from tensor // get ptr from tensor
auto x = input->data<T>(); auto x = input->data<T>();
auto init_h_ptr = init_h->data<T>(); auto init_h_ptr = init_h->data<T>();
...@@ -151,8 +141,9 @@ class RnnXPUKernel : public framework::OpKernel<T> { ...@@ -151,8 +141,9 @@ class RnnXPUKernel : public framework::OpKernel<T> {
auto last_c_ptr = last_c->data<T>(); auto last_c_ptr = last_c->data<T>();
auto i_f_g_o_ptr = reserve_data->data<T>(); auto i_f_g_o_ptr = reserve_data->data<T>();
auto c_ptr = auto c_ptr =
i_f_g_o_ptr + i_f_g_o_ptr + num_layers * block_size * 4; // 4 for i_f_g_o offset
num_layers * direction_num * seq_len * batch_size * hidden_size * 4; auto hidden_data_ptr =
c_ptr + num_layers * block_size * 1; // 1 for c offset
std::vector<int> seq_len_tensor(batch_size, seq_len); std::vector<int> seq_len_tensor(batch_size, seq_len);
if (has_seq_length) { if (has_seq_length) {
...@@ -161,33 +152,26 @@ class RnnXPUKernel : public framework::OpKernel<T> { ...@@ -161,33 +152,26 @@ class RnnXPUKernel : public framework::OpKernel<T> {
int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2]; int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2];
for (int i = 0; i < num_layers; i++) {
auto i_f_g_o = i_f_g_o_ptr +
i * direction_num * seq_len * batch_size * hidden_size * 4;
auto c = c_ptr + i * direction_num * seq_len * batch_size * hidden_size;
const T* cur_input_ptr = nullptr; const T* cur_input_ptr = nullptr;
int cur_xdim = -1; int cur_xdim = -1;
T* cur_output_ptr = y;
for (int i = 0; i < num_layers; i++) {
auto i_f_g_o = i_f_g_o_ptr + i * block_size * 4;
auto c = c_ptr + i * block_size;
cur_output_ptr = y;
if (i < num_layers - 1 && num_layers > 1) {
cur_output_ptr = hidden_data_ptr + i * block_size;
}
if (i == 0) { if (i == 0) {
cur_input_ptr = x; cur_input_ptr = x;
cur_xdim = input_dim; cur_xdim = input_dim;
} else if (i % 2 != 0) {
cur_input_ptr = internal_output_1_ptr;
cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size;
} else { } else {
cur_input_ptr = internal_output_2_ptr; cur_input_ptr = hidden_data_ptr + (i - 1) * block_size;
cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size; cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size;
} }
T* cur_output_ptr = nullptr;
if (i == num_layers - 1) {
cur_output_ptr = y;
} else if (i % 2 != 0) {
cur_output_ptr = internal_output_2_ptr;
} else {
cur_output_ptr = internal_output_1_ptr;
}
auto h_0 = init_h_ptr + direction_num * i * state_offset; auto h_0 = init_h_ptr + direction_num * i * state_offset;
auto c_0 = init_c_ptr + direction_num * i * state_offset; auto c_0 = init_c_ptr + direction_num * i * state_offset;
auto last_h = last_h_ptr + direction_num * i * state_offset; auto last_h = last_h_ptr + direction_num * i * state_offset;
...@@ -233,6 +217,8 @@ class RnnXPUKernel : public framework::OpKernel<T> { ...@@ -233,6 +217,8 @@ class RnnXPUKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class RnnXPUGradKernel : public framework::OpKernel<T> { class RnnXPUGradKernel : public framework::OpKernel<T> {
using XPUTyp = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
// get the tensor pointer for the input // get the tensor pointer for the input
...@@ -243,6 +229,7 @@ class RnnXPUGradKernel : public framework::OpKernel<T> { ...@@ -243,6 +229,7 @@ class RnnXPUGradKernel : public framework::OpKernel<T> {
auto* reserve_data = ctx.Input<Tensor>("Reserve"); auto* reserve_data = ctx.Input<Tensor>("Reserve");
const int& num_layers = ctx.Attr<int>("num_layers"); const int& num_layers = ctx.Attr<int>("num_layers");
const bool& is_bidirec = ctx.Attr<bool>("is_bidirec"); const bool& is_bidirec = ctx.Attr<bool>("is_bidirec");
const float& dropout_prob = ctx.Attr<float>("dropout_prob");
const int& hidden_size = ctx.Attr<int>("hidden_size"); const int& hidden_size = ctx.Attr<int>("hidden_size");
const std::string& mode = ctx.Attr<std::string>("mode"); const std::string& mode = ctx.Attr<std::string>("mode");
...@@ -257,16 +244,6 @@ class RnnXPUGradKernel : public framework::OpKernel<T> { ...@@ -257,16 +244,6 @@ class RnnXPUGradKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"XPU only support LSTM mode now, current mode is %s", mode)); "XPU only support LSTM mode now, current mode is %s", mode));
PADDLE_ENFORCE_EQ(is_bidirec, false,
platform::errors::InvalidArgument(
"XPU only support unidirectional LSTM now"));
PADDLE_ENFORCE_EQ(
num_layers, 1,
platform::errors::InvalidArgument(
"XPU only support 1 layer LSTM now, current layer num is %s",
num_layers));
auto init_h = pre_state[0]; auto init_h = pre_state[0];
auto init_c = pre_state[1]; auto init_c = pre_state[1];
...@@ -289,11 +266,12 @@ class RnnXPUGradKernel : public framework::OpKernel<T> { ...@@ -289,11 +266,12 @@ class RnnXPUGradKernel : public framework::OpKernel<T> {
} }
// check shape // check shape
int seq_len = input->dims()[0]; const int& seq_len = input->dims()[0];
int batch_size = input->dims()[1]; const int& batch_size = input->dims()[1];
int input_dim = input->dims()[2]; const int& input_dim = input->dims()[2];
const int& direction_num = is_bidirec ? 2 : 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
init_h->dims()[0], num_layers, init_h->dims()[0], num_layers * direction_num,
platform::errors::InvalidArgument("The num_layers of in RNN layer must" platform::errors::InvalidArgument("The num_layers of in RNN layer must"
" be the same as first dim of init " " be the same as first dim of init "
"hidden, but received num_layers:%d," "hidden, but received num_layers:%d,"
...@@ -301,7 +279,7 @@ class RnnXPUGradKernel : public framework::OpKernel<T> { ...@@ -301,7 +279,7 @@ class RnnXPUGradKernel : public framework::OpKernel<T> {
num_layers, init_h->dims()[0])); num_layers, init_h->dims()[0]));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
init_c->dims()[0], num_layers, init_c->dims()[0], num_layers * direction_num,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The num_layers of in RNN layer must" "The num_layers of in RNN layer must"
" be the same as first dim of cell state hidden, but received" " be the same as first dim of cell state hidden, but received"
...@@ -323,52 +301,165 @@ class RnnXPUGradKernel : public framework::OpKernel<T> { ...@@ -323,52 +301,165 @@ class RnnXPUGradKernel : public framework::OpKernel<T> {
// allocate the memory and initization the input_grad // allocate the memory and initization the input_grad
input_grad->mutable_data<T>(input->dims(), ctx.GetPlace()); input_grad->mutable_data<T>(input->dims(), ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
phi::funcs::SetConstant<platform::XPUDeviceContext, T> zero;
zero(dev_ctx, input_grad, static_cast<T>(0.0));
Tensor a, b;
Tensor* dynamic_grad_pre_h = &a;
Tensor* dynamic_grad_pre_c = &b;
if (init_h_grad) { if (init_h_grad) {
init_h_grad->mutable_data<T>(init_h->dims(), ctx.GetPlace()); init_h_grad->mutable_data<T>(last_h_grad->dims(), ctx.GetPlace());
zero(dev_ctx, init_h_grad, static_cast<T>(0.0));
} else {
dynamic_grad_pre_h->Resize(last_h_grad->dims());
dynamic_grad_pre_h->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, dynamic_grad_pre_h, static_cast<T>(0.0));
init_h_grad = dynamic_grad_pre_h;
} }
if (init_c_grad) { if (init_c_grad) {
init_c_grad->mutable_data<T>(init_c->dims(), ctx.GetPlace()); init_c_grad->mutable_data<T>(last_c_grad->dims(), ctx.GetPlace());
} else {
dynamic_grad_pre_c->Resize(last_h_grad->dims());
dynamic_grad_pre_c->mutable_data<T>(ctx.GetPlace());
init_c_grad = dynamic_grad_pre_c;
}
Tensor temp_input_grad_1, temp_input_grad_2;
T* input_grad_1_ptr = nullptr;
T* input_grad_2_ptr = nullptr;
if (num_layers >= 2) {
temp_input_grad_1.Resize(output_grad->dims());
input_grad_1_ptr = temp_input_grad_1.mutable_data<T>(ctx.GetPlace());
}
if (num_layers >= 3) {
temp_input_grad_2.Resize(output_grad->dims());
input_grad_2_ptr = temp_input_grad_2.mutable_data<T>(ctx.GetPlace());
} }
// get ptr from tensor // get ptr from tensor
auto x = input->data<T>(); auto x = input->data<T>();
auto h_0 = init_h->data<T>(); auto init_h_ptr = init_h->data<T>();
auto c_0 = init_c->data<T>(); auto init_c_ptr = init_c->data<T>();
auto w_x = parameter_lists[0][0];
auto w_h = parameter_lists[0][1];
auto y = output->data<T>(); auto y = output->data<T>();
auto y_grad = output_grad->data<T>(); auto y_grad = output_grad->data<T>();
auto last_h_grad_ptr = last_h_grad->data<T>(); auto last_h_grad_ptr = last_h_grad->data<T>();
auto last_c_grad_ptr = last_c_grad->data<T>(); auto last_c_grad_ptr = last_c_grad->data<T>();
auto x_grad = input_grad->data<T>(); auto x_grad = input_grad->data<T>();
auto h_0_grad = init_h_grad ? init_h_grad->data<T>() : nullptr; auto init_h_grad_ptr = init_h_grad->data<T>();
auto c_0_grad = init_c_grad ? init_c_grad->data<T>() : nullptr; auto init_c_grad_ptr = init_c_grad->data<T>();
auto w_x_grad = parameter_lists_grad[0][0]; const int& block_size = direction_num * seq_len * batch_size * hidden_size;
auto w_h_grad = parameter_lists_grad[0][1]; auto i_f_g_o_ptr = reserve_data->data<T>();
auto b_x_grad = parameter_lists_grad[0][2]; auto c_ptr = i_f_g_o_ptr + num_layers * block_size * 4;
auto b_h_grad = parameter_lists_grad[0][3]; auto hidden_data_ptr = c_ptr + num_layers * block_size * 1;
auto i_f_g_o = reserve_data->data<T>(); int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2];
auto c = i_f_g_o + seq_len * batch_size * hidden_size * 4;
std::vector<int> seq_len_tensor(batch_size, seq_len); std::vector<int> seq_len_tensor(batch_size, seq_len);
if (has_seq_length) { if (has_seq_length) {
seq_len_tensor = operators::GetDataFromTensor(sequence_length); seq_len_tensor = operators::GetDataFromTensor(sequence_length);
} }
auto& dev_ctx = ctx.template device_context<DeviceContext>(); for (int i = num_layers - 1; i >= 0; --i) {
// the layer input output had saved, just use the data
auto w_x = parameter_lists[i][0];
auto w_h = parameter_lists[i][1];
auto bw_x = parameter_lists[i][4];
auto bw_h = parameter_lists[i][5];
auto i_f_g_o = i_f_g_o_ptr + i * block_size * 4;
auto c = c_ptr + i * block_size;
Tensor layer_input_t;
auto layer_input = x;
if (i > 0) {
layer_input_t.Resize(output->dims());
layer_input = layer_input_t.mutable_data<T>(ctx.GetPlace());
float scale = static_cast<float>(1.0f - dropout_prob);
auto hidden_data = hidden_data_ptr + (i - 1) * block_size;
int r = xpu::scale(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(hidden_data),
const_cast<XPUTyp*>(layer_input), output->numel(),
false, scale, 0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
} else {
layer_input = x;
}
auto layer_output = y;
if (i == num_layers - 1) {
layer_output = y;
} else {
layer_output = hidden_data_ptr + i * block_size;
}
const T* cur_input_ptr = nullptr;
if (i == num_layers - 1) {
cur_input_ptr = y_grad;
} else if (i % 2 != 0) {
cur_input_ptr = input_grad_2_ptr;
} else {
cur_input_ptr = input_grad_1_ptr;
}
T* cur_output_ptr = nullptr;
int cur_xdim = -1;
if (i == 0) {
cur_output_ptr = x_grad;
cur_xdim = input_dim;
} else if (i % 2 != 0) {
cur_output_ptr = input_grad_1_ptr;
cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size;
} else {
cur_output_ptr = input_grad_2_ptr;
cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size;
}
auto w_x_grad = parameter_lists_grad[i][0];
auto w_h_grad = parameter_lists_grad[i][1];
auto b_x_grad = parameter_lists_grad[i][2];
auto b_h_grad = parameter_lists_grad[i][3];
auto h_0 = init_h_ptr + direction_num * i * state_offset;
auto c_0 = init_c_ptr + direction_num * i * state_offset;
auto h_0_grad = init_h_grad_ptr + direction_num * i * state_offset;
auto c_0_grad = init_c_grad_ptr + direction_num * i * state_offset;
auto h_t_grad = last_h_grad_ptr + direction_num * i * state_offset;
auto c_t_grad = last_c_grad_ptr + direction_num * i * state_offset;
if (is_bidirec) {
auto bw_x_grad = parameter_lists_grad[i][4];
auto bw_h_grad = parameter_lists_grad[i][5];
auto bb_x_grad = parameter_lists_grad[i][6];
auto bb_h_grad = parameter_lists_grad[i][7];
int r = xpu::bilstm_grad<T, T, int16_t>(
dev_ctx.x_context(), (const T*)layer_input, (const T*)h_0,
(const T*)c_0, (const T*)w_x, (const T*)w_h, (const T*)bw_x,
(const T*)bw_h, (const T*)layer_output, (const T*)cur_input_ptr,
(const T*)h_t_grad, (const T*)c_t_grad,
reinterpret_cast<T*>(cur_output_ptr),
reinterpret_cast<T*>(h_0_grad), reinterpret_cast<T*>(c_0_grad),
w_x_grad, w_h_grad, b_x_grad, b_h_grad, bw_x_grad, bw_h_grad,
bb_x_grad, bb_h_grad, batch_size, cur_xdim, hidden_size, seq_len,
seq_len_tensor, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, i_f_g_o, c);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bilstm_grad");
} else {
int r = xpu::lstm_grad<T, T, int16_t>( int r = xpu::lstm_grad<T, T, int16_t>(
dev_ctx.x_context(), (const T*)x, (const T*)h_0, (const T*)c_0, dev_ctx.x_context(), (const T*)layer_input, (const T*)h_0,
(const T*)w_x, (const T*)w_h, (const T*)y, (const T*)y_grad, (const T*)c_0, (const T*)w_x, (const T*)w_h, (const T*)layer_output,
(const T*)last_h_grad_ptr, (const T*)last_c_grad_ptr, (const T*)cur_input_ptr, (const T*)h_t_grad, (const T*)c_t_grad,
reinterpret_cast<T*>(x_grad), reinterpret_cast<T*>(h_0_grad), reinterpret_cast<T*>(cur_output_ptr),
reinterpret_cast<T*>(c_0_grad), w_x_grad, w_h_grad, b_x_grad, b_h_grad, reinterpret_cast<T*>(h_0_grad), reinterpret_cast<T*>(c_0_grad),
batch_size, input_dim, hidden_size, seq_len, seq_len_tensor, nullptr, w_x_grad, w_h_grad, b_x_grad, b_h_grad, batch_size, cur_xdim,
nullptr, nullptr, nullptr, i_f_g_o, c); hidden_size, seq_len, seq_len_tensor, nullptr, nullptr, nullptr,
PADDLE_ENFORCE_EQ( nullptr, i_f_g_o, c);
r, xpu::Error_t::SUCCESS,
platform::errors::External("RnnXPUGrad(lstm) return wrong " PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_grad");
"value[%d %s]", }
r, XPUAPIErrorMsg[r])); }
} }
}; };
......
...@@ -306,6 +306,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -306,6 +306,7 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align_grad", {"roi_align_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
...@@ -122,7 +122,7 @@ class XPUTestRNNOp(XPUOpTestWrapper): ...@@ -122,7 +122,7 @@ class XPUTestRNNOp(XPUOpTestWrapper):
def set_xpu(self): def set_xpu(self):
self.__class__.use_xpu = True self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True self.__class__.no_need_check_grad = False
self.__class__.op_type = self.in_type self.__class__.op_type = self.in_type
def test_check_output(self): def test_check_output(self):
...@@ -130,6 +130,15 @@ class XPUTestRNNOp(XPUOpTestWrapper): ...@@ -130,6 +130,15 @@ class XPUTestRNNOp(XPUOpTestWrapper):
self.place, atol=0.01, self.place, atol=0.01,
no_check_set=['Reserve', 'DropoutState']) no_check_set=['Reserve', 'DropoutState'])
def test_grad(self):
if not self.is_test:
var_name_list = self.get_weight_names()
grad_check_list = ['Input', 'init_h', 'init_c']
grad_check_list.extend(var_name_list)
self.check_grad_with_place(self.place,
set(grad_check_list),
['Out', 'last_hidden', 'last_cell'])
def init_size(self): def init_size(self):
self.seq_length = 12 self.seq_length = 12
self.batch_size = 5 self.batch_size = 5
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册