/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fusion_lstm_op.h"
#include <functional>
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"

// Chooses the kernel strategy at runtime: true runs
// FuisonLSTMKernel::SeqCompute (walks each sequence step by step), false runs
// FuisonLSTMKernel::BatchCompute (reorders the sequences into batches).
// InferShape also reads this flag to decide the width of the output XX.
DEFINE_bool(seq_mode, true, "Use sequence mode");

namespace paddle {
namespace operators {

// Validates the inputs/outputs of fusion_lstm and infers the output shapes.
// With X of shape (T x M) and WeightX of shape (M x 4D):
//   Hidden, Cell, BatchCellPreAct: (T x D)
//   BatchedGate:                   (T x 4D)
//   XX: (T x 4D) in seq mode, otherwise (T x min(M, 4D)).
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
                 "Input(WeightX) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
                 "Input(WeightH) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Bias"),
                 "Input(Bias) of LSTM should not be null.");

  PADDLE_ENFORCE(ctx->HasOutput("XX"),
                 "Output(XX) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                 "Output(Hidden) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                 "Output(Cell) of LSTM should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
                 "Output(BatchedGate) of LSTM should not be null.");
  // The message previously named BatchedGate (copy-paste slip).
  PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
                 "Output(BatchCellPreAct) of LSTM should not be null.");

  auto x_dims = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");

  if (ctx->HasInput("H0")) {
    PADDLE_ENFORCE(ctx->HasInput("C0"),
                   "Input(Cell) and Input(Hidden) of LSTM should not "
                   "be null at the same time.");
    auto h_dims = ctx->GetInputDim("H0");
    auto c_dims = ctx->GetInputDim("C0");
    PADDLE_ENFORCE(h_dims == c_dims,
                   "The dimension of Input(H0) and Input(C0) "
                   "should be the same.");
  }

  auto wx_dims = ctx->GetInputDim("WeightX");
  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
                    "The rank of Input(WeightX) should be 2.");
  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
                    "The first dimension of Input(WeightX) "
                    "should be %d.",
                    x_dims[1]);

  int frame_size = wx_dims[1] / 4;
  // frame_size comes from an integer division; reject a WeightX width that
  // is not an exact multiple of 4 instead of silently truncating it.
  PADDLE_ENFORCE_EQ(wx_dims[1], 4 * frame_size,
                    "The second dimension of Input(WeightX) "
                    "should be 4 * %d.",
                    frame_size);

  auto wh_dims = ctx->GetInputDim("WeightH");
  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
                    "The rank of Input(WeightH) should be 2.");
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
                    "The first dimension of Input(WeightH) "
                    "should be %d.",
                    frame_size);
  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
                    "The second dimension of Input(WeightH) "
                    "should be 4 * %d.",
                    frame_size);

  auto b_dims = ctx->GetInputDim("Bias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
  PADDLE_ENFORCE_EQ(b_dims[0], 1,
                    "The first dimension of Input(Bias) should be 1.");

  PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
                 "Do not support peephole yet.");
  PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
                    "The second dimension of Input(Bias) should be "
                    "4 * %d if disable peepholes connection",
                    frame_size);

  framework::DDim out_dims({x_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
  ctx->SetOutputDim("BatchCellPreAct", out_dims);
  ctx->ShareLoD("X", "Hidden");
  ctx->ShareLoD("X", "Cell");

  // XX holds X * WeightX in seq mode. In batch mode it holds whichever of
  // X * WeightX / batched X is narrower (see FuisonLSTMKernel::BatchCompute).
  int xx_width;
  if (FLAGS_seq_mode) {
    xx_width = wx_dims[1];
  } else {
    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
  }
  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
  ctx->ShareLoD("X", "XX");
}

// The kernel's data type follows Input(X); the place comes from the current
// device context.
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
      ctx.device_context());
}

