提交 1c8a0c4b 编写于 作者: D dangqingqing

Refine activation function pointer for LSTM operator.

上级 2c5d4c6d
...@@ -20,7 +20,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) ...@@ -20,7 +20,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
......
if(WITH_AVX) if(WITH_AVX)
cc_library(activation_functions SRCS hl_cpu_functions.cc hl_avx_functions.cc) cc_library(activation_functions SRCS avx_functions.cc)
else()
cc_library(activation_functions SRCS hl_cpu_functions.cc)
endif() endif()
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include "paddle/platform/hostdevice.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle {
namespace operators {
namespace math {
namespace detail {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
namespace forward {
template <typename T>
DEVICE T linear(const T a) {
return a;
}
template <typename T>
DEVICE T relu(const T a) {
return a > static_cast<T>(0.0) ? a : static_cast<T>(0.0);
}
template <typename T>
DEVICE T sigmoid(const T a) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
T tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
}
template <typename T>
DEVICE T tanh(const T a) {
T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
} // namespace forward
namespace backward {
template <typename T>
DEVICE T linear(const T a, const T b) {
return a;
}
template <typename T>
DEVICE T relu(const T a, const T b) {
return a * (b > 0.0 ? 1.0 : 0.0);
}
template <typename T>
DEVICE T sigmoid(const T a, const T b) {
return a * b * (1.0 - b);
}
template <typename T>
DEVICE T tanh(const T a, const T b) {
return a * (1.0 - b * b);
}
} // namespace backward
template <typename T>
struct Active {
typedef T (*Act)(T);
typedef T (*ActGrad)(T, T);
};
static DEVICE Active<float>::Act kActFloat[] = {
&forward::sigmoid<float>, &forward::relu<float>, &forward::tanh<float>,
&forward::linear<float>};
static DEVICE Active<float>::ActGrad kActGradFloat[] = {
&backward::sigmoid<float>, &backward::relu<float>, &backward::tanh<float>,
&backward::linear<float>};
static DEVICE Active<double>::Act kActDouble[] = {
&forward::sigmoid<double>, &forward::relu<double>, &forward::tanh<double>,
&forward::linear<double>};
static DEVICE Active<double>::ActGrad kActGradDouble[] = {
&backward::sigmoid<double>, &backward::relu<double>,
&backward::tanh<double>, &backward::linear<double>};
namespace forward {
inline DEVICE float activation(float a, int index) {
return kActFloat[index](a);
}
inline DEVICE double activation(double a, int index) {
return kActDouble[index](a);
}
} // namespace forward
namespace backward {
inline DEVICE float activation(float a, float b, int index) {
return kActGradFloat[index](a, b);
}
inline DEVICE double activation(double a, double b, int index) {
return kActGradDouble[index](a, b);
}
} // namespace backward
#ifdef __AVX__
namespace forward {
namespace avx {
__m256 relu(const __m256 a);
__m256 sigmoid(const __m256 a);
__m256 tanh(const __m256 a);
__m256 linear(const __m256 a);
} // namespace avx
} // namespace forward
namespace backward {
namespace avx {
__m256 relu(const __m256 a, const __m256 b);
__m256 sigmoid(const __m256 a, const __m256 b);
__m256 tanh(const __m256 a, const __m256 b);
__m256 linear(const __m256 a, const __m256 b);
} // namespace avx
} // namespace backward
static Active<__m256>::Act kActAvx[] = {
&forward::avx::sigmoid, &forward::avx::relu, &forward::avx::tanh,
&forward::avx::linear};
static Active<__m256>::ActGrad kActGradAvx[] = {
&backward::avx::sigmoid, &backward::avx::relu, &backward::avx::tanh,
&backward::avx::linear};
namespace forward {
inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); }
} // namespace forward
namespace backward {
inline __m256 activation(__m256 a, __m256 b, int index) {
return kActGradAvx[index](a, b);
}
} // namespace backward
#endif
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -13,14 +13,19 @@ See the License for the specific language governing permissions and ...@@ -13,14 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <immintrin.h> #include <immintrin.h>
#include "hl_functions.h" #include "paddle/operators/math/detail/activation_functions.h"
// TODO(qingqing) refine this dependence // TODO(qingqing) refine this dependence
#include "paddle/cuda/src/avx_mathfun.h" #include "paddle/cuda/src/avx_mathfun.h"
namespace hppl { namespace paddle {
namespace operators {
namespace math {
namespace detail {
__m256 exp(__m256 a) { return exp256_ps(a); } __m256 exp(__m256 a) { return exp256_ps(a); }
namespace forward {
namespace avx {
__m256 relu(const __m256 a) { __m256 relu(const __m256 a) {
__m256 tmp = _mm256_set1_ps(0.0f); __m256 tmp = _mm256_set1_ps(0.0f);
return _mm256_max_ps(a, tmp); return _mm256_max_ps(a, tmp);
...@@ -50,6 +55,11 @@ __m256 tanh(const __m256 a) { ...@@ -50,6 +55,11 @@ __m256 tanh(const __m256 a) {
__m256 linear(const __m256 a) { return a; } __m256 linear(const __m256 a) { return a; }
} // namespace avx
} // namespace forward
namespace backward {
namespace avx {
__m256 relu(const __m256 a, const __m256 b) { __m256 relu(const __m256 a, const __m256 b) {
return _mm256_mul_ps( return _mm256_mul_ps(
a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS), a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
...@@ -67,4 +77,10 @@ __m256 tanh(const __m256 a, const __m256 b) { ...@@ -67,4 +77,10 @@ __m256 tanh(const __m256 a, const __m256 b) {
} }
__m256 linear(const __m256 a, const __m256 b) { return a; } __m256 linear(const __m256 a, const __m256 b) { return a; }
} // namespace hppl } // namespace avx
} // namespace backward
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_ACTIVATION_FUNCTIONS_H_
#define HL_ACTIVATION_FUNCTIONS_H_
#include "hl_functions.h"
#include "paddle/operators/math/lstm_compute.h"
/**
* Active functions: sigmoid, relu, tanh and linear.
*/
#define FLOAT_ACTIVE_FUNCTION \
{ \
hppl::typef::sigmoid, hppl::typef::relu, hppl::typef::tanh, \
hppl::typef::linear \
}
#define DOUBLE_ACTIVE_FUNCTION \
{ \
hppl::typed::sigmoid, hppl::typed::relu, hppl::typed::tanh, \
hppl::typed::linear \
}
#define AVX_ACTIVE_FUNCTION \
{ hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }
namespace hppl {
using activation_mode_t = paddle::operators::math::activation_mode_t;
/**
* Hppl supports sigmoid, relu, tanh, linear active functions
* for neural networks' forward and backward activation.
*/
template <class T>
class Active {
public:
typedef T (*forward)(T);
typedef T (*backward)(T, T);
};
template <typename T>
struct ForwardActType;
template <>
struct ForwardActType<float> {
using type = Active<float>::forward;
};
template <>
struct ForwardActType<double> {
using type = Active<double>::forward;
};
template <typename T>
struct BackwardActType;
template <>
struct BackwardActType<float> {
using type = Active<float>::backward;
};
template <>
struct BackwardActType<double> {
using type = Active<double>::backward;
};
#ifdef __NVCC__
namespace gpu {
static __device__ Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
static __device__ Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;
static __device__ Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
static __device__ Active<double>::backward backward_d[] =
DOUBLE_ACTIVE_FUNCTION;
template <typename T>
struct ForwardAct {
__device__ typename ForwardActType<T>::type operator()(
activation_mode_t type);
};
template <>
struct ForwardAct<float> {
__device__ ForwardActType<float>::type operator()(activation_mode_t type) {
return forward[type];
}
};
template <>
struct ForwardAct<double> {
__device__ ForwardActType<double>::type operator()(activation_mode_t type) {
return forward_d[type];
}
};
template <typename T>
struct BackwardAct {
__device__ typename BackwardActType<T>::type operator()(
activation_mode_t type);
};
template <>
struct BackwardAct<float> {
__device__ BackwardActType<float>::type operator()(activation_mode_t type) {
return backward[type];
}
};
template <>
struct BackwardAct<double> {
__device__ BackwardActType<double>::type operator()(activation_mode_t type) {
return backward_d[type];
}
};
} // namespace gpu
#else
namespace cpu {
static Active<float>::forward forward[] = FLOAT_ACTIVE_FUNCTION;
static Active<float>::backward backward[] = FLOAT_ACTIVE_FUNCTION;
static Active<double>::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION;
static Active<double>::backward backward_d[] = DOUBLE_ACTIVE_FUNCTION;
template <typename T>
struct ForwardAct {
typename ForwardActType<T>::type operator()(activation_mode_t type);
};
template <>
struct ForwardAct<float> {
ForwardActType<float>::type operator()(activation_mode_t type) {
return forward[type];
}
};
template <>
struct ForwardAct<double> {
ForwardActType<double>::type operator()(activation_mode_t type) {
return forward_d[type];
}
};
template <typename T>
struct BackwardAct {
typename BackwardActType<T>::type operator()(activation_mode_t type);
};
template <>
struct BackwardAct<float> {
BackwardActType<float>::type operator()(activation_mode_t type) {
return backward[type];
}
};
template <>
struct BackwardAct<double> {
BackwardActType<double>::type operator()(activation_mode_t type) {
return backward_d[type];
}
};
} // namespace cpu
#ifdef __AVX__
namespace avx {
static Active<__m256>::forward forward[] = AVX_ACTIVE_FUNCTION;
static Active<__m256>::backward backward[] = AVX_ACTIVE_FUNCTION;
} // namespace avx
#endif
#endif
} // namespace hppl
#endif // HL_ACTIVATION_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_AVX_FUNCTIONS_H_
#define HL_AVX_FUNCTIONS_H_
#include <immintrin.h>
namespace hppl {
__m256 relu(const __m256 a);
__m256 sigmoid(const __m256 a);
__m256 tanh(const __m256 a);
__m256 linear(const __m256 a);
__m256 relu(const __m256 a, const __m256 b);
__m256 sigmoid(const __m256 a, const __m256 b);
__m256 tanh(const __m256 a, const __m256 b);
__m256 linear(const __m256 a, const __m256 b);
} // namespace hppl
#endif // HL_AVX_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include "hl_functions.h"
namespace hppl {
namespace typef {
float relu(const float a) {
return a > static_cast<float>(0.0) ? a : static_cast<float>(0.0);
}
float sigmoid(const float a) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<float>(1.0) / (static_cast<float>(1.0) + exp(-tmp));
}
float tanh(const float a) {
float tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
float linear(const float a) { return a; }
float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); }
float sigmoid(const float a, const float b) {
return a * b * (static_cast<float>(1) - b);
}
float tanh(const float a, const float b) {
return a * (static_cast<float>(1) - b * b);
}
float linear(const float a, const float b) { return a; }
} // namespace typef
namespace typed {
double relu(const double a) {
return a > static_cast<double>(0.0) ? a : static_cast<double>(0.0);
}
double sigmoid(const double a) {
const double min = SIGMOID_THRESHOLD_MIN;
const double max = SIGMOID_THRESHOLD_MAX;
double tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<double>(1.0) / (static_cast<double>(1.0) + exp(-tmp));
}
double tanh(const double a) {
double tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
double linear(const double a) { return a; }
double relu(const double a, const double b) {
return a * (b > 0.0 ? 1.0 : 0.0);
}
double sigmoid(const double a, const double b) {
return a * b * (static_cast<double>(1) - b);
}
double tanh(const double a, const double b) {
return a * (static_cast<double>(1) - b * b);
}
double linear(const double a, const double b) { return a; }
} // namespace typed
} // namespace hppl
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_FUNCTIONS_H_
#define HL_FUNCTIONS_H_
/**
* sigmoid threshold maximum
*/
#define SIGMOID_THRESHOLD_MIN -40.0
/**
* sigmoid threshold minimum
*/
#define SIGMOID_THRESHOLD_MAX 13.0
/**
* The maximum input value for exp, used to avoid overflow problem.
* currently only used for tanh function.
*/
#define EXP_MAX_INPUT 40.0
#ifndef __NVCC__
namespace hppl {
namespace typef {
float relu(const float a);
float sigmoid(const float a);
float tanh(const float a);
float linear(const float a);
float relu(const float a, const float b);
float sigmoid(const float a, const float b);
float tanh(const float a, const float b);
float linear(const float a, const float b);
} // namespace typef
namespace typed {
double relu(const double a);
double sigmoid(const double a);
double tanh(const double a);
double linear(const double a);
double relu(const double a, const double b);
double sigmoid(const double a, const double b);
double tanh(const double a, const double b);
double linear(const double a, const double b);
} // namespace typed
} // namespace hppl
#ifdef __AVX__
#include "hl_avx_functions.h"
#endif
#else
#include "hl_gpu_functions.h"
#endif
#endif // HL_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_GPU_FUNCTIONS_CUH_
#define HL_GPU_FUNCTIONS_CUH_
#include "hl_base.h"
namespace hppl {
namespace typef {
__device__ static float relu(const float a) { return a > 0.0f ? a : 0.0f; }
__device__ static float sigmoid(const float a) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (a < min) ? min : ((a > max) ? max : a);
return __fdividef(1.0f, 1.0f + __expf(-tmp));
}
__device__ static float tanh(const float a) {
float tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return __fdividef(2.0f, (1.0f + __expf(-2.0f * tmp))) - 1.0f;
}
__device__ static float linear(const float a) { return a; }
__device__ static float relu(const float a, const float b) {
return a * (b > 0.0f ? 1.0f : 0.0f);
}
__device__ static float sigmoid(const float a, const float b) {
return a * b * (1.0f - b);
}
__device__ static float tanh(const float a, const float b) {
return a * (1.0f - b * b);
}
__device__ static float linear(const float a, const float b) { return a; }
} // namespace typef
namespace typed {
__device__ static double relu(const double a) { return a > 0.0 ? a : 0.0; }
__device__ static double sigmoid(const double a) {
const double min = SIGMOID_THRESHOLD_MIN;
const double max = SIGMOID_THRESHOLD_MAX;
double tmp = (a < min) ? min : ((a > max) ? max : a);
return 1.0 / (1.0 + exp(-tmp));
}
__device__ static double tanh(const double a) {
double tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0;
}
__device__ static double linear(const double a) { return a; }
__device__ static double relu(const double a, const double b) {
return a * (b > 0.0 ? 1.0 : 0.0);
}
__device__ static double sigmoid(const double a, const double b) {
return a * b * (1 - b);
}
__device__ static double tanh(const double a, const double b) {
return a * (1.0 - b * b);
}
__device__ static double linear(const double a, const double b) { return a; }
} // namespace typef
} // namespace hppl
#endif // HL_GPU_FUNCTIONS_CUH_
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h" #include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/lstm_compute.h" #include "paddle/operators/math/lstm_compute.h"
namespace paddle { namespace paddle {
...@@ -26,7 +26,10 @@ namespace detail { ...@@ -26,7 +26,10 @@ namespace detail {
template <class T, class Op> template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frameSize) { int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn; T rValueIn;
T rValueIg; T rValueIg;
T rValueFg; T rValueFg;
...@@ -58,7 +61,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -58,7 +61,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
} }
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO); rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state);
valueIn[i] = rValueIn; valueIn[i] = rValueIn;
valueIg[i] = rValueIg; valueIg[i] = rValueIg;
...@@ -72,7 +75,10 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -72,7 +75,10 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize) { LstmMetaGrad<T> grad, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn; T rValueIn;
T rValueIg; T rValueIg;
T rValueFg; T rValueFg;
...@@ -122,7 +128,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -122,7 +128,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad); rCheckOGrad, active_node, active_gate, active_state);
gradIn[i] = rGradIn; gradIn[i] = rGradIn;
gradIg[i] = rGradIg; gradIg[i] = rGradIg;
...@@ -176,8 +182,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize, ...@@ -176,8 +182,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
} }
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, hppl::avx::forward[active_node], rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state);
hppl::avx::forward[active_gate], hppl::avx::forward[active_state]);
valueIn[i] = rValueIn; valueIn[i] = rValueIn;
valueIg[i] = rValueIg; valueIg[i] = rValueIg;
...@@ -246,8 +251,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -246,8 +251,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, hppl::avx::backward[active_node], rCheckOGrad, active_node, active_gate, active_state);
hppl::avx::backward[active_gate], hppl::avx::backward[active_state]);
gradIn[i] = rGradIn; gradIn[i] = rGradIn;
gradIg[i] = rGradIg; gradIg[i] = rGradIg;
...@@ -274,7 +278,8 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize, ...@@ -274,7 +278,8 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node, avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state); active_gate, active_state);
} else { } else {
naive_lstm_forward_one_sequence<T>(op, value, frameSize); naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
} }
} }
...@@ -287,7 +292,8 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad, ...@@ -287,7 +292,8 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node, avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state); active_gate, active_state);
} else { } else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize); naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
} }
} }
......
...@@ -13,13 +13,12 @@ See the License for the specific language governing permissions and ...@@ -13,13 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <type_traits> #include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/lstm_compute.h" #include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h" #include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include <glog/logging.h> #include <type_traits>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -32,7 +31,9 @@ namespace detail { ...@@ -32,7 +31,9 @@ namespace detail {
*/ */
template <class T, class Op, bool isBatch> template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize, __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
int batchSize) { int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return; if (frameIdx >= frameSize) return;
...@@ -69,7 +70,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize, ...@@ -69,7 +70,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
} }
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO); rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state);
value.gateValue[frameIdx] = rValueIn; value.gateValue[frameIdx] = rValueIn;
value.gateValue[frameIdx + frameSize] = rValueIg; value.gateValue[frameIdx + frameSize] = rValueIg;
...@@ -88,7 +89,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize, ...@@ -88,7 +89,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
template <class T, class Op, bool isBatch> template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize, LstmMetaGrad<T> grad, int frameSize,
int batchSize) { int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return; if (frameIdx >= frameSize) return;
...@@ -141,7 +144,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, ...@@ -141,7 +144,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad); rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
active_node, active_gate, active_state);
grad.gateGrad[frameIdx] = rGradIn; grad.gateGrad[frameIdx] = rGradIn;
grad.gateGrad[frameIdx + frameSize] = rGradIg; grad.gateGrad[frameIdx + frameSize] = rGradIg;
...@@ -197,11 +201,13 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, ...@@ -197,11 +201,13 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if (batchSize == 1) { if (batchSize == 1) {
KeLstmForward<T, Op, KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize); op, value, frameSize, batchSize, active_node, active_gate,
active_state);
} else { } else {
KeLstmForward<T, Op, KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize); op, value, frameSize, batchSize, active_node, active_gate,
active_state);
} }
} }
...@@ -230,11 +236,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, ...@@ -230,11 +236,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
if (batchSize == 1) { if (batchSize == 1) {
KeLstmBackward<T, Op, KeLstmBackward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>( /* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize); op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
} else { } else {
KeLstmBackward<T, Op, KeLstmBackward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>( /* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize); op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
} }
} }
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h" #include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/hostdevice.h" #include "paddle/platform/hostdevice.h"
#include <type_traits> #include <type_traits>
...@@ -24,45 +24,22 @@ namespace detail { ...@@ -24,45 +24,22 @@ namespace detail {
namespace forward { namespace forward {
template <typename T>
DEVICE inline T sigmoid(const T a) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
T tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
}
template <typename T>
DEVICE inline T tanh(const T a) {
T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
template <class T> template <class T>
class lstm { class lstm {
public: public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output, T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO) { T &checkI, T &checkF, T &checkO,
#if 0 activation_mode_t active_node,
// TODO(qingqing) support to activation speficed by users activation_mode_t active_gate,
valueIn = actInput(valueIn); activation_mode_t active_state) {
valueIg = actGate(valueIg + prevState * checkI); valueIn = activation(valueIn, active_node);
valueFg = actGate(valueFg + prevState * checkF); valueIg = activation(valueIg + prevState * checkI, active_gate);
state = valueIn * valueIg + prevState * valueFg; valueFg = activation(valueFg + prevState * checkF, active_gate);
valueOg = actGate(valueOg + state * checkO);
stateAtv = actState(state);
output = valueOg * stateAtv;
#else
valueIn = tanh<T>(valueIn);
valueIg = sigmoid<T>(valueIg + prevState * checkI);
valueFg = sigmoid<T>(valueFg + prevState * checkF);
state = valueIn * valueIg + prevState * valueFg; state = valueIn * valueIg + prevState * valueFg;
valueOg = sigmoid<T>(valueOg + state * checkO); valueOg = activation(valueOg + state * checkO, active_gate);
stateAtv = tanh<T>(state); stateAtv = activation(state, active_state);
output = valueOg * stateAtv; output = valueOg * stateAtv;
#endif
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...@@ -75,16 +52,19 @@ class lstm { ...@@ -75,16 +52,19 @@ class lstm {
__m256 &valueOg, __m256 &prevState, __m256 &state, __m256 &valueOg, __m256 &prevState, __m256 &state,
__m256 &stateAtv, __m256 &output, __m256 &checkI, __m256 &stateAtv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO, __m256 &checkF, __m256 &checkO,
hppl::Active<__m256>::forward actInput, activation_mode_t active_node,
hppl::Active<__m256>::forward actGate, activation_mode_t active_gate,
hppl::Active<__m256>::forward actState) { activation_mode_t active_state) {
valueIn = actInput(valueIn); valueIn = activation(valueIn, active_node);
valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI))); valueIg = activation(
valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF))); _mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate);
valueFg = activation(
_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate);
state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg), state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
_mm256_mul_ps(prevState, valueFg)); _mm256_mul_ps(prevState, valueFg));
valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO))); valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)),
stateAtv = actState(state); active_gate);
stateAtv = activation(state, active_state);
output = _mm256_mul_ps(valueOg, stateAtv); output = _mm256_mul_ps(valueOg, stateAtv);
} }
#endif #endif
...@@ -95,16 +75,6 @@ class lstm { ...@@ -95,16 +75,6 @@ class lstm {
namespace backward { namespace backward {
template <typename T>
DEVICE inline T sigmoid(const T a, const T b) {
return a * b * (1.0 - b);
}
template <typename T>
DEVICE inline T tanh(const T a, const T b) {
return a * (1.0 - b * b);
}
template <class T> template <class T>
class lstm { class lstm {
public: public:
...@@ -113,29 +83,20 @@ class lstm { ...@@ -113,29 +83,20 @@ class lstm {
T &prevState, T &prevStateGrad, T &state, T &prevState, T &prevStateGrad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad, T &stateGrad, T &stateAtv, T &outputGrad,
T &checkI, T &checkF, T &checkO, T &checkIGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad) { T &checkFGrad, T &checkOGrad,
#if 0 activation_mode_t active_node,
// TODO(qingqing) support to activation speficed by users activation_mode_t active_gate,
gradOg = actGate(outputGrad * stateAtv, valueOg); activation_mode_t active_state) {
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; gradOg = activation(outputGrad * stateAtv, valueOg, active_gate);
gradIn = actInput(stateGrad * valueIg, valueIn); stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) +
gradIg = actGate(stateGrad * valueIn, valueIg); gradOg * checkO;
gradFg = actGate(stateGrad * prevState, valueFg); gradIn = activation(stateGrad * valueIg, valueIn, active_node);
gradIg = activation(stateGrad * valueIn, valueIg, active_gate);
gradFg = activation(stateGrad * prevState, valueFg, active_gate);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState; checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState; checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state; checkOGrad = gradOg * state;
#else
gradOg = sigmoid<T>(outputGrad * stateAtv, valueOg);
stateGrad += tanh<T>(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = tanh<T>(stateGrad * valueIg, valueIn);
gradIg = sigmoid<T>(stateGrad * valueIn, valueIg);
gradFg = sigmoid<T>(stateGrad * prevState, valueFg);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
#endif
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default #ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...@@ -143,24 +104,26 @@ class lstm { ...@@ -143,24 +104,26 @@ class lstm {
#else #else
// Only float support AVX optimization // Only float support AVX optimization
static const bool avx = std::is_same<T, float>::value; static const bool avx = std::is_same<T, float>::value;
HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, HOSTDEVICE void operator()(
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg, __m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg,
__m256 &gradFg, __m256 &gradOg, __m256 &prevState, __m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg,
__m256 &prevStateGrad, __m256 &state, __m256 &prevState, __m256 &prevStateGrad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv, __m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI,
__m256 &outputGrad, __m256 &checkI, __m256 &checkF, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
__m256 &checkO, __m256 &checkIGrad, __m256 &checkOGrad, activation_mode_t active_node,
__m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_gate, activation_mode_t active_state) {
hppl::Active<__m256>::backward actInput, gradOg =
hppl::Active<__m256>::backward actGate, activation(_mm256_mul_ps(outputGrad, stateAtv), valueOg, active_gate);
hppl::Active<__m256>::backward actState) {
gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
stateGrad = _mm256_add_ps( stateGrad = _mm256_add_ps(
actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad); activation(_mm256_mul_ps(outputGrad, valueOg), stateAtv, active_state),
stateGrad);
stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad); stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn); gradIn =
gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg); activation(_mm256_mul_ps(stateGrad, valueIg), valueIn, active_node);
gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg); gradIg =
activation(_mm256_mul_ps(stateGrad, valueIn), valueIg, active_gate);
gradFg =
activation(_mm256_mul_ps(stateGrad, prevState), valueFg, active_gate);
prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI), prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
_mm256_mul_ps(gradFg, checkF)); _mm256_mul_ps(gradFg, checkF));
prevStateGrad = prevStateGrad =
......
...@@ -157,7 +157,7 @@ class TestLstmOp(OpTest): ...@@ -157,7 +157,7 @@ class TestLstmOp(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(atol=1e-8)
#TODO(qingqing) add more unit testing case #TODO(qingqing) add more unit testing case
def test_check_grad(self): def test_check_grad(self):
...@@ -167,7 +167,7 @@ class TestLstmOp(OpTest): ...@@ -167,7 +167,7 @@ class TestLstmOp(OpTest):
self.outputs['BatchCellPreAct'] = np.zeros( self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64') (N, self.D)).astype('float64')
self.check_grad( self.check_grad(
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=0.02) ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
class TestLstmOpHasNoInitial(TestLstmOp): class TestLstmOpHasNoInitial(TestLstmOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册