From 9dd5a177a55f1d2c052c42a511a0eb2cceb3a2c3 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 5 Sep 2018 17:00:08 +0800 Subject: [PATCH] refine batch mode and peephole --- paddle/fluid/operators/fusion_lstm_op.cc | 407 ++++++++++------------- 1 file changed, 179 insertions(+), 228 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index a6dc870bba4..90736137c6c 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -252,154 +252,162 @@ class FuisonLSTMKernel : public framework::OpKernel { const int D3 = D * 3; \ const int D4 = wh_dims[1]; +#define INIT_BASE_INPUT_DATAS \ + const T* x_data = x->data(); \ + const T* wx_data = wx->data(); \ + const T* wh_data = wh->data(); \ + /* diagonal weight*/ \ + const T* wc_data = bias->data() + D4; \ + /* for peephole only*/ \ + Tensor checked_cell; \ + T* checked_cell_data = nullptr; \ + auto place = ctx.GetPlace(); \ + if (use_peepholes) { \ + /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ + checked_cell_data = checked_cell.mutable_data({2, D}, place); \ + } + +/// Compute LSTM +#define GEMM_WH_ADDON(bs, prev, out) \ + blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast(1), prev, D, \ + wh_data, D4, static_cast(1), out, D4) + +// gates: W_ch, W_ih, W_fh, W_oh +#define GET_Ct(ct_1, gates, ct) \ + /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ + act_cand(D, gates, gates); \ + blas.VMUL(D, gates, gates + D, gates + D); \ + blas.VMUL(D, ct_1, gates + D2, gates + D2); \ + blas.VADD(D, gates + D, gates + D2, ct) + +#define GET_Ht(ct, gates, ht) \ + /* H_t = act_cell(C_t) * ogated */ \ + act_cell(D, ct, gates + D2); \ + blas.VMUL(D, gates + D2, gates + D3, ht) + +#define COMPUTE_CtHt_WITHOUT_H0C0(gates, ct, ht) \ + act_gate(D, gates + D, gates + D); \ + act_cand(D, gates, gates); \ + /* C_t = igated * cgated*/ \ + blas.VMUL(D, gates, gates + D, ct); \ + /* get outgated*/ \ + if (use_peepholes) { \ + /* put W_oc * C_t on igated */ \ + blas.VMUL(D, wc_data + D2, ct, gates + D); \ + blas.VADD(D, gates + D, gates + D3, gates + D3); \ + } \ + act_gate(D, gates + D3, gates + D3); \ + GET_Ht(ct, gates, ht) + +#define COMPUTE_CtHt(gates, ct_1, ct, ht) \ + act_gate(D3, gates + D, gates + D); \ + GET_Ct(ct_1, gates, ct); \ + GET_Ht(ct, gates, ht) + +#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \ + /* get fgated and igated*/ \ + blas.VMUL(D, wc_data, ct_1, checked_cell_data); \ + blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \ + blas.VADD(D2, checked_cell_data, gates + D, gates + D); \ + act_gate(D2, gates + D, gates + D); \ + GET_Ct(ct_1, gates, ct); \ + /* get ogated*/ \ + blas.VMUL(D, wc_data + D2, ct, gates + D); \ + blas.VADD(D, gates + D, gates + D3, gates + D3); \ + act_gate(D, gates + D3, gates + D3); \ + GET_Ht(ct, gates, ht) + void SeqCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = paddle::platform::CPUDeviceContext; INIT_BASE_INPUT_OUTPUT INIT_BASE_SIZES INIT_VEC_FUNC + INIT_BASE_INPUT_DATAS auto x_lod = x->lod(); const int total_T = x_dims[0]; const int N = x_lod[0].size() - 1; - const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : nullptr; const T* c0_data = c0 ? c0->data() : nullptr; - const T* wx_data = wx->data(); - const T* wh_data = wh->data(); - const T* wc_data = bias->data() + D4; // diagonal weight - auto place = ctx.GetPlace(); T* xx_data = xx->mutable_data(place); - T* hidden_out_data = hidden_out->mutable_data(place); - T* cell_out_data = cell_out->mutable_data(place); - + T* h_out_data = hidden_out->mutable_data(place); + T* c_out_data = cell_out->mutable_data(place); auto blas = math::GetBlas(ctx); math::FCCompute(blas, total_T, D4, M, x_data, wx_data, xx_data, bias->data()); - // for peephole only - Tensor checked_cell; - T* checked_cell_data = nullptr; - if (use_peepholes) { - // w_ic * Ct-1, w_fc * Ct-1 // , w_oc * Ct => ih - checked_cell_data = checked_cell.mutable_data({2, D}, place); - } int xx_offset = D4; int gate_offset = D; if (is_reverse) { const int offset = (total_T - 1) * D; xx_data = xx_data + offset * 4; - hidden_out_data = hidden_out_data + offset; - cell_out_data = cell_out_data + offset; + h_out_data = h_out_data + offset; + c_out_data = c_out_data + offset; xx_offset = -D4; gate_offset = -D; } - auto move_step = [&]() { - xx_data = xx_data + xx_offset; - hidden_out_data = hidden_out_data + gate_offset; - cell_out_data = cell_out_data + gate_offset; - }; - -#define GEMM_WH_ADDON \ - blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), \ - prev_h_data, D, wh_data, D4, static_cast(1), xx_data, D4) - -#define GET_Ct \ - /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ - act_cand(D, xx_data, xx_data); \ - blas.VMUL(D, xx_data, xx_data + D, xx_data + D); \ - blas.VMUL(D, prev_c_data, xx_data + D2, xx_data + D2); \ - blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data) - -#define GET_Ht_AND_MOVE \ - /* H_t = act_cell(C_t) * ogated */ \ - act_cell(D, cell_out_data, xx_data + D2); \ - blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); \ - /* get prev and move*/ \ - prev_h_data = hidden_out_data; \ - prev_c_data = cell_out_data; \ - move_step() - - for (int i = 0; i < N; ++i) { - int bid = is_reverse ? N - 1 - i : i; - int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; - const T* prev_c_data = nullptr; - const T* prev_h_data = nullptr; - int tstart = 0; - if (h0_data) { - prev_h_data = h0_data + bid * D; - prev_c_data = c0_data + bid * D; - } else { - // W_ch, W_ih, W_fh, W_oh - act_gate(D, xx_data + D, xx_data + D); - act_cand(D, xx_data, xx_data); - // C_t = igated * cgated - blas.VMUL(D, xx_data, xx_data + D, cell_out_data); - - // get outgated - if (use_peepholes) { - // put W_oc * C_t on igated - blas.VMUL(D, wc_data + D2, cell_out_data, xx_data + D); - blas.VADD(D, xx_data + D, xx_data + D3, xx_data + D3); - } - act_gate(D, xx_data + D3, xx_data + D3); - GET_Ht_AND_MOVE; - tstart = 1; - } +#define MOVE_ONE_STEP \ + prev_h_data = h_out_data; \ + prev_c_data = c_out_data; \ + xx_data = xx_data + xx_offset; \ + h_out_data = h_out_data + gate_offset; \ + c_out_data = c_out_data + gate_offset + +#define PROCESS_H0C0 \ + int bid = is_reverse ? N - 1 - i : i; \ + int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \ + const T* prev_c_data = nullptr; \ + const T* prev_h_data = nullptr; \ + int tstart = 0; \ + if (h0_data) { \ + prev_h_data = h0_data + bid * D; \ + prev_c_data = c0_data + bid * D; \ + } else { \ + COMPUTE_CtHt_WITHOUT_H0C0(xx_data, c_out_data, h_out_data); \ + MOVE_ONE_STEP; \ + tstart = 1; \ + } - if (use_peepholes) { + if (use_peepholes) { + for (int i = 0; i < N; ++i) { + PROCESS_H0C0; for (int step = tstart; step < seq_len; ++step) { - GEMM_WH_ADDON; - // get fgated and igated - blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); - blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); - blas.VADD(D2, checked_cell_data, xx_data + D, xx_data + D); - act_gate(D2, xx_data + D, xx_data + D); - GET_Ct; - - // get ogated - blas.VMUL(D, wc_data + D2, cell_out_data, xx_data + D); - blas.VADD(D, xx_data + D, xx_data + D3, xx_data + D3); - act_gate(D, xx_data + D3, xx_data + D3); - GET_Ht_AND_MOVE; - } // for seqlen - } else { + GEMM_WH_ADDON(1, prev_h_data, xx_data); + COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data); + MOVE_ONE_STEP; + } + } + } else { + for (int i = 0; i < N; ++i) { + PROCESS_H0C0; for (int step = tstart; step < seq_len; ++step) { - GEMM_WH_ADDON; - // W_ch, W_ih, W_fh, W_oh - act_gate(D3, xx_data + D, xx_data + D); - GET_Ct; - GET_Ht_AND_MOVE; - } // for seqlen + GEMM_WH_ADDON(1, prev_h_data, xx_data); + COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data); + MOVE_ONE_STEP; + } } - } // for batch -#undef GET_Ht_AND_MOVE -#undef GEMM_WH_ADDON -#undef GET_Ct + } +#undef PROCESS_H0C0 +#undef MOVE_ONE_STEP } void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = platform::CPUDeviceContext; INIT_BASE_INPUT_OUTPUT - if (x->lod()[0].size() == 2) { // batch size == 1 + if (x->lod()[0].size() == 2) { SeqCompute(ctx); return; } INIT_BASE_SIZES INIT_VEC_FUNC + INIT_BASE_INPUT_DATAS auto* reordered_h0 = ctx.Output("ReorderedH0"); auto* reordered_c0 = ctx.Output("ReorderedC0"); auto* batched_input = ctx.Output("BatchedInput"); auto* batched_c_out = ctx.Output("BatchedCell"); auto* batched_h_out = ctx.Output("BatchedHidden"); - - const T* x_data = x->data(); - const T* wx_data = wx->data(); - const T* wh_data = wh->data(); - const T* bias_data = bias->data(); - const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc - auto place = ctx.GetPlace(); T* xx_data = xx->mutable_data(place); T* batched_input_data = batched_input->mutable_data(place); T* batched_c_out_data = batched_c_out->mutable_data(place); @@ -407,12 +415,6 @@ class FuisonLSTMKernel : public framework::OpKernel { hidden_out->mutable_data(place); cell_out->mutable_data(place); - // use local variable - framework::DDim check_dims({3, D}); - Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct - auto checked_cell_data = - checked_cell.mutable_data(check_dims, ctx.GetPlace()); - math::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); @@ -434,27 +436,17 @@ class FuisonLSTMKernel : public framework::OpKernel { reordered_h0->Resize({max_bs, D}); reordered_c0->Resize({max_bs, D}); - T* prev_batch_h_data = nullptr; - T* prev_batch_c_data = nullptr; - T* cur_batch_in_data = batched_input_data; - T* cur_batch_h_out_data = batched_h_out_data; - T* cur_batch_c_out_data = batched_c_out_data; - - auto move_step = [&](int bs) { - cur_batch_in_data += bs * D4; - cur_batch_c_out_data += bs * D; - cur_batch_h_out_data += bs * D; - }; - int tstart = 0; + T* prev_h_data = nullptr; + T* prev_c_data = nullptr; if (h0) { // reorder h0, c0 T* reordered_h0_data = reordered_h0->mutable_data(place); T* reordered_c0_data = reordered_c0->mutable_data(place); const T* h0_data = h0->data(); const T* c0_data = c0->data(); - prev_batch_h_data = reordered_h0_data; - prev_batch_c_data = reordered_c0_data; + prev_h_data = reordered_h0_data; + prev_c_data = reordered_c0_data; size_t sz = sizeof(T) * D; for (int i = 0; i < max_bs; ++i) { std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); @@ -463,123 +455,74 @@ class FuisonLSTMKernel : public framework::OpKernel { reordered_c0_data += D; } } else { - // Compute with no H0/C0 - T* cur_in_data = cur_batch_in_data; - T* cur_c_out_data = cur_batch_c_out_data; - T* cur_h_out_data = cur_batch_h_out_data; - - // If step == 0 and there is no initialized hidden state, that is to say - // the H0 is zeros. Then W_h * H_t-1 can be skiped - - for (int i = 0; i < max_bs; ++i) { // iterate each data in 1st batch - // ~C_t - act_cand(D, cur_in_data, cur_in_data); - - if (use_peepholes) { - // I_t, F_t - act_gate(D2, cur_in_data + D, cur_in_data + D); - } else { - // I_t, F_t, O_t - act_gate(D3, cur_in_data + D, cur_in_data + D); - } - - // C_t = I_t * ~C_t - blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data); - - if (use_peepholes) { - // + W_oc * C_t for peephole connection - blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2); - blas.VADD(D, cur_in_data + D3, checked_cell_data + D2, - cur_in_data + D3); - // O_t - act_gate(D, cur_in_data + D3, cur_in_data + D3); - } - - // hidden out= act_state(cellout) * outgate - act_cell(D, cur_c_out_data, cur_in_data + D2); - // H_t = O_t * act_state(C_t) - blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data); - - // move to next data in the same batch + // compute without h0, c0 + T* cur_in_data = batched_input_data; + T* cur_h_out_data = batched_h_out_data; + T* cur_c_out_data = batched_c_out_data; + for (int i = 0; i < max_bs; ++i) { + COMPUTE_CtHt_WITHOUT_H0C0(cur_in_data, cur_c_out_data, cur_h_out_data); cur_in_data += D4; cur_c_out_data += D; cur_h_out_data += D; } - - // move to data for next timestep - prev_batch_h_data = cur_batch_h_out_data; - prev_batch_c_data = cur_batch_c_out_data; - move_step(max_bs); tstart = 1; + prev_h_data = batched_h_out_data; + prev_c_data = batched_c_out_data; } - const auto& batch_starts = batched_lod[0]; const int max_seq_len = batch_starts.size() - 1; - for (int step = tstart; step < max_seq_len; ++step) { - const int cur_bs = batch_starts[step + 1] - batch_starts[step]; - // + W_h * H_t-1 - blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast(1), - prev_batch_h_data, D, wh_data, D4, static_cast(1), - cur_batch_in_data, D4); - - T* cur_in_data = cur_batch_in_data; - T* cur_c_out_data = cur_batch_c_out_data; - T* cur_h_out_data = cur_batch_h_out_data; - T* prev_c_data = prev_batch_c_data; // NULL if no C0 in step0 - T* prev_h_data = prev_batch_h_data; // NULL if no H0 in step0 - auto next_data_in_batch = [&]() { - cur_in_data += D4; - cur_c_out_data += D; - cur_h_out_data += D; - prev_c_data = prev_c_data ? prev_c_data + D : nullptr; - prev_h_data = prev_h_data ? prev_h_data + D : nullptr; - }; - - for (int i = 0; i < cur_bs; ++i) { // iterate each data in same batch - // ~C_t - act_cand(D, cur_in_data, cur_in_data); - - if (use_peepholes) { - // + W_ic|W_fc * C_t-1 for peephole connection - blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); - blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); - blas.VADD(D2, cur_in_data + D, checked_cell_data, cur_in_data + D); - // I_t, F_t - act_gate(D2, cur_in_data + D, cur_in_data + D); - } else { - // I_t, F_t, O_t - act_gate(D3, cur_in_data + D, cur_in_data + D); - } + const int offset = tstart * max_bs * D; + batched_input_data = batched_input_data + offset * 4; + batched_h_out_data = batched_h_out_data + offset; + batched_c_out_data = batched_c_out_data + offset; + +#define DEFINE_CUR \ + T* cur_in_data = batched_input_data; \ + T* cur_prev_c_data = prev_c_data; \ + T* cur_c_out_data = batched_c_out_data; \ + T* cur_h_out_data = batched_h_out_data + +#define MOVE_ONE_BATCH \ + cur_in_data += D4; \ + cur_prev_c_data += D; \ + cur_c_out_data += D; \ + cur_h_out_data += D + +#define MOVE_ONE_STEP \ + prev_c_data = batched_c_out_data; \ + prev_h_data = batched_h_out_data; \ + batched_c_out_data = cur_c_out_data; \ + batched_h_out_data = cur_h_out_data; \ + batched_input_data = cur_in_data - // F_t * C_t-1 - blas.VMUL(D, cur_in_data + D2, prev_c_data, cur_in_data + D2); - // I_t * ~C_t - blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D); - // C_t = F_t * C_t-1 + I_t * ~C_t - blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data); - - if (use_peepholes) { - // + W_oc * C_t for peephole connection - blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2); - blas.VADD(D, cur_in_data + D3, checked_cell_data + D2, - cur_in_data + D3); - // O_t - act_gate(D, cur_in_data + D3, cur_in_data + D3); + if (use_peepholes) { + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = batch_starts[step + 1] - batch_starts[step]; + GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); + DEFINE_CUR; + for (int i = 0; i < cur_bs; ++i) { + COMPUTE_CtHt_PEEPHOLE(cur_in_data, cur_prev_c_data, cur_c_out_data, + cur_h_out_data); + MOVE_ONE_BATCH; } - - // hidden out= act_state(cellout) * outgate - act_cell(D, cur_c_out_data, cur_in_data + D2); - // H_t = O_t * act_state(C_t) - blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data); - - // move to next data in same batch - next_data_in_batch(); + MOVE_ONE_STEP; + } + } else { + for (int step = tstart; step < max_seq_len; ++step) { + const int cur_bs = batch_starts[step + 1] - batch_starts[step]; + GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); + DEFINE_CUR; + for (int i = 0; i < cur_bs; ++i) { + COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, + cur_h_out_data); + MOVE_ONE_BATCH; + } + MOVE_ONE_STEP; } - // move to data for next timestep - prev_batch_h_data = cur_batch_h_out_data; - prev_batch_c_data = cur_batch_c_out_data; - move_step(cur_bs); } +#undef MOVE_ONE_STEP +#undef MOVE_ONE_BATCH +#undef DEFINE_CUR math::Batch2LoDTensorFunctor to_seq; batched_h_out->set_lod(batched_lod); @@ -595,6 +538,14 @@ class FuisonLSTMKernel : public framework::OpKernel { BatchCompute(ctx); } } + +#undef COMPUTE_CtHt_PEEPHOLE +#undef COMPUTE_CtHt +#undef COMPUTE_CtHt_WITHOUT_H0C0 +#undef GET_Ht +#undef GET_Ct +#undef GEMM_WH_ADDON +#undef INIT_BASE_INPUT_DATAS #undef INIT_BASE_SIZES #undef INIT_BASE_INPUT_OUTPUT #undef INIT_VEC_FUNC -- GitLab