// Declares the inputs, outputs and attributes of fusion_lstm.
// NOTE(review): `use_peepholes` defaults to true, but InferShape currently
// rejects true ("Do not support peephole yet."), so callers must set it to
// false explicitly. Confirm whether the default should change once peephole
// support lands.
void FusionLSTMOpMaker::Make() {
  AddInput("X",
           "(LoDTensor) the input is a LodTensor, which support "
           "variable-time length input sequence. The underlying tensor in "
           "this LoDTensor is a matrix with shape (T X M), where T is the "
           "total time steps in this mini-batch, M is the dim size of x.");
  AddInput("WeightX",
           "(Tensor) the learnable weights of X."
           " - The shape is (M x 4D), where M is the dim size of x, D is the "
           "hidden size. "
           " - Weight = {W_cx, W_ix, W_fx, W_ox}");
  AddInput("WeightH",
           "(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
           " - The shape is (D x 4D), where D is the hidden size. "
           " - Weight = {W_ch, W_ih, W_fh, W_oh}");
  AddInput("Bias",
           "(Tensor) the learnable weights. Almost same as LSTMOp"
           "Note: we should add the fc bias into this (1x4D) in bias."
           "input-hidden bias weight and peephole connections weight if "
           "setting `use_peepholes` True. "
           "1. `use_peepholes = False` "
           " - The shape is (1 x 4D). "
           " - Bias = {b_c, b_i, b_f, b_o}."
           "2. `use_peepholes = True` "
           " - The shape is (1 x 7D). "
           " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
  AddInput("H0",
           "(Tensor, optional) (same as LSTMOp) the initial hidden state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size and D is the hidden size.")
      .AsDispensable();
  AddInput("C0",
           "(Tensor, optional) (same as LSTMOp) (the initial cell state is an "
           "optional "
           "input. This is a tensor with shape (N x D), where N is the "
           "batch size. `H0` and `C0` can be NULL but only at the same time.")
      .AsDispensable();
  AddOutput("Hidden",
            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("Cell",
            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("XX",
            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size, M is the dim size of x input.")
      .AsIntermediate();
  AddOutput("BatchedGate", "(LoDTensor) (same as LSTMOp).").AsIntermediate();
  AddOutput("BatchCellPreAct", "(LoDTensor) (same as LSTMOp).")
      .AsIntermediate();
  AddAttr<bool>("use_peepholes",
                "(bool, defalut: True) "
                "whether to enable diagonal/peephole connections.")
      .SetDefault(true);
  AddAttr<bool>("is_reverse",
                "(bool, defalut: False) "
                "whether to compute reversed LSTM.")
      .SetDefault(false);
  AddAttr<std::string>("gate_activation",
                       "(string, default: sigmoid)"
                       "The activation for input gate, forget gate and output "
                       "gate, `sigmoid` by default.")
      .SetDefault("sigmoid")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("cell_activation",
                       "(string, default: tanh)"
                       "The activation for cell output, `tanh` by defalut.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("candidate_activation",
                       "(string, default: tanh)"
                       "The activation for candidate hidden state, "
                       "`tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddComment(R"DOC(
Fusion Long-Short Term Memory (LSTM) Operator.
This operator fuse the X into LSTM, more details can refer to LSTM op.
)DOC");
}

// Allocates dst with src's dims and copies rows of src into it according to
// index_lod, using math::CopyMatrixRowsFunctor. Used to put H0/C0 into the
// same row order as the batched gates.
template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx,
                             const framework::Tensor& src,
                             framework::Vector<size_t> index_lod,
                             framework::Tensor* dst, bool indexed_src) {
  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  // TODO(TJ): check mem copy perf
  row_shuffle(ctx, src, index_lod, dst, indexed_src);
}

// CPU kernel of fusion_lstm. FLAGS_seq_mode selects between the
// per-sequence implementation (SeqCompute) and the batched one
// (BatchCompute). In both, each 4D-wide gate row is laid out as
// [candidate | input gate | forget gate | output gate], D columns each,
// matching WeightX = {W_cx, W_ix, W_fx, W_ox}.
template <typename T>
class FuisonLSTMKernel : public framework::OpKernel<T> {
 public:
  // Sequence mode: one FC over all T rows writes X * WeightX + Bias into XX.
  // Each LoD sequence is then walked one time step at a time: add
  // prev_hidden * WeightH to the step's gate row, apply the activations in
  // place, and write Cell/Hidden for that row. With is_reverse, sequences
  // and their steps are visited from the last row backwards.
  void SeqCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
    auto* x = ctx.Input<LoDTensor>("X");
    auto* h0 = ctx.Input<Tensor>("H0");
    auto* c0 = ctx.Input<Tensor>("C0");
    auto* wx = ctx.Input<Tensor>("WeightX");
    auto* wh = ctx.Input<Tensor>("WeightH");
    auto* bias = ctx.Input<Tensor>("Bias");

    auto* xx = ctx.Output<LoDTensor>("XX");
    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    bool is_reverse = ctx.Attr<bool>("is_reverse");

    std::function<void(const int, const T*, T*)> act_gate, act_cell, act_cand;
    auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
    auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
    auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
    // Use the AVX vectorized activations when the CPU supports them.
    if (platform::jit::MayIUse(platform::jit::avx)) {
      math::VecActivations<T, platform::jit::avx> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    } else {
      math::VecActivations<T, platform::jit::isa_any> act_functor;
      act_gate = act_functor(act_gate_str);
      act_cell = act_functor(act_cell_str);
      act_cand = act_functor(act_cand_str);
    }

    auto x_lod = x->lod();
    auto x_dims = x->dims();    // T x M
    auto wh_dims = wh->dims();  // D x 4D
    const int total_T = x_dims[0];
    const int N = x_lod[0].size() - 1;  // batch size
    const int M = x_dims[1];            // x frame size
    const int D = wh_dims[0];
    const int D2 = D * 2;
    const int D3 = D * 3;
    const int D4 = wh_dims[1];

    const T* x_data = x->data<T>();
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
    const T* c0_data = c0 ? c0->data<T>() : nullptr;
    const T* wx_data = wx->data<T>();
    const T* wh_data = wh->data<T>();
    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
    T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
    T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());

    auto blas = math::GetBlas<DeviceContext, T>(ctx);
    // xx = X * WeightX + Bias for all T rows at once.
    math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
                                      xx_data, bias->data<T>());

    // Per-step pointer strides; negative when walking in reverse.
    int xx_offset = D4;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 4;
      hidden_out_data = hidden_out_data + offset;
      cell_out_data = cell_out_data + offset;
      xx_offset = -D4;
      gate_offset = -D;
    }
    auto move_step = [&]() {
      xx_data = xx_data + xx_offset;
      hidden_out_data = hidden_out_data + gate_offset;
      cell_out_data = cell_out_data + gate_offset;
    };

    for (int i = 0; i < N; ++i) {
      int bid = is_reverse ? N - 1 - i : i;
      int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
      // An empty sequence owns no rows. Without this guard the no-H0 branch
      // below would consume a row belonging to a neighbouring sequence and
      // shift every following sequence by one row.
      if (seq_len == 0) {
        continue;
      }
      const T* prev_cell_data = nullptr;
      const T* prev_hidden_data = nullptr;
      int tstart = 0;
      if (h0_data) {
        // InferShape guarantees C0 is present whenever H0 is.
        prev_hidden_data = h0_data + bid * D;
        prev_cell_data = c0_data + bid * D;
      } else {
        // No initial state: the first step needs no hidden GEMM and no
        // forget-gate term.
        // W_ch, W_ih, W_fh, W_oh
        act_gate(D3, xx_data + D, xx_data + D);
        act_cand(D, xx_data, xx_data);
        // cell out= input*tilde
        blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
        // hidden out= act_state(cellout) * outgate
        act_cell(D, cell_out_data, xx_data + D2);
        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);

        // prev
        prev_hidden_data = hidden_out_data;
        prev_cell_data = cell_out_data;
        tstart = 1;

        move_step();
      }
      for (int step = tstart; step < seq_len; ++step) {
        // gates += prev_hidden * WeightH
        blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
                  prev_hidden_data, D, wh_data, D4, static_cast<T>(1),
                  xx_data, D4);

        // W_ch, W_ih, W_fh, W_oh
        act_gate(D3, xx_data + D, xx_data + D);
        act_cand(D, xx_data, xx_data);

        // a = forget * prev_cell
        blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);

        // b = input * tilde
        blas.VMUL(D, xx_data, xx_data + D, xx_data + D);

        // cell out= a+b
        blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);

        // hidden out= act_state(cellout) * outgate
        act_cell(D, cell_out_data, xx_data + D2);
        blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);

        // prev
        prev_hidden_data = hidden_out_data;
        prev_cell_data = cell_out_data;

        move_step();
      }
    }
  }

  // Batch mode: rearranges the sequences into batches with
  // LoDTensor2BatchFunctor and runs LstmUnitFunctor once per batch step,
  // then restores Hidden/Cell to the original LoD order. The FC is applied
  // on whichever side keeps XX narrower: when M > 4D, X * WeightX goes into
  // XX and is then batched into BatchedGate; otherwise X is batched into XX
  // and the FC writes straight into BatchedGate.
  void BatchCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = platform::CPUDeviceContext;
    auto* x = ctx.Input<LoDTensor>("X");
    auto* wx = ctx.Input<Tensor>("WeightX");
    auto* wh = ctx.Input<Tensor>("WeightH");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* hidden_t0 = ctx.Input<Tensor>("H0");
    auto* cell_t0 = ctx.Input<Tensor>("C0");

    auto* xx = ctx.Output<LoDTensor>("XX");
    auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
    auto* cell_out = ctx.Output<LoDTensor>("Cell");
    bool is_reverse = ctx.Attr<bool>("is_reverse");

    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
    T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
    hidden_out->mutable_data<T>(ctx.GetPlace());
    cell_out->mutable_data<T>(ctx.GetPlace());

    const T* x_data = x->data<T>();
    const T* wx_data = wx->data<T>();
    auto x_dims = x->dims();
    auto wx_dims = wx->dims();

    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    if (x_dims[1] > wx_dims[1]) {
      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
                                        x_data, wx_data, xx_data,
                                        bias->data<T>());
      to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
    } else {
      to_batch(dev_ctx, *x, xx, true, is_reverse);
      batched_gate->set_lod(xx->lod());
      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
                                        xx_data, wx_data, batched_gate_data,
                                        bias->data<T>());
    }

    int frame_size = static_cast<int>(wx_dims[1] / 4);
    framework::DDim out_dims({x_dims[0], frame_size});
    math::LstmMetaValue<T> lstm_value;
    // no peephole
    lstm_value.check_ig = nullptr;
    lstm_value.check_fg = nullptr;
    lstm_value.check_og = nullptr;
    lstm_value.prev_state_value = nullptr;
    Tensor ordered_c0;

    framework::Vector<size_t> order(batched_gate->lod()[2]);

    if (cell_t0) {
      // Since the batch computing for LSTM reorders the input sequence
      // according to their length. The initialized cell state also needs
      // to reorder.
      ReorderInitState<DeviceContext, T>(dev_ctx, *cell_t0, order,
                                         &ordered_c0, true);
      lstm_value.prev_state_value = ordered_c0.data<T>();
    }

    // Use the local variable as here.
    LoDTensor batch_hidden, batch_cell;
    auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
    batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
    batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
    batch_cell_pre_act->mutable_data<T>(out_dims, ctx.GetPlace());

    auto batch_starts = batched_gate->lod()[0];
    size_t max_seq_len = batch_starts.size() - 1;
    auto gate_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("gate_activation"));
    auto cell_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("cell_activation"));
    auto cand_act = math::detail::GetActivationType(
        ctx.Attr<std::string>("candidate_activation"));

    for (size_t n = 0; n < max_seq_len; n++) {
      int bstart = static_cast<int>(batch_starts[n]);
      int bend = static_cast<int>(batch_starts[n + 1]);

      Tensor gate_t = batched_gate->Slice(bstart, bend);
      Tensor out_t = batch_hidden.Slice(bstart, bend);
      Tensor cell_t = batch_cell.Slice(bstart, bend);
      Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);

      int cur_batch_size = bend - bstart;

      if (n > 0) {
        // gates += previous step's hidden rows * WeightH; only the first
        // cur_batch_size rows of the previous step are used.
        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
        int pre_h_end = pre_h_start + cur_batch_size;
        auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
        // TODO(TJ): use gemm directly
        blas.MatMul(pre_hidden_t, false, *wh, false, static_cast<T>(1.0),
                    &gate_t, static_cast<T>(1.0));
      } else if (hidden_t0) {
        // TODO(TJ): move h0 outside for
        // If n == 0 and there is no initialized hidden state, that is to say
        // the H0 is zeros, the calculation W_h * H0 will be skiped.
        // If n == 0 and there is initialized hidden state, calculate W_h * H0.

        // Since the batch computing for LSTM reorders the input sequence
        // according to their length. The initialized hidden state also needs
        // to reorder.
        Tensor ordered_h0;
        ReorderInitState<DeviceContext, T>(dev_ctx, *hidden_t0, order,
                                           &ordered_h0, true);
        // TODO(TJ): use gemm directly
        blas.MatMul(ordered_h0, false, *wh, false, static_cast<T>(1.0), &gate_t,
                    static_cast<T>(1.0));
      }

      lstm_value.gate_value = gate_t.data<T>();
      lstm_value.output_value = out_t.data<T>();
      lstm_value.state_value = cell_t.data<T>();
      lstm_value.state_active_value = cell_pre_act_t.data<T>();
      math::LstmUnitFunctor<DeviceContext, T>::compute(
          dev_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act,
          cand_act);
      lstm_value.prev_state_value = lstm_value.state_value;
    }

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batch_hidden.set_lod(batched_gate->lod());
    // restore the output hidden in LoDTensor from the batch hidden
    to_seq(dev_ctx, batch_hidden, hidden_out);

    batch_cell.set_lod(batched_gate->lod());
    // restore the output cell state in LoDTensor from the batch cell
    to_seq(dev_ctx, batch_cell, cell_out);
  }

  // Dispatches on FLAGS_seq_mode; InferShape sized XX with the same flag.
  void Compute(const framework::ExecutionContext& ctx) const override {
    if (FLAGS_seq_mode) {
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
                       ops::FuisonLSTMKernel<double>);