diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 870292827dc76b6c5de496b74b88e7fe683edc8f..604c6f183979e8f21c32b51ff6690ab04f06d4e4 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -220,24 +220,105 @@ class FuisonLSTMKernel : public framework::OpKernel { void SeqCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; auto* x = ctx.Input("X"); + auto* h0 = ctx.Input("H0"); + auto* c0 = ctx.Input("C0"); auto* wx = ctx.Input("WeightX"); auto* wh = ctx.Input("WeightH"); auto* bias = ctx.Input("Bias"); auto* xx = ctx.Output("XX"); + auto* hidden_out = ctx.Output("Hidden"); + auto* cell_out = ctx.Output("Cell"); - auto x_dims = x->dims(); // T x M - auto wh_dims = wh->dims(); // D x 4D - const int M = x_dims[1]; // x frame size + auto x_lod = x->lod(); + auto x_dims = x->dims(); // T x M + auto wh_dims = wh->dims(); // D x 4D + const int N = x_lod[0].size() - 1; // batch size + const int M = x_dims[1]; // x frame size + const int D = wh_dims[0]; + const int D2 = D * 2; + const int D3 = D * 3; const int D4 = wh_dims[1]; const T* x_data = x->data(); + const T* h0_data = h0 ? h0->data() : NULL; + const T* c0_data = c0 ? c0->data() : NULL; const T* wx_data = wx->data(); + const T* wh_data = wh->data(); T* xx_data = xx->mutable_data(ctx.GetPlace()); + T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); + T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); auto blas = math::GetBlas(ctx); math::FCCompute(blas, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data()); + + for (int i = 0; i < N; ++i) { + int seq_len = x_lod[0][i + 1] - x_lod[0][i]; + const T* prev_cell_data = NULL; + const T* prev_hidden_data = NULL; + int tstart = 0; + if (h0_data) { + prev_hidden_data = h0_data + i * D; + prev_cell_data = c0_data + i * D; + } else { + // W_ch, W_ih, W_fh, W_oh + // actgate + math::vec_sigmoid(D3, xx_data + D, xx_data + D); + // ch gate + math::vec_tanh(D, xx_data, xx_data); + // cell out= input*tilde + blas.VMUL(D, xx_data, xx_data + D, cell_out_data); + // hidden out= act_state(cellout) * outgate + // act state + math::vec_tanh(D, cell_out_data, xx_data + D2); + blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); + + // prev + prev_hidden_data = hidden_out_data; + prev_cell_data = cell_out_data; + tstart = 1; + + // move offset + xx_data = xx_data + D4; + hidden_out_data = hidden_out_data + D; + cell_out_data = cell_out_data + D; + } + for (int step = tstart; step < seq_len; ++step) { + blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), + prev_hidden_data, D, wh_data, D4, static_cast(1), xx_data, + D4); + + // W_ch, W_ih, W_fh, W_oh + // actgate + math::vec_sigmoid(D3, xx_data + D, xx_data + D); + // ch gate + math::vec_tanh(D, xx_data, xx_data); + + // a = forget * prev_cell + blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2); + + // b = input * tilde + blas.VMUL(D, xx_data, xx_data + D, xx_data + D); + + // cell out= a+b + blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); + + // hidden out= act_state(cellout) * outgate + // act state + math::vec_tanh(D, cell_out_data, xx_data + D2); + blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); + + // prev + prev_hidden_data = hidden_out_data; + prev_cell_data = cell_out_data; + + // move offset + xx_data = xx_data + D4; + hidden_out_data = hidden_out_data + D; + cell_out_data = cell_out_data + D; + } + } } void BatchCompute(const framework::ExecutionContext& ctx) const {