Commit 6ed20474 authored by tensor-tang

refine attention lstm infershape

Parent 508548f8
@@ -26,86 +26,102 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
-                 "Input(WeightX) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
-                 "Input(WeightH) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                 "Input(Bias) of LSTM should not be null.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("XX"),
-                 "Output(XX) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Output(Cell) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
-                 "Output(BatchedGate) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                 "Output(BatchedGate) of LSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("X"),
+                 "Input(X) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("C0"),
+                 "Input(C0) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
+                 "Input(LSTMWeight) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
+                 "Input(LSTMBias) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
+                 "Input(AttentionWeight) of AttentionLSTM should not be null.");
+
+  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                 "Output(Hidden) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
+                 "Output(Cell) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
+                 "Output(AttentionedX) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
+                 "Output(AttentionFCOut) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
+                 "Output(LSTMX) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
+                 "Output(LSTMOUT) of AttentionLSTM should not be null.");
 
   auto x_dims = ctx->GetInputDim("X");
+  const int M = x_dims[1];
   PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
 
+  auto w_dims = ctx->GetInputDim("LSTMWeight");
+  const int D = w_dims[1] / 4;
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
+                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
+
+  auto b_dims = ctx->GetInputDim("LSTMBias");
+  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M,
+                    D);
+  PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).",
+                    M, D);
+
+  auto c_dims = ctx->GetInputDim("C0");
+  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
   if (ctx->HasInput("H0")) {
-    PADDLE_ENFORCE(ctx->HasInput("C0"),
-                   "Input(Cell) and Input(Hidden) of LSTM should not "
-                   "be null at the same time.");
     auto h_dims = ctx->GetInputDim("H0");
-    auto c_dims = ctx->GetInputDim("C0");
     PADDLE_ENFORCE(h_dims == c_dims,
                    "The dimension of Input(H0) and Input(C0) "
                    "should be the same.");
   }
 
-  // fc_out , shape (maxseqlen, 1)
-  int max_seq_len = 0;
-
-  auto wx_dims = ctx->GetInputDim("WeightX");
-  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
-                    "The rank of Input(WeightX) should be 2.");
-  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
-                    "The first dimension of Input(WeightX) "
-                    "should be %d.",
-                    x_dims[1]);
-
-  int frame_size = wx_dims[1] / 4;
-  auto wh_dims = ctx->GetInputDim("WeightH");
-  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
-                    "The rank of Input(WeightH) should be 2.");
-  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
-                    "The first dimension of Input(WeightH) "
-                    "should be %d.",
-                    frame_size);
-  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
-                    "The second dimension of Input(WeightH) "
-                    "should be 4 * %d.",
-                    frame_size);
-
-  auto b_dims = ctx->GetInputDim("Bias");
-  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1,
-                    "The first dimension of Input(Bias) should be 1.");
-
-  PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
-                 "Do not support peephole yet.");
-  PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
-                    "The second dimension of Input(Bias) should be "
-                    "4 * %d if disable peepholes connection",
-                    frame_size);
+  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
+  PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
+                    "Input(AttentionWeight)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+  PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+  if (ctx->HasInput("AttentionBias")) {
+    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
+    PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
+                      "Input(AttentionBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
+                      "AttentionBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
+                      "AttentionBias shapes must be 1 * 1.");
+  }
+
+  if (ctx->HasInput("AttentionScalar")) {
+    auto dims = ctx->GetInputDim("AttentionScalar");
+    PADDLE_ENFORCE_EQ(dims.size(), 2,
+                      "Input(AttentionScalar)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
+  }
+
+  if (ctx->HasInput("AttentionScalarBias")) {
+    auto dims = ctx->GetInputDim("AttentionScalarBias");
+    PADDLE_ENFORCE(
+        ctx->HasInput("AttentionScalar"),
+        "AttentionScalar should not be null when AttentionScalarBias is set.");
+    PADDLE_ENFORCE_EQ(dims.size(), 2,
+                      "Input(AttentionScalarBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
+  }
 
-  framework::DDim out_dims({x_dims[0], frame_size});
+  framework::DDim out_dims({x_dims[0], D});
   ctx->SetOutputDim("Hidden", out_dims);
   ctx->SetOutputDim("Cell", out_dims);
-  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
-  ctx->SetOutputDim("BatchCellPreAct", out_dims);
+  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
+  ctx->SetOutputDim("LSTMX", {1, M});
+  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
+  // AttentionFCOut should be reshaped to (max_seq_len, 1) at runtime.
   ctx->ShareLoD("X", "Hidden");
   ctx->ShareLoD("X", "Cell");
-  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
-  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
-  ctx->ShareLoD("X", "XX");
 }
 
 framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
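Note on the hunk above: the rewritten InferShape derives every static shape from two scalars, M = x_dims[1] (the x frame size) and D = w_dims[1] / 4 (the gate frame size). The following is a minimal standalone sketch, not part of the commit, that prints the shape contract exactly as the new checks enforce it; the values chosen for M and D are arbitrary.

#include <cstdio>

int main() {
  const int M = 32;  // x frame size: X is (T x M)
  const int D = 64;  // gate frame size: LSTMWeight is ((M + D) x 4D)

  std::printf("X                : T x %d\n", M);
  std::printf("LSTMWeight       : %d x %d\n", M + D, 4 * D);
  std::printf("LSTMBias         : 1 x %d\n", M + D);  // as checked here
  std::printf("C0 (and H0)      : N x %d\n", D);
  std::printf("AttentionWeight  : %d x 1\n", M + D);
  std::printf("AttentionBias / AttentionScalar / AttentionScalarBias : 1 x 1 (optional)\n");
  std::printf("Hidden / Cell    : T x %d\n", D);
  std::printf("AttentionedX     : T x 1\n");
  std::printf("LSTMX            : 1 x %d\n", M);
  std::printf("LSTMOUT          : 1 x %d\n", 4 * D);
  return 0;
}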
@@ -164,11 +180,10 @@ void AttentionLSTMOpMaker::Make() {
   AddOutput("Cell",
             "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
             "The shape is (T x D), and lod is the same with the `Input`.");
-  AddOutput(
-      "AttentionedX",
-      "(LodTensor) shape is (T x 1), the result after X * AttentionWeight,"
-      " where T is the total time steps in this mini-batch,"
-      " D is the hidden size.")
+  AddOutput("AttentionedX",
+            "(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
+            " where T is the total time steps in this mini-batch,"
+            " D is the hidden size.")
       .AsIntermediate();
   AddOutput("AttentionFCOut",
             "(Tensor) (max_seq_len, 1), compute at each step.")
@@ -316,12 +331,31 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");  // (D+M) x D*4
     auto* lstm_b = ctx.Input<Tensor>("LSTMBias");    // 1 x D*4
 
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");  // TxD
     auto* cell_out = ctx.Output<LoDTensor>("Cell");      // TxD
-    auto* atted_x = ctx.Output<LoDTensor>("AttentionedX");  // T x 1
+    auto* atted_x = ctx.Output<Tensor>("AttentionedX");   // T x 1
     auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");  // max_seq_len x 1
     auto* lstm_x = ctx.Output<Tensor>("LSTMX");           // 1 x M
     auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");       // 1 x 4D
 
+    // Some shapes must be set here, since InferShape cannot get the lod info.
+    auto x_lod = x->lod();
+    const int N = x_lod[0].size() - 1;  // batch size
+    auto x_dims = x->dims();            // T x M
+    auto w_dims = w->dims();            // (D+M) x 4D
+    const int M = x_dims[1];      // x frame size
+    const int D = w_dims[1] / 4;  // gate frame size
+    const int D2 = D * 2;
+    const int D3 = D * 3;
+    const int D4 = w_dims[1];
+    int max_seq_len = x_lod[0][1];
+    for (int i = 1; i < N; ++i) {
+      int len = x_lod[0][i + 1] - x_lod[0][i];
+      max_seq_len = max_seq_len < len ? len : max_seq_len;
+    }
+    PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
+    PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
+    fc_out->Resize({max_seq_len, 1});
+
     const T* x_data = x->data<T>();
     const T* h0_data = h0->data<T>();
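The max_seq_len loop added above is why the resize moved into the kernel: it needs the level-0 LoD offsets, which are unavailable at InferShape time. A standalone sketch, not from the commit, using made-up offsets {0, 4, 9, 11}: the batch holds N = 3 sequences of lengths 4, 5, and 2, so AttentionFCOut must be resized to (5, 1) at runtime.

#include <cstdio>
#include <vector>

int main() {
  std::vector<size_t> lod0 = {0, 4, 9, 11};         // example level-0 offsets
  const int N = static_cast<int>(lod0.size()) - 1;  // batch size

  int max_seq_len = static_cast<int>(lod0[1]);  // length of sequence 0
  for (int i = 1; i < N; ++i) {
    int len = static_cast<int>(lod0[i + 1] - lod0[i]);
    if (len > max_seq_len) max_seq_len = len;
  }
  std::printf("N = %d, max_seq_len = %d\n", N, max_seq_len);  // N = 3, 5
  return 0;
}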
@@ -341,16 +375,6 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     T* lstm_x_data = lstm_x->mutable_data<T>();
     T* lstm_out_data = lstm_out->mutable_data<T>();
 
-    auto x_lod = x->lod();
-    auto x_dims = x->dims();    // T x M
-    auto w_dims = w->dims();    // (D+M) x 4D
-    const int M = x_dims[1];    // x frame size
-    const int D = w_dims[1] / 4;  // gate frame size
-    const int D2 = D * 2;
-    const int D3 = D * 3;
-    const int D4 = w_dims[1];
-    const int batch_size = x_lod[0].size() - 1;  // assert lod.size() == 1
-
     // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     math::FCCompute<DeviceContext, T>(blas, T, 1, M, x_data, atten_w_data,
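The comment above describes how the (M+D) x 1 attention weight is split: its first M rows multiply X once for all T time steps (the FCCompute call), while its last D rows enter per step as a dot product with the previous cell state. A plain-loop sketch of that split, not from the commit and omitting the optional AttentionScalar and bias inputs for brevity:

#include <algorithm>
#include <cstddef>
#include <vector>

// score_t = relu(x_t . w_x + c_prev . w_c), for one sequence of length len.
void attention_scores(const std::vector<std::vector<float>>& x,  // len x M
                      const std::vector<float>& atten_w,         // M + D
                      const std::vector<float>& c_prev,          // D
                      std::vector<float>* out) {                 // len
  const std::size_t M = x.empty() ? 0 : x[0].size();
  const std::size_t D = c_prev.size();
  float cell_part = 0.f;  // c_prev . w_c, shared by every step of a sequence
  for (std::size_t d = 0; d < D; ++d) cell_part += c_prev[d] * atten_w[M + d];
  out->resize(x.size());
  for (std::size_t t = 0; t < x.size(); ++t) {
    float s = cell_part;
    for (std::size_t m = 0; m < M; ++m) s += x[t][m] * atten_w[m];  // x_t . w_x
    (*out)[t] = std::max(s, 0.f);  // relu
  }
}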
@@ -361,7 +385,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     const T* prev_hidden_data = NULL;
     T* cur_cell_out_data = cell_out_data;
     T* cur_hidden_out_data = hidden_out_data;
-    for (int i = 0; i < batch_size; ++i) {
+    for (int i = 0; i < N; ++i) {
       int seq_len = x_lod[0][i + 1];
       prev_cell_data = c0_data + i * D;
       prev_hidden_data = h0 ? h0_data + i * D : NULL;
@@ -370,13 +394,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
       /// compute attention vector
       // prev_cell(1xD) * fc(D) rest part of atten_wgt
       // T = cblas_dot();
-      T prev_cell_bias = blas.VDOT(D, prev_cell_data, atten_w_data + M);
+      T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
       // add cell bias and relu
       bias_relu<T>(seq_len, atted_x_data, &prev_cell_bias, fc_out_data);
       // fc2: scalar
       if (atten_scalar_data) {
         // x = a*x
-        blas.VSCAL(seq_len, atten_scalar_data, fc_out_data);
+        blas.SCAL(seq_len, atten_scalar_data, fc_out_data);
         bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
                      fc_out_data);
       }
...
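These last hunks only rename blas.VDOT/VSCAL to blas.DOT/SCAL. The bias_relu helper they call is defined outside the hunks shown; a minimal sketch consistent with its call sites here (broadcast an optional scalar bias over n elements, then apply ReLU), offered as an illustration rather than the file's actual definition:

template <typename T>
inline void bias_relu(const int n, const T* x, const T* bias, T* y) {
  for (int i = 0; i < n; ++i) {
    // bias may be null, in which case the add is skipped
    T v = bias ? x[i] + bias[0] : x[i];
    y[i] = v > static_cast<T>(0) ? v : static_cast<T>(0);  // relu
  }
}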