diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 14985a3f74aa234d22a774dee3c9b46c75d24d8f..5d57703c0b9dace80b6613624e40b2fc5b7b2c1d 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -59,10 +59,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
 
   auto b_dims = ctx->GetInputDim("LSTMBias");
   PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M,
-                    D);
-  PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).",
-                    M, D);
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
 
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
@@ -148,8 +146,8 @@ void AttentionLSTMOpMaker::Make() {
            "(Tensor) the weights of attention fc. Always relu the fc result."
            "The shape is ((M+D) x 1), where M is the dim size of x, D is the "
            "gate size of LSTM.");
-  AddInput("AttentionBias, optional",
-           "(Tensor) the bias of attention fc."
+  AddInput("AttentionBias",
+           "(Tensor, optional) the bias of attention fc."
            "The shape is (1 x 1)")
       .AsDispensable();
   AddInput("AttentionScalar",
@@ -281,7 +279,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto* atten_w = ctx.Input<Tensor>("AttentionWeight");    // (M+D) x 1
     auto* atten_b = ctx.Input<Tensor>("AttentionBias");      // 1x1
     auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");  // 1x1
-    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalar");  // 1x1
+    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");  // 1x1
     auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");  // (D+M) x D*4
     auto* lstm_b = ctx.Input<Tensor>("LSTMBias");    // 1 x D*4
 
@@ -319,7 +317,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     // }
 
     const T* x_data = x->data<T>();
-    const T* h0_data = h0->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
     const T* c0_data = c0->data<T>();
     const T* lstm_w_data = lstm_w->data<T>();
     const T* lstm_b_data = lstm_b->data<T>();
@@ -341,36 +339,35 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
                                       atted_x_data, atten_b_data);
 
+    const T* cur_atten_x_data = atted_x_data;
     const T* cur_x_data = x_data;
     const T* prev_cell_data = NULL;
     const T* prev_hidden_data = NULL;
     T* cur_cell_out_data = cell_out_data;
     T* cur_hidden_out_data = hidden_out_data;
     for (int i = 0; i < N; ++i) {
-      int seq_len = x_lod[0][i + 1];
+      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
       prev_cell_data = c0_data + i * D;
-      prev_hidden_data = h0 ? h0_data + i * D : NULL;
-
+      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
       for (int step = 0; step < seq_len; ++step) {
-        /// compute attention vector
-        // prev_cell(1xD) * fc(D) rest part of atten_wgt
-        // T = cblas_dot();
+        /// 1. compute attention vector
+        // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
         T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
-        // add cell bias and relu
-        bias_relu<T>(seq_len, atted_x_data, &prev_cell_bias, fc_out_data);
-        // fc2: scalar
+        // 1b. add cell bias and relu
+        bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
+        // 1c. fc scalar
         if (atten_scalar_data) {
-          // x = a*x
           blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
           bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
                        fc_out_data);
         }
+        // 1d. softmax
         vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
         // mul x(seq_len*M) and sum pool
         math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
                                           cur_x_data, lstm_x_data);
 
-        /// compute LSTM step
+        /// 2. compute LSTM step
         // lstm weight : concat[forget , input , output , tilde]
         // shape : (D + M) x (4 * D)
         // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
@@ -407,6 +404,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
         cur_hidden_out_data = cur_hidden_out_data + D;
       }
       cur_x_data = cur_x_data + seq_len * M;
+      cur_atten_x_data = cur_atten_x_data + seq_len;
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
index dea6ec7668934eb0b689095f3719540cbb204560..cb02c7e5868774395e8129c4c4ea6793c93113fa 100644
--- a/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_attention_lstm_op.py
@@ -40,19 +40,20 @@ def attention_lstm(
     D = b.shape[1] / 4
     assert T == x.shape[0]
     assert len(fcws) == len(fcbs)
-
     hidden = []
     cell = []
 
     start_offset = 0
     for bid in range(N):
         seq_len = lod[0][bid]
-        xi = np.copy(x[start_offset:seq_len, :]).reshape(seq_len, M)
+        xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len,
+                                                                        M)
         prev_cell = np.copy(c0[bid]).reshape([1, D])
         prev_hidden = np.copy(h0[bid]).reshape([1, D])
         for step in range(seq_len):
             expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
             tmp = np.concatenate((xi, expanded_cell), axis=1)
+            assert tmp.shape[0] == seq_len
             assert tmp.shape[1] == M + D
             for fcid in range(len(fcbs)):
                 tmp = fc(tmp, fcws[fcid], fcbs[fcid])
@@ -62,7 +63,7 @@
             lstmx = xi * tmp  # seq * M
             lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
             lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
-            lstmout = np.dot(lstmin, w).reshape([1, 4 * D])
+            lstmout = fc(lstmin, w, b).reshape([1, 4 * D])
 
             g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
             g_f = act_gate(g_f).reshape([1, D])
@@ -88,7 +89,7 @@
 
 class TestAttentionLSTMOp(OpTest):
     def set_conf(self):
-        self.lod = [[3]]
+        pass
 
     def setUp(self):
         self.op_type = 'attention_lstm'
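
Note for reviewers (not part of the patch): the common thread of the kernel and test fixes is how each sequence is sliced out of the LoD-packed batch. The old code used the absolute end offset as a length. Below is a minimal numpy sketch of the intended indexing; the sequence lengths [3, 2] and feature width M = 4 are made up purely for illustration.

import numpy as np

lengths = [3, 2]                    # length-style LoD, as used by the Python test
offsets = np.cumsum([0] + lengths)  # offset-style LoD, as seen by the C++ kernel: [0, 3, 5]
M = 4
x = np.random.rand(offsets[-1], M).astype('float32')

for i, seq_len in enumerate(lengths):
    start = offsets[i]
    # C++ fix: seq_len = x_lod[0][i + 1] - x_lod[0][i], not x_lod[0][i + 1]
    assert seq_len == offsets[i + 1] - offsets[i]
    # Python-test fix: slice [start, start + seq_len), not [start, seq_len)
    xi = x[start:start + seq_len, :]
    assert xi.shape == (seq_len, M)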