From 3cace73701a052c6593f6cf9151be14c3874f2e8 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 16 Oct 2017 13:23:08 +0800 Subject: [PATCH] Add lstm implementation. --- paddle/operators/lstm_op.cc | 54 +++- paddle/operators/lstm_op.h | 35 +- .../math/detail/hl_activation_functions.h | 64 ++++ .../operators/math/detail/hl_avx_functions.cc | 68 ++++ .../operators/math/detail/hl_avx_functions.h | 32 ++ .../operators/math/detail/hl_cpu_functions.cc | 44 +++ paddle/operators/math/detail/hl_functions.h | 63 ++++ .../operators/math/detail/hl_gpu_functions.h | 80 +++++ .../operators/math/detail/lstm_cpu_kernel.h | 306 ++++++++++++++++++ .../operators/math/detail/lstm_gpu_kernel.h | 244 ++++++++++++++ paddle/operators/math/detail/lstm_kernel.h | 138 ++++++++ paddle/operators/math/lstm_compute.cc | 73 +++++ paddle/operators/math/lstm_compute.cu | 73 +++++ paddle/operators/math/lstm_compute.h | 87 +++++ paddle/operators/math/sequence2batch.cc | 31 ++ paddle/operators/math/sequence2batch.cu | 47 +++ paddle/operators/math/sequence2batch.h | 19 +- 17 files changed, 1436 insertions(+), 22 deletions(-) create mode 100644 paddle/operators/math/detail/hl_activation_functions.h create mode 100644 paddle/operators/math/detail/hl_avx_functions.cc create mode 100644 paddle/operators/math/detail/hl_avx_functions.h create mode 100644 paddle/operators/math/detail/hl_cpu_functions.cc create mode 100644 paddle/operators/math/detail/hl_functions.h create mode 100644 paddle/operators/math/detail/hl_gpu_functions.h create mode 100644 paddle/operators/math/detail/lstm_cpu_kernel.h create mode 100644 paddle/operators/math/detail/lstm_gpu_kernel.h create mode 100644 paddle/operators/math/detail/lstm_kernel.h create mode 100644 paddle/operators/math/lstm_compute.cc create mode 100644 paddle/operators/math/lstm_compute.cu create mode 100644 paddle/operators/math/lstm_compute.h diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 6233e12923..1803aa1e44 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -1,18 +1,18 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/

-#include "paddle/operators/lstm_unit_op.h"
+#include "paddle/operators/lstm_op.h"

 namespace paddle {
 namespace operators {

@@ -44,8 +44,36 @@ class LSTMOp : public framework::OperatorWithKernel {
                       "should be the same.");
   }

+    int frame_size = x_dims[1];
+    auto w_dims = ctx->GetInputDim("Weight");
+    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
+                      "The rank of Input(Weight) should be 2.");
+    PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
+                      "The first dimension of Input(Weight) "
+                      "should be %d.",
+                      frame_size);
+    PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
+                      "The second dimension of Input(Weight) "
+                      "should be 4 * %d.",
+                      frame_size);
+    auto b_dims = ctx->GetInputDim("Bias");
+    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
+    PADDLE_ENFORCE_EQ(b_dims[0], 1,
+                      "The first dimension of Input(Bias) should be 1.");
+    if (ctx->Attrs().Get<bool>("use_peepholes")) {
+      PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
+                        "The second dimension of Input(Bias) should be "
+                        "7 * %d if enable peepholes connection",
+                        frame_size);
+    } else {
+      PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
+                        "The second dimension of Input(Bias) should be "
+                        "4 * %d if disable peepholes connection",
+                        frame_size);
+    }
     ctx->SetOutputDim("Hidden", x_dims);
     ctx->SetOutputDim("Cell", x_dims);
     ctx->ShareLoD("Input", "Hidden");
     ctx->ShareLoD("Input", "Cell");
   }
@@ -82,6 +110,8 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
              "2. `use_peepholes = True` "
              " - The shape is (1 x 7*D). "
              " - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}.");
+    AddOutput("Batch",
+              "(LoDTensor) save the reorganized input as batch info. ")
+        .AsIntermediate();
     AddOutput("Hidden",
               "(LoDTensor) the hidden state lod tensor of LSTM operator. "
               "The shape and lod is the same with the `Input`.");
@@ -92,6 +122,10 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
              "(bool, default: True) "
              "whether to enable diagonal/peephole connections.")
         .SetDefault(true);
+    AddAttr<bool>("is_reverse",
+                  "(bool, default: False) "
+                  "whether to compute reversed LSTM.")
+        .SetDefault(false);
     AddAttr<std::string>(
         "gate_activation",
         "(string, default: sigmoid)"
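A note on the Bias checks above: they encode the packed row layout documented in LSTMOpMaker — the four gate biases always occupy the first 4*D entries of the single bias row, and the three diagonal peephole weights follow only when use_peepholes is set. A minimal standalone sketch of those offsets (illustration only, not part of the patch; the struct and helper names are invented here):

#include <cstddef>

// Offsets into a (1 x 7*D) bias row laid out as
// {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}; D is the frame size.
// Without peepholes the row stops after b_o and is only 4*D wide.
struct LstmBiasOffsets {
  size_t b_i, b_f, b_c, b_o;  // gate biases, D entries each
  size_t w_ic, w_fc, w_oc;    // peephole (diagonal) weights, D entries each
};

inline LstmBiasOffsets LstmBiasLayout(size_t D) {
  return {0 * D, 1 * D, 2 * D, 3 * D, 4 * D, 5 * D, 6 * D};
}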
diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h
index 6e77cadead..037f0485a1 100644
--- a/paddle/operators/lstm_op.h
+++ b/paddle/operators/lstm_op.h
@@ -1,19 +1,18 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at

-   http://www.apache.org/licenses/LICENSE-2.0
+http://www.apache.org/licenses/LICENSE-2.0

-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */

 #pragma once

-#include "glog/logging.h"
 #include "paddle/framework/op_registry.h"

 namespace paddle {
@@ -25,7 +24,21 @@ using framework::Tensor;
 template <typename Place, typename T>
 class LSTMKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {}
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input_t = ctx.Input<framework::LoDTensor>("Input");
+    auto* batch_t = ctx.Output<framework::LoDTensor>("Batch");
+    auto* bias_t = ctx.Input<framework::Tensor>("Bias");
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
+    math::LoDTensor2BatchFunctor<Place, T> to_batch(ctx.device_context(),
+                                                    input_t, batch_t,
+                                                    is_reverse);
+
+    auto in_dims = input_t->dims();
+    int frame_size = in_dims[1];
+
+    if (bias_t) {
+      auto b = EigenMatrix<T>::From(*bias_t);
+    }
+  }
 };

 template <typename Place, typename T>
diff --git a/paddle/operators/math/detail/hl_activation_functions.h b/paddle/operators/math/detail/hl_activation_functions.h
new file mode 100644
index 0000000000..d5cf874636
--- /dev/null
+++ b/paddle/operators/math/detail/hl_activation_functions.h
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef HL_ACTIVATION_FUNCTIONS_H_
+#define HL_ACTIVATION_FUNCTIONS_H_
+
+#include "hl_functions.h"
+
+/**
+ * Active functions: sigmoid, relu, tanh and linear.
+ */
+#define HPPL_ACTIVE_FUNCTION \
+  { hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }
+
+namespace hppl {
+
+/**
+ * Hppl supports sigmoid, relu, tanh, linear active functions
+ * for neural networks' forward and backward activation.
+ */
+template <class T>
+class Active {
+ public:
+  typedef T (*forward)(T);
+  typedef T (*backward)(T, T);
+};
+
+#ifdef __NVCC__
+namespace gpu {
+static __device__ Active<float>::forward forward[] = HPPL_ACTIVE_FUNCTION;
+static __device__ Active<float>::backward backward[] = HPPL_ACTIVE_FUNCTION;
+// double-precision tables, suffixed _d so they do not redefine the above
+static __device__ Active<double>::forward forward_d[] = HPPL_ACTIVE_FUNCTION;
+static __device__ Active<double>::backward backward_d[] = HPPL_ACTIVE_FUNCTION;
+}  // namespace gpu
+#else
+namespace cpu {
+static Active<float>::forward forward[] = HPPL_ACTIVE_FUNCTION;
+static Active<float>::backward backward[] = HPPL_ACTIVE_FUNCTION;
+// double-precision tables, suffixed _d so they do not redefine the above
+static Active<double>::forward forward_d[] = HPPL_ACTIVE_FUNCTION;
+static Active<double>::backward backward_d[] = HPPL_ACTIVE_FUNCTION;
+}  // namespace cpu
+
+#ifdef __AVX__
+namespace avx {
+static Active<__m256>::forward forward[] = HPPL_ACTIVE_FUNCTION;
+static Active<__m256>::backward backward[] = HPPL_ACTIVE_FUNCTION;
+}  // namespace avx
+#endif
+#endif
+
+}  // namespace hppl
+
+#endif  // HL_ACTIVATION_FUNCTIONS_H_
diff --git a/paddle/operators/math/detail/hl_avx_functions.cc b/paddle/operators/math/detail/hl_avx_functions.cc
new file mode 100644
index 0000000000..70e7d80304
--- /dev/null
+++ b/paddle/operators/math/detail/hl_avx_functions.cc
@@ -0,0 +1,68 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <immintrin.h>
+#include "hl_functions.h"
+
+namespace hppl {
+
+extern __m256 exp(__m256 a);
+
+__m256 relu(const __m256 a) {
+  __m256 tmp = _mm256_set1_ps(0.0f);
+  return _mm256_max_ps(a, tmp);
+}
+
+__m256 sigmoid(const __m256 a) {
+  __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
+  __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
+  __m256 tmp = _mm256_max_ps(a, min);
+  tmp = _mm256_min_ps(tmp, max);
+  tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
+  tmp = exp(tmp);
+  tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
+  tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
+  return tmp;
+}
+
+__m256 tanh(const __m256 a) {
+  __m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
+  __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
+  tmp = _mm256_min_ps(tmp, max);
+  tmp = exp(tmp);
+  return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
+                                     _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
+                       _mm256_set1_ps(1.0f));
+}
+
+__m256 linear(const __m256 a) { return a; }
+
+__m256 relu(const __m256 a, const __m256 b) {
+  return _mm256_mul_ps(
+      a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
+                       _mm256_set1_ps(1.0f)));
+}
+
+__m256 sigmoid(const __m256 a, const __m256 b) {
+  return _mm256_mul_ps(_mm256_mul_ps(a, b),
+                       _mm256_sub_ps(_mm256_set1_ps(1.0f), b));
+}
+
+__m256 tanh(const __m256 a, const __m256 b) {
+  return _mm256_mul_ps(
+      a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b)));
+}
+
+__m256 linear(const __m256 a, const __m256 b) { return a; }
+}  // namespace hppl
diff --git a/paddle/operators/math/detail/hl_avx_functions.h b/paddle/operators/math/detail/hl_avx_functions.h
new file mode 100644
index 0000000000..35f4eabb4c
--- /dev/null
+++ b/paddle/operators/math/detail/hl_avx_functions.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef HL_AVX_FUNCTIONS_H_
+#define HL_AVX_FUNCTIONS_H_
+
+#include <immintrin.h>
+
+namespace hppl {
+__m256 relu(const __m256 a);
+__m256 sigmoid(const __m256 a);
+__m256 tanh(const __m256 a);
+__m256 linear(const __m256 a);
+
+__m256 relu(const __m256 a, const __m256 b);
+__m256 sigmoid(const __m256 a, const __m256 b);
+__m256 tanh(const __m256 a, const __m256 b);
+__m256 linear(const __m256 a, const __m256 b);
+}  // namespace hppl
+
+#endif  // HL_AVX_FUNCTIONS_H_
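For reference, the AVX sigmoid above vectorizes the same clipped formula used by the scalar CPU path: the input is clamped to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] before exponentiation so exp cannot overflow. A scalar sketch of the intended math (illustration only, with the thresholds hard-coded):

#include <cmath>

// Clipped sigmoid: matches the AVX version lane-for-lane.
static float sigmoid_ref(float x) {
  const float kMin = -40.0f;  // SIGMOID_THRESHOLD_MIN
  const float kMax = 13.0f;   // SIGMOID_THRESHOLD_MAX
  x = x < kMin ? kMin : (x > kMax ? kMax : x);
  return 1.0f / (1.0f + std::exp(-x));
}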
diff --git a/paddle/operators/math/detail/hl_cpu_functions.cc b/paddle/operators/math/detail/hl_cpu_functions.cc
new file mode 100644
index 0000000000..b42e11fd90
--- /dev/null
+++ b/paddle/operators/math/detail/hl_cpu_functions.cc
@@ -0,0 +1,44 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <math.h>
+#include "paddle/operators/math/detail/hl_functions.h"
+
+namespace hppl {
+
+real relu(const real a) { return a > 0.0f ? a : 0.0f; }
+
+real sigmoid(const real a) {
+  const real min = SIGMOID_THRESHOLD_MIN;
+  const real max = SIGMOID_THRESHOLD_MAX;
+  real tmp = (a < min) ? min : ((a > max) ? max : a);
+  return 1.0 / (1.0 + exp(-tmp));
+}
+
+real tanh(const real a) {
+  real tmp = -2.0 * a;
+  tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
+  return (2.0 / (1.0 + exp(tmp))) - 1.0;
+}
+
+real linear(const real a) { return a; }
+
+real relu(const real a, const real b) { return a * (b > 0.0f ? 1.0f : 0.0f); }
+
+real sigmoid(const real a, const real b) { return a * b * (1 - b); }
+
+real tanh(const real a, const real b) { return a * (1.0f - b * b); }
+
+real linear(const real a, const real b) { return a; }
+}  // namespace hppl
diff --git a/paddle/operators/math/detail/hl_functions.h b/paddle/operators/math/detail/hl_functions.h
new file mode 100644
index 0000000000..4eda1adfe9
--- /dev/null
+++ b/paddle/operators/math/detail/hl_functions.h
@@ -0,0 +1,63 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef HL_FUNCTIONS_H_
+#define HL_FUNCTIONS_H_
+
+/**
+ * sigmoid threshold minimum
+ */
+#define SIGMOID_THRESHOLD_MIN -40.0
+
+/**
+ * sigmoid threshold maximum
+ */
+#define SIGMOID_THRESHOLD_MAX 13.0
+
+/**
+ * exp input clamp used by the tanh implementations; 40.0 is assumed here,
+ * matching the value defined in the original paddle/cuda hl_functions.h
+ */
+#define EXP_MAX_INPUT 40.0
+
+#ifndef __NVCC__
+namespace hppl {
+/*
+ * forward activation
+ */
+template <typename T>
+T relu(const T a);
+template <typename T>
+T sigmoid(const T a);
+template <typename T>
+T tanh(const T a);
+template <typename T>
+T linear(const T a);
+
+/*
+ * backward activation
+ */
+template <typename T>
+T relu(const T a, const T b);
+template <typename T>
+T sigmoid(const T a, const T b);
+template <typename T>
+T tanh(const T a, const T b);
+template <typename T>
+T linear(const T a, const T b);
+}  // namespace hppl
+
+#ifdef __AVX__
+#include "hl_avx_functions.h"
+#endif
+
+#else
+#include "hl_gpu_functions.h"
+#endif
+
+#endif  // HL_FUNCTIONS_H_
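Both tanh implementations (the scalar one in hl_cpu_functions.cc above and the AVX one earlier) reuse the exponential rather than calling a tanh routine, relying on the identity below; the exponent -2a is clamped to EXP_MAX_INPUT so exp cannot overflow:

\tanh(a) = \frac{2}{1 + e^{-2a}} - 1 = 2\,\sigma(2a) - 1,
\qquad \text{where } \sigma(x) = \frac{1}{1 + e^{-x}}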
diff --git a/paddle/operators/math/detail/hl_gpu_functions.h b/paddle/operators/math/detail/hl_gpu_functions.h
new file mode 100644
index 0000000000..25fa7c409a
--- /dev/null
+++ b/paddle/operators/math/detail/hl_gpu_functions.h
@@ -0,0 +1,80 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef HL_GPU_FUNCTIONS_CUH_
+#define HL_GPU_FUNCTIONS_CUH_
+
+#include "hl_base.h"
+
+namespace hppl {
+
+template <typename T>
+__device__ static T relu(const T a) {
+  return a > 0.0f ? a : 0.0f;
+}
+
+// primary templates for sigmoid/tanh, specialized per type below
+template <typename T>
+__device__ static T sigmoid(const T a);
+template <typename T>
+__device__ static T tanh(const T a);
+
+template <>
+__device__ static float sigmoid(const float a) {
+  const float min = SIGMOID_THRESHOLD_MIN;
+  const float max = SIGMOID_THRESHOLD_MAX;
+  float tmp = (a < min) ? min : ((a > max) ? max : a);
+  return __fdividef(1.0f, 1.0f + __expf(-tmp));
+}
+
+template <>
+__device__ static double sigmoid(const double a) {
+  const double min = SIGMOID_THRESHOLD_MIN;
+  const double max = SIGMOID_THRESHOLD_MAX;
+  double tmp = (a < min) ? min : ((a > max) ? max : a);
+  return 1.0 / (1.0 + exp(-tmp));
+}
+
+template <>
+__device__ static float tanh(const float a) {
+  return __fdividef(2.0f, (1.0f + __expf(-2.0f * a))) - 1.0f;
+}
+
+template <>
+__device__ static double tanh(const double a) {
+  return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0;
+}
+
+template <typename T>
+__device__ static T linear(const T a) {
+  return a;
+}
+
+template <typename T>
+__device__ static T relu(const T a, const T b) {
+  return a * (b > 0.0f ? 1.0f : 0.0f);
+}
+
+template <typename T>
+__device__ static T sigmoid(const T a, const T b) {
+  return a * b * (1 - b);
+}
+
+template <typename T>
+__device__ static T tanh(const T a, const T b) {
+  return a * (1.0f - b * b);
+}
+
+template <typename T>
+__device__ static T linear(const T a, const T b) {
+  return a;
+}
+
+}  // namespace hppl
+
+#endif  // HL_GPU_FUNCTIONS_CUH_
diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h
new file mode 100644
index 0000000000..a8e78a449d
--- /dev/null
+++ b/paddle/operators/math/detail/lstm_cpu_kernel.h
@@ -0,0 +1,306 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/operators/math/lstm_compute.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace detail {
+
+#ifndef __NVCC__
+
+template <class T, class Op>
+void naive_lstm_forward_one_sequence(Op op, lstm_value<T> value,
+                                     int frameSize,
+                                     activation_mode_t active_node,
+                                     activation_mode_t active_gate,
+                                     activation_mode_t active_state) {
+  T rValueIn;
+  T rValueIg;
+  T rValueFg;
+  T rValueOg;
+  T rCheckI;
+  T rCheckF;
+  T rCheckO;
+  T rState;
+  T rPrevState = 0;
+  T rStateAtv;
+  T rOut;
+
+  T *valueIn = value.gateValue;
+  T *valueIg = value.gateValue + frameSize;
+  T *valueFg = value.gateValue + frameSize * 2;
+  T *valueOg = value.gateValue + frameSize * 3;
+
+  for (int i = 0; i < frameSize; i++) {
+    rValueIn = valueIn[i];
+    rValueIg = valueIg[i];
+    rValueFg = valueFg[i];
+    rValueOg = valueOg[i];
+    rCheckI = value.checkIg[i];
+    rCheckF = value.checkFg[i];
+    rCheckO = value.checkOg[i];
+
+    if (value.prevStateValue) {
+      rPrevState = value.prevStateValue[i];
+    }
+
+    op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
+       rOut, rCheckI, rCheckF, rCheckO, hppl::cpu::forward[active_node],
+       hppl::cpu::forward[active_gate], hppl::cpu::forward[active_state]);
+
+    valueIn[i] = rValueIn;
+    valueIg[i] = rValueIg;
+    valueFg[i] = rValueFg;
+    valueOg[i] = rValueOg;
+    value.stateValue[i] = rState;
+    value.stateActiveValue[i] = rStateAtv;
+    value.outputValue[i] = rOut;
+  }
+}
+
+template <class T, class Op>
+void naive_lstm_backward_one_sequence(Op op, lstm_value<T> value,
+                                      lstm_grad<T> grad, int frameSize,
+                                      activation_mode_t active_node,
+                                      activation_mode_t active_gate,
+                                      activation_mode_t active_state) {
+  T rValueIn;
+  T rValueIg;
+  T rValueFg;
+  T rValueOg;
+  T rGradIn;
+  T rGradIg;
+  T rGradFg;
+  T rGradOg;
+  T rPrevState = 0;
+  T rPrevStateGrad;
+  T rState;
+  T rStateGrad;
+  T rStateAtv;
+  T rOutputGrad;
+  T rCheckI;
+  T rCheckF;
+  T rCheckO;
+  T rCheckIGrad;
+  T rCheckFGrad;
+  T rCheckOGrad;
+
+  T *valueIn = value.gateValue;
+  T *valueIg = value.gateValue + frameSize;
+  T *valueFg = value.gateValue + frameSize * 2;
+  T *valueOg = value.gateValue + frameSize * 3;
+  T *gradIn = grad.gateGrad;
+  T *gradIg = grad.gateGrad + frameSize;
+  T *gradFg = grad.gateGrad + frameSize * 2;
+  T *gradOg = grad.gateGrad + frameSize * 3;
+
+  for (int i = 0; i < frameSize; i++) {
+    rValueIn = valueIn[i];
+    rValueIg = valueIg[i];
+    rValueFg = valueFg[i];
+    rValueOg = valueOg[i];
+    rCheckI = value.checkIg[i];
+    rCheckF = value.checkFg[i];
+    rCheckO = value.checkOg[i];
+    rState = value.stateValue[i];
+    rStateAtv = value.stateActiveValue[i];
+    rOutputGrad = grad.outputGrad[i];
+    rStateGrad = grad.stateGrad[i];
+    if (value.prevStateValue) {
+      rPrevState = value.prevStateValue[i];
+    }
+
+    op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
+       rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
+       rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
+       rCheckOGrad, hppl::cpu::backward[active_node],
+       hppl::cpu::backward[active_gate], hppl::cpu::backward[active_state]);
+
+    gradIn[i] = rGradIn;
+    gradIg[i] = rGradIg;
+    gradFg[i] = rGradFg;
+    gradOg[i] = rGradOg;
+    grad.stateGrad[i] = rStateGrad;
+
+    if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad;
+    if (value.prevStateValue) {
+      if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad;
+      if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad;
+    }
+    if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad;
+  }
+}
+
+template <class T, class Op>
+void avx_lstm_forward_one_sequence(Op op, lstm_value<T> value, int frameSize,
+                                   activation_mode_t active_node,
+                                   activation_mode_t active_gate,
+                                   activation_mode_t active_state) {
+#ifdef __AVX__
+  __m256 rValueIn;
+  __m256 rValueIg;
+  __m256 rValueFg;
+  __m256 rValueOg;
+  __m256 rCheckI;
+  __m256 rCheckF;
+  __m256 rCheckO;
+  __m256 rState;
+  __m256 rPrevState = _mm256_set1_ps(0.0f);
+  __m256 rStateAtv;
+  __m256 rOut;
+
+  __m256 *valueIn = (__m256 *)value.gateValue;
+  __m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
+  __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
+  __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
+
+  for (int i = 0; i < frameSize / 8; i++) {
+    rValueIn = valueIn[i];
+    rValueIg = valueIg[i];
+    rValueFg = valueFg[i];
+    rValueOg = valueOg[i];
+    rCheckI = ((__m256 *)value.checkIg)[i];
+    rCheckF = ((__m256 *)value.checkFg)[i];
+    rCheckO = ((__m256 *)value.checkOg)[i];
+
+    if (value.prevStateValue) {
+      rPrevState = ((__m256 *)value.prevStateValue)[i];
+    }
+
+    op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
+       rOut, rCheckI, rCheckF, rCheckO, hppl::avx::forward[active_node],
+       hppl::avx::forward[active_gate], hppl::avx::forward[active_state]);
+
+    valueIn[i] = rValueIn;
+    valueIg[i] = rValueIg;
+    valueFg[i] = rValueFg;
+    valueOg[i] = rValueOg;
+    ((__m256 *)value.stateValue)[i] = rState;
+    ((__m256 *)value.stateActiveValue)[i] = rStateAtv;
+    ((__m256 *)value.outputValue)[i] = rOut;
+  }
+#endif
+}
+
+template <class T, class Op>
+void avx_lstm_backward_one_sequence(Op op, lstm_value<T> value,
+                                    lstm_grad<T> grad, int frameSize,
+                                    activation_mode_t active_node,
+                                    activation_mode_t active_gate,
+                                    activation_mode_t active_state) {
+#ifdef __AVX__
+  __m256 rValueIn;
+  __m256 rValueIg;
+  __m256 rValueFg;
+  __m256 rValueOg;
+  __m256 rGradIn;
+  __m256 rGradIg;
+  __m256 rGradFg;
+  __m256 rGradOg;
+  __m256 rPrevState = _mm256_set1_ps(0.0f);
+  __m256 rPrevStateGrad;
+  __m256 rStateGrad;
+  __m256 rState;
+  __m256 rStateAtv;
+  __m256 rOutputGrad;
+  __m256 rCheckI;
+  __m256 rCheckF;
+  __m256 rCheckO;
+  __m256 rCheckIGrad;
+  __m256 rCheckFGrad;
+  __m256 rCheckOGrad;
+
+  __m256 *valueIn = (__m256 *)value.gateValue;
+  __m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
+  __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
+  __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
+  __m256 *gradIn = (__m256 *)grad.gateGrad;
+  __m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize);
+  __m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2);
+  __m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3);
+
+  for (int i = 0; i < frameSize / 8; i++) {
+    rValueIn = valueIn[i];
+    rValueIg = valueIg[i];
+    rValueFg = valueFg[i];
+    rValueOg = valueOg[i];
+    rCheckI = ((__m256 *)value.checkIg)[i];
+    rCheckF = ((__m256 *)value.checkFg)[i];
+    rCheckO = ((__m256 *)value.checkOg)[i];
+    rState = ((__m256 *)value.stateValue)[i];
+    rStateAtv = ((__m256 *)value.stateActiveValue)[i];
+    rOutputGrad = ((__m256 *)grad.outputGrad)[i];
+    rStateGrad = ((__m256 *)grad.stateGrad)[i];
+    if (value.prevStateValue) {
+      rPrevState = ((__m256 *)value.prevStateValue)[i];
+    }
+
+    op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
+       rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
+       rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
+       rCheckOGrad, hppl::avx::backward[active_node],
+       hppl::avx::backward[active_gate], hppl::avx::backward[active_state]);
+
+    gradIn[i] = rGradIn;
+    gradIg[i] = rGradIg;
+    gradFg[i] = rGradFg;
+    gradOg[i] = rGradOg;
+    ((__m256 *)grad.stateGrad)[i] = rStateGrad;
+
+    if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad;
+    if (value.prevStateValue) {
+      if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad;
+      if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad;
+    }
+    if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad;
+  }
+#endif
+}
+
+template <class T, class Op>
+void cpu_lstm_forward(Op op, lstm_value<T> value, int frameSize,
+                      activation_mode_t active_node,
+                      activation_mode_t active_gate,
+                      activation_mode_t active_state) {
+  if (Op::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
+    avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
+                                     active_gate, active_state);
+  } else {
+    naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
+                                       active_gate, active_state);
+  }
+}
+
+template <class T, class Op>
+void cpu_lstm_backward(Op op, lstm_value<T> value, lstm_grad<T> grad,
+                       int frameSize, activation_mode_t active_node,
+                       activation_mode_t active_gate,
+                       activation_mode_t active_state) {
+  if (Op::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
+    avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
+                                      active_gate, active_state);
+  } else {
+    naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize,
+                                        active_node, active_gate,
+                                        active_state);
+  }
+}
+
+#endif
+
+}  // namespace detail
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
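cpu_lstm_forward/backward above dispatch to the AVX path only when all three of their conditions hold. Restated as a standalone predicate (a sketch; the helper name is invented here for illustration):

#include <cstddef>

// The AVX kernels process 8 floats per __m256 register, so they require:
//  - an Op that provides the __m256 overload (Op::avx),
//  - a frame size that is a multiple of 8 (frameSize & 7 == 0),
//  - single precision (sizeof(T) == 4); doubles take the naive loop.
inline bool UseAvxPath(bool op_has_avx, int frame_size, size_t elem_size) {
  return op_has_avx && (frame_size % 8 == 0) && (elem_size == 4);
}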
diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h
new file mode 100644
index 0000000000..8d0274c19d
--- /dev/null
+++ b/paddle/operators/math/detail/lstm_gpu_kernel.h
@@ -0,0 +1,244 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/operators/math/detail/lstm_kernel.h"
+#include "paddle/operators/math/lstm_compute.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace detail {
+
+/*
+ * threads(framePerBlock, batchPerBlock)
+ * grid(frameBlocks, batchBlocks)
+ */
+template <class T, class Op, bool isBatch>
+__global__ void KeLstmForward(Op op, lstm_value<T> value, int frameSize,
+                              int batchSize, activation_mode_t active_node,
+                              activation_mode_t active_gate,
+                              activation_mode_t active_state) {
+  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frameIdx >= frameSize) return;
+
+  int batchIdx = 0;
+  if (isBatch) {
+    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batchIdx >= batchSize) return;
+    value.gateValue += batchIdx * frameSize * 4;
+    value.outputValue += batchIdx * frameSize;
+    value.stateValue += batchIdx * frameSize;
+    value.stateActiveValue += batchIdx * frameSize;
+  }
+
+  T rState;
+  T rPrevState = 0;
+  T rStateAtv;
+  T rOut;
+  T rValueIn;
+  T rValueIg;
+  T rValueFg;
+  T rValueOg;
+  T rCheckI = value.checkIg[frameIdx];
+  T rCheckF = value.checkFg[frameIdx];
+  T rCheckO = value.checkOg[frameIdx];
+
+  rValueIn = value.gateValue[frameIdx];
+  rValueIg = value.gateValue[frameIdx + frameSize];
+  rValueFg = value.gateValue[frameIdx + frameSize * 2];
+  rValueOg = value.gateValue[frameIdx + frameSize * 3];
+
+  if (value.prevStateValue) {
+    if (isBatch) value.prevStateValue += batchIdx * frameSize;
+    rPrevState = value.prevStateValue[frameIdx];
+  }
+
+  op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
+     rOut, rCheckI, rCheckF, rCheckO, hppl::gpu::forward[active_node],
+     hppl::gpu::forward[active_gate], hppl::gpu::forward[active_state]);
+
+  value.gateValue[frameIdx] = rValueIn;
+  value.gateValue[frameIdx + frameSize] = rValueIg;
+  value.gateValue[frameIdx + frameSize * 2] = rValueFg;
+  value.gateValue[frameIdx + frameSize * 3] = rValueOg;
+
+  value.stateValue[frameIdx] = rState;
+  value.stateActiveValue[frameIdx] = rStateAtv;
+  value.outputValue[frameIdx] = rOut;
+}
+
+/*
+ * threads(framePerBlock, batchPerBlock)
+ * grid(frameBlocks, batchBlocks)
+ */
+template <class T, class Op, bool isBatch>
+__global__ void KeLstmBackward(Op op, lstm_value<T> value, lstm_grad<T> grad,
+                               int frameSize, int batchSize,
+                               activation_mode_t active_node,
+                               activation_mode_t active_gate,
+                               activation_mode_t active_state) {
+  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frameIdx >= frameSize) return;
+
+  int batchIdx = 0;
+  if (isBatch) {
+    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batchIdx >= batchSize) return;
+    value.gateValue += batchIdx * frameSize * 4;
+    value.stateValue += batchIdx * frameSize;
+    value.stateActiveValue += batchIdx * frameSize;
+    grad.gateGrad += batchIdx * frameSize * 4;
+    grad.stateGrad += batchIdx * frameSize;
+    grad.outputGrad += batchIdx * frameSize;
+  }
+
+  T rValueIn;
+  T rValueIg;
+  T rValueFg;
+  T rValueOg;
+  T rGradIn;
+  T rGradIg;
+  T rGradFg;
+  T rGradOg;
+  T rPrevState = 0;
+  T rPrevStateGrad;
+  T rState;
+  T rStateGrad;
+  T rStateAtv;
+  T rOutputGrad;
+  T rCheckI = value.checkIg[frameIdx];
+  T rCheckF = value.checkFg[frameIdx];
+  T rCheckO = value.checkOg[frameIdx];
+  T rCheckIGrad;
+  T rCheckFGrad;
+  T rCheckOGrad;
+
+  rValueIn = value.gateValue[frameIdx];
+  rValueIg = value.gateValue[frameIdx + frameSize];
+  rValueFg = value.gateValue[frameIdx + frameSize * 2];
+  rValueOg = value.gateValue[frameIdx + frameSize * 3];
+  rState = value.stateValue[frameIdx];
+  rStateAtv = value.stateActiveValue[frameIdx];
+  rOutputGrad = grad.outputGrad[frameIdx];
+  rStateGrad = grad.stateGrad[frameIdx];
+
+  if (value.prevStateValue) {
+    if (isBatch) value.prevStateValue += batchIdx * frameSize;
+    rPrevState = value.prevStateValue[frameIdx];
+  }
+
+  op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
+     rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
+     rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
+     rCheckOGrad, hppl::gpu::backward[active_node],
+     hppl::gpu::backward[active_gate], hppl::gpu::backward[active_state]);
+
+  grad.gateGrad[frameIdx] = rGradIn;
+  grad.gateGrad[frameIdx + frameSize] = rGradIg;
+  grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
+  grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
+  grad.stateGrad[frameIdx] = rStateGrad;
+  if (grad.prevStateGrad) {
+    if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
+    grad.prevStateGrad[frameIdx] = rPrevStateGrad;
+  }
+
+  if (isBatch) {
+    if (value.prevStateValue) {
+      if (grad.checkIgGrad)
+        paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
+                                        rCheckIGrad);
+      if (grad.checkFgGrad)
+        paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
+                                        rCheckFGrad);
+    }
+    if (grad.checkOgGrad)
+      paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
+  } else {
+    if (value.prevStateValue) {
+      if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
+      if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
+    }
+    if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
+  }
+}
+
+template <class T, class Op>
+void gpu_lstm_forward(Op op, lstm_value<T> value, int frameSize, int batchSize,
+                      activation_mode_t active_node,
+                      activation_mode_t active_gate,
+                      activation_mode_t active_state) {
+  dim3 threads;
+  dim3 grid;
+  if (batchSize == 1) {
+    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
+    int frameBlocks = (frameSize + 1024 - 1) / 1024;
+    threads = dim3(framePerBlock, 1);
+    grid = dim3(frameBlocks, 1);
+  } else {
+    /* framePerBlock = 32 batchPerBlock = 32 */
+    threads = dim3(32, 32);
+    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+  }
+
+  if (batchSize == 1) {
+    KeLstmForward<T, Op, /* isBatch= */ false><<<grid, threads>>>(
+        op, value, frameSize, batchSize, active_node, active_gate,
+        active_state);
+  } else {
+    KeLstmForward<T, Op, /* isBatch= */ true><<<grid, threads>>>(
+        op, value, frameSize, batchSize, active_node, active_gate,
+        active_state);
+  }
+}
+
+template <class T, class Op>
+void gpu_lstm_backward(Op op, lstm_value<T> value, lstm_grad<T> grad,
+                       int frameSize, int batchSize,
+                       activation_mode_t active_node,
+                       activation_mode_t active_gate,
+                       activation_mode_t active_state) {
+  dim3 threads;
+  dim3 grid;
+  if (batchSize == 1) {
+    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
+    int frameBlocks = (frameSize + 1024 - 1) / 1024;
+    threads = dim3(framePerBlock, 1);
+    grid = dim3(frameBlocks, 1);
+  } else {
+    /* framePerBlock = 32 batchPerBlock = 32 */
+    threads = dim3(32, 32);
+    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+  }
+
+  if (batchSize == 1) {
+    KeLstmBackward<T, Op, /* isBatch= */ false><<<grid, threads>>>(
+        op, value, grad, frameSize, batchSize, active_node, active_gate,
+        active_state);
+  } else {
+    KeLstmBackward<T, Op, /* isBatch= */ true><<<grid, threads>>>(
+        op, value, grad, frameSize, batchSize, active_node, active_gate,
+        active_state);
+  }
+}
+
+}  // namespace detail
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
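The two launchers above share one launch-shape rule: one thread per (frame, batch) element, a flat 1-D launch when there is a single sequence, and 32x32 tiles otherwise. Restated as a small host-side helper (a sketch; LaunchShape is a name introduced here, not part of the patch):

#include <algorithm>

struct LaunchShape {
  int block_x, block_y;  // threads per block
  int grid_x, grid_y;    // blocks per grid
};

inline LaunchShape LstmLaunchShape(int frame_size, int batch_size) {
  if (batch_size == 1) {
    // one row of threads across the frame, capped at 1024 per block
    return {std::min(frame_size, 1024), 1, (frame_size + 1023) / 1024, 1};
  }
  // 32x32 tiles over the (frame, batch) plane
  return {32, 32, (frame_size + 31) / 32, (batch_size + 31) / 32};
}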
diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h
new file mode 100644
index 0000000000..107030f8ba
--- /dev/null
+++ b/paddle/operators/math/detail/lstm_kernel.h
@@ -0,0 +1,138 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "hl_activation_functions.h"
+
+#ifdef __CUDA_ARCH__
+#define INLINE __device__ inline
+#else
+#define INLINE inline
+#endif
+
+namespace paddle {
+namespace operators {
+namespace math {
+namespace detail {
+
+using hppl::Active;
+
+namespace forward {
+
+template <class T>
+class lstm {
+ public:
+  INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
+                         T &prevState, T &state, T &stateAtv, T &output,
+                         T &checkI, T &checkF, T &checkO,
+                         typename Active<T>::forward actInput,
+                         typename Active<T>::forward actGate,
+                         typename Active<T>::forward actState) {
+    valueIn = actInput(valueIn);
+    valueIg = actGate(valueIg + prevState * checkI);
+    valueFg = actGate(valueFg + prevState * checkF);
+    state = valueIn * valueIg + prevState * valueFg;
+    valueOg = actGate(valueOg + state * checkO);
+    stateAtv = actState(state);
+    output = valueOg * stateAtv;
+  }
+#ifndef __NVCC__
+#ifndef __AVX__
+  static const bool avx = false;
+#else
+  static const bool avx = true;
+  INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
+                         __m256 &valueOg, __m256 &prevState, __m256 &state,
+                         __m256 &stateAtv, __m256 &output, __m256 &checkI,
+                         __m256 &checkF, __m256 &checkO,
+                         Active<__m256>::forward actInput,
+                         Active<__m256>::forward actGate,
+                         Active<__m256>::forward actState) {
+    valueIn = actInput(valueIn);
+    valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
+    valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
+    state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
+                          _mm256_mul_ps(prevState, valueFg));
+    valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)));
+    stateAtv = actState(state);
+    output = _mm256_mul_ps(valueOg, stateAtv);
+  }
+#endif
+#endif
+};
+
+}  // namespace forward
+
+namespace backward {
+
+template <class T>
+class lstm {
+ public:
+  INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
+                         T &gradIn, T &gradIg, T &gradFg, T &gradOg,
+                         T &prevState, T &prevStateGrad, T &state,
+                         T &stateGrad, T &stateAtv, T &outputGrad, T &checkI,
+                         T &checkF, T &checkO, T &checkIGrad, T &checkFGrad,
+                         T &checkOGrad,
+                         typename Active<T>::backward actInput,
+                         typename Active<T>::backward actGate,
+                         typename Active<T>::backward actState) {
+    gradOg = actGate(outputGrad * stateAtv, valueOg);
+    stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
+    gradIn = actInput(stateGrad * valueIg, valueIn);
+    gradIg = actGate(stateGrad * valueIn, valueIg);
+    gradFg = actGate(stateGrad * prevState, valueFg);
+    prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
+    checkIGrad = gradIg * prevState;
+    checkFGrad = gradFg * prevState;
+    checkOGrad = gradOg * state;
+  }
+#ifndef __NVCC__
+#ifndef __AVX__
+  static const bool avx = false;
+#else
+  static const bool avx = true;
+  INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
+                         __m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
+                         __m256 &gradFg, __m256 &gradOg, __m256 &prevState,
+                         __m256 &prevStateGrad, __m256 &state,
+                         __m256 &stateGrad, __m256 &stateAtv,
+                         __m256 &outputGrad, __m256 &checkI, __m256 &checkF,
+                         __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
+                         __m256 &checkOGrad, Active<__m256>::backward actInput,
+                         Active<__m256>::backward actGate,
+                         Active<__m256>::backward actState) {
+    gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
+    stateGrad = _mm256_add_ps(
+        actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
+    stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
+    gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn);
+    gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg);
+    gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg);
+    prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
+                                  _mm256_mul_ps(gradFg, checkF));
+    prevStateGrad =
+        _mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
+    checkIGrad = _mm256_mul_ps(gradIg, prevState);
+    checkFGrad = _mm256_mul_ps(gradFg, prevState);
+    checkOGrad = _mm256_mul_ps(gradOg, state);
+  }
+#endif
+#endif
+};
+
+}  // namespace backward
+
+}  // namespace detail
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
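In standard notation, the forward functor above computes the peephole LSTM cell, where a_i, a_f, a_c, a_o are the four pre-activation slices of gateValue, c is the cell state (stateValue), w_ic/w_fc/w_oc are the diagonal peephole weights (checkI/checkF/checkO), and sigma_g/sigma_c/sigma_h are the gate, input and state activations:

\begin{aligned}
i_t &= \sigma_g(a_i + w_{ic} \odot c_{t-1}) \\
f_t &= \sigma_g(a_f + w_{fc} \odot c_{t-1}) \\
c_t &= i_t \odot \sigma_c(a_c) + f_t \odot c_{t-1} \\
o_t &= \sigma_g(a_o + w_{oc} \odot c_t) \\
h_t &= o_t \odot \sigma_h(c_t)
\end{aligned}

The backward functor is the reverse-mode differentiation of these five lines, which is why it accumulates into stateGrad before producing the gate gradients.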
diff --git a/paddle/operators/math/lstm_compute.cc b/paddle/operators/math/lstm_compute.cc
new file mode 100644
index 0000000000..77d317048a
--- /dev/null
+++ b/paddle/operators/math/lstm_compute.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/lstm_compute.h"
+#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
+#include "paddle/operators/math/detail/lstm_kernel.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <class T>
+struct LstmUnitFunctor<platform::CPUPlace, T> {
+  static void compute(lstm_value<T> value, int frame_size, int batch_size,
+                      std::string gate_act, std::string cell_act,
+                      std::string cand_act) {
+    for (int b = 0; b < batch_size; b++) {
+      detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
+                               ActiveType(cand_act), ActiveType(gate_act),
+                               ActiveType(cell_act));
+      value.gateValue += frame_size * 4;
+      value.stateValue += frame_size;
+      value.stateActiveValue += frame_size;
+      value.outputValue += frame_size;
+      if (value.prevStateValue) {
+        value.prevStateValue += frame_size;
+      }
+    }
+  }
+};
+
+template <class T>
+struct LstmUnitGradFunctor<platform::CPUPlace, T> {
+  static void compute(lstm_value<T> value, lstm_grad<T> grad, int frame_size,
+                      int batch_size, std::string gate_act,
+                      std::string cell_act, std::string cand_act) {
+    for (int b = 0; b < batch_size; b++) {
+      detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
+                                frame_size, ActiveType(cand_act),
+                                ActiveType(gate_act), ActiveType(cell_act));
+
+      value.gateValue += frame_size * 4;
+      value.stateValue += frame_size;
+      value.stateActiveValue += frame_size;
+      value.outputValue += frame_size;
+      if (value.prevStateValue) {
+        value.prevStateValue += frame_size;
+      }
+
+      grad.gateGrad += frame_size * 4;
+      grad.stateGrad += frame_size;
+      grad.stateActiveGrad += frame_size;
+      grad.outputGrad += frame_size;
+      if (grad.prevStateGrad) {
+        grad.prevStateGrad += frame_size;
+      }
+    }
+  }
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/lstm_compute.cu b/paddle/operators/math/lstm_compute.cu
new file mode 100644
index 0000000000..a7e23920aa
--- /dev/null
+++ b/paddle/operators/math/lstm_compute.cu
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/lstm_compute.h"
+#include "paddle/operators/math/detail/lstm_gpu_kernel.h"
+#include "paddle/operators/math/detail/lstm_kernel.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <class T>
+struct LstmUnitFunctor<platform::GPUPlace, T> {
+  static void compute(lstm_value<T> value, int frame_size, int batch_size,
+                      std::string gate_act, std::string cell_act,
+                      std::string cand_act) {
+    // The GPU kernel covers the whole batch in one launch (isBatch is
+    // selected inside gpu_lstm_forward), so no per-sequence loop is needed.
+    detail::gpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
+                             batch_size, ActiveType(cand_act),
+                             ActiveType(gate_act), ActiveType(cell_act));
+  }
+};
+
+template <class T>
+struct LstmUnitGradFunctor<platform::GPUPlace, T> {
+  static void compute(lstm_value<T> value, lstm_grad<T> grad, int frame_size,
+                      int batch_size, std::string gate_act,
+                      std::string cell_act, std::string cand_act) {
+    detail::gpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
+                              frame_size, batch_size, ActiveType(cand_act),
+                              ActiveType(gate_act), ActiveType(cell_act));
+  }
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h
new file mode 100644
index 0000000000..2d7fccf1a0
--- /dev/null
+++ b/paddle/operators/math/lstm_compute.h
@@ -0,0 +1,87 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+
+#include "paddle/platform/enforce.h"
+#include "paddle/platform/macros.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+typedef enum {
+  HL_ACTIVATION_SIGMOID = 0,
+  HL_ACTIVATION_RELU = 1,
+  HL_ACTIVATION_TANH = 2,
+  HL_ACTIVATION_LINEAR = 3,
+  HL_ACTIVATION_END
+} activation_mode_t;
+
+template <class real>
+struct lstm_value {
+  real *gateValue;
+  real *prevStateValue;
+  real *stateValue;
+  real *stateActiveValue;
+  real *outputValue;
+  real *checkIg;
+  real *checkFg;
+  real *checkOg;
+};
+
+template <class real>
+struct lstm_grad {
+  real *gateGrad;
+  real *prevStateGrad;
+  real *stateGrad;
+  real *stateActiveGrad;
+  real *outputGrad;
+  real *checkIgGrad;
+  real *checkFgGrad;
+  real *checkOgGrad;
+};
+
+inline activation_mode_t ActiveType(const std::string &type) {
+  if (type == "sigmoid") {
+    return HL_ACTIVATION_SIGMOID;
+  } else if (type == "relu") {
+    return HL_ACTIVATION_RELU;
+  } else if (type == "tanh") {
+    return HL_ACTIVATION_TANH;
+  } else if (type == "linear" || type == "") {
+    return HL_ACTIVATION_LINEAR;
+  } else {
+    PADDLE_THROW("Do not support activation type.");
+  }
+}
+
+template <typename Place, typename T>
+class LstmUnitFunctor {
+ public:
+  static void compute(lstm_value<T> value, int frame_size, int batch_size,
+                      std::string gate_act, std::string cell_act,
+                      std::string cand_act);
+};
+
+template <typename Place, typename T>
+class LstmUnitGradFunctor {
+ public:
+  static void compute(lstm_value<T> value, lstm_grad<T> grad, int frame_size,
+                      int batch_size, std::string gate_act,
+                      std::string cell_act, std::string cand_act);
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
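ActiveType in the header above is the bridge between the op's string attributes and the kernel-side enum. A short usage sketch (illustrative; it compiles only inside the Paddle tree, against the header just shown):

#include <cassert>
#include <string>

// Mapping from attribute strings to activation enums, as the functors
// are expected to call it once per compute() invocation.
void ActiveTypeExample() {
  using namespace paddle::operators::math;
  assert(ActiveType("sigmoid") == HL_ACTIVATION_SIGMOID);
  assert(ActiveType("tanh") == HL_ACTIVATION_TANH);
  assert(ActiveType("") == HL_ACTIVATION_LINEAR);  // empty string -> linear
}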
diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc
index c29baaae08..f4da949d4e 100644
--- a/paddle/operators/math/sequence2batch.cc
+++ b/paddle/operators/math/sequence2batch.cc
@@ -18,6 +18,37 @@ namespace paddle {
 namespace operators {
 namespace math {

+template <typename T>
+class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& src, const size_t* index,
+                  framework::Tensor& dst, bool is_src_index) {
+    auto src_dims = src.dims();
+    auto dst_dims = dst.dims();
+    PADDLE_ENFORCE_EQ(src_dims.size(), 2,
+                      "The src must be matrix with rank 2.");
+    PADDLE_ENFORCE_EQ(dst_dims.size(), 2,
+                      "The dst must be matrix with rank 2.");
+    PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
+                      "The width of src and dst must be same.");
+    auto height = dst_dims[0];
+    auto width = dst_dims[1];
+    auto* src_data = src.data<T>();
+    auto* dst_data = dst.data<T>();
+    for (int i = 0; i < height; ++i) {
+      if (is_src_index) {
+        memcpy(dst_data + i * width, src_data + index[i] * width,
+               width * sizeof(T));
+      } else {
+        memcpy(dst_data + index[i] * width, src_data + i * width,
+               width * sizeof(T));
+      }
+    }
+  }
+};
+
+template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
+template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
+
 template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
 template class Batch2LoDTensor2Functor<platform::CPUPlace, float>;
diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu
index 5afb87e4a4..ecd05a30d3 100644
--- a/paddle/operators/math/sequence2batch.cu
+++ b/paddle/operators/math/sequence2batch.cu
@@ -18,6 +18,53 @@ namespace paddle {
 namespace operators {
 namespace math {

+template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
+__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
+                                     int height, int width,
+                                     const bool is_src_index) {
+  int idx = threadIdx.x;
+  int idy = threadIdx.y;
+  int id = blockIdx.x + idy * GridDimX;
+  while (id < height) {
+    int src_idx = is_src_index ? index[id] : id;
+    int dst_idx = is_src_index ? id : index[id];
+    const T* src_data = src + src_idx * width;
+    T* dst_data = dst + dst_idx * width;
+    for (int i = idx; i < width; i += BlockDimX) {
+      dst_data[i] = src_data[i];
+    }
+    id += BlockDimY * GridDimX;
+  }
+}
+
+template <typename T>
+class CopyMatrixRowsFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& src, const size_t* index,
+                  framework::Tensor& dst, bool is_src_index) {
+    auto src_dims = src.dims();
+    auto dst_dims = dst.dims();
+    PADDLE_ENFORCE_EQ(src_dims.size(), 2,
+                      "The src must be matrix with rank 2.");
+    PADDLE_ENFORCE_EQ(dst_dims.size(), 2,
+                      "The dst must be matrix with rank 2.");
+    PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
+                      "The width of src and dst must be same.");
+    auto height = dst_dims[0];
+    auto width = dst_dims[1];
+    auto* src_data = src.data<T>();
+    auto* dst_data = dst.data<T>();
+
+    dim3 threads(128, 8);
+    dim3 grid(8, 1);
+    auto stream =
+        reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
+    CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
+        src_data, dst_data, index, height, width, is_src_index);
+  }
+};
+
+template class CopyMatrixRowsFunctor<platform::GPUPlace, float>;
+template class CopyMatrixRowsFunctor<platform::GPUPlace, double>;
+
 template class LoDTensor2BatchFunctor<platform::GPUPlace, float>;
 template class Batch2LoDTensor2Functor<platform::GPUPlace, float>;
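CopyMatrixRowsFunctor implements a row gather/scatter: with is_src_index true it gathers dst[i, :] = src[index[i], :], otherwise it scatters dst[index[i], :] = src[i, :]. A plain-CPU sketch of the same contract (illustrative only; assumes all index values are in range and both buffers are index.size() x width):

#include <cstring>
#include <vector>

template <typename T>
void CopyRowsRef(const std::vector<T>& src, std::vector<T>& dst,
                 const std::vector<size_t>& index, size_t width,
                 bool is_src_index) {
  for (size_t i = 0; i < index.size(); ++i) {
    const T* s = src.data() + (is_src_index ? index[i] : i) * width;
    T* d = dst.data() + (is_src_index ? i : index[i]) * width;
    std::memcpy(d, s, width * sizeof(T));  // copy one row of `width` elements
  }
}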
diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h
index 6ee870cf78..e662292a02 100644
--- a/paddle/operators/math/sequence2batch.h
+++ b/paddle/operators/math/sequence2batch.h
@@ -16,6 +16,19 @@ namespace paddle {
 namespace operators {
 namespace math {

+template <typename Place, typename T>
+class CopyMatrixRowsFunctor {
+ public:
+  // If is_src_index is true,
+  // copy the indexed rows of input src to the output dst.
+  // If is_src_index is false,
+  // copy the input src to the indexed rows of output dst.
+  // The indexed rows are based on the input index.
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& src, const size_t* index,
+                  framework::Tensor& dst, const bool is_src_index);
+};
+
 template <typename Place, typename T>
 class LoDTensor2BatchFunctor {
  public:
@@ -97,8 +110,11 @@ class LoDTensor2BatchFunctor {
       }
       batch_starts[n + 1] = batch_id;
     }
+
+    CopyMatrixRowsFunctor<Place, T> to_batch;
+    to_batch(context, lod_tensor, batch, true);
   }
-}
+};

 template <typename Place, typename T>
 class Batch2LoDTensor2Functor {
@@ -107,6 +123,7 @@ class Batch2LoDTensor2Functor {
                   const framework::LoDTensor& batch,
                   framework::LoDTensor& lod_tensor,
                   const bool is_reverse) const;
+};

 }  // namespace math
 }  // namespace operators
--
GitLab