Commit dd938d0b authored by: T tensor-tang

fix bugs and pass op test

Parent 522b3e41
@@ -59,10 +59,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto b_dims = ctx->GetInputDim("LSTMBias");
   PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M,
-                    D);
-  PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).",
-                    M, D);
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
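The corrected check reflects that `LSTMBias` holds one bias slot per LSTM gate, so its shape is 1 x 4D rather than 1 x (M+D). A minimal NumPy sketch of the intended layout, using hypothetical sizes M = 30 and D = 15 (not taken from this commit):

```python
import numpy as np

M, D = 30, 15                       # hypothetical input dim and LSTM gate size
lstm_bias = np.zeros((1, 4 * D))    # one bias per gate: forget, input, output, tilde

assert lstm_bias.ndim == 2          # rank must be 2
assert lstm_bias.shape[0] == 1
assert lstm_bias.shape[1] == 4 * D  # 4 * D, not M + D as the old message implied

# the four gate biases can be recovered with a split along axis 1
b_f, b_i, b_o, b_c = np.split(lstm_bias, 4, axis=1)
```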
@@ -148,8 +146,8 @@ void AttentionLSTMOpMaker::Make() {
            "(Tensor) the weights of attention fc. Always relu the fc result."
            "The shape is ((M+D) x 1), where M is the dim size of x, D is the "
            "gate size of LSTM.");
-  AddInput("AttentionBias, optional",
-           "(Tensor) the bias of attention fc."
+  AddInput("AttentionBias",
+           "(Tensor, optional) the bias of attention fc."
            "The shape is (1 x 1)")
       .AsDispensable();
   AddInput("AttentionScalar",
@@ -281,7 +279,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto* atten_w = ctx.Input<Tensor>("AttentionWeight");  // (M+D) x 1
     auto* atten_b = ctx.Input<Tensor>("AttentionBias");    // 1x1
     auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");  // 1x1
-    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalar");  // 1x1
+    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");  // 1x1
     auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");  // (D+M) x D*4
     auto* lstm_b = ctx.Input<Tensor>("LSTMBias");    // 1 x D*4
@@ -319,7 +317,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     // }
     const T* x_data = x->data<T>();
-    const T* h0_data = h0->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
     const T* c0_data = c0->data<T>();
     const T* lstm_w_data = lstm_w->data<T>();
     const T* lstm_b_data = lstm_b->data<T>();
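`H0` is a dispensable input, so `h0` can be a null pointer and must not be dereferenced unconditionally; the fix only reads the data when the tensor is present. On the reference side the same guard would look roughly like the sketch below, assuming (an assumption, not stated in this commit) that a missing `H0` behaves like an all-zero initial hidden state:

```python
import numpy as np

def initial_hidden(h0, bid, D):
    # h0 may be None because the op input is dispensable
    if h0 is None:
        return np.zeros((1, D))
    return np.copy(h0[bid]).reshape([1, D])
```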
@@ -341,36 +339,35 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
                                       atted_x_data, atten_b_data);
+    const T* cur_atten_x_data = atted_x_data;
     const T* cur_x_data = x_data;
     const T* prev_cell_data = NULL;
     const T* prev_hidden_data = NULL;
     T* cur_cell_out_data = cell_out_data;
     T* cur_hidden_out_data = hidden_out_data;
     for (int i = 0; i < N; ++i) {
-      int seq_len = x_lod[0][i + 1];
+      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
       prev_cell_data = c0_data + i * D;
-      prev_hidden_data = h0 ? h0_data + i * D : NULL;
+      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
       for (int step = 0; step < seq_len; ++step) {
-        /// compute attention vector
-        // prev_cell(1xD) * fc(D) rest part of atten_wgt
-        // T = cblas_dot();
+        /// 1. compute attention vector
+        // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
         T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
-        // add cell bias and relu
-        bias_relu<T>(seq_len, atted_x_data, &prev_cell_bias, fc_out_data);
-        // fc2: scalar
+        // 1b. add cell bias and relu
+        bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
+        // 1c. fc scalar
         if (atten_scalar_data) {
           // x = a*x
           blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
           bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
                        fc_out_data);
         }
+        // 1d. softmax
        vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
         // mul x(seq_len*M) and sum pool
         math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
                                           cur_x_data, lstm_x_data);
-        /// compute LSTM step
+        /// 2. compute LSTM step
         // lstm weight : concat[forget , input , output , tilde]
         // shape : (D + M) x (4 * D)
         // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
@@ -407,6 +404,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
         cur_hidden_out_data = cur_hidden_out_data + D;
       }
       cur_x_data = cur_x_data + seq_len * M;
+      cur_atten_x_data = cur_atten_x_data + seq_len;
     }
   }
 };
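Two of the kernel fixes above are easy to misread: `x_lod[0]` holds cumulative offsets, so the length of sequence `i` is the difference of adjacent entries, and the pointer into the pre-computed attention projection must advance by `seq_len` per sequence just as `cur_x_data` advances by `seq_len * M`. A small NumPy walk-through of that bookkeeping, with made-up sizes:

```python
import numpy as np

lod_offsets = [0, 3, 7]    # roughly what x_lod[0] looks like: two sequences, lengths 3 and 4
M = 5                      # hypothetical feature size
total_T = lod_offsets[-1]
x = np.random.rand(total_T, M)
atted_x = np.random.rand(total_T, 1)   # attention fc of x over all time steps

x_off = att_off = 0
for i in range(len(lod_offsets) - 1):
    # the fix: a sequence length is the difference of adjacent offsets,
    # not the raw offset value
    seq_len = lod_offsets[i + 1] - lod_offsets[i]

    cur_x = x[x_off:x_off + seq_len]                   # (seq_len, M)
    cur_atten_x = atted_x[att_off:att_off + seq_len]   # (seq_len, 1)

    # ... attention scoring and LSTM steps for this sequence ...

    x_off += seq_len    # cur_x_data moves forward by seq_len rows of width M
    att_off += seq_len  # cur_atten_x_data moves forward by seq_len scalars
```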
@@ -40,19 +40,20 @@ def attention_lstm(
     D = b.shape[1] / 4
     assert T == x.shape[0]
     assert len(fcws) == len(fcbs)
     hidden = []
     cell = []
     start_offset = 0
     for bid in range(N):
         seq_len = lod[0][bid]
-        xi = np.copy(x[start_offset:seq_len, :]).reshape(seq_len, M)
+        xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len,
+                                                                        M)
         prev_cell = np.copy(c0[bid]).reshape([1, D])
         prev_hidden = np.copy(h0[bid]).reshape([1, D])
         for step in range(seq_len):
             expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
             tmp = np.concatenate((xi, expanded_cell), axis=1)
             assert tmp.shape[0] == seq_len
             assert tmp.shape[1] == M + D
             for fcid in range(len(fcbs)):
                 tmp = fc(tmp, fcws[fcid], fcbs[fcid])
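The slicing fix matters as soon as the batch has more than one sequence: with lengths `[3, 4]`, the second iteration has `start_offset = 3`, and the old slice `x[3:4, :]` returns a single row instead of the intended four. A tiny check with made-up numbers:

```python
import numpy as np

lod = [[3, 4]]                       # hypothetical: two sequences, lengths 3 and 4
x = np.arange(7 * 2).reshape(7, 2)   # total_T = 7, M = 2

start_offset = 3                     # offset of the second sequence
seq_len = lod[0][1]                  # 4

old_slice = x[start_offset:seq_len, :]                  # rows 3..3 -> shape (1, 2)
new_slice = x[start_offset:start_offset + seq_len, :]   # rows 3..6 -> shape (4, 2)

assert old_slice.shape == (1, 2)
assert new_slice.shape == (4, 2)
```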
@@ -62,7 +63,7 @@ def attention_lstm(
             lstmx = xi * tmp  # seq * M
             lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
             lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
-            lstmout = np.dot(lstmin, w).reshape([1, 4 * D])
+            lstmout = fc(lstmin, w, b).reshape([1, 4 * D])
             g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
             g_f = act_gate(g_f).reshape([1, D])
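The other reference-side fix routes the LSTM projection through the test's `fc` helper so the bias `b` (1 x 4D) is added before the gates are split, matching the kernel, which folds `LSTMBias` into the fc result. A sketch of what that helper is assumed to do (the real `fc` is defined elsewhere in this test file):

```python
import numpy as np

def fc(x, w, b):
    # assumed behaviour: plain affine transform; np.dot(x, w) alone would
    # silently drop the bias term, which is what the old line did
    return np.dot(x, w) + b

# here lstmin is (1, D + M), w is (D + M, 4 * D), b is (1, 4 * D)
```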
@@ -88,7 +89,7 @@ def attention_lstm(
 class TestAttentionLSTMOp(OpTest):
     def set_conf(self):
-        self.lod = [[3]]
+        pass
     def setUp(self):
         self.op_type = 'attention_lstm'
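With `set_conf` reduced to a hook, the defaults presumably live in `setUp`, and variants of the test override only what they change. A hypothetical subclass for illustration (the class name and LoD values are not from this commit; it extends the test class above):

```python
class TestAttentionLSTMOpMultiSeq(TestAttentionLSTMOp):
    def set_conf(self):
        # override only the LoD: a batch of three sequences
        self.lod = [[2, 3, 4]]
```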