diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index f9761d6ec4019a2ceb65a024f6e8a13b3d958ce9..a6dc870bba499a142d7d5fa0fd07dd56a1a9b9e4 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -272,6 +272,10 @@ class FuisonLSTMKernel : public framework::OpKernel { T* hidden_out_data = hidden_out->mutable_data(place); T* cell_out_data = cell_out->mutable_data(place); + auto blas = math::GetBlas(ctx); + math::FCCompute(blas, total_T, D4, M, x_data, wx_data, + xx_data, bias->data()); + // for peephole only Tensor checked_cell; T* checked_cell_data = nullptr; if (use_peepholes) { @@ -279,9 +283,6 @@ class FuisonLSTMKernel : public framework::OpKernel { checked_cell_data = checked_cell.mutable_data({2, D}, place); } - auto blas = math::GetBlas(ctx); - math::FCCompute(blas, total_T, D4, M, x_data, wx_data, - xx_data, bias->data()); int xx_offset = D4; int gate_offset = D; if (is_reverse) { @@ -299,6 +300,26 @@ class FuisonLSTMKernel : public framework::OpKernel { cell_out_data = cell_out_data + gate_offset; }; +#define GEMM_WH_ADDON \ + blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), \ + prev_h_data, D, wh_data, D4, static_cast(1), xx_data, D4) + +#define GET_Ct \ + /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ + act_cand(D, xx_data, xx_data); \ + blas.VMUL(D, xx_data, xx_data + D, xx_data + D); \ + blas.VMUL(D, prev_c_data, xx_data + D2, xx_data + D2); \ + blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data) + +#define GET_Ht_AND_MOVE \ + /* H_t = act_cell(C_t) * ogated */ \ + act_cell(D, cell_out_data, xx_data + D2); \ + blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); \ + /* get prev and move*/ \ + prev_h_data = hidden_out_data; \ + prev_c_data = cell_out_data; \ + move_step() + for (int i = 0; i < N; ++i) { int bid = is_reverse ? N - 1 - i : i; int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; @@ -312,67 +333,49 @@ class FuisonLSTMKernel : public framework::OpKernel { // W_ch, W_ih, W_fh, W_oh act_gate(D, xx_data + D, xx_data + D); act_cand(D, xx_data, xx_data); - // C_t = input * tilde + // C_t = igated * cgated blas.VMUL(D, xx_data, xx_data + D, cell_out_data); - // H_t = act_state(cellout) * outgate + // get outgated if (use_peepholes) { - // + W_oc * C_t for peephole connection - // put result on W_ih + // put W_oc * C_t on igated blas.VMUL(D, wc_data + D2, cell_out_data, xx_data + D); blas.VADD(D, xx_data + D, xx_data + D3, xx_data + D3); } act_gate(D, xx_data + D3, xx_data + D3); - act_cell(D, cell_out_data, xx_data + D2); - blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); - - // prev - prev_h_data = hidden_out_data; - prev_c_data = cell_out_data; + GET_Ht_AND_MOVE; tstart = 1; - move_step(); } - for (int step = tstart; step < seq_len; ++step) { - // + W_h * H_t-1 - blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), - prev_h_data, D, wh_data, D4, static_cast(1), xx_data, D4); - - // W_ch, W_ih, W_fh, W_oh - if (use_peepholes) { - // + W_ic|W_fc * C_t-1 for peephole connection + if (use_peepholes) { + for (int step = tstart; step < seq_len; ++step) { + GEMM_WH_ADDON; + // get fgated and igated blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); blas.VADD(D2, checked_cell_data, xx_data + D, xx_data + D); act_gate(D2, xx_data + D, xx_data + D); - } else { - act_gate(D3, xx_data + D, xx_data + D); - } - // a = I_t * act_cand(ch) - act_cand(D, xx_data, xx_data); - blas.VMUL(D, xx_data, xx_data + D, xx_data + D); - // b = C_t-1 * F_t - blas.VMUL(D, prev_c_data, xx_data + D2, xx_data + D2); - // C_t = a + b - blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); + GET_Ct; - // H_t = act_cell(C_t) * act_gate(O_c += C_t * W_oc) - if (use_peepholes) { - // put result on W_ih + // get ogated blas.VMUL(D, wc_data + D2, cell_out_data, xx_data + D); blas.VADD(D, xx_data + D, xx_data + D3, xx_data + D3); act_gate(D, xx_data + D3, xx_data + D3); - } - act_cell(D, cell_out_data, xx_data + D2); - blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); - - // prev - prev_h_data = hidden_out_data; - prev_c_data = cell_out_data; - - move_step(); - } // for seqlen - } // for batch + GET_Ht_AND_MOVE; + } // for seqlen + } else { + for (int step = tstart; step < seq_len; ++step) { + GEMM_WH_ADDON; + // W_ch, W_ih, W_fh, W_oh + act_gate(D3, xx_data + D, xx_data + D); + GET_Ct; + GET_Ht_AND_MOVE; + } // for seqlen + } + } // for batch +#undef GET_Ht_AND_MOVE +#undef GEMM_WH_ADDON +#undef GET_Ct } void BatchCompute(const framework::ExecutionContext& ctx) const {