Commit dd938d0b authored by T tensor-tang

fix bugs and pass op test

Parent 522b3e41
@@ -59,10 +59,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto b_dims = ctx->GetInputDim("LSTMBias");
   PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M,
-                    D);
-  PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).",
-                    M, D);
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
@@ -148,8 +146,8 @@ void AttentionLSTMOpMaker::Make() {
            "(Tensor) the weights of attention fc. Always relu the fc result."
            "The shape is ((M+D) x 1), where M is the dim size of x, D is the "
            "gate size of LSTM.");
-  AddInput("AttentionBias, optional",
-           "(Tensor) the bias of attention fc."
+  AddInput("AttentionBias",
+           "(Tensor, optional) the bias of attention fc."
            "The shape is (1 x 1)")
       .AsDispensable();
   AddInput("AttentionScalar",
@@ -281,7 +279,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto* atten_w = ctx.Input<Tensor>("AttentionWeight");  // (M+D) x 1
     auto* atten_b = ctx.Input<Tensor>("AttentionBias");    // 1x1
     auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");  // 1x1
-    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalar");      // 1x1
+    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");  // 1x1
     auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");  // (D+M) x D*4
     auto* lstm_b = ctx.Input<Tensor>("LSTMBias");    // 1 x D*4
@@ -319,7 +317,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     // }
     const T* x_data = x->data<T>();
-    const T* h0_data = h0->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
     const T* c0_data = c0->data<T>();
     const T* lstm_w_data = lstm_w->data<T>();
     const T* lstm_b_data = lstm_b->data<T>();
@@ -341,36 +339,35 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
                                       atted_x_data, atten_b_data);
+    const T* cur_atten_x_data = atted_x_data;
     const T* cur_x_data = x_data;
     const T* prev_cell_data = NULL;
     const T* prev_hidden_data = NULL;
     T* cur_cell_out_data = cell_out_data;
     T* cur_hidden_out_data = hidden_out_data;
     for (int i = 0; i < N; ++i) {
-      int seq_len = x_lod[0][i + 1];
+      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
       prev_cell_data = c0_data + i * D;
-      prev_hidden_data = h0 ? h0_data + i * D : NULL;
+      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
       for (int step = 0; step < seq_len; ++step) {
-        /// compute attention vector
-        // prev_cell(1xD) * fc(D) rest part of atten_wgt
-        // T = cblas_dot();
+        /// 1. compute attention vector
+        // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
         T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
-        // add cell bias and relu
-        bias_relu<T>(seq_len, atted_x_data, &prev_cell_bias, fc_out_data);
-        // fc2: scalar
+        // 1b. add cell bias and relu
+        bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
+        // 1c. fc scalar
         if (atten_scalar_data) {
+          // x = a*x
           blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
           bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
                        fc_out_data);
         }
+        // 1d. softmax
         vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
         // mul x(seq_len*M) and sum pool
         math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
                                           cur_x_data, lstm_x_data);
-        /// compute LSTM step
+        /// 2. compute LSTM step
         // lstm weight : concat[forget , input , output , tilde]
         // shape : (D + M) x (4 * D)
         // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
@@ -407,6 +404,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
         cur_hidden_out_data = cur_hidden_out_data + D;
       }
       cur_x_data = cur_x_data + seq_len * M;
+      cur_atten_x_data = cur_atten_x_data + seq_len;
     }
   }
 };
......
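Note on the seq_len fix above: a minimal sketch of the indexing it corrects, assuming (as the kernel code now implies) that x_lod[0] stores cumulative offsets rather than raw lengths. The helper below is hypothetical and for illustration only; it is not part of the patch. The old code read x_lod[0][i + 1], which is the end offset of sequence i, not its length; the fixed code subtracts the start offset.

    # Sketch only: recover per-sequence lengths from a cumulative-offset LoD level.
    # For offsets [0, 3, 7, 12] there are three sequences of lengths 3, 4 and 5.
    def seq_lengths_from_offsets(offsets):
        return [offsets[i + 1] - offsets[i] for i in range(len(offsets) - 1)]

    assert seq_lengths_from_offsets([0, 3, 7, 12]) == [3, 4, 5]

The numpy reference in the test below, by contrast, treats lod[0][bid] directly as a length and advances start_offset by it, which is why its slicing fix takes a different form (start_offset:start_offset + seq_len).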
@@ -40,19 +40,20 @@ def attention_lstm(
     D = b.shape[1] / 4
     assert T == x.shape[0]
     assert len(fcws) == len(fcbs)
     hidden = []
     cell = []
     start_offset = 0
     for bid in range(N):
         seq_len = lod[0][bid]
-        xi = np.copy(x[start_offset:seq_len, :]).reshape(seq_len, M)
+        xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len,
+                                                                        M)
         prev_cell = np.copy(c0[bid]).reshape([1, D])
         prev_hidden = np.copy(h0[bid]).reshape([1, D])
         for step in range(seq_len):
             expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
             tmp = np.concatenate((xi, expanded_cell), axis=1)
+            assert tmp.shape[0] == seq_len
             assert tmp.shape[1] == M + D
             for fcid in range(len(fcbs)):
                 tmp = fc(tmp, fcws[fcid], fcbs[fcid])
@@ -62,7 +63,7 @@ def attention_lstm(
             lstmx = xi * tmp  # seq * M
             lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
             lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
-            lstmout = np.dot(lstmin, w).reshape([1, 4 * D])
+            lstmout = fc(lstmin, w, b).reshape([1, 4 * D])
             g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
             g_f = act_gate(g_f).reshape([1, D])
@@ -88,7 +89,7 @@ def attention_lstm(
 class TestAttentionLSTMOp(OpTest):
     def set_conf(self):
-        self.lod = [[3]]
+        pass

     def setUp(self):
         self.op_type = 'attention_lstm'
......
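Note on the lstmout fix above: a minimal sketch of what it changes, assuming the test's fc helper is a plain affine map (its exact definition is not shown in this hunk, so the version below is an illustrative assumption, not the repository's code).

    import numpy as np

    # Sketch only: assumed shape contract of the test's fc helper.
    # x: (n, in), w: (in, out), b: (1, out) or None -> (n, out)
    def fc(x, w, b=None):
        out = np.dot(x, w)
        return out if b is None else out + b

The old reference computed np.dot(lstmin, w) alone and silently dropped LSTMBias (1 x 4D). Switching to fc(lstmin, w, b) adds the bias to the 1 x 4D gate pre-activations before they are split into the forget/input/output/candidate gates, which presumably matches what the kernel's FCCompute call does and lets the op test pass.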