From e61cf3214da019ca1de1fb68ae143928877b4e62 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sun, 26 Aug 2018 21:00:56 +0800 Subject: [PATCH] complete reverse seq --- paddle/fluid/operators/fusion_lstm_op.cc | 41 ++++++++++++------- .../tests/unittests/test_fusion_lstm_op.py | 17 ++++---- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 97852e29283..e4e4ac8e333 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -229,6 +229,7 @@ class FuisonLSTMKernel : public framework::OpKernel { auto* xx = ctx.Output("XX"); auto* hidden_out = ctx.Output("Hidden"); auto* cell_out = ctx.Output("Cell"); + bool is_reverse = ctx.Attr("is_reverse"); std::function act_gate, act_cell, act_cand; auto& act_gate_str = ctx.Attr("gate_activation"); @@ -247,8 +248,9 @@ class FuisonLSTMKernel : public framework::OpKernel { } auto x_lod = x->lod(); - auto x_dims = x->dims(); // T x M - auto wh_dims = wh->dims(); // D x 4D + auto x_dims = x->dims(); // T x M + auto wh_dims = wh->dims(); // D x 4D + const int total_T = x_dims[0]; const int N = x_lod[0].size() - 1; // batch size const int M = x_dims[1]; // x frame size const int D = wh_dims[0]; @@ -266,17 +268,34 @@ class FuisonLSTMKernel : public framework::OpKernel { T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); auto blas = math::GetBlas(ctx); - math::FCCompute(blas, x_dims[0], D4, M, x_data, wx_data, + math::FCCompute(blas, total_T, D4, M, x_data, wx_data, xx_data, bias->data()); + int xx_offset = D4; + int gate_offset = D; + if (is_reverse) { + const int offset = (total_T - 1) * D; + xx_data = xx_data + offset * 4; + hidden_out_data = hidden_out_data + offset; + cell_out_data = cell_out_data + offset; + xx_offset = -D4; + gate_offset = -D; + } + + auto move_step = [&]() { + xx_data = xx_data + xx_offset; + hidden_out_data = hidden_out_data + gate_offset; + cell_out_data = cell_out_data + gate_offset; + }; for (int i = 0; i < N; ++i) { - int seq_len = x_lod[0][i + 1] - x_lod[0][i]; + int bid = is_reverse ? N - 1 - i : i; + int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; const T* prev_cell_data = NULL; const T* prev_hidden_data = NULL; int tstart = 0; if (h0_data) { - prev_hidden_data = h0_data + i * D; - prev_cell_data = c0_data + i * D; + prev_hidden_data = h0_data + bid * D; + prev_cell_data = c0_data + bid * D; } else { // W_ch, W_ih, W_fh, W_oh act_gate(D3, xx_data + D, xx_data + D); @@ -292,10 +311,7 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_cell_data = cell_out_data; tstart = 1; - // move offset - xx_data = xx_data + D4; - hidden_out_data = hidden_out_data + D; - cell_out_data = cell_out_data + D; + move_step(); } for (int step = tstart; step < seq_len; ++step) { blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), @@ -323,10 +339,7 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_hidden_data = hidden_out_data; prev_cell_data = cell_out_data; - // move offset - xx_data = xx_data + D4; - hidden_out_data = hidden_out_data + D; - cell_out_data = cell_out_data + D; + move_step(); } } } diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 19f22fc7bd2..5805bdf4619 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -122,14 +122,15 @@ class TestFusionLSTMOpInit(TestFusionLSTMOp): self.has_initial_state = True -# class TestFusionLSTMOpReverse(TestFusionLSTMOp): -# def set_conf(self): -# self.is_reverse = True - -# class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): -# def set_conf(self): -# self.has_initial_state = True -# self.is_reverse = True +class TestFusionLSTMOpReverse(TestFusionLSTMOp): + def set_conf(self): + self.is_reverse = True + + +class TestFusionLSTMOpInitReverse(TestFusionLSTMOp): + def set_conf(self): + self.has_initial_state = True + self.is_reverse = True class TestFusionLSTMOpMD1(TestFusionLSTMOp): -- GitLab