From 6ed20474d47a2577159a3799549c457e9f38f420 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Wed, 22 Aug 2018 10:17:47 +0800
Subject: [PATCH] refine attention lstm infershape

---
 paddle/fluid/operators/attention_lstm_op.cc | 198 +++++++++++---------
 1 file changed, 111 insertions(+), 87 deletions(-)

diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 178a1c19a9e..636deb04a13 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -26,86 +26,102 @@ namespace paddle {
 namespace operators {
 
 void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
-                 "Input(WeightX) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
-                 "Input(WeightH) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                 "Input(Bias) of LSTM should not be null.");
-
-  PADDLE_ENFORCE(ctx->HasOutput("XX"),
-                 "Output(XX) of LSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("X"),
+                 "Input(X) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("C0"),
+                 "Input(C0) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
+                 "Input(LSTMWeight) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
+                 "Input(LSTMBias) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
+                 "Input(AttentionWeight) of AttentionLSTM should not be null.");
+
   PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                 "Output(Hidden) of LSTM should not be null.");
+                 "Output(Hidden) of AttentionLSTM should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                 "Output(Cell) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
-                 "Output(BatchedGate) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                 "Output(BatchedGate) of LSTM should not be null.");
+                 "Output(Cell) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
+                 "Output(AttentionedX) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
+                 "Output(AttentionFCOut) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
+                 "Output(LSTMX) of AttentionLSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
+                 "Output(LSTMOUT) of AttentionLSTM should not be null.");
 
   auto x_dims = ctx->GetInputDim("X");
   PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
+  const int M = x_dims[1];
 
+  auto w_dims = ctx->GetInputDim("LSTMWeight");
+  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
+  const int D = w_dims[1] / 4;
+  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
+                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
+
+  auto b_dims = ctx->GetInputDim("LSTMBias");
+  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M,
+                    D);
+  PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).",
+                    M, D);
+
+  auto c_dims = ctx->GetInputDim("C0");
+  PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
   if (ctx->HasInput("H0")) {
-    PADDLE_ENFORCE(ctx->HasInput("C0"),
-                   "Input(Cell) and Input(Hidden) of LSTM should not "
-                   "be null at the same time.");
     auto h_dims = ctx->GetInputDim("H0");
-    auto c_dims = ctx->GetInputDim("C0");
     PADDLE_ENFORCE(h_dims == c_dims,
                    "The dimension of Input(H0) and Input(C0) "
                    "should be the same.");
   }
 
-  // fc_out , shape (maxseqlen,1)
-  int max_seq_len = 0;
-
-  auto wx_dims = ctx->GetInputDim("WeightX");
-  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
-                    "The rank of Input(WeightX) should be 2.");
-  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
-                    "The first dimension of Input(WeightX) "
-                    "should be %d.",
-                    x_dims[1]);
-
-  int frame_size = wx_dims[1] / 4;
-  auto wh_dims = ctx->GetInputDim("WeightH");
-  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
-                    "The rank of Input(WeightH) should be 2.");
-  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
-                    "The first dimension of Input(WeightH) "
-                    "should be %d.",
-                    frame_size);
-  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
-                    "The second dimension of Input(WeightH) "
-                    "should be 4 * %d.",
-                    frame_size);
-
-  auto b_dims = ctx->GetInputDim("Bias");
-  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1,
-                    "The first dimension of Input(Bias) should be 1.");
-
-  PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
-                 "Do not support peephole yet.");
-  PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
-                    "The second dimension of Input(Bias) should be "
-                    "4 * %d if disable peepholes connection",
-                    frame_size);
-
-  framework::DDim out_dims({x_dims[0], frame_size});
+  auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
+  PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
+                    "Input(AttentionWeight)'s rank must be 2.");
+  PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+  PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+
+  if (ctx->HasInput("AttentionBias")) {
+    auto atten_b_dims = ctx->GetInputDim("AttentionBias");
+    PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
+                      "Input(AttentionBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
+                      "AttentionBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
+                      "AttentionBias shapes must be 1 * 1.");
+  }
+
+  if (ctx->HasInput("AttentionScalar")) {
+    auto dims = ctx->GetInputDim("AttentionScalar");
+    PADDLE_ENFORCE_EQ(dims.size(), 2,
+                      "Input(AttentionScalar)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
+  }
+
+  if (ctx->HasInput("AttentionScalarBias")) {
+    auto dims = ctx->GetInputDim("AttentionScalarBias");
+    PADDLE_ENFORCE(
+        ctx->HasInput("AttentionScalar"),
+        "AttentionScalar should not be null when AttentionScalarBias is set.");
+    PADDLE_ENFORCE_EQ(dims.size(), 2,
+                      "Input(AttentionScalarBias)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
+  }
+
+  framework::DDim out_dims({x_dims[0], D});
   ctx->SetOutputDim("Hidden", out_dims);
   ctx->SetOutputDim("Cell", out_dims);
-  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
-  ctx->SetOutputDim("BatchCellPreAct", out_dims);
+  ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
+  ctx->SetOutputDim("LSTMX", {1, M});
+  ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
+  // AttentionFCOut should be reshaped to (max_seq_len, 1) at runtime.
   ctx->ShareLoD("X", "Hidden");
   ctx->ShareLoD("X", "Cell");
-
-  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
-  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
-  ctx->ShareLoD("X", "XX");
 }
 
 framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
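Note: taken together, the checks in the rewritten InferShape pin down the operator's dimension contract, with M the x frame size, D the gate (hidden) frame size, N the batch size, and T the total time steps. The standalone sketch below restates those relations with plain integers; it is illustrative only (hypothetical names, no Paddle types):

#include <cassert>

// Shape contract enforced above (illustrative):
//   X: T x M             LSTMWeight: (M + D) x 4D    LSTMBias: 1 x (M + D)
//   C0 (and H0): N x D   AttentionWeight: (M + D) x 1
//   Hidden, Cell: T x D; AttentionedX: T x 1; LSTMX: 1 x M; LSTMOUT: 1 x 4D
int main() {
  const int T = 10, M = 8, D = 16, N = 3;  // hypothetical sizes

  const int x_dims[2] = {T, M};
  const int w_dims[2] = {M + D, 4 * D};
  const int b_dims[2] = {1, M + D};
  const int c0_dims[2] = {N, D};
  const int atten_w_dims[2] = {M + D, 1};

  assert(w_dims[0] == x_dims[1] + w_dims[1] / 4);  // rows = M + D
  assert(b_dims[0] == 1 && b_dims[1] == M + D);
  assert(c0_dims[1] == D);
  assert(atten_w_dims[0] == M + D && atten_w_dims[1] == 1);
  return 0;
}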
@@ -164,11 +180,10 @@ void AttentionLSTMOpMaker::Make() {
   AddOutput("Cell",
             "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
             "The shape is (T x D), and lod is the same with the `Input`.");
-  AddOutput(
-      "AttentionedX",
-      "(LodTensor) shape is (T x 1), the result after X * AttentionWeight,"
-      " where T is the total time steps in this mini-batch,"
-      " D is the hidden size.")
+  AddOutput("AttentionedX",
+            "(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
+            " where T is the total time steps in this mini-batch,"
+            " D is the hidden size.")
       .AsIntermediate();
   AddOutput("AttentionFCOut",
             "(Tensor) (max_seq_len, 1), compute at each step.")
@@ -316,12 +331,31 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");  // (D+M) x D*4
     auto* lstm_b = ctx.Input<Tensor>("LSTMBias");    // 1 x D*4
 
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");  // TxD
-    auto* cell_out = ctx.Output<LoDTensor>("Cell");  // TxD
-    auto* atted_x = ctx.Output<Tensor>("AttentionedX");  // T x 1
-    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");  // max_seq_len x 1
-    auto* lstm_x = ctx.Output<Tensor>("LSTMX");  // 1 x M
-    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");  // 1 x 4D
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   // TxD
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");       // TxD
+    auto* atted_x = ctx.Output<Tensor>("AttentionedX");   // T x 1
+    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");  // max_seq_len x 1
+    auto* lstm_x = ctx.Output<Tensor>("LSTMX");           // 1 x M
+    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");       // 1 x 4D
+
+    // Some shapes must be set here at runtime; InferShape cannot get LoD info.
+    auto x_lod = x->lod();
+    PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
+    const int N = x_lod[0].size() - 1;  // batch size
+    auto x_dims = x->dims();            // T x M
+    auto w_dims = lstm_w->dims();       // (D+M) x 4D
+    const int M = x_dims[1];      // x frame size
+    const int D = w_dims[1] / 4;  // gate frame size
+    const int D2 = D * 2;
+    const int D3 = D * 3;
+    const int D4 = w_dims[1];
+    int max_seq_len = x_lod[0][1];
+    for (int i = 1; i < N; ++i) {
+      int len = x_lod[0][i + 1] - x_lod[0][i];
+      max_seq_len = max_seq_len < len ? len : max_seq_len;
+    }
+    PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
+    fc_out->Resize({max_seq_len, 1});
 
     const T* x_data = x->data<T>();
     const T* h0_data = h0->data<T>();
@@ -341,16 +375,6 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
     T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
 
-    auto x_lod = x->lod();
-    auto x_dims = x->dims();    // T x M
-    auto w_dims = lstm_w->dims();    // (D+M) x 4D
-    const int M = x_dims[1];    // x frame size
-    const int D = w_dims[1] / 4;    // gate frame size
-    const int D2 = D * 2;
-    const int D3 = D * 3;
-    const int D4 = w_dims[1];
-    const int batch_size = x_lod[0].size() - 1;  // assert lod.size() == 1
-
     // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     math::FCCompute<DeviceContext, T>(blas, x_dims[0], 1, M, x_data, atten_w_data,
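Note: the block added above derives the batch size and the maximum sequence length from X's level-0 LoD at runtime (InferShape cannot see LoD), then resizes AttentionFCOut to (max_seq_len, 1). A minimal standalone sketch of that LoD arithmetic, using an assumed example LoD rather than the Paddle API:

#include <cstdio>
#include <vector>

int main() {
  // Level-0 LoD offsets: sequence i occupies rows [lod[i], lod[i+1]) of X.
  std::vector<int> lod = {0, 4, 9, 11};  // three sequences of lengths 4, 5, 2
  const int n = static_cast<int>(lod.size()) - 1;  // batch size N

  int max_seq_len = lod[1];  // length of sequence 0 (valid since lod[0] == 0)
  for (int i = 1; i < n; ++i) {
    int len = lod[i + 1] - lod[i];
    max_seq_len = max_seq_len < len ? len : max_seq_len;
  }
  std::printf("N = %d, max_seq_len = %d\n", n, max_seq_len);  // N = 3, max_seq_len = 5
  return 0;
}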
@@ -361,7 +385,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     const T* prev_hidden_data = NULL;
     T* cur_cell_out_data = cell_out_data;
     T* cur_hidden_out_data = hidden_out_data;
-    for (int i = 0; i < batch_size; ++i) {
+    for (int i = 0; i < N; ++i) {
       int seq_len = x_lod[0][i + 1] - x_lod[0][i];
       prev_cell_data = c0_data + i * D;
       prev_hidden_data = h0 ? h0_data + i * D : NULL;
@@ -370,13 +394,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
 
       /// compute attention vector
       // prev_cell(1xD) * fc(D) rest part of atten_wgt
      // T = cblas_dot();
-      T prev_cell_bias = blas.VDOT(D, prev_cell_data, atten_w_data + M);
+      T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
       // add cell bias and relu
       bias_relu(seq_len, atted_x_data, &prev_cell_bias, fc_out_data);
       // fc2: scalar
       if (atten_scalar_data) {
         // x = a*x
-        blas.VSCAL(seq_len, atten_scalar_data, fc_out_data);
+        blas.SCAL(seq_len, atten_scalar_data, fc_out_data);
         bias_relu(seq_len, fc_out_data, atten_scalar_bias_data, fc_out_data);
       }
--
GitLab
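Note: inside the per-sequence loop shown in the last two hunks, each of the seq_len steps t is scored as fc_out[t] = relu(atted_x[t] + dot(prev_cell, cell part of atten_wgt)), optionally followed by the scalar FC (x = a * x, then bias and relu again). The sketch below restates that scoring step in plain C++ under those assumptions; the names are illustrative and this is not the Paddle implementation:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// fc_out[t] = relu(atted_x[t] + dot(prev_cell, w_cell)); optionally
// fc_out[t] = relu(scalar * fc_out[t] + scalar_bias) afterwards.
void AttentionScores(const std::vector<float>& atted_x,    // seq_len x 1 (X * w part)
                     const std::vector<float>& prev_cell,  // 1 x D
                     const std::vector<float>& w_cell,     // cell part of atten_wgt
                     const float* scalar, const float* scalar_bias,
                     std::vector<float>* fc_out) {
  float cell_bias = 0.f;  // dot(prev_cell, w_cell), shared by all steps
  for (std::size_t d = 0; d < w_cell.size(); ++d) {
    cell_bias += prev_cell[d] * w_cell[d];
  }
  fc_out->resize(atted_x.size());
  for (std::size_t t = 0; t < atted_x.size(); ++t) {
    float v = std::max(0.f, atted_x[t] + cell_bias);  // add cell bias and relu
    if (scalar) {  // fc2: scalar, then bias and relu again
      v = std::max(0.f, v * (*scalar) + (scalar_bias ? *scalar_bias : 0.f));
    }
    (*fc_out)[t] = v;
  }
}

int main() {
  std::vector<float> atted_x = {0.5f, -1.0f, 2.0f};
  std::vector<float> prev_cell = {1.0f, 0.5f};
  std::vector<float> w_cell = {0.2f, -0.4f};  // dot(prev_cell, w_cell) == 0
  std::vector<float> out;
  AttentionScores(atted_x, prev_cell, w_cell, nullptr, nullptr, &out);
  std::printf("%g %g %g\n", out[0], out[1], out[2]);  // prints: 0.5 0 2
  return 0;
}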