From ddb05dffb6b8e7230c0a3fec50812f9caf2a96cd Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 15 Aug 2018 14:56:45 +0800 Subject: [PATCH] init fusion lstm op --- paddle/fluid/operators/fusion_lstm_op.cc | 366 +++++++++++++++++++++++ paddle/fluid/operators/fusion_lstm_op.h | 46 +++ 2 files changed, 412 insertions(+) create mode 100644 paddle/fluid/operators/fusion_lstm_op.cc create mode 100644 paddle/fluid/operators/fusion_lstm_op.h diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc new file mode 100644 index 00000000000..b2c631b8480 --- /dev/null +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -0,0 +1,366 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fusion_lstm_op.h" +#include + +namespace paddle { +namespace operators { + +void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(Weight) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Bias"), + "Input(Bias) of LSTM should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(Hidden) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Cell"), + "Output(Cell) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), + "Output(BatchGate) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"), + "Output(BatchGate) of LSTM should not be null."); + + auto in_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2."); + + if (ctx->HasInput("H0")) { + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(Cell) and Input(Hidden) of LSTM should not " + "be null at the same time."); + auto h_dims = ctx->GetInputDim("H0"); + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } + + int frame_size = in_dims[1] / 4; + auto w_dims = ctx->GetInputDim("Weight"); + PADDLE_ENFORCE_EQ(w_dims.size(), 2, "The rank of Input(Weight) should be 2."); + PADDLE_ENFORCE_EQ(w_dims[0], frame_size, + "The first dimension of Input(Weight) " + "should be %d.", + frame_size); + PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size, + "The second dimension of Input(Weight) " + "should be 4 * %d.", + frame_size); + + auto b_dims = ctx->GetInputDim("Bias"); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); + PADDLE_ENFORCE_EQ(b_dims[0], 1, + "The first dimension of Input(Bias) should be 1."); + + if (ctx->Attrs().Get("use_peepholes")) { + PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, + "The second dimension of Input(Bias) should be " + "7 * %d if enable peepholes connection", + frame_size); + } else { + PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, + "The second dimension of Input(Bias) should be " + "4 * %d if disable peepholes connection", + frame_size); + } + + framework::DDim out_dims({in_dims[0], frame_size}); + ctx->SetOutputDim("Hidden", out_dims); + ctx->SetOutputDim("Cell", out_dims); + ctx->SetOutputDim("BatchGate", in_dims); + ctx->SetOutputDim("BatchCellPreAct", out_dims); + ctx->ShareLoD("Input", "Hidden"); + ctx->ShareLoD("Input", "Cell"); +} + +framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + ctx.device_context()); +} + +void FusionLSTMOpMaker::Make() { + AddInput("Input", + "(LoDTensor) the first input is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTensor is a matrix with shape (T X 4D), where T is the " + "total time steps in this mini-batch, D is the hidden size."); + AddInput("H0", + "(Tensor, optional) the initial hidden state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size and D is the hidden size.") + .AsDispensable(); + AddInput("C0", + "(Tensor, optional) the initial cell state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size. `H0` and `C0` can be NULL but only at the same time.") + .AsDispensable(); + AddInput("Weight", + "(Tensor) the learnable hidden-hidden weights." + " - The shape is (D x 4D), where D is the hidden size. " + " - Weight = {W_ch, W_ih, W_fh, W_oh}"); + AddInput("Bias", + "(Tensor) the learnable weights, which contains two parts: " + "input-hidden bias weight and peephole connections weight if " + "setting `use_peepholes` True. " + "1. `use_peepholes = False` " + " - The shape is (1 x 4D). " + " - Bias = {b_c, b_i, b_f, b_o}." + "2. `use_peepholes = True` " + " - The shape is (1 x 7D). " + " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); + AddOutput("Hidden", + "(LoDTensor) the hidden state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) the cell state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); + AddOutput("BatchGate", + "(LoDTensor) This LoDTensor contains input gate, forget gate " + "and output gate after the nonlinear computation. This " + "LoDTensor has the same shape as the reorganized input, which " + "is also be called batch input. The LoD size is 2. The first " + "LoD is the batch offsets and the second LoD contains the " + "indexes, which denote the position of reorganized sequence " + "in the raw input.") + .AsIntermediate(); + AddOutput("BatchCellPreAct", + "(LoDTensor) This LoDTensor is obtained in the forward and used " + "in the backward.") + .AsIntermediate(); + AddAttr("use_peepholes", + "(bool, defalut: True) " + "whether to enable diagonal/peephole connections.") + .SetDefault(true); + AddAttr("is_reverse", + "(bool, defalut: False) " + "whether to compute reversed LSTM.") + .SetDefault(false); + AddAttr("gate_activation", + "(string, default: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by default.") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("cell_activation", + "(string, default: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddAttr("candidate_activation", + "(string, default: tanh)" + "The activation for candidate hidden state, " + "`tanh` by default.") + .SetDefault("tanh") + .InEnum({"sigmoid", "tanh", "relu", "identity"}); + AddComment(R"DOC( +Long-Short Term Memory (LSTM) Operator. + +The defalut implementation is diagonal/peephole connection +(https://arxiv.org/pdf/1402.1128.pdf), the formula is as follows: + +$$ i_t = \\sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i) $$ + +$$ f_t = \\sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f) $$ + +$$ \\tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c) $$ + +$$ o_t = \\sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o) $$ + +$$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$ + +$$ h_t = o_t \\odot act_h(c_t) $$ + +- W terms denote weight matrices (e.g. $W_{xi}$ is the matrix + of weights from the input gate to the input), $W_{ic}, W_{fc}, W_{oc}$ + are diagonal weight matrices for peephole connections. In our implementation, + we use vectors to reprenset these diagonal weight matrices. +- The b terms denote bias vectors ($b_i$ is the input gate bias vector). +- $\sigma$ is the non-line activations, such as logistic sigmoid function. +- $i, f, o$ and $c$ are the input gate, forget gate, output gate, + and cell activation vectors, respectively, all of which have the same size as + the cell output activation vector $h$. +- The $\odot$ is the element-wise product of the vectors. +- $act_g$ and $act_h$ are the cell input and cell output activation functions + and `tanh` is usually used for them. +- $\tilde{c_t}$ is also called candidate hidden state, + which is computed based on the current input and the previous hidden state. + +Set `use_peepholes` False to disable peephole connection. The formula +is omitted here, please refer to the paper +http://www.bioinf.jku.at/publications/older/2604.pdf for details. + +Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$ +operations on the input $x_{t}$ are NOT included in this operator. +Users can choose to use fully-connect operator before LSTM operator. + +)DOC"); +} + +template +inline void ReorderInitState(const DeviceContext& ctx, + const framework::Tensor& src, + framework::Vector index_lod, + framework::Tensor* dst, bool indexed_src) { + math::CopyMatrixRowsFunctor row_shuffle; + dst->mutable_data(src.dims(), ctx.GetPlace()); + row_shuffle(ctx, src, index_lod, dst, indexed_src); +} + +template +class LSTMKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* weight = ctx.Input("Weight"); + auto* bias = ctx.Input("Bias"); + + auto* hidden_t0 = ctx.Input("H0"); + auto* cell_t0 = ctx.Input("C0"); + + auto* batch_gate = ctx.Output("BatchGate"); + batch_gate->mutable_data(ctx.GetPlace()); + auto* hidden_out = ctx.Output("Hidden"); + hidden_out->mutable_data(ctx.GetPlace()); + auto* cell_out = ctx.Output("Cell"); + cell_out->mutable_data(ctx.GetPlace()); + + bool is_reverse = ctx.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + auto& device_ctx = ctx.template device_context(); + to_batch(device_ctx, *input, batch_gate, true, is_reverse); + + auto in_dims = input->dims(); + int frame_size = static_cast(in_dims[1] / 4); + framework::DDim dims({in_dims[0], frame_size}); + + if (bias) { + Tensor b = *bias; + b.Resize({bias->numel(), 1}); + Tensor gate_bias = b.Slice(0, 4 * frame_size); + math::RowwiseAdd add_bias; + add_bias(device_ctx, *batch_gate, gate_bias, batch_gate); + } + + math::LstmMetaValue lstm_value; + if (bias && ctx.Attr("use_peepholes")) { + T* bias_data = const_cast(bias->data()); + // the code style in LstmMetaValue will be updated later. + + lstm_value.check_ig = bias_data + 4 * frame_size; + lstm_value.check_fg = lstm_value.check_ig + frame_size; + lstm_value.check_og = lstm_value.check_fg + frame_size; + } else { + lstm_value.check_ig = nullptr; + lstm_value.check_fg = nullptr; + lstm_value.check_og = nullptr; + } + lstm_value.prev_state_value = nullptr; + Tensor ordered_c0; + + framework::Vector order(batch_gate->lod()[2]); + + if (cell_t0) { + // Since the batch computing for LSTM reorders the input sequence + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState(device_ctx, *cell_t0, order, + &ordered_c0, true); + lstm_value.prev_state_value = ordered_c0.data(); + } + + // Use the local variable as here. + LoDTensor batch_hidden, batch_cell; + auto* batch_cell_pre_act = ctx.Output("BatchCellPreAct"); + batch_hidden.mutable_data(dims, ctx.GetPlace()); + batch_cell.mutable_data(dims, ctx.GetPlace()); + batch_cell_pre_act->mutable_data(dims, ctx.GetPlace()); + + auto batch_starts = batch_gate->lod()[0]; + size_t num_batch = batch_starts.size() - 1; + auto gate_act = math::detail::GetActivationType( + ctx.Attr("gate_activation")); + auto cell_act = math::detail::GetActivationType( + ctx.Attr("cell_activation")); + auto cand_act = math::detail::GetActivationType( + ctx.Attr("candidate_activation")); + + auto blas = math::GetBlas(device_ctx); + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor out_t = batch_hidden.Slice(bstart, bend); + Tensor cell_t = batch_cell.Slice(bstart, bend); + Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend); + + int cur_batch_size = bend - bstart; + + if (n > 0) { + int pre_h_start = static_cast(batch_starts[n - 1]); + int pre_h_end = pre_h_start + cur_batch_size; + auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end); + blas.MatMul(pre_hidden_t, false, *weight, false, static_cast(1.0), + &gate_t, static_cast(1.0)); + } else if (hidden_t0) { + // If n == 0 and there is no initialized hidden state, that is to say + // the H0 is zeros, the calculation W_h * H0 will be skiped. + // If n == 0 and there is initialized hidden state, calculate W_h * H0. + + // Since the batch computing for LSTM reorders the input sequence + // according to their length. The initialized hidden state also needs + // to reorder. + Tensor ordered_h0; + ReorderInitState(device_ctx, *hidden_t0, order, + &ordered_h0, true); + blas.MatMul(ordered_h0, false, *weight, false, static_cast(1.0), + &gate_t, static_cast(1.0)); + } + + lstm_value.gate_value = gate_t.data(); + lstm_value.output_value = out_t.data(); + lstm_value.state_value = cell_t.data(); + lstm_value.state_active_value = cell_pre_act_t.data(); + math::LstmUnitFunctor::compute( + device_ctx, lstm_value, frame_size, cur_batch_size, gate_act, + cell_act, cand_act); + lstm_value.prev_state_value = lstm_value.state_value; + } + + math::Batch2LoDTensorFunctor to_seq; + batch_hidden.set_lod(batch_gate->lod()); + // restore the output hidden in LoDTensor from the batch hidden + to_seq(device_ctx, batch_hidden, hidden_out); + + batch_cell.set_lod(batch_gate->lod()); + // restore the output cell state in LoDTensor from the batch cell + to_seq(device_ctx, batch_cell, cell_out); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(lstm, ops::LSTMOp, ops::LSTMOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OP_CPU_KERNEL( + fusion_lstm, ops::LSTMKernel, + ops::LSTMKernel); diff --git a/paddle/fluid/operators/fusion_lstm_op.h b/paddle/fluid/operators/fusion_lstm_op.h new file mode 100644 index 00000000000..88a65aee033 --- /dev/null +++ b/paddle/fluid/operators/fusion_lstm_op.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +// #include +#include "paddle/fluid/framework/op_registry.h" +// #include "paddle/fluid/operators/math/blas.h" +// #include "paddle/fluid/operators/math/detail/activation_functions.h" +// #include "paddle/fluid/operators/math/lstm_compute.h" +// #include "paddle/fluid/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +class FusionLSTMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class FusionLSTMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle -- GitLab