From 8728b3cce24c69f76167d843b9bb667027110c56 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 12 Oct 2017 11:30:44 +0800 Subject: [PATCH] Add LSTM Operators. --- paddle/operators/lstm_op.cc | 185 ++++++++++++++++++++++++ paddle/operators/lstm_op.h | 38 +++++ paddle/operators/lstm_unit_op.h | 1 - paddle/operators/math/cross_entropy.cu | 2 - paddle/operators/math/sequence2batch.cc | 26 ++++ paddle/operators/math/sequence2batch.cu | 26 ++++ paddle/operators/math/sequence2batch.h | 113 +++++++++++++++ 7 files changed, 388 insertions(+), 3 deletions(-) create mode 100644 paddle/operators/lstm_op.cc create mode 100644 paddle/operators/lstm_op.h create mode 100644 paddle/operators/math/sequence2batch.cc create mode 100644 paddle/operators/math/sequence2batch.cu create mode 100644 paddle/operators/math/sequence2batch.h diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc new file mode 100644 index 00000000000..6233e129233 --- /dev/null +++ b/paddle/operators/lstm_op.cc @@ -0,0 +1,185 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lstm_unit_op.h" + +namespace paddle { +namespace operators { + +class LSTMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(Input) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(Hidden) of LSTM should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("H"), + "Output(Cell) of LSTM should not be null."); + + auto x_dims = ctx->GetInputDim("Input"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + + if (ctx->HasInput("H0")) { + PADDLE_ENFORCE(ctx->HasInput("C0"), + "Input(Cell) and Input(Hidden) of LSTM should not " + "be null at the same time."); + auto h_dims = ctx->GetInputDim("H0"); + auto c_dims = ctx->GetInputDim("C0"); + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } + + ctx->SetOutputDim("Hidden", x_dims); + ctx->SetOutputDim("Cell", x_dims); + ctx->ShareLoD("Input", "Hidden"); + ctx->ShareLoD("Input", "Cell"); + } +}; + +class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(LoDTensor) the first input is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTenosr is a matrix with shape (T X D), where, T is the " + "total time steps in this mini-batch, D is the hidden size."); + AddInput("H0", + "(Tensor, optional) the initial hidden state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size, D is the hidden size."); + AddInput("C0", + "(Tensor, optional) the initial cell state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size. `H0` and `C0` can be NULL but only at the same time"); + AddInput("Weight", + "(Tensor) the learnable hidden-hidden weights." + " - The shape is (D x 4*D), where D is the hidden size. " + " - Weight = {W_ih, W_fh, W_ch, W_oh}"); + AddInput("Bias", + "(Tensor) the learnable weights, which contains two parts: " + "input-hidden bias weight and peephole connections weight if " + "seting `use_peepholes` True. " + "1. `use_peepholes = False` " + " - The shape is (1 x 4*D). " + " - Bias = {b_i, b_f, b_c, b_o}." + "2. `use_peepholes = True` " + " - The shape is (1 x 7*D). " + " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}."); + AddOutput("Hidden", + "(LoDTensor) the hidden state lod tensor of LSTM operator. " + "The shape and lod is the same with the `Input`."); + AddOutput("Cell", + "(LoDTensor) the cell state lod tensor of LSTM operator. " + "The shape and lod is the same with the `Input`."); + AddAttr("use_peepholes", + "(bool, defalut: True) " + "whether to enable diagonal/peephole connections.") + .SetDefault(true); + AddAttr( + "gate_activation", + "(string, defalut: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by defalut.") + .SetDefault("sigmoid"); + AddAttr("cell_activation", + "(string, defalut: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh"); + AddAttr("candidate_activation", + "(string, defalut: tanh)" + "The activation for candidate hidden state, " + "`tanh` by defalut.") + .SetDefault("tanh"); + AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator + +The defalut implementation is diagonal/peephole connection [1], the formula is +as follows + + i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i) + + f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f) + + \tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c) + + o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o) + + c_t = f_t ⊙ c_{t-1} + i_t ⊙ \tilde{c_t} + + h_t = o_t ⊙ act_h(c_t) + +where the W terms denote weight matrices (e.g. \f$W_{xi}\f$ is the matrix +of weights from the input gate to the input), \f$W_{ic}, W_{fc}, W_{oc}\f$ +are diagonal weight matrices for peephole connections. In our implenmention, +We use vectors to reprenset these diagonal weight matrices. The b terms +denote bias vectors (\f$b_i\f$ is the input gate bias vector), \f$\sigma\f$ +is the non-line actications, such as logistic sigmoid function, and +\f$i, f, o\f$ and \f$c\f$ are respectively the input gate, forget gate, +output gate and cell activation vectors, all of which are the same size as +the cell output activation vector \f$h\f$. + +The ⊙ is the element-wise product of the vectors, \f$act_g\f$ and \f$act_h\f$ +are the cell input and cell output activation functions, `tanh` is usually +used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state, +which is computed based on the current input and the previous hidden state. + +Set `use_peepholes` False to disable peephole connection [2]. The formula +is omitted here. + +@note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$ +operations on the input x_{t} were NOT included in this operator. The +users can choose to use fully-connect operator before LSTM operator. + +[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory +recurrent neural network architectures for large scale acoustic modeling. +INTERSPEECH, 2014. + +[2] S. Hochreiter and J. Schmidhuber. Long Short-Term Memory. +Neural Computation, 9(8):1735-1780, 1997. + +)DOC"); + } +}; + +class LSTMGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), + "Input(Hidden@GRAD) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")), + "Input(Cell@GRAD) should not be null"); + ctx->SetOutputDim(framework::GradVarName("Weight"), + ctx->GetInputDim("Weight")); + ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lstm, ops::LSTMOp, ops::LSTMOpMaker, lstm_grad, ops::LSTMGradOp); +REGISTER_OP_CPU_KERNEL(lstm, ops::LSTMKernel, + ops::LSTMKernel); +REGISTER_OP_CPU_KERNEL(lstm_grad, + ops::LSTMGradKernel, + ops::LSTMGradKernel); diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h new file mode 100644 index 00000000000..6e77cadead6 --- /dev/null +++ b/paddle/operators/lstm_op.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::LoDTensor; +using framework::Tensor; + +template +class LSTMKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override {} +}; + +template +class LSTMGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override {} +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h index a0ff498c1d3..625b1852c2f 100644 --- a/paddle/operators/lstm_unit_op.h +++ b/paddle/operators/lstm_unit_op.h @@ -19,7 +19,6 @@ namespace paddle { namespace operators { -using framework::LoDTensor; using framework::Tensor; template diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu index 367190e6b06..db878129d65 100644 --- a/paddle/operators/math/cross_entropy.cu +++ b/paddle/operators/math/cross_entropy.cu @@ -22,8 +22,6 @@ namespace { template __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label, const int N, const int D) { - // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file. - // CUDA_1D_KERNEL_LOOP(i, N) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) { PADDLE_ASSERT(label[i] >= 0 && label[i] < D); diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc new file mode 100644 index 00000000000..c29baaae086 --- /dev/null +++ b/paddle/operators/math/sequence2batch.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { +namespace math { + +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensor2Functor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu new file mode 100644 index 00000000000..5afb87e4a43 --- /dev/null +++ b/paddle/operators/math/sequence2batch.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/sequence2batch.h" + +namespace paddle { +namespace operators { +namespace math { + +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensor2Functor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h new file mode 100644 index 00000000000..6ee870cf786 --- /dev/null +++ b/paddle/operators/math/sequence2batch.h @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +namespace paddle { +namespace operators { +namespace math { + +template +class LoDTensor2BatchFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::LoDTensor& lod_tensor, + framework::LoDTensor& batch, const bool is_reverse) const { + auto lods = lod_tensor->lod(); + PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now."); + auto lod = lods[0]; + + // Calculate the length of each sequence and + // sort sequence index by the length. + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} + // + struct SeqInfo { + SeqInfo(int start, int length, int seq_idx) + : start(start), length(length), seqIdx(seq_idx) {} + int start; + int length; + int seq_idx; + }; + + std::vector seq_info; + for (size_t seq_id = 0; seq_id < lod.size(); ++seq_id) { + int length = lod[seq_id + 1] - lod[seq_id]; + seq_info.emplace_back(lod[seq_id], length, seq_id); + } + + std::sort(seq_info.begin(), seq_info.end(), + [](SeqInfo a, SeqInfo b) { return a.length > b.length; }); + + // calculate the start position of each batch + // (numBatch equal the maxLength of sequences) + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // num_batch = 5, + // batchIndex = {b0, b1, b2, b3, b4} + // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 + // batch_start_positions[6] = {0, 3, 6, 9, 11, 12} + // seq2batch_idx[12] = {4, 0, 9, + // 5, 1, 10, + // 6, 2, 11, + // 7, 3, + // 8} + + // The batch number represents batch size after rearranging the + // input LodTensor. It is also the maximum length of input sequence. + auto batch_lods = batch->lod(); + if (!batch_lods) { + batch_lods->resize(2); + } + // batch_lods[0] is the start positions for batch LoDTensor + int num_batch = (size_t)seq_info[0].length; + batch_lods[0]->resize(num_batch + 1); + // batch_lods[1] is the raw index in the input LoDTensor + auto dims = lod_tensor->dims(); + batch_lods[1]->resize(dims[0]); + + auto* batch_starts = batch_lods[0].data(); + auto* seq2batch_idx = batch_lods[1].data(); + batch_starts[0] = 0; + for (size_t n = 0; n < num_batch; n++) { + int batch_id = batch_starts[n]; + for (size_t i = 0; i < seq_info.size(); ++i) { + size_t seq_len = seq_info[i].length; + int start = seq_info[i].start; + if (n < seq_len) { + if (!is_reverse) { + seq2batch_idx[batch_id] = start + n; + } else { + seq2batch_idx[batch_id] = start + seq_len - 1 - n; + } + batch_id++; + } else { + break; + } + } + batch_starts[n + 1] = batch_id; + } + } +} + +template +class Batch2LoDTensor2Functor { + public: + void operator()(const platform::DeviceContext& context, + const framework::LoDTensor& batch, + framework::LoDTensor& lod_tensor, + const bool is_reverse) const; + +} // namespace math +} // namespace operators +} // namespace paddle -- GitLab