From 04272c0d4124d3c69718be1b8801a07081969ced Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Wed, 5 Sep 2018 10:08:39 +0800 Subject: [PATCH] Enable lstm peephole (#13160) * Refine fusion lstm op code for better readability * Enable peephole in fusion lstm op (seq_mode part) and add unit test * Enable peephole in fused lstop op (batch_mode part) Set batch_mode as default as well * Use pre-commit to clean format * Follow up review comments as well as adding more unit tests for seq mode --- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 1 - paddle/fluid/operators/fusion_lstm_op.cc | 265 +++++++++++++----- .../tests/unittests/test_fusion_lstm_op.py | 65 +++++ 3 files changed, 257 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 00f5e7fad2e..55153ecc3ed 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include #include "paddle/fluid/framework/lod_tensor.h" diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index f91236975d0..104e160e2d7 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -89,12 +89,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ(b_dims[0], 1, "The first dimension of Input(Bias) should be 1."); - PADDLE_ENFORCE(!ctx->Attrs().Get("use_peepholes"), - "Do not support peephole yet."); - PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, + auto use_peepholes = ctx->Attrs().Get("use_peepholes"); + PADDLE_ENFORCE_EQ(b_dims[1], (use_peepholes ? 7 : 4) * frame_size, "The second dimension of Input(Bias) should be " - "4 * %d if disable peepholes connection", - frame_size); + "7 * %d if enable peepholes connection or" + "4 * %d if disable peepholes", + frame_size, frame_size); framework::DDim out_dims({x_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); @@ -232,16 +232,17 @@ class FuisonLSTMKernel : public framework::OpKernel { act_cand = act_functor(act_cand_str); \ } -#define INIT_BASE_INPUT_OUTPUT \ - auto* x = ctx.Input("X"); \ - auto* h0 = ctx.Input("H0"); \ - auto* c0 = ctx.Input("C0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* bias = ctx.Input("Bias"); \ - auto* xx = ctx.Output("XX"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - auto* cell_out = ctx.Output("Cell"); \ +#define INIT_BASE_INPUT_OUTPUT \ + auto* x = ctx.Input("X"); \ + auto* h0 = ctx.Input("H0"); \ + auto* c0 = ctx.Input("C0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* bias = ctx.Input("Bias"); \ + auto* xx = ctx.Output("XX"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + auto* cell_out = ctx.Output("Cell"); \ + bool use_peepholes = ctx.Attr("use_peepholes"); \ bool is_reverse = ctx.Attr("is_reverse"); #define INIT_BASE_SIZES \ @@ -266,12 +267,21 @@ class FuisonLSTMKernel : public framework::OpKernel { const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : nullptr; const T* c0_data = c0 ? c0->data() : nullptr; + const T* bias_data = bias->data(); + const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc const T* wx_data = wx->data(); const T* wh_data = wh->data(); + T* xx_data = xx->mutable_data(ctx.GetPlace()); T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); + // use local variable + framework::DDim check_dims({3, D}); + Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct + auto checked_cell_data = + checked_cell.mutable_data(check_dims, ctx.GetPlace()); + auto blas = math::GetBlas(ctx); math::FCCompute(blas, total_T, D4, M, x_data, wx_data, xx_data, bias->data()); @@ -297,46 +307,86 @@ class FuisonLSTMKernel : public framework::OpKernel { int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; const T* prev_c_data = nullptr; const T* prev_h_data = nullptr; + int tstart = 0; if (h0_data) { prev_h_data = h0_data + bid * D; prev_c_data = c0_data + bid * D; } else { - // W_ch, W_ih, W_fh, W_oh - act_gate(D3, xx_data + D, xx_data + D); + // If step == 0 and there is no initialized hidden state, that is to say + // the H0 is zeros. Then W_h * H_t-1 can be skipped + + // ~C_t act_cand(D, xx_data, xx_data); - // cell out= input*tilde + if (use_peepholes) { + // I_t, F_t + act_gate(D2, xx_data + D, xx_data + D); + } else { + // I_t, F_t, O_t + act_gate(D3, xx_data + D, xx_data + D); + } + // C_t = I_t * ~C_t blas.VMUL(D, xx_data, xx_data + D, cell_out_data); + + if (use_peepholes) { + // + W_oc * C_t for peephole connection + blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2); + blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3); + // O_t + act_gate(D, xx_data + D3, xx_data + D3); + } + // hidden out= act_state(cellout) * outgate act_cell(D, cell_out_data, xx_data + D2); + // H_t = O_t * act_state(C_t) blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev prev_h_data = hidden_out_data; prev_c_data = cell_out_data; - tstart = 1; + tstart = 1; move_step(); } + for (int step = tstart; step < seq_len; ++step) { + // + W_h * H_t-1 blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), prev_h_data, D, wh_data, D4, static_cast(1), xx_data, D4); - // W_ch, W_ih, W_fh, W_oh - act_gate(D3, xx_data + D, xx_data + D); + // ~C_t act_cand(D, xx_data, xx_data); - // a = forget * prev_cell + if (use_peepholes) { + // + W_ic|W_fc * C_t-1 for peephole connection + blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); + blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); + blas.VADD(D2, xx_data + D, checked_cell_data, xx_data + D); + // I_t, F_t + act_gate(D2, xx_data + D, xx_data + D); + } else { + // I_t, F_t, O_t + act_gate(D3, xx_data + D, xx_data + D); + } + + // F_t * C_t-1 blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2); - - // b = input * tilde + // I_t * ~C_t blas.VMUL(D, xx_data, xx_data + D, xx_data + D); - - // cell out= a+b + // C_t = F_t * C_t-1 + I_t * ~C_t blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); + if (use_peepholes) { + // + W_oc * C_t for peephole connection + blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2); + blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3); + // O_t + act_gate(D, xx_data + D3, xx_data + D3); + } + // hidden out= act_state(cellout) * outgate act_cell(D, cell_out_data, xx_data + D2); + // H_t = O_t * act_state(C_t) blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev @@ -344,14 +394,14 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_c_data = cell_out_data; move_step(); - } - } + } // for each step in batch + } // for each batch } void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = platform::CPUDeviceContext; INIT_BASE_INPUT_OUTPUT - if (x->lod()[0].size() == 2) { + if (x->lod()[0].size() == 2) { // batch size == 1 SeqCompute(ctx); return; } @@ -367,6 +417,8 @@ class FuisonLSTMKernel : public framework::OpKernel { const T* x_data = x->data(); const T* wx_data = wx->data(); const T* wh_data = wh->data(); + const T* bias_data = bias->data(); + const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc auto place = ctx.GetPlace(); T* xx_data = xx->mutable_data(place); T* batched_input_data = batched_input->mutable_data(place); @@ -375,6 +427,12 @@ class FuisonLSTMKernel : public framework::OpKernel { hidden_out->mutable_data(place); cell_out->mutable_data(place); + // use local variable + framework::DDim check_dims({3, D}); + Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct + auto checked_cell_data = + checked_cell.mutable_data(check_dims, ctx.GetPlace()); + math::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); @@ -396,17 +454,27 @@ class FuisonLSTMKernel : public framework::OpKernel { reordered_h0->Resize({max_bs, D}); reordered_c0->Resize({max_bs, D}); + T* prev_batch_h_data = nullptr; + T* prev_batch_c_data = nullptr; + T* cur_batch_in_data = batched_input_data; + T* cur_batch_h_out_data = batched_h_out_data; + T* cur_batch_c_out_data = batched_c_out_data; + + auto move_step = [&](int bs) { + cur_batch_in_data += bs * D4; + cur_batch_c_out_data += bs * D; + cur_batch_h_out_data += bs * D; + }; + int tstart = 0; - T* prev_h_data = nullptr; - T* prev_c_data = nullptr; if (h0) { // reorder h0, c0 T* reordered_h0_data = reordered_h0->mutable_data(place); T* reordered_c0_data = reordered_c0->mutable_data(place); const T* h0_data = h0->data(); const T* c0_data = c0->data(); - prev_h_data = reordered_h0_data; - prev_c_data = reordered_c0_data; + prev_batch_h_data = reordered_h0_data; + prev_batch_c_data = reordered_c0_data; size_t sz = sizeof(T) * D; for (int i = 0; i < max_bs; ++i) { std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); @@ -415,71 +483,122 @@ class FuisonLSTMKernel : public framework::OpKernel { reordered_c0_data += D; } } else { - // compute without h0, c0 - T* cur_in_data = batched_input_data; - T* cur_h_out_data = batched_h_out_data; - T* cur_c_out_data = batched_c_out_data; - // W_ch, W_ih, W_fh, W_oh - for (int i = 0; i < max_bs; ++i) { - act_gate(D3, cur_in_data + D, cur_in_data + D); + // Compute with no H0/C0 + T* cur_in_data = cur_batch_in_data; + T* cur_c_out_data = cur_batch_c_out_data; + T* cur_h_out_data = cur_batch_h_out_data; + + // If step == 0 and there is no initialized hidden state, that is to say + // the H0 is zeros. Then W_h * H_t-1 can be skiped + + for (int i = 0; i < max_bs; ++i) { // iterate each data in 1st batch + // ~C_t act_cand(D, cur_in_data, cur_in_data); - // cell out= input*tilde + + if (use_peepholes) { + // I_t, F_t + act_gate(D2, cur_in_data + D, cur_in_data + D); + } else { + // I_t, F_t, O_t + act_gate(D3, cur_in_data + D, cur_in_data + D); + } + + // C_t = I_t * ~C_t blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data); + + if (use_peepholes) { + // + W_oc * C_t for peephole connection + blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2); + blas.VADD(D, cur_in_data + D3, checked_cell_data + D2, + cur_in_data + D3); + // O_t + act_gate(D, cur_in_data + D3, cur_in_data + D3); + } + // hidden out= act_state(cellout) * outgate act_cell(D, cur_c_out_data, cur_in_data + D2); + // H_t = O_t * act_state(C_t) blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data); - // add offset + // move to next data in the same batch cur_in_data += D4; cur_c_out_data += D; cur_h_out_data += D; } + + // move to data for next timestep + prev_batch_h_data = cur_batch_h_out_data; + prev_batch_c_data = cur_batch_c_out_data; + move_step(max_bs); tstart = 1; - prev_h_data = batched_h_out_data; - prev_c_data = batched_c_out_data; } - // Then start from next + const auto& batch_starts = batched_lod[0]; const int max_seq_len = batch_starts.size() - 1; - const int offset = tstart * max_bs * D; - batched_input_data = batched_input_data + offset * 4; - batched_h_out_data = batched_h_out_data + offset; - batched_c_out_data = batched_c_out_data + offset; for (int step = tstart; step < max_seq_len; ++step) { const int cur_bs = batch_starts[step + 1] - batch_starts[step]; + // + W_h * H_t-1 blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast(1), - prev_h_data, D, wh_data, D4, static_cast(1), - batched_input_data, D4); - - T* cur_in_data = batched_input_data; - T* cur_prev_c_data = prev_c_data; - T* cur_c_out_data = batched_c_out_data; - T* cur_h_out_data = batched_h_out_data; - for (int i = 0; i < cur_bs; ++i) { - // W_ch, W_ih, W_fh, W_oh - act_gate(D3, cur_in_data + D, cur_in_data + D); + prev_batch_h_data, D, wh_data, D4, static_cast(1), + cur_batch_in_data, D4); + + T* cur_in_data = cur_batch_in_data; + T* cur_c_out_data = cur_batch_c_out_data; + T* cur_h_out_data = cur_batch_h_out_data; + T* prev_c_data = prev_batch_c_data; // NULL if no C0 in step0 + T* prev_h_data = prev_batch_h_data; // NULL if no H0 in step0 + auto next_data_in_batch = [&]() { + cur_in_data += D4; + cur_c_out_data += D; + cur_h_out_data += D; + prev_c_data = prev_c_data ? prev_c_data + D : nullptr; + prev_h_data = prev_h_data ? prev_h_data + D : nullptr; + }; + + for (int i = 0; i < cur_bs; ++i) { // iterate each data in same batch + // ~C_t act_cand(D, cur_in_data, cur_in_data); - // a = forget * prev_cell - blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2); - // b = input * tilde + + if (use_peepholes) { + // + W_ic|W_fc * C_t-1 for peephole connection + blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); + blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); + blas.VADD(D2, cur_in_data + D, checked_cell_data, cur_in_data + D); + // I_t, F_t + act_gate(D2, cur_in_data + D, cur_in_data + D); + } else { + // I_t, F_t, O_t + act_gate(D3, cur_in_data + D, cur_in_data + D); + } + + // F_t * C_t-1 + blas.VMUL(D, cur_in_data + D2, prev_c_data, cur_in_data + D2); + // I_t * ~C_t blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D); - // cell out= a+b + // C_t = F_t * C_t-1 + I_t * ~C_t blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data); + + if (use_peepholes) { + // + W_oc * C_t for peephole connection + blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2); + blas.VADD(D, cur_in_data + D3, checked_cell_data + D2, + cur_in_data + D3); + // O_t + act_gate(D, cur_in_data + D3, cur_in_data + D3); + } + // hidden out= act_state(cellout) * outgate act_cell(D, cur_c_out_data, cur_in_data + D2); + // H_t = O_t * act_state(C_t) blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data); - cur_in_data += D4; - cur_prev_c_data += D; - cur_c_out_data += D; - cur_h_out_data += D; + // move to next data in same batch + next_data_in_batch(); } - - prev_c_data = batched_c_out_data; - prev_h_data = batched_h_out_data; - batched_c_out_data = cur_c_out_data; - batched_h_out_data = cur_h_out_data; - batched_input_data = cur_in_data; + // move to data for next timestep + prev_batch_h_data = cur_batch_h_out_data; + prev_batch_c_data = cur_batch_c_out_data; + move_step(cur_bs); } math::Batch2LoDTensorFunctor to_seq; diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 1f1eb37667e..4767e9433ea 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest): self.act_cell = 'tanh' self.act_cand = 'tanh' self.use_peepholes = False + self.use_seq = False self.set_conf() T = sum(self.lod[0]) @@ -107,6 +108,7 @@ class TestFusionLSTMOp(OpTest): } self.attrs = { 'use_peepholes': self.use_peepholes, + 'use_seq': self.use_seq, 'is_reverse': self.is_reverse, 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, @@ -159,5 +161,68 @@ class TestFusionLSTMOpBS1(TestFusionLSTMOp): self.D = 16 +class TestFusionLSTMOpPeepholes(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + + +class TestFusionLSTMOpPeepholesInit(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.has_initial_state = True + + +class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.is_reverse = True + + +class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.lod = [[3]] + self.D = 16 + + +class TestFusionLSTMOpSeqInit(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.has_initial_state = True + + +class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.is_reverse = True + + +class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.has_initial_state = True + self.is_reverse = True + + +class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.use_peepholes = True + + +class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.use_peepholes = True + self.has_initial_state = True + + +class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.use_peepholes = True + self.is_reverse = True + + if __name__ == '__main__': unittest.main() -- GitLab