diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 75fcc1cda165197fc4413efc6bbbc440088cb4cd..f97bc837dca09060c55cae6a5524c49cd69df28b 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -115,7 +115,8 @@ set(DEPS_OPS
     softmax_with_cross_entropy_op
     sum_op
     pool_op
-    pool_with_index_op)
+    pool_with_index_op
+    lstm_op)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
@@ -126,6 +127,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(sum_op DEPS net_op)
 op_library(pool_op DEPS pooling)
 op_library(pool_with_index_op DEPS pooling)
+op_library(lstm_op DEPS sequence2batch lstm_compute)
 
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})
diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0a089b7c2dc1e05224525bc4fe5399ec39036d01
--- /dev/null
+++ b/paddle/operators/lstm_op.cc
@@ -0,0 +1,226 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/lstm_op.h"
+
+namespace paddle {
+namespace operators {
+
+class LSTMOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("Input"),
+                   "Input(Input) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
+                   "Output(Hidden) of LSTM should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Cell"),
+                   "Output(Cell) of LSTM should not be null.");
+
+    auto x_dims = ctx->GetInputDim("Input");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Input)'s rank must be 2.");
+
+    if (ctx->HasInput("H0")) {
+      PADDLE_ENFORCE(ctx->HasInput("C0"),
+                     "Input(H0) and Input(C0) of LSTM should not "
+                     "be null at the same time.");
+      auto h_dims = ctx->GetInputDim("H0");
+      auto c_dims = ctx->GetInputDim("C0");
+      PADDLE_ENFORCE(h_dims == c_dims,
+                     "The dimension of Input(H0) and Input(C0) "
+                     "should be the same.");
+    }
+
+    int frame_size = x_dims[1] / 4;
+    auto w_dims = ctx->GetInputDim("Weight");
+    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
+                      "The rank of Input(Weight) should be 2.");
+    PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
+                      "The first dimension of Input(Weight) "
+                      "should be %d.",
+                      frame_size);
+    PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
+                      "The second dimension of Input(Weight) "
+                      "should be 4 * %d.",
+                      frame_size);
+    auto b_dims = ctx->GetInputDim("Bias");
+    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
+    PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                      "The first dimension of Input(Bias) should be 1.");
+    if (ctx->Attrs().Get<bool>("usePeepholes")) {
+      PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
+                        "The second dimension of Input(Bias) should be "
+                        "7 * %d if peephole connections are enabled.",
+                        frame_size);
+    } else {
+      PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
+                        "The second dimension of Input(Bias) should be "
+                        "4 * %d if peephole connections are disabled.",
+                        frame_size);
+    }
+    ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
+    ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
+    ctx->SetOutputDim("BatchGate", x_dims);
+    ctx->ShareLoD("Input", "Hidden");
+    ctx->ShareLoD("Input", "Cell");
+  }
+};
+
+class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  LSTMOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Input",
+             "(LoDTensor) the first input is a LoDTensor, which supports "
+             "variable-length input sequences. The underlying tensor in "
+             "this LoDTensor is a matrix with shape (T x 4D), where T is "
+             "the total number of time steps in this mini-batch and D is "
+             "the hidden size.");
+    AddInput("H0",
+             "(Tensor, optional) the initial hidden state is an optional "
+             "input. This is a tensor with shape (N x D), where N is the "
+             "batch size, D is the hidden size.");
+    AddInput("C0",
+             "(Tensor, optional) the initial cell state is an optional "
+             "input. This is a tensor with shape (N x D), where N is the "
+             "batch size. `H0` and `C0` must be either both provided or "
+             "both null.");
+    AddInput("Weight",
+             "(Tensor) the learnable hidden-hidden weights."
+             " - The shape is (D x 4D), where D is the hidden size. "
+             " - Weight = {W_ch, W_ih, W_fh, W_oh}");
+    AddInput("Bias",
+             "(Tensor) the learnable biases, which contain two parts: the "
+             "input-hidden biases and, if `usePeepholes` is set to True, "
+             "the peephole connection weights. "
+             "1. `usePeepholes = False` "
+             " - The shape is (1 x 4D). "
+             " - Bias = {b_c, b_i, b_f, b_o}."
+             "2. `usePeepholes = True` "
+             " - The shape is (1 x 7D). "
+             " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
+    AddOutput("BatchGate",
+              "(LoDTensor) This LoDTensor contains the input gate, forget "
+              "gate and output gate after the nonlinear computation. It "
+              "has the same shape as the reorganized input, which is also "
+              "called the batch input. Its LoD has two levels: the first "
+              "level stores the batch offsets and the second level stores "
+              "the indexes that map each row of the reorganized sequence "
+              "back to its position in the raw input.")
+        .AsIntermediate();
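+    // Worked example of the BatchGate LoD described above (an illustration
+    // based on the assumed sequence2batch reordering, not extra behavior):
+    // for two sequences of lengths 3 and 2, the batch-reorganized layout
+    // has three time-step batches holding 2, 2 and 1 rows, so the first
+    // LoD level is {0, 2, 4, 5} and the second level maps each reorganized
+    // row back to its index in the raw input.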
" + "The shape and lod is the same with the `Input`."); + AddAttr("usePeepholes", + "(bool, defalut: True) " + "whether to enable diagonal/peephole connections.") + .SetDefault(true); + AddAttr("isReverse", + "(bool, defalut: False) " + "whether to compute reversed LSTM.") + .SetDefault(false); + AddAttr( + "gateActivation", + "(string, default: sigmoid)" + "The activation for input gate, forget gate and output " + "gate, `sigmoid` by default.") + .SetDefault("sigmoid"); + AddAttr("cellActivation", + "(string, default: tanh)" + "The activation for cell output, `tanh` by defalut.") + .SetDefault("tanh"); + AddAttr("candidateActivation", + "(string, default: tanh)" + "The activation for candidate hidden state, " + "`tanh` by default.") + .SetDefault("tanh"); + AddComment(R"DOC(Long-Short Term Memory (LSTM) Operator + +The defalut implementation is diagonal/peephole connection [1], the formula is +as follows + + i_t = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i) + + f_t = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f) + + \tilde{c_t} = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c) + + o_t = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o) + + c_t = f_t ⊙ c_{t-1} + i_t ⊙ \tilde{c_t} + + h_t = o_t ⊙ act_h(c_t) + +where the W terms denote weight matrices (e.g. \f$W_{xi}\f$ is the matrix +of weights from the input gate to the input), \f$W_{ic}, W_{fc}, W_{oc}\f$ +are diagonal weight matrices for peephole connections. In our implenmention, +We use vectors to reprenset these diagonal weight matrices. The b terms +denote bias vectors (\f$b_i\f$ is the input gate bias vector), \f$\sigma\f$ +is the non-line actications, such as logistic sigmoid function, and +\f$i, f, o\f$ and \f$c\f$ are respectively the input gate, forget gate, +output gate and cell activation vectors, all of which are the same size as +the cell output activation vector \f$h\f$. + +The ⊙ is the element-wise product of the vectors, \f$act_g\f$ and \f$act_h\f$ +are the cell input and cell output activation functions, `tanh` is usually +used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state, +which is computed based on the current input and the previous hidden state. + +Set `usePeepholes` False to disable peephole connection [2]. The formula +is omitted here. + +@note These \f$W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}\f$ +operations on the input x_{t} were NOT included in this operator. +Users can choose to use fully-connect operator before LSTM operator. + +[1] Hasim Sak, Andrew Senior, and Francoise Beaufays. Long short-term memory +recurrent neural network architectures for large scale acoustic modeling. +INTERSPEECH, 2014. + +[2] S. Hochreiter and J. Schmidhuber. Long Short-Term Memory. +Neural Computation, 9(8):1735-1780, 1997. 
+ +)DOC"); + } +}; + +class LSTMGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), + "Input(Hidden@GRAD) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")), + "Input(Cell@GRAD) should not be null"); + ctx->SetOutputDim(framework::GradVarName("Weight"), + ctx->GetInputDim("Weight")); + ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lstm, ops::LSTMOp, ops::LSTMOpMaker, lstm_grad, ops::LSTMGradOp); +REGISTER_OP_CPU_KERNEL(lstm, ops::LSTMKernel, + ops::LSTMKernel); +REGISTER_OP_CPU_KERNEL(lstm_grad, + ops::LSTMGradKernel, + ops::LSTMGradKernel); diff --git a/paddle/operators/lstm_op.cu b/paddle/operators/lstm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..9ad56941553bf19a56c25f41f76fe20dfa3a106f --- /dev/null +++ b/paddle/operators/lstm_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/lstm_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(lstm, ops::LSTMKernel, + ops::LSTMKernel); +REGISTER_OP_GPU_KERNEL(lstm_grad, + ops::LSTMGradKernel, + ops::LSTMGradKernel); diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0af5694c48fcb4437e3acd422606de013bb2e145 --- /dev/null +++ b/paddle/operators/lstm_op.h @@ -0,0 +1,139 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/lstm_compute.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/sequence2batch.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::LoDTensor;
+using framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class LSTMKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<LoDTensor>("Input");
+    auto* weight = ctx.Input<Tensor>("Weight");
+    auto* bias = ctx.Input<Tensor>("Bias");
+
+    auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
+    batch_gate->mutable_data<T>(ctx.GetPlace());
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    hidden_out->mutable_data<T>(ctx.GetPlace());
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    cell_out->mutable_data<T>(ctx.GetPlace());
+
+    // Now the function ShareLoD in InferShape is not implemented.
+    // So copy LoD here.
+    ctx.ShareLoD("Input", "Hidden");
+    ctx.ShareLoD("Input", "Cell");
+
+    bool is_reverse = ctx.Attr<bool>("isReverse");
+    math::LoDTensor2BatchFunctor<Place, T> to_batch;
+    to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);
+
+    auto in_dims = input->dims();
+    int frame_size = static_cast<int>(in_dims[1] / 4);
+    framework::DDim dims({in_dims[0], frame_size});
+
+    if (bias) {
+      Eigen::array<int, 2> extents({{1, 4 * frame_size}});
+      Eigen::array<int, 2> offsets({{0, 0}});
+      auto b = EigenMatrix<T>::From(*bias);
+      auto gate = EigenMatrix<T>::From(*batch_gate);
+      gate.device(ctx.GetEigenDevice<Place>()) =
+          gate +
+          b.slice(offsets, extents)
+              .reshape(Eigen::array<int, 2>({{1, frame_size * 4}}))
+              .broadcast(
+                  Eigen::array<int, 2>({{static_cast<int>(in_dims[0]), 1}}));
+    }
+
+    math::LstmMetaValue<T> lstm_value;
+    T* bias_data = const_cast<T*>(bias->data<T>());
+    // the code style in LstmMetaValue will be updated later.
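+    // Layout assumed for Bias when usePeepholes is true (see the OpMaker
+    // doc): Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}, so the three
+    // peephole weight vectors start right after the 4 * frame_size gate
+    // biases; the offsets below follow from that layout.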
+    lstm_value.checkIg = bias_data + 4 * frame_size;
+    lstm_value.checkFg = lstm_value.checkIg + frame_size;
+    lstm_value.checkOg = lstm_value.checkFg + frame_size;
+    lstm_value.prevStateValue = nullptr;
+
+    framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act;
+    batch_out.mutable_data<T>(dims, ctx.GetPlace());
+    batch_cell.mutable_data<T>(dims, ctx.GetPlace());
+    batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
+
+    auto batch_starts = batch_gate->lod()[0];
+    size_t num_batch = batch_starts.size() - 1;
+    auto gate_act = ctx.Attr<std::string>("gateActivation");
+    auto cell_act = ctx.Attr<std::string>("cellActivation");
+    auto cand_act = ctx.Attr<std::string>("candidateActivation");
+
+    for (size_t n = 0; n < num_batch; n++) {
+      int bstart = static_cast<int>(batch_starts[n]);
+      int bend = static_cast<int>(batch_starts[n + 1]);
+
+      Tensor gate_t = batch_gate->Slice(bstart, bend);
+      Tensor out_t = batch_out.Slice(bstart, bend);
+      Tensor cell_t = batch_cell.Slice(bstart, bend);
+      Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend);
+
+      int cur_batch_size = bend - bstart;
+
+      if (n != 0) {
+        int pre_h_start = static_cast<int>(batch_starts[n - 1]);
+        int pre_h_end = pre_h_start + cur_batch_size;
+        auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end);
+        math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
+                               *weight, false, static_cast<T>(1.0), &gate_t,
+                               static_cast<T>(1.0));
+      }
+      // else if : FIXME support the initial hidden and cell
+
+      lstm_value.gateValue = gate_t.data<T>();
+      lstm_value.outputValue = out_t.data<T>();
+      lstm_value.stateValue = cell_t.data<T>();
+      lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
+      math::LstmUnitFunctor<Place, T>::compute(ctx.device_context(),
+                                               lstm_value, frame_size,
+                                               cur_batch_size, gate_act,
+                                               cell_act, cand_act);
+      lstm_value.prevStateValue = lstm_value.stateValue;
+    }
+
+    math::Batch2LoDTensorFunctor<Place, T> to_seq;
+    batch_out.set_lod(batch_gate->lod());
+    // restore the output hidden in LoDTensor from the batch hidden
+    to_seq(ctx.device_context(), batch_out, *hidden_out);
+
+    batch_cell.set_lod(batch_gate->lod());
+    // restore the output cell state in LoDTensor from the batch cell
+    to_seq(ctx.device_context(), batch_cell, *cell_out);
+  }
+};
+
+template <typename Place, typename T>
+class LSTMGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h
index a0ff498c1d3ed2aaa10f5473ef91de168c250649..625b1852c2f0eb2ed435f73fea251c40c614a7dd 100644
--- a/paddle/operators/lstm_unit_op.h
+++ b/paddle/operators/lstm_unit_op.h
@@ -19,7 +19,6 @@
 namespace paddle {
 namespace operators {
 
-using framework::LoDTensor;
 using framework::Tensor;
 
 template <typename T>
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index 72ce8585045b5166df424a401442db39b47ab098..5598669ef96535b7d47150052b3841771c37c60b 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(detail)
+
 if(WITH_GPU)
     nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator)
     nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor)
@@ -7,6 +9,8 @@ if(WITH_GPU)
     nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
     nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
     nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
+    nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
+    nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
     cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -14,6 +18,8 @@ else()
     cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
     cc_library(pooling SRCS pooling.cc DEPS device_context)
     cc_library(vol2col SRCS vol2col.cc DEPS device_context)
+    cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
+    cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
 endif()
 
 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu
index 367190e6b0682ec62550e869e2f04c3a2b2cbec3..db878129d650d663e187ecabb106eea0e39db6fa 100644
--- a/paddle/operators/math/cross_entropy.cu
+++ b/paddle/operators/math/cross_entropy.cu
@@ -22,8 +22,6 @@ namespace {
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
                                    const int N, const int D) {
-  // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
-  // CUDA_1D_KERNEL_LOOP(i, N) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
diff --git a/paddle/operators/math/detail/CMakeLists.txt b/paddle/operators/math/detail/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..49cf228de2204cb4888cf645a0cb68ed04cc3371
--- /dev/null
+++ b/paddle/operators/math/detail/CMakeLists.txt
@@ -0,0 +1,5 @@
+if(WITH_AVX)
+  cc_library(activation_functions SRCS hl_cpu_functions.cc hl_avx_functions.cc)
+else()
+  cc_library(activation_functions SRCS hl_cpu_functions.cc)
+endif()
diff --git a/paddle/operators/math/detail/hl_activation_functions.h b/paddle/operators/math/detail/hl_activation_functions.h
new file mode 100644
index 0000000000000000000000000000000000000000..9d7d9914f0090bff17049038dfa2288d84f3dbda
--- /dev/null
+++ b/paddle/operators/math/detail/hl_activation_functions.h
@@ -0,0 +1,188 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef HL_ACTIVATION_FUNCTIONS_H_
+#define HL_ACTIVATION_FUNCTIONS_H_
+
+#include "hl_functions.h"
+#include "paddle/operators/math/lstm_compute.h"
+
+/**
+ * Activation functions: sigmoid, relu, tanh and linear.
+ */
+#define FLOAT_ACTIVE_FUNCTION                                   \
+  {                                                             \
+    hppl::typef::sigmoid, hppl::typef::relu, hppl::typef::tanh, \
+        hppl::typef::linear                                     \
+  }
+
+#define DOUBLE_ACTIVE_FUNCTION                                  \
+  {                                                             \
+    hppl::typed::sigmoid, hppl::typed::relu, hppl::typed::tanh, \
+        hppl::typed::linear                                     \
+  }
+
+#define AVX_ACTIVE_FUNCTION \
+  { hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }
+
+namespace hppl {
+
+using activation_mode_t = paddle::operators::math::activation_mode_t;
+
+/**
+ * Hppl supports sigmoid, relu, tanh and linear activation functions
+ * for neural networks' forward and backward computation.
+ */
+template <class T>
+class Active {
+ public:
+  typedef T (*forward)(T);
+  typedef T (*backward)(T, T);
+};
+
+template <typename T>
+struct ForwardActType;
+
+template <>
+struct ForwardActType<float> {
+  using type = Active<float>::forward;
+};
+
+template <>
+struct ForwardActType<double> {
+  using type = Active<double>::forward;
+};
+
+template <typename T>
+struct BackwardActType;
+
+template <>
+struct BackwardActType<float> {
+  using type = Active<float>::backward;
+};
+
+template <>
+struct BackwardActType<double> {
+  using type = Active<double>::backward;
+};
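+
+/*
+ * Illustrative use of the traits above (a sketch, not part of the original
+ * patch): a templated kernel resolves an activation at run time from the
+ * function-pointer tables defined below, e.g.
+ *
+ *   hppl::cpu::ForwardAct<float> act;
+ *   auto fn = act(active_gate);  // active_gate is an activation_mode_t
+ *   float y = fn(x);             // apply the selected activation
+ */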
+
+#ifdef __NVCC__
+namespace gpu {
+static __device__ Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
+static __device__ Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;
+
+static __device__ Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
+static __device__ Active<double>::backward backward_d[] =
+    DOUBLE_ACTIVE_FUNCTION;
+
+template <typename T>
+struct ForwardAct {
+  __device__ typename ForwardActType<T>::type operator()(
+      activation_mode_t type);
+};
+
+template <>
+struct ForwardAct<float> {
+  __device__ ForwardActType<float>::type operator()(activation_mode_t type) {
+    return forward[type];
+  }
+};
+
+template <>
+struct ForwardAct<double> {
+  __device__ ForwardActType<double>::type operator()(activation_mode_t type) {
+    return forward_d[type];
+  }
+};
+
+template <typename T>
+struct BackwardAct {
+  __device__ typename BackwardActType<T>::type operator()(
+      activation_mode_t type);
+};
+
+template <>
+struct BackwardAct<float> {
+  __device__ BackwardActType<float>::type operator()(activation_mode_t type) {
+    return backward[type];
+  }
+};
+
+template <>
+struct BackwardAct<double> {
+  __device__ BackwardActType<double>::type operator()(activation_mode_t type) {
+    return backward_d[type];
+  }
+};
+
+}  // namespace gpu
+#else
+namespace cpu {
+static Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
+static Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;
+
+static Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
+static Active<double>::backward backward_d[] = DOUBLE_ACTIVE_FUNCTION;
+
+template <typename T>
+struct ForwardAct {
+  typename ForwardActType<T>::type operator()(activation_mode_t type);
+};
+
+template <>
+struct ForwardAct<float> {
+  ForwardActType<float>::type operator()(activation_mode_t type) {
+    return forward[type];
+  }
+};
+
+template <>
+struct ForwardAct<double> {
+  ForwardActType<double>::type operator()(activation_mode_t type) {
+    return forward_d[type];
+  }
+};
+
+template <typename T>
+struct BackwardAct {
+  typename BackwardActType<T>::type operator()(activation_mode_t type);
+};
+
+template <>
+struct BackwardAct<float> {
+  BackwardActType<float>::type operator()(activation_mode_t type) {
+    return backward[type];
+  }
+};
+
+template <>
+struct BackwardAct<double> {
+  BackwardActType<double>::type operator()(activation_mode_t type) {
+    return backward_d[type];
+  }
+};
+
+}  // namespace cpu
+
+#ifdef __AVX__
+namespace avx {
+static Active<__m256>::forward forward[] = AVX_ACTIVE_FUNCTION;
+static Active<__m256>::backward backward[] = AVX_ACTIVE_FUNCTION;
+}  // namespace avx
+#endif
+#endif
+
+}  // namespace hppl
+
+#endif  // HL_ACTIVATION_FUNCTIONS_H_
diff --git a/paddle/operators/math/detail/hl_avx_functions.cc b/paddle/operators/math/detail/hl_avx_functions.cc
new file mode 100644
index 0000000000000000000000000000000000000000..415bac5d93ee00244d072b0998c6941b14d4f8d8
--- /dev/null
+++ b/paddle/operators/math/detail/hl_avx_functions.cc
@@ -0,0 +1,70 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <immintrin.h>
+#include "hl_functions.h"
+// TODO(qingqing) refine this dependency
+#include "paddle/cuda/src/avx_mathfun.h"
+
+namespace hppl {
+
+__m256 exp(__m256 a) { return exp256_ps(a); }
+
+__m256 relu(const __m256 a) {
+  __m256 tmp = _mm256_set1_ps(0.0f);
+  return _mm256_max_ps(a, tmp);
+}
+
+__m256 sigmoid(const __m256 a) {
+  __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
+  __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
+  __m256 tmp = _mm256_max_ps(a, min);
+  tmp = _mm256_min_ps(tmp, max);
+  tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
+  tmp = exp(tmp);
+  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
+  tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
+  return tmp;
+}
+
+__m256 tanh(const __m256 a) {
+  __m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
+  __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
+  tmp = _mm256_min_ps(tmp, max);
+  tmp = exp(tmp);
+  return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
+                                     _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
+                       _mm256_set1_ps(1.0f));
+}
+
+__m256 linear(const __m256 a) { return a; }
+
+__m256 relu(const __m256 a, const __m256 b) {
+  return _mm256_mul_ps(
+      a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
+                       _mm256_set1_ps(1.0f)));
+}
+
+__m256 sigmoid(const __m256 a, const __m256 b) {
+  return _mm256_mul_ps(_mm256_mul_ps(a, b),
+                       _mm256_sub_ps(_mm256_set1_ps(1.0f), b));
+}
+
+__m256 tanh(const __m256 a, const __m256 b) {
+  return _mm256_mul_ps(
+      a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b)));
+}
+
+__m256 linear(const __m256 a, const __m256 b) { return a; }
+}  // namespace hppl
diff --git a/paddle/operators/math/detail/hl_avx_functions.h b/paddle/operators/math/detail/hl_avx_functions.h
new file mode 100644
index 0000000000000000000000000000000000000000..35f4eabb4c07c6cc9d2edded02e5b6290b1232f8
--- /dev/null
+++ b/paddle/operators/math/detail/hl_avx_functions.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef HL_AVX_FUNCTIONS_H_
+#define HL_AVX_FUNCTIONS_H_
+
+#include <immintrin.h>
+
+namespace hppl {
+__m256 relu(const __m256 a);
+__m256 sigmoid(const __m256 a);
+__m256 tanh(const __m256 a);
+__m256 linear(const __m256 a);
+
+__m256 relu(const __m256 a, const __m256 b);
+__m256 sigmoid(const __m256 a, const __m256 b);
+__m256 tanh(const __m256 a, const __m256 b);
+__m256 linear(const __m256 a, const __m256 b);
+}  // namespace hppl
+
+#endif  // HL_AVX_FUNCTIONS_H_
diff --git a/paddle/operators/math/detail/hl_cpu_functions.cc b/paddle/operators/math/detail/hl_cpu_functions.cc
new file mode 100644
index 0000000000000000000000000000000000000000..21ec78f9629af0e4673a56517d76ac6734f57db8
--- /dev/null
+++ b/paddle/operators/math/detail/hl_cpu_functions.cc
@@ -0,0 +1,89 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <math.h>
+#include "hl_functions.h"
+
+namespace hppl {
+namespace typef {
+
+float relu(const float a) {
+  return a > static_cast<float>(0.0) ? a : static_cast<float>(0.0);
+}
+
+float sigmoid(const float a) {
+  const float min = SIGMOID_THRESHOLD_MIN;
+  const float max = SIGMOID_THRESHOLD_MAX;
+  float tmp = (a < min) ? min : ((a > max) ? max : a);
+  return static_cast<float>(1.0) / (static_cast<float>(1.0) + exp(-tmp));
+}
+
+float tanh(const float a) {
+  float tmp = -2.0 * a;
+  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
+  return (2.0 / (1.0 + exp(tmp))) - 1.0;
+}
+
+float linear(const float a) { return a; }
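+
+// The two-argument overloads below are the backward forms used by the LSTM
+// kernels: `a` is the incoming gradient and `b` is the value produced by
+// the forward pass, so e.g. sigmoid(a, b) = a * b * (1 - b) is the sigmoid
+// gradient expressed in terms of the forward output. The same convention
+// holds for the double and AVX versions.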
+
+float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); }
+
+float sigmoid(const float a, const float b) {
+  return a * b * (static_cast<float>(1) - b);
+}
+
+float tanh(const float a, const float b) {
+  return a * (static_cast<float>(1) - b * b);
+}
+
+float linear(const float a, const float b) { return a; }
+
+}  // namespace typef
+
+namespace typed {
+double relu(const double a) {
+  return a > static_cast<double>(0.0) ? a : static_cast<double>(0.0);
+}
+
+double sigmoid(const double a) {
+  const double min = SIGMOID_THRESHOLD_MIN;
+  const double max = SIGMOID_THRESHOLD_MAX;
+  double tmp = (a < min) ? min : ((a > max) ? max : a);
+  return static_cast<double>(1.0) / (static_cast<double>(1.0) + exp(-tmp));
+}
+
+double tanh(const double a) {
+  double tmp = -2.0 * a;
+  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
+  return (2.0 / (1.0 + exp(tmp))) - 1.0;
+}
+
+double linear(const double a) { return a; }
+
+double relu(const double a, const double b) {
+  return a * (b > 0.0 ? 1.0 : 0.0);
+}
+
+double sigmoid(const double a, const double b) {
+  return a * b * (static_cast<double>(1) - b);
+}
+
+double tanh(const double a, const double b) {
+  return a * (static_cast<double>(1) - b * b);
+}
+
+double linear(const double a, const double b) { return a; }
+
+}  // namespace typed
+}  // namespace hppl
diff --git a/paddle/operators/math/detail/hl_functions.h b/paddle/operators/math/detail/hl_functions.h
new file mode 100644
index 0000000000000000000000000000000000000000..3e2f0c9ee6d3ae2ed598c4d5f09b85b7d61fdd51
--- /dev/null
+++ b/paddle/operators/math/detail/hl_functions.h
@@ -0,0 +1,71 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef HL_FUNCTIONS_H_
+#define HL_FUNCTIONS_H_
+
+/**
+ * sigmoid threshold minimum
+ */
+#define SIGMOID_THRESHOLD_MIN -40.0
+
+/**
+ * sigmoid threshold maximum
+ */
+#define SIGMOID_THRESHOLD_MAX 13.0
+
+/**
+ * The maximum input value for exp, used to avoid overflow problems.
+ * Currently only used for the tanh function.
+ */
+#define EXP_MAX_INPUT 40.0
+
+#ifndef __NVCC__
+namespace hppl {
+namespace typef {
+float relu(const float a);
+float sigmoid(const float a);
+float tanh(const float a);
+float linear(const float a);
+
+float relu(const float a, const float b);
+float sigmoid(const float a, const float b);
+float tanh(const float a, const float b);
+float linear(const float a, const float b);
+
+}  // namespace typef
+
+namespace typed {
+double relu(const double a);
+double sigmoid(const double a);
+double tanh(const double a);
+double linear(const double a);
+
+double relu(const double a, const double b);
+double sigmoid(const double a, const double b);
+double tanh(const double a, const double b);
+double linear(const double a, const double b);
+}  // namespace typed
+
+}  // namespace hppl
+
+#ifdef __AVX__
+#include "hl_avx_functions.h"
+#endif
+
+#else
+#include "hl_gpu_functions.h"
+#endif
+
+#endif  // HL_FUNCTIONS_H_
diff --git a/paddle/operators/math/detail/hl_gpu_functions.h b/paddle/operators/math/detail/hl_gpu_functions.h
new file mode 100644
index 0000000000000000000000000000000000000000..72f2204e7b2cfdba1367b51e3731dde11fb292d6
--- /dev/null
+++ b/paddle/operators/math/detail/hl_gpu_functions.h
@@ -0,0 +1,93 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef HL_GPU_FUNCTIONS_CUH_
+#define HL_GPU_FUNCTIONS_CUH_
+
+#include "hl_base.h"
+
+namespace hppl {
+namespace typef {
+
+__device__ static float relu(const float a) { return a > 0.0f ? a : 0.0f; }
+
+__device__ static float sigmoid(const float a) {
+  const float min = SIGMOID_THRESHOLD_MIN;
+  const float max = SIGMOID_THRESHOLD_MAX;
+  float tmp = (a < min) ? min : ((a > max) ? max : a);
+  return __fdividef(1.0f, 1.0f + __expf(-tmp));
+}
+
+__device__ static float tanh(const float a) {
+  float tmp = -2.0 * a;
+  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
+  return __fdividef(2.0f, (1.0f + __expf(tmp))) - 1.0f;
+}
+
+__device__ static float linear(const float a) { return a; }
+
+__device__ static float relu(const float a, const float b) {
+  return a * (b > 0.0f ? 1.0f : 0.0f);
+}
+
+__device__ static float sigmoid(const float a, const float b) {
+  return a * b * (1.0f - b);
+}
+
+__device__ static float tanh(const float a, const float b) {
+  return a * (1.0f - b * b);
+}
+
+__device__ static float linear(const float a, const float b) { return a; }
+
+}  // namespace typef
+
+namespace typed {
+
+__device__ static double relu(const double a) { return a > 0.0 ? a : 0.0; }
+
+__device__ static double sigmoid(const double a) {
+  const double min = SIGMOID_THRESHOLD_MIN;
+  const double max = SIGMOID_THRESHOLD_MAX;
+  double tmp = (a < min) ? min : ((a > max) ? max : a);
+  return 1.0 / (1.0 + exp(-tmp));
+}
+
+__device__ static double tanh(const double a) {
+  double tmp = -2.0 * a;
+  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
+  return (2.0 / (1.0 + exp(tmp))) - 1.0;
+}
+
+__device__ static double linear(const double a) { return a; }
+
+__device__ static double relu(const double a, const double b) {
+  return a * (b > 0.0 ? 1.0 : 0.0);
+}
+
+__device__ static double sigmoid(const double a, const double b) {
+  return a * b * (1 - b);
+}
+
+__device__ static double tanh(const double a, const double b) {
+  return a * (1.0 - b * b);
+}
+
+__device__ static double linear(const double a, const double b) { return a; }
+
+}  // namespace typed
+
+}  // namespace hppl
+
+#endif  // HL_GPU_FUNCTIONS_CUH_
diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..74d51d7bc9b91f4c8088384d77183131f57aafab
--- /dev/null
+++ b/paddle/operators/math/detail/lstm_cpu_kernel.h
@@ -0,0 +1,310 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <type_traits>
+#include "paddle/operators/math/detail/hl_activation_functions.h"
+#include "paddle/operators/math/lstm_compute.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace detail {
+
+#ifndef __NVCC__
+
+template <class T, class Op>
+void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
+                                     int frameSize,
+                                     activation_mode_t active_node,
+                                     activation_mode_t active_gate,
+                                     activation_mode_t active_state) {
+  T rValueIn;
+  T rValueIg;
+  T rValueFg;
+  T rValueOg;
+  T rCheckI;
+  T rCheckF;
+  T rCheckO;
+  T rState;
+  T rPrevState = 0;
+  T rStateAtv;
+  T rOut;
+
+  T *valueIn = value.gateValue;
+  T *valueIg = value.gateValue + frameSize;
+  T *valueFg = value.gateValue + frameSize * 2;
+  T *valueOg = value.gateValue + frameSize * 3;
+
+  for (int i = 0; i < frameSize; i++) {
+    rValueIn = valueIn[i];
+    rValueIg = valueIg[i];
+    rValueFg = valueFg[i];
+    rValueOg = valueOg[i];
+    rCheckI = value.checkIg[i];
+    rCheckF = value.checkFg[i];
+    rCheckO = value.checkOg[i];
+
+    if (value.prevStateValue) {
+      rPrevState = value.prevStateValue[i];
+    }
+
+    hppl::cpu::ForwardAct<T> act;
+    op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
+       rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
+       act(active_state));
+
+    valueIn[i] = rValueIn;
+    valueIg[i] = rValueIg;
+    valueFg[i] = rValueFg;
+    valueOg[i] = rValueOg;
+    value.stateValue[i] = rState;
+    value.stateActiveValue[i] = rStateAtv;
+    value.outputValue[i] = rOut;
+  }
+}
+
+template <class T, class Op>
+void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
+                                      LstmMetaGrad<T> grad, int frameSize,
+                                      activation_mode_t active_node,
+                                      activation_mode_t active_gate,
+                                      activation_mode_t active_state) {
+  T rValueIn;
+  T rValueIg;
+  T rValueFg;
+  T rValueOg;
+  T rGradIn;
+  T rGradIg;
+  T rGradFg;
+  T rGradOg;
+  T rPrevState = 0;
+  T rPrevStateGrad;
+  T rState;
+  T rStateGrad;
+  T rStateAtv;
+  T rOutputGrad;
+  T rCheckI;
+  T rCheckF;
+  T rCheckO;
+  T rCheckIGrad;
+  T rCheckFGrad;
+  T rCheckOGrad;
+
+  T *valueIn = value.gateValue;
+  T *valueIg = value.gateValue + frameSize;
+  T *valueFg = value.gateValue + frameSize * 2;
+  T *valueOg = value.gateValue + frameSize * 3;
+  T *gradIn = grad.gateGrad;
+  T *gradIg = grad.gateGrad + frameSize;
+  T *gradFg = grad.gateGrad + frameSize * 2;
+  T *gradOg = grad.gateGrad + frameSize * 3;
+
+  for (int i = 0; i < frameSize; i++) {
+    rValueIn = valueIn[i];
+    rValueIg = valueIg[i];
+    rValueFg = valueFg[i];
+    rValueOg = valueOg[i];
+    rCheckI = value.checkIg[i];
+    rCheckF = value.checkFg[i];
+    rCheckO = value.checkOg[i];
+    rState = value.stateValue[i];
+    rStateAtv = value.stateActiveValue[i];
+    rOutputGrad = grad.outputGrad[i];
+    rStateGrad = grad.stateGrad[i];
+    if (value.prevStateValue) {
+      rPrevState = value.prevStateValue[i];
+    }
+
+    hppl::cpu::BackwardAct<T> act;
+    op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
+       rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
+       rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
+       rCheckOGrad, act(active_node), act(active_gate), act(active_state));
+
+    gradIn[i] = rGradIn;
+    gradIg[i] = rGradIg;
+    gradFg[i] = rGradFg;
+    gradOg[i] = rGradOg;
+    grad.stateGrad[i] = rStateGrad;
+
+    if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad;
+    if (value.prevStateValue) {
+      if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad;
+      if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad;
+    }
+    if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad;
+  }
+}
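+
+// The AVX variants below process eight packed floats per __m256 register;
+// cpu_lstm_forward/cpu_lstm_backward at the bottom of this file only
+// dispatch to them when frameSize is a multiple of 8 and T is float.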
+
+template <class T, class Op>
+void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
+                                   int frameSize,
+                                   activation_mode_t active_node,
+                                   activation_mode_t active_gate,
+                                   activation_mode_t active_state) {
+#ifdef __AVX__
+  __m256 rValueIn;
+  __m256 rValueIg;
+  __m256 rValueFg;
+  __m256 rValueOg;
+  __m256 rCheckI;
+  __m256 rCheckF;
+  __m256 rCheckO;
+  __m256 rState;
+  __m256 rPrevState = _mm256_set1_ps(0.0f);
+  __m256 rStateAtv;
+  __m256 rOut;
+
+  __m256 *valueIn = (__m256 *)value.gateValue;
+  __m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
+  __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
+  __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
+
+  for (int i = 0; i < frameSize / 8; i++) {
+    rValueIn = valueIn[i];
+    rValueIg = valueIg[i];
+    rValueFg = valueFg[i];
+    rValueOg = valueOg[i];
+    rCheckI = ((__m256 *)value.checkIg)[i];
+    rCheckF = ((__m256 *)value.checkFg)[i];
+    rCheckO = ((__m256 *)value.checkOg)[i];
+
+    if (value.prevStateValue) {
+      rPrevState = ((__m256 *)value.prevStateValue)[i];
+    }
+
+    op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
+       rOut, rCheckI, rCheckF, rCheckO, hppl::avx::forward[active_node],
+       hppl::avx::forward[active_gate], hppl::avx::forward[active_state]);
+
+    valueIn[i] = rValueIn;
+    valueIg[i] = rValueIg;
+    valueFg[i] = rValueFg;
+    valueOg[i] = rValueOg;
+    ((__m256 *)value.stateValue)[i] = rState;
+    ((__m256 *)value.stateActiveValue)[i] = rStateAtv;
+    ((__m256 *)value.outputValue)[i] = rOut;
+  }
+#endif
+}
+
+template <class T, class Op>
+void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
+                                    LstmMetaGrad<T> grad, int frameSize,
+                                    activation_mode_t active_node,
+                                    activation_mode_t active_gate,
+                                    activation_mode_t active_state) {
+#ifdef __AVX__
+  __m256 rValueIn;
+  __m256 rValueIg;
+  __m256 rValueFg;
+  __m256 rValueOg;
+  __m256 rGradIn;
+  __m256 rGradIg;
+  __m256 rGradFg;
+  __m256 rGradOg;
+  __m256 rPrevState = _mm256_set1_ps(0.0f);
+  __m256 rPrevStateGrad;
+  __m256 rStateGrad;
+  __m256 rState;
+  __m256 rStateAtv;
+  __m256 rOutputGrad;
+  __m256 rCheckI;
+  __m256 rCheckF;
+  __m256 rCheckO;
+  __m256 rCheckIGrad;
+  __m256 rCheckFGrad;
+  __m256 rCheckOGrad;
+
+  __m256 *valueIn = (__m256 *)value.gateValue;
+  __m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
+  __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
+  __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
+  __m256 *gradIn = (__m256 *)grad.gateGrad;
+  __m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize);
+  __m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2);
+  __m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3);
+
+  for (int i = 0; i < frameSize / 8; i++) {
+    rValueIn = valueIn[i];
+    rValueIg = valueIg[i];
+    rValueFg = valueFg[i];
+    rValueOg = valueOg[i];
+    rCheckI = ((__m256 *)value.checkIg)[i];
+    rCheckF = ((__m256 *)value.checkFg)[i];
+    rCheckO = ((__m256 *)value.checkOg)[i];
+    rState = ((__m256 *)value.stateValue)[i];
+    rStateAtv = ((__m256 *)value.stateActiveValue)[i];
+    rOutputGrad = ((__m256 *)grad.outputGrad)[i];
+    rStateGrad = ((__m256 *)grad.stateGrad)[i];
+    if (value.prevStateValue) {
+      rPrevState = ((__m256 *)value.prevStateValue)[i];
+    }
+
+    op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
+       rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
+       rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
+       rCheckOGrad, hppl::avx::backward[active_node],
+       hppl::avx::backward[active_gate], hppl::avx::backward[active_state]);
+
+    gradIn[i] = rGradIn;
+    gradIg[i] = rGradIg;
+    gradFg[i] = rGradFg;
+    gradOg[i] = rGradOg;
+    ((__m256 *)grad.stateGrad)[i] = rStateGrad;
+
+    if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad;
+    if (value.prevStateValue) {
+      if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad;
+      if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad;
+    }
+    if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad;
+  }
+#endif
+}
+
+template <class T, class Op>
+void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
+                      activation_mode_t active_node,
+                      activation_mode_t active_gate,
+                      activation_mode_t active_state) {
+  if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
+    avx_lstm_forward_one_sequence(op, value, frameSize, active_node,
+                                  active_gate, active_state);
+  } else {
+    naive_lstm_forward_one_sequence(op, value, frameSize, active_node,
+                                    active_gate, active_state);
+  }
+}
+
+template <class T, class Op>
+void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
+                       int frameSize, activation_mode_t active_node,
+                       activation_mode_t active_gate,
+                       activation_mode_t active_state) {
+  if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
+    avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node,
+                                   active_gate, active_state);
+  } else {
+    naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node,
+                                     active_gate, active_state);
+  }
+}
+
+#endif
+
+}  // namespace detail
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..9573eaefb6a9d678ef70f2e2bffdc6a3011b21ea
--- /dev/null
+++ b/paddle/operators/math/detail/lstm_gpu_kernel.h
@@ -0,0 +1,256 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <type_traits>
+#include "paddle/operators/math/detail/hl_activation_functions.h"
+#include "paddle/operators/math/lstm_compute.h"
+#include "paddle/platform/cuda_helper.h"
+#include "paddle/platform/device_context.h"
+
+#include <glog/logging.h>
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace detail {
+
+/*
+ * threads(framePerBlock, batchPerBlock)
+ * grid(frameBlocks, batchBlocks)
+ */
+template <class T, class Op, bool isBatch>
+__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
+                              int batchSize, activation_mode_t active_node,
+                              activation_mode_t active_gate,
+                              activation_mode_t active_state) {
+  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frameIdx >= frameSize) return;
+
+  int batchIdx = 0;
+  if (isBatch) {
+    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batchIdx >= batchSize) return;
+    value.gateValue += batchIdx * frameSize * 4;
+    value.outputValue += batchIdx * frameSize;
+    value.stateValue += batchIdx * frameSize;
+    value.stateActiveValue += batchIdx * frameSize;
+  }
+
+  T rState;
+  T rPrevState = 0;
+  T rStateAtv;
+  T rOut;
+  T rValueIn;
+  T rValueIg;
+  T rValueFg;
+  T rValueOg;
+  T rCheckI = value.checkIg[frameIdx];
+  T rCheckF = value.checkFg[frameIdx];
+  T rCheckO = value.checkOg[frameIdx];
+
+  rValueIn = value.gateValue[frameIdx];
+  rValueIg = value.gateValue[frameIdx + frameSize];
+  rValueFg = value.gateValue[frameIdx + frameSize * 2];
+  rValueOg = value.gateValue[frameIdx + frameSize * 3];
+
+  if (value.prevStateValue) {
+    if (isBatch) value.prevStateValue += batchIdx * frameSize;
+    rPrevState = value.prevStateValue[frameIdx];
+  }
+
+  hppl::gpu::ForwardAct<T> act;
+  op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
+     rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
+     act(active_state));
+
+  value.gateValue[frameIdx] = rValueIn;
+  value.gateValue[frameIdx + frameSize] = rValueIg;
+  value.gateValue[frameIdx + frameSize * 2] = rValueFg;
+  value.gateValue[frameIdx + frameSize * 3] = rValueOg;
+
+  value.stateValue[frameIdx] = rState;
+  value.stateActiveValue[frameIdx] = rStateAtv;
+  value.outputValue[frameIdx] = rOut;
+}
+
+/*
+ * threads(framePerBlock, batchPerBlock)
+ * grid(frameBlocks, batchBlocks)
+ */
+template <class T, class Op, bool isBatch>
+__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
+                               LstmMetaGrad<T> grad, int frameSize,
+                               int batchSize, activation_mode_t active_node,
+                               activation_mode_t active_gate,
+                               activation_mode_t active_state) {
+  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frameIdx >= frameSize) return;
+
+  int batchIdx = 0;
+  if (isBatch) {
+    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batchIdx >= batchSize) return;
+    value.gateValue += batchIdx * frameSize * 4;
+    value.stateValue += batchIdx * frameSize;
+    value.stateActiveValue += batchIdx * frameSize;
+    grad.gateGrad += batchIdx * frameSize * 4;
+    grad.stateGrad += batchIdx * frameSize;
+    grad.outputGrad += batchIdx * frameSize;
+  }
+
+  T rValueIn;
+  T rValueIg;
+  T rValueFg;
+  T rValueOg;
+  T rGradIn;
+  T rGradIg;
+  T rGradFg;
+  T rGradOg;
+  T rPrevState = 0;
+  T rPrevStateGrad;
+  T rState;
+  T rStateGrad;
+  T rStateAtv;
+  T rOutputGrad;
+  T rCheckI = value.checkIg[frameIdx];
+  T rCheckF = value.checkFg[frameIdx];
+  T rCheckO = value.checkOg[frameIdx];
+  T rCheckIGrad;
+  T rCheckFGrad;
+  T rCheckOGrad;
+
+  rValueIn = value.gateValue[frameIdx];
+  rValueIg = value.gateValue[frameIdx + frameSize];
+  rValueFg = value.gateValue[frameIdx + frameSize * 2];
+  rValueOg = value.gateValue[frameIdx + frameSize * 3];
+  rState = value.stateValue[frameIdx];
+  rStateAtv = value.stateActiveValue[frameIdx];
+  rOutputGrad = grad.outputGrad[frameIdx];
+  rStateGrad = grad.stateGrad[frameIdx];
+
+  if (value.prevStateValue) {
+    if (isBatch) value.prevStateValue += batchIdx * frameSize;
+    rPrevState = value.prevStateValue[frameIdx];
+  }
+
+  hppl::gpu::BackwardAct<T> act;
+  op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
+     rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
+     rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
+     rCheckOGrad, act(active_node), act(active_gate), act(active_state));
+
+  grad.gateGrad[frameIdx] = rGradIn;
+  grad.gateGrad[frameIdx + frameSize] = rGradIg;
+  grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
+  grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
+  grad.stateGrad[frameIdx] = rStateGrad;
+  if (grad.prevStateGrad) {
+    if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
+    grad.prevStateGrad[frameIdx] = rPrevStateGrad;
+  }
+
+  if (isBatch) {
+    if (value.prevStateValue) {
+      if (grad.checkIgGrad)
+        paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
+                                        rCheckIGrad);
+      if (grad.checkFgGrad)
+        paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
+                                        rCheckFGrad);
+    }
+    if (grad.checkOgGrad)
+      paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
+  } else {
+    if (value.prevStateValue) {
+      if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
+      if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
+    }
+    if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
+  }
+}
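+
+/*
+ * Worked example of the launch shapes chosen in the dispatch functions
+ * below: with batchSize > 1, say frameSize = 256 and batchSize = 16, each
+ * block covers 32 frames x 32 batch rows, so threads = dim3(32, 32) and
+ * grid = dim3((256 + 32 - 1) / 32, (16 + 32 - 1) / 32) = dim3(8, 1).
+ */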
+
+template <class T, class Op>
+void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
+                      LstmMetaValue<T> value, int frameSize, int batchSize,
+                      activation_mode_t active_node,
+                      activation_mode_t active_gate,
+                      activation_mode_t active_state) {
+  dim3 threads;
+  dim3 grid;
+  if (batchSize == 1) {
+    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
+    int frameBlocks = (frameSize + 1024 - 1) / 1024;
+    threads = dim3(framePerBlock, 1);
+    grid = dim3(frameBlocks, 1);
+  } else {
+    /* framePerBlock = 32 batchPerBlock = 32 */
+    threads = dim3(32, 32);
+    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+  }
+
+  auto stream =
+      reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
+  if (batchSize == 1) {
+    KeLstmForward<T, Op,
+                  /* isBatch= */ false><<<grid, threads, 0, stream>>>(
+        op, value, frameSize, batchSize, active_node, active_gate,
+        active_state);
+  } else {
+    KeLstmForward<T, Op,
+                  /* isBatch= */ true><<<grid, threads, 0, stream>>>(
+        op, value, frameSize, batchSize, active_node, active_gate,
+        active_state);
+  }
+}
+
+template <class T, class Op>
+void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
+                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
+                       int frameSize, int batchSize,
+                       activation_mode_t active_node,
+                       activation_mode_t active_gate,
+                       activation_mode_t active_state) {
+  dim3 threads;
+  dim3 grid;
+  if (batchSize == 1) {
+    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
+    int frameBlocks = (frameSize + 1024 - 1) / 1024;
+    threads = dim3(framePerBlock, 1);
+    grid = dim3(frameBlocks, 1);
+  } else {
+    /* framePerBlock = 32 batchPerBlock = 32 */
+    threads = dim3(32, 32);
+    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+  }
+
+  auto stream =
+      reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
+  if (batchSize == 1) {
+    KeLstmBackward<T, Op,
+                   /* isBatch= */ false><<<grid, threads, 0, stream>>>(
+        op, value, grad, frameSize, batchSize, active_node, active_gate,
+        active_state);
+  } else {
+    KeLstmBackward<T, Op,
+                   /* isBatch= */ true><<<grid, threads, 0, stream>>>(
+        op, value, grad, frameSize, batchSize, active_node, active_gate,
+        active_state);
+  }
+}
+
+}  // namespace detail
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..6f3ead2397d5131b4468d0ad288513cedb289594
--- /dev/null
+++ b/paddle/operators/math/detail/lstm_kernel.h
@@ -0,0 +1,138 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/detail/hl_activation_functions.h"
+#include "paddle/platform/hostdevice.h"
+
+#include <type_traits>
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace detail {
+
+namespace forward {
+
+template <class T>
+class lstm {
+ public:
+  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
+                             T &prevState, T &state, T &stateAtv, T &output,
+                             T &checkI, T &checkF, T &checkO,
+                             typename hppl::ForwardActType<T>::type actInput,
+                             typename hppl::ForwardActType<T>::type actGate,
+                             typename hppl::ForwardActType<T>::type actState) {
+    valueIn = actInput(valueIn);
+    valueIg = actGate(valueIg + prevState * checkI);
+    valueFg = actGate(valueFg + prevState * checkF);
+    state = valueIn * valueIg + prevState * valueFg;
+    valueOg = actGate(valueOg + state * checkO);
+    stateAtv = actState(state);
+    output = valueOg * stateAtv;
+  }
+#ifndef __NVCC__
+#ifndef __AVX__  // If not compiled with AVX instructions, disable AVX by default.
+  static const bool avx = false;
+#else
+  // Only float supports AVX optimization.
+  static const bool avx = std::is_same<T, float>::value;
+
+  HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
+                             __m256 &valueOg, __m256 &prevState, __m256 &state,
+                             __m256 &stateAtv, __m256 &output, __m256 &checkI,
+                             __m256 &checkF, __m256 &checkO,
+                             hppl::Active<__m256>::forward actInput,
+                             hppl::Active<__m256>::forward actGate,
+                             hppl::Active<__m256>::forward actState) {
+    valueIn = actInput(valueIn);
+    valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
+    valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
+    state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
+                          _mm256_mul_ps(prevState, valueFg));
+    valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)));
+    stateAtv = actState(state);
+    output = _mm256_mul_ps(valueOg, stateAtv);
+  }
+#endif
+#endif
+};
+
+}  // namespace forward
+
+namespace backward {
+
+template <class T>
+class lstm {
+ public:
+  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
+                             T &gradIn, T &gradIg, T &gradFg, T &gradOg,
+                             T &prevState, T &prevStateGrad, T &state,
+                             T &stateGrad, T &stateAtv, T &outputGrad,
+                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
+                             T &checkFGrad, T &checkOGrad,
+                             typename hppl::BackwardActType<T>::type actInput,
+                             typename hppl::BackwardActType<T>::type actGate,
+                             typename hppl::BackwardActType<T>::type actState) {
+    gradOg = actGate(outputGrad * stateAtv, valueOg);
+    stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
+    gradIn = actInput(stateGrad * valueIg, valueIn);
+    gradIg = actGate(stateGrad * valueIn, valueIg);
+    gradFg = actGate(stateGrad * prevState, valueFg);
+    prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
+    checkIGrad = gradIg * prevState;
+    checkFGrad = gradFg * prevState;
+    checkOGrad = gradOg * state;
+  }
+#ifndef __NVCC__
+#ifndef __AVX__  // If not compiled with AVX instructions, disable AVX by default.
+template <class T>
+class lstm {
+ public:
+  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
+                             T &gradIn, T &gradIg, T &gradFg, T &gradOg,
+                             T &prevState, T &prevStateGrad, T &state,
+                             T &stateGrad, T &stateAtv, T &outputGrad,
+                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
+                             T &checkFGrad, T &checkOGrad,
+                             typename hppl::BackwardActType<T>::type actInput,
+                             typename hppl::BackwardActType<T>::type actGate,
+                             typename hppl::BackwardActType<T>::type actState) {
+    gradOg = actGate(outputGrad * stateAtv, valueOg);
+    stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
+    gradIn = actInput(stateGrad * valueIg, valueIn);
+    gradIg = actGate(stateGrad * valueIn, valueIg);
+    gradFg = actGate(stateGrad * prevState, valueFg);
+    prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
+    checkIGrad = gradIg * prevState;
+    checkFGrad = gradFg * prevState;
+    checkOGrad = gradOg * state;
+  }
+#ifndef __NVCC__
+#ifndef __AVX__  // If not compiled with AVX instructions, disable AVX by default.
+  static const bool avx = false;
+#else
+  // Only float supports AVX optimization.
+  static const bool avx = std::is_same<T, float>::value;
+  HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
+                             __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
+                             __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
+                             __m256 &prevStateGrad, __m256 &state,
+                             __m256 &stateGrad, __m256 &stateAtv,
+                             __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
+                             __m256 &checkO, __m256 &checkIGrad,
+                             __m256 &checkFGrad, __m256 &checkOGrad,
+                             hppl::Active<__m256>::backward actInput,
+                             hppl::Active<__m256>::backward actGate,
+                             hppl::Active<__m256>::backward actState) {
+    gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
+    stateGrad = _mm256_add_ps(
+        actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
+    stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
+    gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn);
+    gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg);
+    gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg);
+    prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
+                                  _mm256_mul_ps(gradFg, checkF));
+    prevStateGrad =
+        _mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
+    checkIGrad = _mm256_mul_ps(gradIg, prevState);
+    checkFGrad = _mm256_mul_ps(gradFg, prevState);
+    checkOGrad = _mm256_mul_ps(gradOg, state);
+  }
+#endif
+#endif
+};
+
+}  // namespace backward
+
+}  // namespace detail
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0febf8e3b70111d12f858cf6259a2801a42d9a90
--- /dev/null
+++ b/paddle/operators/math/lstm_compute.cc
@@ -0,0 +1,82 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/lstm_compute.h"
+#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
+#include "paddle/operators/math/detail/lstm_kernel.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
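+// CPU specialization: the batch is processed one sample at a time; after
+// each sample the LstmMetaValue pointers are advanced by one frame (the
+// gate buffer holds 4 * frame_size values per sample).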
+template <class T>
+struct LstmUnitFunctor<platform::CPUPlace, T> {
+  static void compute(const platform::DeviceContext& context,
+                      LstmMetaValue<T> value, int frame_size, int batch_size,
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
+    for (int b = 0; b < batch_size; b++) {
+      detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
+                               ActiveType(cand_act), ActiveType(gate_act),
+                               ActiveType(cell_act));
+      value.gateValue += frame_size * 4;
+      value.stateValue += frame_size;
+      value.stateActiveValue += frame_size;
+      value.outputValue += frame_size;
+      if (value.prevStateValue) {
+        value.prevStateValue += frame_size;
+      }
+    }
+  }
+};
+
+template <class T>
+struct LstmUnitGradFunctor<platform::CPUPlace, T> {
+  static void compute(const platform::DeviceContext& context,
+                      LstmMetaValue<T> value, LstmMetaGrad<T> grad,
+                      int frame_size, int batch_size,
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
+    for (int b = 0; b < batch_size; b++) {
+      detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
+                                frame_size, ActiveType(cand_act),
+                                ActiveType(gate_act), ActiveType(cell_act));
+
+      value.gateValue += frame_size * 4;
+      value.stateValue += frame_size;
+      value.stateActiveValue += frame_size;
+      value.outputValue += frame_size;
+      if (value.prevStateValue) {
+        value.prevStateValue += frame_size;
+      }
+
+      grad.gateGrad += frame_size * 4;
+      grad.stateGrad += frame_size;
+      grad.stateActiveGrad += frame_size;
+      grad.outputGrad += frame_size;
+      if (grad.prevStateGrad) {
+        grad.prevStateGrad += frame_size;
+      }
+    }
+  }
+};
+
+template class LstmUnitFunctor<platform::CPUPlace, float>;
+template class LstmUnitFunctor<platform::CPUPlace, double>;
+template class LstmUnitGradFunctor<platform::CPUPlace, float>;
+template class LstmUnitGradFunctor<platform::CPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b2122f2a5c08a6d9d53293833177f0ba2c3ab860
--- /dev/null
+++ b/paddle/operators/math/lstm_compute.cu
@@ -0,0 +1,55 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/detail/lstm_gpu_kernel.h"
+#include "paddle/operators/math/detail/lstm_kernel.h"
+#include "paddle/operators/math/lstm_compute.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
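+// GPU specialization: the whole (batch_size x frame_size) problem is handed
+// to a single kernel launch, so no per-sample pointer stepping is needed.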
+template <class T>
+struct LstmUnitFunctor<platform::GPUPlace, T> {
+  static void compute(const platform::DeviceContext& context,
+                      LstmMetaValue<T> value, int frame_size, int batch_size,
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
+    detail::gpu_lstm_forward(context, detail::forward::lstm<T>(), value,
+                             frame_size, batch_size, ActiveType(cand_act),
+                             ActiveType(gate_act), ActiveType(cell_act));
+  }
+};
+
+template <class T>
+struct LstmUnitGradFunctor<platform::GPUPlace, T> {
+  static void compute(const platform::DeviceContext& context,
+                      LstmMetaValue<T> value, LstmMetaGrad<T> grad,
+                      int frame_size, int batch_size,
+                      const std::string& gate_act, const std::string& cell_act,
+                      const std::string& cand_act) {
+    detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value,
+                              grad, frame_size, batch_size,
+                              ActiveType(cand_act), ActiveType(gate_act),
+                              ActiveType(cell_act));
+  }
+};
+
+template class LstmUnitFunctor<platform::GPUPlace, float>;
+template class LstmUnitFunctor<platform::GPUPlace, double>;
+template class LstmUnitGradFunctor<platform::GPUPlace, float>;
+template class LstmUnitGradFunctor<platform::GPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..28d2c6fd3b0d8143da90c37f241072e37397f98b
--- /dev/null
+++ b/paddle/operators/math/lstm_compute.h
@@ -0,0 +1,91 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/enforce.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+typedef enum {
+  HL_ACTIVATION_SIGMOID = 0,
+  HL_ACTIVATION_RELU = 1,
+  HL_ACTIVATION_TANH = 2,
+  HL_ACTIVATION_LINEAR = 3,
+  HL_ACTIVATION_END
+} activation_mode_t;
+
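+// Bundles of raw device pointers handed to the LSTM kernels. gateValue
+// points at the (batch_size x 4 * frame_size) gate buffer, and
+// checkIg/checkFg/checkOg are the peephole weight vectors (frame_size each).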
+template <class T>
+struct LstmMetaValue {
+  T *gateValue;
+  T *prevStateValue;
+  T *stateValue;
+  T *stateActiveValue;
+  T *outputValue;
+  T *checkIg;
+  T *checkFg;
+  T *checkOg;
+};
+
+template <class T>
+struct LstmMetaGrad {
+  T *gateGrad;
+  T *prevStateGrad;
+  T *stateGrad;
+  T *stateActiveGrad;
+  T *outputGrad;
+  T *checkIgGrad;
+  T *checkFgGrad;
+  T *checkOgGrad;
+};
+
+inline activation_mode_t ActiveType(const std::string &type) {
+  if (type == "sigmoid") {
+    return HL_ACTIVATION_SIGMOID;
+  } else if (type == "relu") {
+    return HL_ACTIVATION_RELU;
+  } else if (type == "tanh") {
+    return HL_ACTIVATION_TANH;
+  } else if (type == "linear" || type == "identity" || type == "") {
+    return HL_ACTIVATION_LINEAR;
+  } else {
+    PADDLE_THROW("Do not support activation type.");
+  }
+}
+
+template <typename Place, typename T>
+class LstmUnitFunctor {
+ public:
+  static void compute(const platform::DeviceContext &context,
+                      LstmMetaValue<T> value, int frame_size, int batch_size,
+                      const std::string &gate_act, const std::string &cell_act,
+                      const std::string &cand_act);
+};
+
+template <typename Place, typename T>
+class LstmUnitGradFunctor {
+ public:
+  static void compute(const platform::DeviceContext &context,
+                      LstmMetaValue<T> value, LstmMetaGrad<T> grad,
+                      int frame_size, int batch_size,
+                      const std::string &gate_act, const std::string &cell_act,
+                      const std::string &cand_act);
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc
new file mode 100644
index 0000000000000000000000000000000000000000..10c6e105b950b9d510e7a14828d72531e8eb0028
--- /dev/null
+++ b/paddle/operators/math/sequence2batch.cc
@@ -0,0 +1,61 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/sequence2batch.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::LoDTensor& src, const size_t* index,
+                  framework::LoDTensor& dst, bool is_src_index) {
+    auto src_dims = src.dims();
+    auto dst_dims = dst.dims();
+    PADDLE_ENFORCE_EQ(src_dims.size(), 2UL,
+                      "The src must be matrix with rank 2.");
+    PADDLE_ENFORCE_EQ(dst_dims.size(), 2UL,
+                      "The dst must be matrix with rank 2.");
+    PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
+                      "The width of src and dst must be same.");
+    auto height = dst_dims[0];
+    auto width = dst_dims[1];
+    auto* src_data = src.data<T>();
+    auto* dst_data = dst.data<T>();
+    for (int i = 0; i < height; ++i) {
+      if (is_src_index) {
+        memcpy(dst_data + i * width, src_data + index[i] * width,
+               width * sizeof(T));
+      } else {
+        memcpy(dst_data + index[i] * width, src_data + i * width,
+               width * sizeof(T));
+      }
+    }
+  }
+};
+
+template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
+template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
+
+template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
+template class LoDTensor2BatchFunctor<platform::CPUPlace, double>;
+template class Batch2LoDTensorFunctor<platform::CPUPlace, float>;
+template class Batch2LoDTensorFunctor<platform::CPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4f349946785171e6c59b22163ba76791c7244f88
--- /dev/null
+++ b/paddle/operators/math/sequence2batch.cu
@@ -0,0 +1,78 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/sequence2batch.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
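+// Grid-stride row copy: each (blockIdx.x, threadIdx.y) pair owns one matrix
+// row at a time while the x-dimension threads stride across its columns;
+// `index` redirects either the source row (gather) or the destination row
+// (scatter), depending on is_src_index.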
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
+__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
+                                     int64_t height, int64_t width,
+                                     bool is_src_index) {
+  int idx = threadIdx.x;
+  int idy = threadIdx.y;
+  int id = blockIdx.x + idy * GridDimX;
+  while (id < height) {
+    int src_idx = is_src_index ? index[id] : id;
+    int dst_idx = is_src_index ? id : index[id];
+    const T* src_data = src + src_idx * width;
+    T* dst_data = dst + dst_idx * width;
+    for (int i = idx; i < width; i += BlockDimX) {
+      dst_data[i] = src_data[i];
+    }
+    id += BlockDimY * GridDimX;
+  }
+}
+
+template <typename T>
+class CopyMatrixRowsFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::LoDTensor& src, const size_t* index,
+                  framework::LoDTensor& dst, bool is_src_index) {
+    auto src_dims = src.dims();
+    auto dst_dims = dst.dims();
+    PADDLE_ENFORCE_EQ(src_dims.size(), 2,
+                      "The src must be matrix with rank 2.");
+    PADDLE_ENFORCE_EQ(dst_dims.size(), 2,
+                      "The dst must be matrix with rank 2.");
+    PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
+                      "The width of src and dst must be same.");
+    auto height = dst_dims[0];
+    auto width = dst_dims[1];
+    auto* src_data = src.data<T>();
+    auto* dst_data = dst.data<T>();
+
+    dim3 threads(128, 8);
+    dim3 grid(8, 1);
+    auto stream =
+        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
+    CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
+        src_data, dst_data, index, height, width, is_src_index);
+  }
+};
+
+template class CopyMatrixRowsFunctor<platform::GPUPlace, float>;
+template class CopyMatrixRowsFunctor<platform::GPUPlace, double>;
+
+template class LoDTensor2BatchFunctor<platform::GPUPlace, float>;
+template class LoDTensor2BatchFunctor<platform::GPUPlace, double>;
+template class Batch2LoDTensorFunctor<platform::GPUPlace, float>;
+template class Batch2LoDTensorFunctor<platform::GPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h
new file mode 100644
index 0000000000000000000000000000000000000000..03cd018e46e90c9bbe689c9686377e0e998ee513
--- /dev/null
+++ b/paddle/operators/math/sequence2batch.h
@@ -0,0 +1,148 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename Place, typename T>
+class CopyMatrixRowsFunctor {
+ public:
+  // If is_src_index is true,
+  // copy the indexed rows of input src to the output dst.
+  // If is_src_index is false,
+  // copy the input src to the indexed rows of output dst.
+  // The indexed rows are based on the input index.
+  void operator()(const platform::DeviceContext& context,
+                  const framework::LoDTensor& src, const size_t* index,
+                  framework::LoDTensor& dst, bool is_src_index);
+};
+
+template <typename Place, typename T>
+class LoDTensor2BatchFunctor {
+  // Calculate the length of each sequence and
+  // sort the sequence indexes by length in descending order.
+  // example: sequences = {s0, s1, s2}
+  //          s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
+  //          seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
+  //          where each tuple is (start, length, seq_idx).
+  struct SeqInfo {
+    SeqInfo(int start, int length, int seq_idx)
+        : start(start), length(length), seq_idx(seq_idx) {}
+    int start;
+    int length;
+    int seq_idx;
+  };
+
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::LoDTensor& lod_tensor,
+                  framework::LoDTensor& batch, bool is_reverse) const {
+    auto lods = lod_tensor.lod();
+    PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
+    auto lod = lods[0];
+
+    std::vector<SeqInfo> seq_info;
+    for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
+      int length = lod[seq_id + 1] - lod[seq_id];
+      seq_info.emplace_back(lod[seq_id], length, seq_id);
+    }
+
+    std::sort(seq_info.begin(), seq_info.end(),
+              [](SeqInfo a, SeqInfo b) { return a.length > b.length; });
+
+    // Calculate the start position of each batch.
+    // (num_batch equals the maximum length of the sequences)
+    // example: sequences = {s0, s1, s2}
+    //          s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
+    //          num_batch = 5,
+    //          batchIndex = {b0, b1, b2, b3, b4}
+    //          b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
+    //          batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
+    //          batch_start_positions[1] = len(b0)
+    //          batch_start_positions[2] = len(b0) + len(b1)
+    //          batch_start_positions[3] = len(b0) + len(b1) + len(b2)
+    //          ...
+    //          seq2batch_idx[12] = {4, 0, 9,
+    //                               5, 1, 10,
+    //                               6, 2, 11,
+    //                               7, 3,
+    //                               8}
+    // The batch number represents the batch size after rearranging the
+    // input LoDTensor. It is also the maximum length of the input sequences.
+
+    paddle::framework::LoD batch_lods;
+    batch_lods.emplace_back(std::vector<size_t>{0});
+    batch_lods.emplace_back(std::vector<size_t>{0});
+
+    // batch_lods[0] is the start positions for batch LoDTensor
+    int num_batch = seq_info[0].length;
+    batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
+    // batch_lods[1] is the raw index in the input LoDTensor
+    auto dims = lod_tensor.dims();
+    batch_lods[1].resize(static_cast<size_t>(dims[0]));
+
+    size_t* batch_starts = batch_lods[0].data();
+    size_t* seq2batch_idx = batch_lods[1].data();
+    batch_starts[0] = 0;
+    for (int n = 0; n < num_batch; n++) {
+      auto batch_id = static_cast<int>(batch_starts[n]);
+      for (size_t i = 0; i < seq_info.size(); ++i) {
+        int seq_len = seq_info[i].length;
+        int start = seq_info[i].start;
+        if (n < seq_len) {
+          seq2batch_idx[batch_id] =
+              is_reverse ? start + seq_len - 1 - n : start + n;
+          batch_id++;
+        } else {
+          break;
+        }
+      }
+      batch_starts[n + 1] = static_cast<size_t>(batch_id);
+    }
+    batch.set_lod(batch_lods);
+
+    CopyMatrixRowsFunctor<Place, T> to_batch;
+    to_batch(context, lod_tensor, seq2batch_idx, batch, true);
+  }
+};
+
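+// Restores a batch-major LoDTensor produced by LoDTensor2BatchFunctor back
+// to the original sequence order, using the raw-index map saved in the
+// second level of the batch LoD.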
+template <typename Place, typename T>
+class Batch2LoDTensorFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::LoDTensor& batch,
+                  framework::LoDTensor& lod_tensor) const {
+    auto in_lod = batch.lod();
+    PADDLE_ENFORCE_EQ(in_lod.size(), 2UL,
+                      "The LoD size of input `batch` should be 2.");
+    auto out_lod = lod_tensor.lod()[0];
+    auto num = out_lod[out_lod.size() - 1];
+    PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]);
+    PADDLE_ENFORCE_EQ(num, in_lod[1].size());
+    PADDLE_ENFORCE_EQ(num, batch.dims()[0]);
+    CopyMatrixRowsFunctor<Place, T> to_seq;
+    size_t* index = in_lod[1].data();
+    to_seq(context, batch, index, lod_tensor, false);
+  }
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index 215fa0b94e423755b7bc3f05a2b14a8c85451202..169052fe412f546a5081c383da4520e4deb6c122 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -242,7 +242,7 @@ class OpTest(unittest.TestCase):
             self.assertTrue(
                 np.allclose(
                     actual, expect, atol=atol),
-                "output name: " + out_name + " has diff.")
+                "Output (" + out_name + ") has diff at " + str(place))
         else:
             actual = np.array(self.scope.find_var(out_name).get_tensor())
             expect = self.outputs[out_name]
@@ -250,7 +250,7 @@
             self.assertTrue(
                 np.allclose(
                     actual, expect, atol=atol),
-                "output name: " + out_name + " has diff.")
+                "Output (" + out_name + ") has diff at " + str(place))
 
     def check_output(self, atol=1e-5):
         places = [core.CPUPlace()]
diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcce8d32c944a39e6d6aad4c99f8aa152222c3c1
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_lstm_op.py
@@ -0,0 +1,185 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+SIGMOID_THRESHOLD_MIN = -40.0
+SIGMOID_THRESHOLD_MAX = 13.0
+EXP_MAX_INPUT = 40.0
+
+
+def identity(x):
+    return x
+
+
+def sigmoid(x):
+    # Clip the input to avoid overflow in np.exp.
+    y = np.copy(x)
+    y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
+    y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
+    return 1. / (1. + np.exp(-y))
+
+
+def tanh(x):
+    # tanh(x) = 2 * sigmoid(2x) - 1, with the exponent clipped.
+    y = -2. * x
+    y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
+    return (2. / (1. + np.exp(y))) - 1.
+
+
+def relu(x):
+    return np.maximum(x, 0)
+
+
+ACTIVATION = {
+    'identity': identity,
+    'sigmoid': sigmoid,
+    'tanh': tanh,
+    'relu': relu
+}
+
+
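+# NumPy reference implementation of the LSTM operator: each sequence is
+# evaluated step by step (optionally reversed), the peephole connections
+# are applied when w_c is given, and the hidden states, cell states and
+# pre-hidden gate values are returned for every time step.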
+def lstm(
+        input,  # T x 4D
+        lod,  # 1 x N
+        h0=None,  # N x D
+        c0=None,  # N x D
+        w_h=None,  # D x 4D
+        w_b=None,  # 1 x 4D
+        w_c=None,  # 1 x 3D
+        is_reverse=False,
+        act_gate=None,
+        act_cell=None,
+        act_cand=None):
+    def _step(x, w_h, w_c, h_pre, c_pre, act_gate, act_cell, act_cand):
+        g = np.dot(h_pre, w_h)  # 1 x 4D
+        g = g + x
+        g = np.reshape(g, (1, g.size))
+        c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
+        if w_c is None:
+            g_i = act_gate(g_i)  # 1 x D
+            g_f = act_gate(g_f)  # 1 x D
+        else:
+            w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
+            g_i = act_gate(g_i + w_ic * c_pre)  # 1 x D
+            g_f = act_gate(g_f + w_fc * c_pre)  # 1 x D
+        c = g_f * c_pre + g_i * act_cand(c_tmp)  # 1 x D
+
+        if w_c is None:
+            g_o = act_gate(g_o)  # 1 x D
+        else:
+            _, _, w_oc = np.split(w_c, 3, axis=1)
+            g_o = act_gate(g_o + w_oc * c)  # 1 x D
+        h = g_o * act_cell(c)
+        bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1)
+        return h, c, bg
+
+    def _reverse(x, lod):
+        y = np.zeros_like(x)
+        for i in range(len(lod) - 1):
+            b, e = lod[i], lod[i + 1]
+            y[b:e, :] = np.flip(x[b:e, :], 0)
+        return y
+
+    offset = lod[0]
+    batch_size = len(offset) - 1
+    hidden = []
+    cell = []
+    gate = []
+    input = _reverse(input, offset) if is_reverse else input
+    if w_b is not None:
+        input = input + np.tile(w_b, (offset[-1], 1))
+    for i in range(batch_size):
+        # compute one sequence
+        seq_len = offset[i + 1] - offset[i]
+        x = input[offset[i]:offset[i + 1], :]
+        h_pre = h0[i]  # 1 x D
+        c_pre = c0[i]  # 1 x D
+        for j in range(seq_len):
+            # compute one step
+            h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
+                                        act_cell, act_cand)
+            hidden.append(h_pre.flatten())
+            cell.append(c_pre.flatten())
+            gate.append(g_pre.flatten())
+
+    hidden = np.array(hidden).astype("float64")
+    cell = np.array(cell).astype("float64")
+    gate = np.array(gate).astype("float64")
+
+    hidden = _reverse(hidden, offset) if is_reverse else hidden
+    cell = _reverse(cell, offset) if is_reverse else cell
+
+    assert gate.shape == input.shape
+    assert hidden.shape == (input.shape[0], input.shape[1] / 4)
+    assert cell.shape == (input.shape[0], input.shape[1] / 4)
+    return hidden, cell, gate
+
+
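+# The LoD [0, 2, 6, 9] describes three sequences of lengths 2, 4 and 3.
+# sort_idx is the expected seq-to-batch row mapping after sorting the
+# sequences by length in descending order; it is used to reorder the
+# reference gate output into the layout of the BatchGate output.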
+class TestLstmOp(OpTest):
+    def set_data(self):
+        self.lod = [[0, 2, 6, 9]]
+        self.D = 64
+        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+
+        self.act_gate = "sigmoid"
+        self.act_cell = "tanh"
+        self.act_cand = "tanh"
+
+        self.is_reverse = False
+
+    def setUp(self):
+        self.set_data()
+        self.op_type = "lstm"
+
+        T = self.lod[0][-1]
+        N = len(self.lod[0]) - 1
+
+        x = np.random.normal(size=(T, 4 * self.D)).astype("float64")
+        h0 = np.zeros((N, self.D)).astype("float64")
+        c0 = np.zeros((N, self.D)).astype("float64")
+        w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64")
+        b = np.random.normal(size=(1, 7 * self.D)).astype("float64")
+
+        w_b = b[:, 0:4 * self.D]
+        w_c = b[:, 4 * self.D:]
+        h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
+                       ACTIVATION[self.act_gate], ACTIVATION[self.act_cell],
+                       ACTIVATION[self.act_cand])
+
+        g_sort = np.zeros_like(x)
+        for i, j in enumerate(self.sort_idx):
+            g_sort[i, :] = g[j, :]
+
+        self.inputs = {
+            'Input': (x, self.lod),
+            'H0': h0,
+            'C0': c0,
+            'Weight': w,
+            'Bias': b
+        }
+        self.outputs = {'Hidden': h, 'Cell': c, 'BatchGate': g_sort}
+        self.attrs = {
+            'usePeepholes': True,
+            'isReverse': self.is_reverse,
+            'gateActivation': 'sigmoid',
+            'cellActivation': 'tanh',
+            'candidateActivation': 'tanh'
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestLstmOpReverse(TestLstmOp):
+    def set_data(self):
+        self.lod = [[0, 2, 6, 9]]
+        self.D = 64
+        self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
+
+        self.act_gate = "sigmoid"
+        self.act_cell = "tanh"
+        self.act_cand = "tanh"
+
+        self.is_reverse = True
+
+
+if __name__ == "__main__":
+    unittest.main()