diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index b2c631b8480ab2a54559cfa698d556ba0707c1cf..35215d7fa6cf984415fad5db6e290046fc4bea46 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -14,29 +14,37 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/fusion_lstm_op.h"
 #include <string>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/detail/activation_functions.h"
+#include "paddle/fluid/operators/math/lstm_compute.h"
+#include "paddle/fluid/operators/math/sequence2batch.h"
+DECLARE_int32(paddle_num_threads);
 
 namespace paddle {
 namespace operators {
 
 void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "Input(Input) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                 "Input(Weight) of LSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightX"),
+                 "Input(WeightX) of LSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
+                 "Input(WeightH) of LSTM should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("Bias"),
                  "Input(Bias) of LSTM should not be null.");
 
+  PADDLE_ENFORCE(ctx->HasOutput("XX"),
+                 "Output(XX) of LSTM should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                  "Output(Hidden) of LSTM should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                  "Output(Cell) of LSTM should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
-                 "Output(BatchGate) of LSTM should not be null.");
+  PADDLE_ENFORCE(ctx->HasOutput("BatchedGate"),
+                 "Output(BatchedGate) of LSTM should not be null.");
   PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                 "Output(BatchGate) of LSTM should not be null.");
+                 "Output(BatchedGate) of LSTM should not be null.");
 
-  auto in_dims = ctx->GetInputDim("Input");
-  PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
+  auto x_dims = ctx->GetInputDim("X");
+  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
 
   if (ctx->HasInput("H0")) {
     PADDLE_ENFORCE(ctx->HasInput("C0"),
@@ -49,15 +57,24 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                       "should be the same.");
   }
 
-  int frame_size = in_dims[1] / 4;
-  auto w_dims = ctx->GetInputDim("Weight");
-  PADDLE_ENFORCE_EQ(w_dims.size(), 2, "The rank of Input(Weight) should be 2.");
-  PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
-                    "The first dimension of Input(Weight) "
+  auto wx_dims = ctx->GetInputDim("WeightX");
+  PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
+                    "The rank of Input(WeightX) should be 2.");
+  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
+                    "The first dimension of Input(WeightX) "
+                    "should be %d.",
+                    x_dims[1]);
+
+  int frame_size = wx_dims[1] / 4;
+  auto wh_dims = ctx->GetInputDim("WeightH");
+  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
+                    "The rank of Input(WeightH) should be 2.");
+  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
+                    "The first dimension of Input(WeightH) "
                     "should be %d.",
                     frame_size);
-  PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
-                    "The second dimension of Input(Weight) "
+  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
+                    "The second dimension of Input(WeightH) "
                     "should be 4 * %d.",
                     frame_size);
 
@@ -66,36 +83,35 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   PADDLE_ENFORCE_EQ(b_dims[0], 1,
                     "The first dimension of Input(Bias) should be 1.");
 
-  if (ctx->Attrs().Get<bool>("use_peepholes")) {
-    PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
-                      "The second dimension of Input(Bias) should be "
-                      "7 * %d if enable peepholes connection",
-                      frame_size);
-  } else {
-    PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
-                      "The second dimension of Input(Bias) should be "
-                      "4 * %d if disable peepholes connection",
-                      frame_size);
-  }
+  PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
+                 "Do not support peephole yet.");
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
+                    "The second dimension of Input(Bias) should be "
+                    "4 * %d if disable peepholes connection",
+                    frame_size);
 
-  framework::DDim out_dims({in_dims[0], frame_size});
+  framework::DDim out_dims({x_dims[0], frame_size});
   ctx->SetOutputDim("Hidden", out_dims);
   ctx->SetOutputDim("Cell", out_dims);
-  ctx->SetOutputDim("BatchGate", in_dims);
+  ctx->SetOutputDim("BatchedGate", {x_dims[0], wx_dims[1]});
   ctx->SetOutputDim("BatchCellPreAct", out_dims);
-  ctx->ShareLoD("Input", "Hidden");
-  ctx->ShareLoD("Input", "Cell");
+  ctx->ShareLoD("X", "Hidden");
+  ctx->ShareLoD("X", "Cell");
+
+  int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
+  ctx->ShareLoD("X", "XX");
 }
 
 framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
+      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
       ctx.device_context());
 }
 
 void FusionLSTMOpMaker::Make() {
-  AddInput("Input",
+  AddInput("X",
            "(LoDTensor) the first input is a LodTensor, which support "
            "variable-time length input sequence. The underlying tensor in "
            "this LoDTensor is a matrix with shape (T X 4D), where T is the "
@@ -130,7 +146,12 @@ void FusionLSTMOpMaker::Make() {
   AddOutput("Cell",
             "(LoDTensor) the cell state of LSTM operator. "
             "The shape is (T x D), and lod is the same with the `Input`.");
-  AddOutput("BatchGate",
+  AddOutput("XX",
+            "(LoDTensor) the result after X * WeightX (shape T x 4D) or "
+            "batched X (shape T x M), where T is the total time steps in "
+            "this mini-batch, D is the hidden size and M is the frame size "
+            "of X.");
+  AddOutput("BatchedGate",
             "(LoDTensor) This LoDTensor contains input gate, forget gate "
             "and output gate after the nonlinear computation. This "
             "LoDTensor has the same shape as the reorganized input, which "
@@ -219,80 +240,102 @@ inline void ReorderInitState(const DeviceContext& ctx,
                              framework::Tensor* dst, bool indexed_src) {
   math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
   dst->mutable_data<T>(src.dims(), ctx.GetPlace());
+  // TODO(TJ): check mem copy perf
   row_shuffle(ctx, src, index_lod, dst, indexed_src);
 }
 
+// TODO(TJ): can move to math::details
+template <typename DeviceContext, typename T>
+inline void SimpleFC(const math::BlasT<DeviceContext, T>& blas, const int M,
+                     const int N, const int K, const T* A, const T* B, T* C,
+                     const T* bias_data = NULL) {
+  blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, static_cast<T>(1), A, B,
+            static_cast<T>(0), C);
+  if (bias_data) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
+#endif
+    for (int i = 0; i < M; i++) {
+      blas.AXPY(N, static_cast<T>(1), bias_data, C + i * N);
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
-class LSTMKernel : public framework::OpKernel<T> {
+class FusionLSTMKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* input = ctx.Input<LoDTensor>("Input");
-    auto* weight = ctx.Input<Tensor>("Weight");
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* wx = ctx.Input<Tensor>("WeightX");  // x*4D
+    auto* wh = ctx.Input<Tensor>("WeightH");  // D*4D
     auto* bias = ctx.Input<Tensor>("Bias");
-
     auto* hidden_t0 = ctx.Input<Tensor>("H0");
     auto* cell_t0 = ctx.Input<Tensor>("C0");
 
-    auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
-    batch_gate->mutable_data<T>(ctx.GetPlace());
+    // the result after x*Wx (size: sum_words*4D) or batched_x (size:
+    // sum_words*x)
+    auto* xx = ctx.Output<LoDTensor>("XX");
+    auto* batched_gate = ctx.Output<LoDTensor>("BatchedGate");
     auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-    hidden_out->mutable_data<T>(ctx.GetPlace());
     auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+    T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
+    T* batched_gate_data = batched_gate->mutable_data<T>(ctx.GetPlace());
+    hidden_out->mutable_data<T>(ctx.GetPlace());
     cell_out->mutable_data<T>(ctx.GetPlace());
 
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
+    const T* x_data = x->data<T>();
+    const T* wx_data = wx->data<T>();
+    auto x_dims = x->dims();
+    auto wx_dims = wx->dims();
+
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    auto& device_ctx = ctx.template device_context<DeviceContext>();
-    to_batch(device_ctx, *input, batch_gate, true, is_reverse);
-
-    auto in_dims = input->dims();
-    int frame_size = static_cast<int>(in_dims[1] / 4);
-    framework::DDim dims({in_dims[0], frame_size});
-
-    if (bias) {
-      Tensor b = *bias;
-      b.Resize({bias->numel(), 1});
-      Tensor gate_bias = b.Slice(0, 4 * frame_size);
-      math::RowwiseAdd<DeviceContext, T> add_bias;
-      add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+    // TODO(TJ): op test these two cases
+    if (x_dims[1] > wx_dims[1]) {
+      SimpleFC<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1], x_data,
+                                 wx_data, xx_data, bias->data<T>());
+      to_batch(dev_ctx, *xx, batched_gate, true, is_reverse);
+    } else {
+      to_batch(dev_ctx, *x, xx, true, is_reverse);
+      SimpleFC<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
+                                 xx_data, wx_data, batched_gate_data,
+                                 bias->data<T>());
     }
 
+    int frame_size = static_cast<int>(wx_dims[1] / 4);
+    framework::DDim out_dims({x_dims[0], frame_size});
     math::LstmMetaValue<T> lstm_value;
-    if (bias && ctx.Attr<bool>("use_peepholes")) {
-      T* bias_data = const_cast<T*>(bias->data<T>());
-      // the code style in LstmMetaValue will be updated later.
+    // no peephole
+    lstm_value.check_ig = nullptr;
+    lstm_value.check_fg = nullptr;
+    lstm_value.check_og = nullptr;
 
-      lstm_value.check_ig = bias_data + 4 * frame_size;
-      lstm_value.check_fg = lstm_value.check_ig + frame_size;
-      lstm_value.check_og = lstm_value.check_fg + frame_size;
-    } else {
-      lstm_value.check_ig = nullptr;
-      lstm_value.check_fg = nullptr;
-      lstm_value.check_og = nullptr;
-    }
     lstm_value.prev_state_value = nullptr;
     Tensor ordered_c0;
 
-    framework::Vector<size_t> order(batch_gate->lod()[2]);
+    framework::Vector<size_t> order(batched_gate->lod()[2]);
 
     if (cell_t0) {
       // Since the batch computing for LSTM reorders the input sequence
       // according to their length. The initialized cell state also needs
       // to reorder.
-      ReorderInitState<DeviceContext, T>(device_ctx, *cell_t0, order,
-                                         &ordered_c0, true);
+      ReorderInitState<DeviceContext, T>(dev_ctx, *cell_t0, order, &ordered_c0,
+                                         true);
       lstm_value.prev_state_value = ordered_c0.data<T>();
     }
 
     // Use the local variable as here.
     LoDTensor batch_hidden, batch_cell;
     auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
-    batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
-    batch_cell.mutable_data<T>(dims, ctx.GetPlace());
-    batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());
+    batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
+    batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
+    batch_cell_pre_act->mutable_data<T>(out_dims, ctx.GetPlace());
 
-    auto batch_starts = batch_gate->lod()[0];
-    size_t num_batch = batch_starts.size() - 1;
+    auto batch_starts = batched_gate->lod()[0];
+    size_t max_seq_len = batch_starts.size() - 1;
     auto gate_act = math::detail::GetActivationType(
         ctx.Attr<std::string>("gate_activation"));
     auto cell_act = math::detail::GetActivationType(
@@ -300,12 +343,11 @@ class LSTMKernel : public framework::OpKernel<T> {
     auto cand_act = math::detail::GetActivationType(
         ctx.Attr<std::string>("candidate_activation"));
 
-    auto blas = math::GetBlas<DeviceContext, T>(device_ctx);
-    for (size_t n = 0; n < num_batch; n++) {
+    for (size_t n = 0; n < max_seq_len; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
 
-      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor gate_t = batched_gate->Slice(bstart, bend);
       Tensor out_t = batch_hidden.Slice(bstart, bend);
       Tensor cell_t = batch_cell.Slice(bstart, bend);
       Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);
@@ -316,9 +358,11 @@ class LSTMKernel : public framework::OpKernel<T> {
         int pre_h_start = static_cast<int>(batch_starts[n - 1]);
         int pre_h_end = pre_h_start + cur_batch_size;
         auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
-        blas.MatMul(pre_hidden_t, false, *weight, false, static_cast<T>(1.0),
+        // TODO(TJ): use gemm directly
+        blas.MatMul(pre_hidden_t, false, *wh, false, static_cast<T>(1.0),
                     &gate_t, static_cast<T>(1.0));
       } else if (hidden_t0) {
+        // TODO(TJ): move h0 outside for
         // If n == 0 and there is no initialized hidden state, that is to say
         // the H0 is zeros, the calculation W_h * H0 will be skiped.
         // If n == 0 and there is initialized hidden state, calculate W_h * H0.
 
         // Since the batch computing for LSTM reorders the input sequence
         // according to their length. The initialized hidden state also needs
         // to reorder.
         Tensor ordered_h0;
-        ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
+        ReorderInitState<DeviceContext, T>(dev_ctx, *hidden_t0, order,
                                            &ordered_h0, true);
-        blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
-                    &gate_t, static_cast<T>(1.0));
+        // TODO(TJ): use gemm directly
+        blas.MatMul(ordered_h0, false, *wh, false, static_cast<T>(1.0), &gate_t,
+                    static_cast<T>(1.0));
       }
 
       lstm_value.gate_value = gate_t.data<T>();
@@ -338,19 +383,19 @@ class LSTMKernel : public framework::OpKernel<T> {
       lstm_value.state_value = cell_t.data<T>();
       lstm_value.state_active_value = cell_pre_act_t.data<T>();
       math::LstmUnitFunctor<DeviceContext, T>::compute(
-          device_ctx, lstm_value, frame_size, cur_batch_size, gate_act,
-          cell_act, cand_act);
+          dev_ctx, lstm_value, frame_size, cur_batch_size, gate_act, cell_act,
+          cand_act);
       lstm_value.prev_state_value = lstm_value.state_value;
     }
 
     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
-    batch_hidden.set_lod(batch_gate->lod());
+    batch_hidden.set_lod(batched_gate->lod());
     // restore the output hidden in LoDTensor from the batch hidden
-    to_seq(device_ctx, batch_hidden, hidden_out);
+    to_seq(dev_ctx, batch_hidden, hidden_out);
 
-    batch_cell.set_lod(batch_gate->lod());
+    batch_cell.set_lod(batched_gate->lod());
     // restore the output cell state in LoDTensor from the batch cell
-    to_seq(device_ctx, batch_cell, cell_out);
+    to_seq(dev_ctx, batch_cell, cell_out);
   }
 };
 
@@ -358,9 +403,10 @@ class LSTMKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(lstm, ops::LSTMOp, ops::LSTMOpMaker,
+REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 
 REGISTER_OP_CPU_KERNEL(
-    fusion_lstm, ops::LSTMKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::LSTMKernel<paddle::platform::CPUDeviceContext, double>);
+    fusion_lstm,
+    ops::FusionLSTMKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::FusionLSTMKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/fusion_lstm_op.h b/paddle/fluid/operators/fusion_lstm_op.h
index 88a65aee033ee8fc1e263e4afb03f9cf0c39cf18..39dc09b4d116193399d8ac9a51e88dbc3e239918 100644
--- a/paddle/fluid/operators/fusion_lstm_op.h
+++ b/paddle/fluid/operators/fusion_lstm_op.h
@@ -15,10 +15,6 @@ limitations under the License. */
 #pragma once
 // #include <string>
 #include "paddle/fluid/framework/op_registry.h"
-// #include "paddle/fluid/operators/math/blas.h"
-// #include "paddle/fluid/operators/math/detail/activation_functions.h"
-// #include "paddle/fluid/operators/math/lstm_compute.h"
-// #include "paddle/fluid/operators/math/sequence2batch.h"
 
 namespace paddle {
 namespace operators {
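
Note on the fused input projection: SimpleFC in this patch is one GEMM over the whole mini-batch (M = total time steps, K = input width, N = 4 * frame_size) followed by a per-row AXPY that broadcasts the bias, so x * WeightX for every time step is done in a single BLAS call instead of once per step. The plain-loop reference below spells out the same arithmetic as a reading aid only; the name SimpleFCReference and the naive loops are our illustration, not part of the patch, which delegates to blas.GEMM and blas.AXPY.

// Naive row-major reference for C = A * B (+ bias broadcast over rows).
// A is M x K, B is K x N, C is M x N; bias (if given) has N entries.
template <typename T>
void SimpleFCReference(int M, int N, int K, const T* A, const T* B, T* C,
                       const T* bias = nullptr) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      T acc = bias ? bias[j] : static_cast<T>(0);  // start from the bias term
      for (int k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + j];
      }
      C[i * N + j] = acc;
    }
  }
}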
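
The kernel's step loop depends on the layout produced by math::LoDTensor2BatchFunctor: sequences are reordered by descending length, and "batch" n packs the n-th time step of every sequence that is still alive, which is why max_seq_len = batch_starts.size() - 1 and bend - bstart is the active batch size at step n. The self-contained sketch below shows how those offsets relate to per-sequence lengths; it is our own illustration of the layout, not the functor's implementation.

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// batch_starts[n] is the row where step n begins; the difference
// batch_starts[n + 1] - batch_starts[n] is the number of sequences whose
// length exceeds n, i.e. the active batch size at step n.
std::vector<size_t> BatchStarts(std::vector<size_t> seq_lens) {
  std::sort(seq_lens.begin(), seq_lens.end(), std::greater<size_t>());
  const size_t max_seq_len = seq_lens.empty() ? 0 : seq_lens.front();
  std::vector<size_t> batch_starts{0};
  for (size_t n = 0; n < max_seq_len; ++n) {
    const size_t active = static_cast<size_t>(std::count_if(
        seq_lens.begin(), seq_lens.end(),
        [n](size_t len) { return len > n; }));
    batch_starts.push_back(batch_starts.back() + active);
  }
  return batch_starts;
}

For example, lengths {3, 1, 2} give batch_starts {0, 3, 5, 6}: three sequences are active at step 0, two at step 1, and one at step 2.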
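
For reference, the per-frame recurrence that math::LstmUnitFunctor applies to each row of gate_t, written element-wise for a single cell. The gate ordering and the fixed sigmoid/tanh activations here are illustrative assumptions; the real functor takes gate_act, cell_act and cand_act from the op attributes, and peephole terms are omitted to match the new "Do not support peephole yet." check.

#include <cmath>

// One LSTM cell update without peepholes. Assumed pre-activation gate layout
// in `gate`: [input gate, candidate, forget gate, output gate].
template <typename T>
void LstmCellStep(const T gate[4], T prev_c, T* out_c, T* out_h) {
  auto sigmoid = [](T x) {
    return static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
  };
  const T i = sigmoid(gate[0]);       // input gate
  const T cand = std::tanh(gate[1]);  // candidate cell state
  const T f = sigmoid(gate[2]);       // forget gate
  const T o = sigmoid(gate[3]);       // output gate
  *out_c = f * prev_c + i * cand;     // c_t = f * c_{t-1} + i * c~
  *out_h = o * std::tanh(*out_c);     // h_t = o * tanh(c_t)
}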