提交 6ed20474 编写于 作者: T tensor-tang

refine attention lstm infershape

上级 508548f8
......@@ -26,86 +26,102 @@ namespace paddle {
namespace operators {
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightX"),
"Input(WeightX) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Input(WeightH) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("XX"),
"Output(XX) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LSTMWeight"),
"Input(LSTMWeight) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LSTMBias"),
"Input(LSTMBias) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AttentionWeight"),
"Input(AttentionWeight) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
"Output(Hidden) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
"Output(BatchedGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchedGate) of LSTM should not be null.");
"Output(Cell) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionedX"),
"Output(AttentionedX) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("AttentionFCOut"),
"Output(AttentionFCOut) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMX"),
"Output(LSTMX) of AttentionLSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("LSTMOUT"),
"Output(LSTMOUT) of AttentionLSTM should not be null.");
auto x_dims = ctx->GetInputDim("X");
const int M = x_dims[1];
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto w_dims = ctx->GetInputDim("LSTMWeight");
const int D = w_dims[1] / 4;
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], D + M,
"LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D);
auto b_dims = ctx->GetInputDim("LSTMBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M,
D);
PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).",
M, D);
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time.");
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
}
// fc_out , shape (maxseqlen,1)
int max_seq_len = 0;
auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
"The rank of Input(WeightX) should be 2.");
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
"The first dimension of Input(WeightX) "
"should be %d.",
x_dims[1]);
int frame_size = wx_dims[1] / 4;
auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2.");
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
"The second dimension of Input(WeightH) "
"should be 4 * %d.",
frame_size);
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
"Do not support peephole yet.");
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection",
frame_size);
framework::DDim out_dims({x_dims[0], frame_size});
auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
"Input(AttentionWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
"AttentionWeight shapes must be (%d + %d) * 1.", M, D);
if (ctx->HasInput("AttentionBias")) {
auto atten_b_dims = ctx->GetInputDim("AttentionBias");
PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
"Input(AttentionBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
"AttentionBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
"AttentionBias shapes must be 1 * 1.");
}
if (ctx->HasInput("AttentionScalar")) {
auto dims = ctx->GetInputDim("AttentionScalar");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalar)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
}
if (ctx->HasInput("AttentionScalarBias")) {
auto dims = ctx->GetInputDim("AttentionScalarBias");
PADDLE_ENFORCE(
ctx->HasInput("AttentionScalar"),
"AttentionScalar should not be null when have AttentionScalarBias.");
PADDLE_ENFORCE_EQ(dims.size(), 2,
"Input(AttentionScalarBias)'s rank must be 2.");
PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
}
framework::DDim out_dims({x_dims[0], D});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchCellPreAct", out_dims);
ctx->SetOutputDim("AttentionedX", {x_dims[0], 1});
ctx->SetOutputDim("LSTMX", {1, M});
ctx->SetOutputDim("LSTMOUT", {1, 4 * D});
// AttentionFCOut should be reshape as (maxseqlen,1) in runtime
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX");
}
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
......@@ -164,11 +180,10 @@ void AttentionLSTMOpMaker::Make() {
AddOutput("Cell",
"(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput(
"AttentionedX",
"(LodTensor) shape is (T x 1), the result after X * AttentionWeight,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size.")
AddOutput("AttentionedX",
"(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size.")
.AsIntermediate();
AddOutput("AttentionFCOut",
"(Tensor) (max_seq_len, 1), compute at each step.")
......@@ -316,12 +331,31 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
auto* lstm_w = ctx.Input<Tensor>("LSTMWeight"); // (D+M) x D*4
auto* lstm_b = ctx.Input<Tensor>("LSTMBias"); // 1 x D*4
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); // TxD
auto* cell_out = ctx.Output<LoDTensor>("Cell"); // TxD
auto* atted_x = ctx.Output<LoDTensor>("AttentionedX"); // T x 1
auto* fc_out = ctx.Output<Tensor>('AttentionFCOut'); // max_seq_len x 1
auto* lstm_x = ctx.Output<Tensor>("LSTMX"); // 1 x M
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT"); // 1 x 4D
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); // TxD
auto* cell_out = ctx.Output<LoDTensor>("Cell"); // TxD
auto* atted_x = ctx.Output<Tensor>("AttentionedX"); // T x 1
auto* fc_out = ctx.Output<Tensor>('AttentionFCOut'); // max_seq_len x 1
auto* lstm_x = ctx.Output<Tensor>("LSTMX"); // 1 x M
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT"); // 1 x 4D
// some shape should be reshape here since infershape can not get lod info
auto x_lod = x->lod();
const int N = x_lod[0].size() - 1; // batch size
auto x_dims = x->dims(); // T x M
auto w_dims = w->dims(); // (D+M) x 4D
const int M = x_dims[1]; // x frame size
const int D = w_dims[1] / 4; // gate frame size
const int D2 = D * 2;
const int D3 = D * 3;
const int D4 = w_dims[1];
int max_seq_len = x_lod[0][1];
for (int i = 1; i < N; ++i) {
int len = x_lod[0][i + 1] - x_lod[0][i];
max_seq_len = max_seq_len < len ? len : max_seq_len;
}
PADDLE_ENFORCE_EQ(x_lod.size(), 1, "Input(X)'s lod size must be 1.");
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
fc_out->Resize({max_seq_len, 1});
const T* x_data = x->data<T>();
const T* h0_data = h0->data<T>();
......@@ -341,16 +375,6 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
T* lstm_x_data = lstm_x->mutable_data<T>();
T* lstm_out_data = lstm_out->mutable_data<T>();
auto x_lod = x->lod();
auto x_dims = x->dims(); // T x M
auto w_dims = w->dims(); // (D+M) x 4D
const int M = x_dims[1]; // x frame size
const int D = w_dims[1] / 4; // gate frame size
const int D2 = D * 2;
const int D3 = D * 3;
const int D4 = w_dims[1];
const int batch_size = x_lod[0].size() - 1; // assert lod.size() == 1
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, T, 1, M, x_data, atten_w_data,
......@@ -361,7 +385,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
const T* prev_hidden_data = NULL;
T* cur_cell_out_data = cell_out_data;
T* cur_hidden_out_data = hidden_out_data;
for (int i = 0; i < batch_size; ++i) {
for (int i = 0; i < N; ++i) {
int seq_len = x_lod[0][i + 1];
prev_cell_data = c0_data + i * D;
prev_hidden_data = h0 ? h0_data + i * D : NULL;
......@@ -370,13 +394,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
/// compute attention vector
// prev_cell(1xD) * fc(D) rest part of atten_wgt
// T = cblas_dot();
T prev_cell_bias = blas.VDOT(D, prev_cell_data, atten_w_data + M);
T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
// add cell bias and relu
bias_relu<T>(seq_len, atted_x_data, &prev_cell_bias, fc_out_data);
// fc2: scalar
if (atten_scalar_data) {
// x = a*x
blas.VSCAL(seq_len, atten_scalar_data, fc_out_data);
blas.SCAL(seq_len, atten_scalar_data, fc_out_data);
bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
fc_out_data);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册