提交 29a9f9b5 编写于 作者: D dangqingqing

Refine code format and fix the number of threads.

上级 5a4cdbb3
...@@ -32,17 +32,17 @@ namespace detail { ...@@ -32,17 +32,17 @@ namespace detail {
namespace forward { namespace forward {
template <typename T> template <typename T>
DEVICE T linear(const T a) { DEVICE T Identity(const T a) {
return a; return a;
} }
template <typename T> template <typename T>
DEVICE T relu(const T a) { DEVICE T Relu(const T a) {
return a > static_cast<T>(0.0) ? a : static_cast<T>(0.0); return a > static_cast<T>(0.0) ? a : static_cast<T>(0.0);
} }
template <typename T> template <typename T>
DEVICE T sigmoid(const T a) { DEVICE T Sigmoid(const T a) {
const T min = SIGMOID_THRESHOLD_MIN; const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX; const T max = SIGMOID_THRESHOLD_MAX;
T tmp = (a < min) ? min : ((a > max) ? max : a); T tmp = (a < min) ? min : ((a > max) ? max : a);
...@@ -50,7 +50,7 @@ DEVICE T sigmoid(const T a) { ...@@ -50,7 +50,7 @@ DEVICE T sigmoid(const T a) {
} }
template <typename T> template <typename T>
DEVICE T tanh(const T a) { DEVICE T Tanh(const T a) {
T tmp = -2.0 * a; T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0; return (2.0 / (1.0 + exp(tmp))) - 1.0;
...@@ -61,22 +61,22 @@ DEVICE T tanh(const T a) { ...@@ -61,22 +61,22 @@ DEVICE T tanh(const T a) {
namespace backward { namespace backward {
template <typename T> template <typename T>
DEVICE T linear(const T a, const T b) { DEVICE T Identity(const T a, const T b) {
return a; return a;
} }
template <typename T> template <typename T>
DEVICE T relu(const T a, const T b) { DEVICE T Relu(const T a, const T b) {
return a * (b > 0.0 ? 1.0 : 0.0); return a * (b > 0.0 ? 1.0 : 0.0);
} }
template <typename T> template <typename T>
DEVICE T sigmoid(const T a, const T b) { DEVICE T Sigmoid(const T a, const T b) {
return a * b * (1.0 - b); return a * b * (1.0 - b);
} }
template <typename T> template <typename T>
DEVICE T tanh(const T a, const T b) { DEVICE T Tanh(const T a, const T b) {
return a * (1.0 - b * b); return a * (1.0 - b * b);
} }
...@@ -89,20 +89,20 @@ struct Active { ...@@ -89,20 +89,20 @@ struct Active {
}; };
static DEVICE Active<float>::Act kActFloat[] = { static DEVICE Active<float>::Act kActFloat[] = {
&forward::sigmoid<float>, &forward::relu<float>, &forward::tanh<float>, &forward::Sigmoid<float>, &forward::Relu<float>, &forward::Tanh<float>,
&forward::linear<float>}; &forward::Identity<float>};
static DEVICE Active<float>::ActGrad kActGradFloat[] = { static DEVICE Active<float>::ActGrad kActGradFloat[] = {
&backward::sigmoid<float>, &backward::relu<float>, &backward::tanh<float>, &backward::Sigmoid<float>, &backward::Relu<float>, &backward::Tanh<float>,
&backward::linear<float>}; &backward::Identity<float>};
static DEVICE Active<double>::Act kActDouble[] = { static DEVICE Active<double>::Act kActDouble[] = {
&forward::sigmoid<double>, &forward::relu<double>, &forward::tanh<double>, &forward::Sigmoid<double>, &forward::Relu<double>, &forward::Tanh<double>,
&forward::linear<double>}; &forward::Identity<double>};
static DEVICE Active<double>::ActGrad kActGradDouble[] = { static DEVICE Active<double>::ActGrad kActGradDouble[] = {
&backward::sigmoid<double>, &backward::relu<double>, &backward::Sigmoid<double>, &backward::Relu<double>,
&backward::tanh<double>, &backward::linear<double>}; &backward::Tanh<double>, &backward::Identity<double>};
namespace forward { namespace forward {
inline DEVICE float activation(float a, int index) { inline DEVICE float activation(float a, int index) {
...@@ -128,29 +128,29 @@ inline DEVICE double activation(double a, double b, int index) { ...@@ -128,29 +128,29 @@ inline DEVICE double activation(double a, double b, int index) {
#ifdef __AVX__ #ifdef __AVX__
namespace forward { namespace forward {
namespace avx { namespace avx {
__m256 relu(const __m256 a); __m256 Relu(const __m256 a);
__m256 sigmoid(const __m256 a); __m256 Sigmoid(const __m256 a);
__m256 tanh(const __m256 a); __m256 Tanh(const __m256 a);
__m256 linear(const __m256 a); __m256 Identity(const __m256 a);
} // namespace avx } // namespace avx
} // namespace forward } // namespace forward
namespace backward { namespace backward {
namespace avx { namespace avx {
__m256 relu(const __m256 a, const __m256 b); __m256 Relu(const __m256 a, const __m256 b);
__m256 sigmoid(const __m256 a, const __m256 b); __m256 Sigmoid(const __m256 a, const __m256 b);
__m256 tanh(const __m256 a, const __m256 b); __m256 Tanh(const __m256 a, const __m256 b);
__m256 linear(const __m256 a, const __m256 b); __m256 Identity(const __m256 a, const __m256 b);
} // namespace avx } // namespace avx
} // namespace backward } // namespace backward
static Active<__m256>::Act kActAvx[] = { static Active<__m256>::Act kActAvx[] = {
&forward::avx::sigmoid, &forward::avx::relu, &forward::avx::tanh, &forward::avx::Sigmoid, &forward::avx::Relu, &forward::avx::Tanh,
&forward::avx::linear}; &forward::avx::Identity};
static Active<__m256>::ActGrad kActGradAvx[] = { static Active<__m256>::ActGrad kActGradAvx[] = {
&backward::avx::sigmoid, &backward::avx::relu, &backward::avx::tanh, &backward::avx::Sigmoid, &backward::avx::Relu, &backward::avx::Tanh,
&backward::avx::linear}; &backward::avx::Identity};
namespace forward { namespace forward {
inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); } inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); }
......
...@@ -22,61 +22,61 @@ namespace operators { ...@@ -22,61 +22,61 @@ namespace operators {
namespace math { namespace math {
namespace detail { namespace detail {
__m256 exp(__m256 a) { return exp256_ps(a); } __m256 Exp(__m256 a) { return exp256_ps(a); }
namespace forward { namespace forward {
namespace avx { namespace avx {
__m256 relu(const __m256 a) { __m256 Relu(const __m256 a) {
__m256 tmp = _mm256_set1_ps(0.0f); __m256 tmp = _mm256_set1_ps(0.0f);
return _mm256_max_ps(a, tmp); return _mm256_max_ps(a, tmp);
} }
__m256 sigmoid(const __m256 a) { __m256 Sigmoid(const __m256 a) {
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
__m256 tmp = _mm256_max_ps(a, min); __m256 tmp = _mm256_max_ps(a, min);
tmp = _mm256_min_ps(tmp, max); tmp = _mm256_min_ps(tmp, max);
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
tmp = exp(tmp); tmp = Exp(tmp);
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp); tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
return tmp; return tmp;
} }
__m256 tanh(const __m256 a) { __m256 Tanh(const __m256 a) {
__m256 max = _mm256_set1_ps(EXP_MAX_INPUT); __m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
__m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
tmp = _mm256_min_ps(tmp, max); tmp = _mm256_min_ps(tmp, max);
tmp = exp(tmp); tmp = Exp(tmp);
return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f), return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
_mm256_add_ps(_mm256_set1_ps(1.0f), tmp)), _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
_mm256_set1_ps(1.0f)); _mm256_set1_ps(1.0f));
} }
__m256 linear(const __m256 a) { return a; } __m256 Identity(const __m256 a) { return a; }
} // namespace avx } // namespace avx
} // namespace forward } // namespace forward
namespace backward { namespace backward {
namespace avx { namespace avx {
__m256 relu(const __m256 a, const __m256 b) { __m256 Relu(const __m256 a, const __m256 b) {
return _mm256_mul_ps( return _mm256_mul_ps(
a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS), a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
_mm256_set1_ps(1.0f))); _mm256_set1_ps(1.0f)));
} }
__m256 sigmoid(const __m256 a, const __m256 b) { __m256 Sigmoid(const __m256 a, const __m256 b) {
return _mm256_mul_ps(_mm256_mul_ps(a, b), return _mm256_mul_ps(_mm256_mul_ps(a, b),
_mm256_sub_ps(_mm256_set1_ps(1.0f), b)); _mm256_sub_ps(_mm256_set1_ps(1.0f), b));
} }
__m256 tanh(const __m256 a, const __m256 b) { __m256 Tanh(const __m256 a, const __m256 b) {
return _mm256_mul_ps( return _mm256_mul_ps(
a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b))); a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b)));
} }
__m256 linear(const __m256 a, const __m256 b) { return a; } __m256 Identity(const __m256 a, const __m256 b) { return a; }
} // namespace avx } // namespace avx
} // namespace backward } // namespace backward
......
...@@ -226,9 +226,9 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, ...@@ -226,9 +226,9 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
threads = dim3(framePerBlock, 1); threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1); grid = dim3(frameBlocks, 1);
} else { } else {
/* framePerBlock = 32 batchPerBlock = 32 */ /* framePerBlock = 32 batchPerBlock = 16 */
threads = dim3(32, 16); threads = dim3(32, 16);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16);
} }
auto stream = auto stream =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册