From d7ac1cc83642bf19b133752156c57883000324a1 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 5 Sep 2018 18:32:48 +0800 Subject: [PATCH] refine seq when bs is large --- paddle/fluid/operators/fusion_lstm_op.cc | 87 ++++++++++++------- .../tests/unittests/test_fusion_lstm_op.py | 2 +- 2 files changed, 59 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 90736137c..ef23ab3f9 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -285,18 +285,23 @@ class FuisonLSTMKernel : public framework::OpKernel { act_cell(D, ct, gates + D2); \ blas.VMUL(D, gates + D2, gates + D3, ht) -#define COMPUTE_CtHt_WITHOUT_H0C0(gates, ct, ht) \ - act_gate(D, gates + D, gates + D); \ - act_cand(D, gates, gates); \ - /* C_t = igated * cgated*/ \ - blas.VMUL(D, gates, gates + D, ct); \ - /* get outgated*/ \ - if (use_peepholes) { \ - /* put W_oc * C_t on igated */ \ - blas.VMUL(D, wc_data + D2, ct, gates + D); \ - blas.VADD(D, gates + D, gates + D3, gates + D3); \ - } \ - act_gate(D, gates + D3, gates + D3); \ +#define GET_Ct_NOH0C0(gates, ct) \ + /* C_t = igated * cgated*/ \ + act_gate(D, gates + D, gates + D); \ + act_cand(D, gates, gates); \ + blas.VMUL(D, gates, gates + D, ct) + +#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \ + GET_Ct_NOH0C0(gates, ct); \ + act_gate(D, gates + D3, gates + D3); \ + GET_Ht(ct, gates, ht) + +#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \ + GET_Ct_NOH0C0(gates, ct); \ + /* get outgated, put W_oc * C_t on igated */ \ + blas.VMUL(D, wc_data + D2, ct, gates + D); \ + blas.VADD(D, gates + D, gates + D3, gates + D3); \ + act_gate(D, gates + D3, gates + D3); \ GET_Ht(ct, gates, ht) #define COMPUTE_CtHt(gates, ct_1, ct, ht) \ @@ -354,24 +359,38 @@ class FuisonLSTMKernel : public framework::OpKernel { h_out_data = h_out_data + gate_offset; \ c_out_data = c_out_data + gate_offset -#define PROCESS_H0C0 \ - int bid = is_reverse ? N - 1 - i : i; \ - int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \ - const T* prev_c_data = nullptr; \ - const T* prev_h_data = nullptr; \ - int tstart = 0; \ - if (h0_data) { \ - prev_h_data = h0_data + bid * D; \ - prev_c_data = c0_data + bid * D; \ - } else { \ - COMPUTE_CtHt_WITHOUT_H0C0(xx_data, c_out_data, h_out_data); \ - MOVE_ONE_STEP; \ - tstart = 1; \ +#define PROCESS_H0C0_DEFINES \ + int bid = is_reverse ? N - 1 - i : i; \ + int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \ + const T* prev_c_data = nullptr; \ + const T* prev_h_data = nullptr; \ + int tstart = 0 + +#define PROCESS_H0C0_PEEPHOLE \ + PROCESS_H0C0_DEFINES; \ + if (h0_data) { \ + prev_h_data = h0_data + bid * D; \ + prev_c_data = c0_data + bid * D; \ + } else { \ + COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \ + MOVE_ONE_STEP; \ + tstart = 1; \ + } + +#define PROCESS_H0C0 \ + PROCESS_H0C0_DEFINES; \ + if (h0_data) { \ + prev_h_data = h0_data + bid * D; \ + prev_c_data = c0_data + bid * D; \ + } else { \ + COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \ + MOVE_ONE_STEP; \ + tstart = 1; \ } if (use_peepholes) { for (int i = 0; i < N; ++i) { - PROCESS_H0C0; + PROCESS_H0C0_PEEPHOLE for (int step = tstart; step < seq_len; ++step) { GEMM_WH_ADDON(1, prev_h_data, xx_data); COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data); @@ -380,7 +399,7 @@ class FuisonLSTMKernel : public framework::OpKernel { } } else { for (int i = 0; i < N; ++i) { - PROCESS_H0C0; + PROCESS_H0C0 for (int step = tstart; step < seq_len; ++step) { GEMM_WH_ADDON(1, prev_h_data, xx_data); COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data); @@ -388,6 +407,8 @@ class FuisonLSTMKernel : public framework::OpKernel { } } } +#undef PROCESS_H0C0_DEFINES +#undef PROCESS_H0C0_PEEPHOLE #undef PROCESS_H0C0 #undef MOVE_ONE_STEP } @@ -460,7 +481,13 @@ class FuisonLSTMKernel : public framework::OpKernel { T* cur_h_out_data = batched_h_out_data; T* cur_c_out_data = batched_c_out_data; for (int i = 0; i < max_bs; ++i) { - COMPUTE_CtHt_WITHOUT_H0C0(cur_in_data, cur_c_out_data, cur_h_out_data); + GET_Ct_NOH0C0(cur_in_data, cur_c_out_data); + if (use_peepholes) { + blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D); + blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3); + } + act_gate(D, cur_in_data + D3, cur_in_data + D3); + GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data); cur_in_data += D4; cur_c_out_data += D; cur_h_out_data += D; @@ -541,7 +568,9 @@ class FuisonLSTMKernel : public framework::OpKernel { #undef COMPUTE_CtHt_PEEPHOLE #undef COMPUTE_CtHt -#undef COMPUTE_CtHt_WITHOUT_H0C0 +#undef GET_Ct_NOH0C0 +#undef COMPUTE_CtHt_NOH0C0 +#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0 #undef GET_Ht #undef GET_Ct #undef GEMM_WH_ADDON diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 6ffb52185..de0c86f96 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -183,7 +183,7 @@ class TestFusionLSTMOpPeepholesInitReverse(TestFusionLSTMOp): self.is_reverse = True -class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp): +class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp): def set_conf(self): self.use_peepholes = True self.lod = [[2]] -- GitLab