diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index 3888333ec5626f1d8d35db215085f483c985cf0a..e4e4ac8e333ba423e151dea05e40a0e41042570e 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -15,10 +15,14 @@ limitations under the License. */
 #include "paddle/fluid/operators/fusion_lstm_op.h"
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/cpu_vec.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
 #include "paddle/fluid/operators/math/lstm_compute.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+DEFINE_bool(seq_mode, true, "Use sequence mode");
 
 namespace paddle {
 namespace operators {
@@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->ShareLoD("X", "Hidden");
   ctx->ShareLoD("X", "Cell");
 
-  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  int xx_width;
+  if (FLAGS_seq_mode) {
+    xx_width = wx_dims[1];
+  } else {
+    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  }
   ctx->SetOutputDim("XX", {x_dims[0], xx_width});
   ctx->ShareLoD("X", "XX");
 }
@@ -205,10 +214,138 @@ inline void ReorderInitState(const DeviceContext& ctx,
   row_shuffle(ctx, src, index_lod, dst, indexed_src);
 }
 
-template <typename DeviceContext, typename T>
+template <typename T>
 class FuisonLSTMKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void SeqCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = paddle::platform::CPUDeviceContext;
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* h0 = ctx.Input<Tensor>("H0");
+    auto* c0 = ctx.Input<Tensor>("C0");
+    auto* wx = ctx.Input<Tensor>("WeightX");
+    auto* wh = ctx.Input<Tensor>("WeightH");
+    auto* bias = ctx.Input<Tensor>("Bias");
+
+    auto* xx = ctx.Output<LoDTensor>("XX");
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+    std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
+    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
+    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
+    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
+    if (platform::jit::MayIUse(platform::jit::avx)) {
+      math::VecActivations<T, platform::jit::avx> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    } else {
+      math::VecActivations<T, platform::jit::isa_any> act_functor;
+      act_gate = act_functor(act_gate_str);
+      act_cell = act_functor(act_cell_str);
+      act_cand = act_functor(act_cand_str);
+    }
+
+    auto x_lod = x->lod();
+    auto x_dims = x->dims();    // T x M
+    auto wh_dims = wh->dims();  // D x 4D
+    const int total_T = x_dims[0];
+    const int N = x_lod[0].size() - 1;  // batch size
+    const int M = x_dims[1];            // x frame size
+    const int D = wh_dims[0];
+    const int D2 = D * 2;
+    const int D3 = D * 3;
+    const int D4 = wh_dims[1];
+
+    const T* x_data = x->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
+    const T* c0_data = c0 ? c0->data<T>() : NULL;
+    const T* wx_data = wx->data<T>();
+    const T* wh_data = wh->data<T>();
+    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
+    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
+
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+    math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
+                                      xx_data, bias->data<T>());
+    int xx_offset = D4;
+    int gate_offset = D;
+    if (is_reverse) {
+      const int offset = (total_T - 1) * D;
+      xx_data = xx_data + offset * 4;
+      hidden_out_data = hidden_out_data + offset;
+      cell_out_data = cell_out_data + offset;
+      xx_offset = -D4;
+      gate_offset = -D;
+    }
+
+    auto move_step = [&]() {
+      xx_data = xx_data + xx_offset;
+      hidden_out_data = hidden_out_data + gate_offset;
+      cell_out_data = cell_out_data + gate_offset;
+    };
+
+    for (int i = 0; i < N; ++i) {
+      int bid = is_reverse ? N - 1 - i : i;
+      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
+      const T* prev_cell_data = NULL;
+      const T* prev_hidden_data = NULL;
+      int tstart = 0;
+      if (h0_data) {
+        prev_hidden_data = h0_data + bid * D;
+        prev_cell_data = c0_data + bid * D;
+      } else {
+        // W_ch, W_ih, W_fh, W_oh
+        act_gate(D3, xx_data + D, xx_data + D);
+        act_cand(D, xx_data, xx_data);
+        // cell out= input*tilde
+        blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
+        // hidden out= act_state(cellout) * outgate
+        act_cell(D, cell_out_data, xx_data + D2);
+        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
+
+        // prev
+        prev_hidden_data = hidden_out_data;
+        prev_cell_data = cell_out_data;
+        tstart = 1;
+
+        move_step();
+      }
+      for (int step = tstart; step < seq_len; ++step) {
+        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
+                  prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data,
+                  D4);
+
+        // W_ch, W_ih, W_fh, W_oh
+        act_gate(D3, xx_data + D, xx_data + D);
+        act_cand(D, xx_data, xx_data);
+
+        // a = forget * prev_cell
+        blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
+
+        // b = input * tilde
+        blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
+
+        // cell out= a+b
+        blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
+
+        // hidden out= act_state(cellout) * outgate
+        act_cell(D, cell_out_data, xx_data + D2);
+        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
+
+        // prev
+        prev_hidden_data = hidden_out_data;
+        prev_cell_data = cell_out_data;
+
+        move_step();
+      }
+    }
+  }
+
+  void BatchCompute(const framework::ExecutionContext& ctx) const {
+    using DeviceContext = platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
     auto* wx = ctx.Input<Tensor>("WeightX");
     auto* wh = ctx.Input<Tensor>("WeightH");
@@ -339,6 +476,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     // restore the output cell state in LoDTensor from the batch cell
     to_seq(dev_ctx, batch_cell, cell_out);
   }
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    if (FLAGS_seq_mode) {
+      SeqCompute(ctx);
+    } else {
+      BatchCompute(ctx);
+    }
+  }
 };
 
 }  // namespace operators
@@ -348,7 +492,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 
-REGISTER_OP_CPU_KERNEL(
-    fusion_lstm,
-    ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
+                       ops::FuisonLSTMKernel<double>);
diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
index 9d8bef677fd16fb6bdc20b929137b4d885f4efd1..5805bdf461998e90611dec05b079cd55feda520d 100644
--- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
@@ -43,13 +43,13 @@ def fusion_lstm(
                 act_cell, act_cand)
 
 
-class TestLstmOp(OpTest):
-    def set_argument(self):
-        self.lod = [[2, 3, 2]]
+class TestFusionLSTMOp(OpTest):
+    def set_conf(self):
+        pass
 
     def setUp(self):
         self.op_type = 'fusion_lstm'
-        self.lod = [[2, 3, 2]]
+        self.lod = [[2, 3, 5, 4]]
         self.M = 8
         self.D = 16
         self.has_initial_state = False
@@ -58,33 +58,33 @@ class TestLstmOp(OpTest):
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
         self.use_peepholes = False
-        self.set_argument()
+        self.set_conf()
 
         T = sum(self.lod[0])
         bs = len(self.lod[0])
 
-        x = np.random.normal(size=(T, self.M)).astype('float64')
+        x = np.random.normal(size=(T, self.M)).astype('float32')
         if self.has_initial_state:
-            h0 = np.random.normal(size=(bs, self.D)).astype('float64')
-            c0 = np.random.normal(size=(bs, self.D)).astype('float64')
+            h0 = np.random.normal(size=(bs, self.D)).astype('float32')
+            c0 = np.random.normal(size=(bs, self.D)).astype('float32')
         else:
-            h0 = np.zeros((bs, self.D)).astype('float64')
-            c0 = np.zeros((bs, self.D)).astype('float64')
+            h0 = np.zeros((bs, self.D)).astype('float32')
+            c0 = np.zeros((bs, self.D)).astype('float32')
 
-        wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
+        wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32')
         if self.use_peepholes:
-            b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
+            b = np.random.normal(size=(1, 7 * self.D)).astype('float32')
         else:
-            b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
+            b = np.random.normal(size=(1, 4 * self.D)).astype('float32')
 
         w_b = np.copy(b[:, 0:4 * self.D])
         w_c = b[:, 4 * self.D:] if self.use_peepholes else None
 
         # this is the weight of fc
-        wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64')
+        wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32')
         # this is the bias of fc
         # and it should be manually added into the bias of this fusion LSTM
-        bx = np.random.normal(size=(1, 4 * self.D)).astype('float64')
+        bx = np.random.normal(size=(1, 4 * self.D)).astype('float32')
         b[0, 0:4 * self.D] += bx[0, :]
         h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
                            self.is_reverse, ACTIVATION[self.act_gate],
@@ -114,35 +114,45 @@ class TestLstmOp(OpTest):
         }
 
     def test_check_output(self):
-        self.check_output(atol=1e-8)
+        self.check_output()
 
 
-class TestLstmOpInitReverse(TestLstmOp):
-    def set_argument(self):
+class TestFusionLSTMOpInit(TestFusionLSTMOp):
+    def set_conf(self):
+        self.has_initial_state = True
+
+
+class TestFusionLSTMOpReverse(TestFusionLSTMOp):
+    def set_conf(self):
+        self.is_reverse = True
+
+
+class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
+    def set_conf(self):
         self.has_initial_state = True
         self.is_reverse = True
 
 
-class TestLstmOpMD1(TestLstmOp):
-    def set_argument(self):
+class TestFusionLSTMOpMD1(TestFusionLSTMOp):
+    def set_conf(self):
         self.M = 36
         self.D = 8
 
 
-class TestLstmOpMD2(TestLstmOp):
-    def set_argument(self):
+class TestFusionLSTMOpMD2(TestFusionLSTMOp):
+    def set_conf(self):
         self.M = 8
         self.D = 8
 
 
-class TestLstmOpMD3(TestLstmOp):
-    def set_argument(self):
+class TestFusionLSTMOpMD3(TestFusionLSTMOp):
+    def set_conf(self):
         self.M = 15
         self.D = 3
 
 
-class TestLstmOpBS1(TestLstmOp):
-    def set_argument(self):
+class TestFusionLSTMOpBS1(TestFusionLSTMOp):
+    def set_conf(self):
         self.lod = [[3]]
         self.D = 16
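For reference, the per-step arithmetic that SeqCompute performs can be summarized with a small NumPy sketch. This is illustrative only and not part of the patch: the helper names (seq_step, sigmoid) are made up here, the gate packing follows the "W_ch, W_ih, W_fh, W_oh" comment in the kernel (candidate, input, forget, output laid out across the 4D-wide xx buffer), and sigmoid/tanh are assumed because they are the defaults in the unit test above.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def seq_step(xx, prev_h, prev_c, wh, D):
        # xx: 1 x 4D row of the precomputed x * wx + bias (FCCompute output)
        # GEMM in the kernel: accumulate prev_h * wh into the packed gate buffer
        xx = xx + prev_h.dot(wh)
        cand = np.tanh(xx[:, 0:D])           # act_cand over the W_ch block
        gates = sigmoid(xx[:, D:4 * D])      # act_gate over W_ih, W_fh, W_oh
        i, f, o = gates[:, 0:D], gates[:, D:2 * D], gates[:, 2 * D:3 * D]
        c = f * prev_c + i * cand            # VMUL, VMUL, VADD
        h = o * np.tanh(c)                   # act_cell, then VMUL with out gate
        return h, c

With D = wh.shape[0], iterating seq_step over the rows of the FCCompute output reproduces the values the kernel writes into Hidden and Cell; the first step of a sequence without H0/C0 reduces to c = i * cand, matching the else branch in SeqCompute.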