Commit 29a9f9b5, authored by dangqingqing

Refine code format and fix threads number.

Parent 5a4cdbb3
@@ -32,17 +32,17 @@ namespace detail {
namespace forward {
template <typename T>
DEVICE T linear(const T a) {
DEVICE T Identity(const T a) {
return a;
}
template <typename T>
DEVICE T relu(const T a) {
DEVICE T Relu(const T a) {
return a > static_cast<T>(0.0) ? a : static_cast<T>(0.0);
}
template <typename T>
DEVICE T sigmoid(const T a) {
DEVICE T Sigmoid(const T a) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
T tmp = (a < min) ? min : ((a > max) ? max : a);
@@ -50,7 +50,7 @@ DEVICE T sigmoid(const T a) {
}
template <typename T>
DEVICE T tanh(const T a) {
DEVICE T Tanh(const T a) {
T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
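Note on the formulas above (not part of the patch): the forward Sigmoid clamps its input to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] before exponentiating (its return line is elided in this view but is the standard 1/(1+exp(-x)) form), and Tanh is computed as 2/(1+exp(-2a)) - 1 with the exponent clamped at EXP_MAX_INPUT. A minimal standalone C++ sketch, using illustrative stand-in values for those constants, checks the formulation against std::tanh:

```cpp
// Standalone sketch (not part of the patch). The three constants below are
// illustrative stand-ins, not the actual values defined in this file.
#include <cmath>
#include <cstdio>

static const double kSigmoidMin = -40.0;  // stand-in for SIGMOID_THRESHOLD_MIN
static const double kSigmoidMax = 13.0;   // stand-in for SIGMOID_THRESHOLD_MAX
static const double kExpMaxInput = 40.0;  // stand-in for EXP_MAX_INPUT

// Sigmoid with input clamping, mirroring the forward helper above.
double SigmoidRef(double a) {
  double x = (a < kSigmoidMin) ? kSigmoidMin : ((a > kSigmoidMax) ? kSigmoidMax : a);
  return 1.0 / (1.0 + std::exp(-x));
}

// Tanh via 2 / (1 + exp(-2a)) - 1, with the exponent clamped as above.
double TanhRef(double a) {
  double tmp = -2.0 * a;
  tmp = (tmp > kExpMaxInput) ? kExpMaxInput : tmp;
  return 2.0 / (1.0 + std::exp(tmp)) - 1.0;
}

int main() {
  for (double a : {-3.0, -0.5, 0.0, 0.5, 3.0}) {
    std::printf("a=%+.2f  tanh: %.6f vs std::tanh %.6f  sigmoid: %.6f\n",
                a, TanhRef(a), std::tanh(a), SigmoidRef(a));
  }
  return 0;
}
```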
@@ -61,22 +61,22 @@ DEVICE T tanh(const T a) {
namespace backward {
template <typename T>
DEVICE T linear(const T a, const T b) {
DEVICE T Identity(const T a, const T b) {
return a;
}
template <typename T>
DEVICE T relu(const T a, const T b) {
DEVICE T Relu(const T a, const T b) {
return a * (b > 0.0 ? 1.0 : 0.0);
}
template <typename T>
DEVICE T sigmoid(const T a, const T b) {
DEVICE T Sigmoid(const T a, const T b) {
return a * b * (1.0 - b);
}
template <typename T>
DEVICE T tanh(const T a, const T b) {
DEVICE T Tanh(const T a, const T b) {
return a * (1.0 - b * b);
}
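The backward helpers take the upstream gradient a and the forward output b, so the closed forms a * b * (1 - b) (Sigmoid) and a * (1 - b * b) (Tanh) are the chain-rule derivatives expressed through the forward output. A standalone sketch (not part of the patch) checking them against finite differences:

```cpp
// Standalone sketch: numerically verifies that the gradient formulas above,
// written in terms of the forward output b, match central finite differences.
#include <cmath>
#include <cstdio>

double Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

int main() {
  const double x = 0.7, eps = 1e-6;

  // Sigmoid: d/dx sigmoid(x) == b * (1 - b), with b = sigmoid(x).
  double b = Sigmoid(x);
  double numeric = (Sigmoid(x + eps) - Sigmoid(x - eps)) / (2 * eps);
  std::printf("sigmoid grad: closed=%.8f numeric=%.8f\n", b * (1 - b), numeric);

  // Tanh: d/dx tanh(x) == 1 - b * b, with b = tanh(x).
  double t = std::tanh(x);
  double tnum = (std::tanh(x + eps) - std::tanh(x - eps)) / (2 * eps);
  std::printf("tanh grad:    closed=%.8f numeric=%.8f\n", 1 - t * t, tnum);
  return 0;
}
```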
@@ -89,20 +89,20 @@ struct Active {
};
static DEVICE Active<float>::Act kActFloat[] = {
&forward::sigmoid<float>, &forward::relu<float>, &forward::tanh<float>,
&forward::linear<float>};
&forward::Sigmoid<float>, &forward::Relu<float>, &forward::Tanh<float>,
&forward::Identity<float>};
static DEVICE Active<float>::ActGrad kActGradFloat[] = {
&backward::sigmoid<float>, &backward::relu<float>, &backward::tanh<float>,
&backward::linear<float>};
&backward::Sigmoid<float>, &backward::Relu<float>, &backward::Tanh<float>,
&backward::Identity<float>};
static DEVICE Active<double>::Act kActDouble[] = {
&forward::sigmoid<double>, &forward::relu<double>, &forward::tanh<double>,
&forward::linear<double>};
&forward::Sigmoid<double>, &forward::Relu<double>, &forward::Tanh<double>,
&forward::Identity<double>};
static DEVICE Active<double>::ActGrad kActGradDouble[] = {
&backward::sigmoid<double>, &backward::relu<double>,
&backward::tanh<double>, &backward::linear<double>};
&backward::Sigmoid<double>, &backward::Relu<double>,
&backward::Tanh<double>, &backward::Identity<double>};
namespace forward {
inline DEVICE float activation(float a, int index) {
@@ -128,29 +128,29 @@ inline DEVICE double activation(double a, double b, int index) {
#ifdef __AVX__
namespace forward {
namespace avx {
__m256 relu(const __m256 a);
__m256 sigmoid(const __m256 a);
__m256 tanh(const __m256 a);
__m256 linear(const __m256 a);
__m256 Relu(const __m256 a);
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
__m256 Identity(const __m256 a);
} // namespace avx
} // namespace forward
namespace backward {
namespace avx {
__m256 relu(const __m256 a, const __m256 b);
__m256 sigmoid(const __m256 a, const __m256 b);
__m256 tanh(const __m256 a, const __m256 b);
__m256 linear(const __m256 a, const __m256 b);
__m256 Relu(const __m256 a, const __m256 b);
__m256 Sigmoid(const __m256 a, const __m256 b);
__m256 Tanh(const __m256 a, const __m256 b);
__m256 Identity(const __m256 a, const __m256 b);
} // namespace avx
} // namespace backward
static Active<__m256>::Act kActAvx[] = {
&forward::avx::sigmoid, &forward::avx::relu, &forward::avx::tanh,
&forward::avx::linear};
&forward::avx::Sigmoid, &forward::avx::Relu, &forward::avx::Tanh,
&forward::avx::Identity};
static Active<__m256>::ActGrad kActGradAvx[] = {
&backward::avx::sigmoid, &backward::avx::relu, &backward::avx::tanh,
&backward::avx::linear};
&backward::avx::Sigmoid, &backward::avx::Relu, &backward::avx::Tanh,
&backward::avx::Identity};
namespace forward {
inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); }
......
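The kAct*/kActGrad* arrays above are fixed-order function-pointer tables, and activation(a, index) simply indexes into them. A minimal plain-C++ sketch of the same dispatch pattern (the ActType enum below is illustrative, not the actual enum used by the surrounding code):

```cpp
// Minimal sketch of the dispatch pattern used above: activations live in a
// fixed-order function-pointer table and are selected by an integer index.
// The order (sigmoid, relu, tanh, identity) mirrors the arrays in the diff.
#include <cmath>
#include <cstdio>

using Act = float (*)(float);

float Sigmoid(float a) { return 1.0f / (1.0f + std::exp(-a)); }
float Relu(float a) { return a > 0.0f ? a : 0.0f; }
float Tanh(float a) { return std::tanh(a); }
float Identity(float a) { return a; }

// Table order must match the index values handed to activation().
static Act kActFloat[] = {&Sigmoid, &Relu, &Tanh, &Identity};

enum ActType { kSigmoid = 0, kRelu = 1, kTanh = 2, kIdentity = 3 };

inline float activation(float a, int index) { return kActFloat[index](a); }

int main() {
  std::printf("relu(-1.5) = %.2f\n", activation(-1.5f, kRelu));  // 0.00
  std::printf("tanh(0.5)  = %.4f\n", activation(0.5f, kTanh));   // 0.4621
  return 0;
}
```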
@@ -22,61 +22,61 @@ namespace operators {
namespace math {
namespace detail {
__m256 exp(__m256 a) { return exp256_ps(a); }
__m256 Exp(__m256 a) { return exp256_ps(a); }
namespace forward {
namespace avx {
__m256 relu(const __m256 a) {
__m256 Relu(const __m256 a) {
__m256 tmp = _mm256_set1_ps(0.0f);
return _mm256_max_ps(a, tmp);
}
__m256 sigmoid(const __m256 a) {
__m256 Sigmoid(const __m256 a) {
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
__m256 tmp = _mm256_max_ps(a, min);
tmp = _mm256_min_ps(tmp, max);
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
tmp = exp(tmp);
tmp = Exp(tmp);
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
return tmp;
}
__m256 tanh(const __m256 a) {
__m256 Tanh(const __m256 a) {
__m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
__m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
tmp = _mm256_min_ps(tmp, max);
tmp = exp(tmp);
tmp = Exp(tmp);
return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
_mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
_mm256_set1_ps(1.0f));
}
__m256 linear(const __m256 a) { return a; }
__m256 Identity(const __m256 a) { return a; }
} // namespace avx
} // namespace forward
namespace backward {
namespace avx {
__m256 relu(const __m256 a, const __m256 b) {
__m256 Relu(const __m256 a, const __m256 b) {
return _mm256_mul_ps(
a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
_mm256_set1_ps(1.0f)));
}
__m256 sigmoid(const __m256 a, const __m256 b) {
__m256 Sigmoid(const __m256 a, const __m256 b) {
return _mm256_mul_ps(_mm256_mul_ps(a, b),
_mm256_sub_ps(_mm256_set1_ps(1.0f), b));
}
__m256 tanh(const __m256 a, const __m256 b) {
__m256 Tanh(const __m256 a, const __m256 b) {
return _mm256_mul_ps(
a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b)));
}
__m256 linear(const __m256 a, const __m256 b) { return a; }
__m256 Identity(const __m256 a, const __m256 b) { return a; }
} // namespace avx
} // namespace backward
......
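For context, each AVX helper above processes eight packed floats per call. A small self-contained sketch of that pattern, shown with Relu so it builds with <immintrin.h> alone (the real Sigmoid/Tanh additionally rely on the exp256_ps helper wrapped as Exp above):

```cpp
// Standalone sketch of the 8-wide AVX pattern used above. Relu is chosen so
// the example needs only standard intrinsics. Build with -mavx.
#include <immintrin.h>
#include <cstdio>

__m256 Relu(const __m256 a) {
  __m256 zero = _mm256_set1_ps(0.0f);
  return _mm256_max_ps(a, zero);  // element-wise max(a, 0)
}

int main() {
  alignas(32) float in[8] = {-2.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 2.f, 3.f};
  alignas(32) float out[8];
  __m256 v = _mm256_load_ps(in);   // load 8 floats at once
  _mm256_store_ps(out, Relu(v));   // apply Relu to all 8 lanes
  for (int i = 0; i < 8; ++i) std::printf("%.1f ", out[i]);
  std::printf("\n");  // expected: 0.0 0.0 0.0 0.0 0.5 1.0 2.0 3.0
  return 0;
}
```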
@@ -226,9 +226,9 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
/* framePerBlock = 32 batchPerBlock = 32 */
/* framePerBlock = 32 batchPerBlock = 16 */
threads = dim3(32, 16);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16);
}
auto stream =
......
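The last hunk is the "fix threads number" part of the commit: with threads = dim3(32, 16) a block covers 32 frame columns but only 16 batch rows, so the grid's y dimension must be computed with a divisor of 16, not 32. A small host-side sketch with hypothetical sizes illustrates the difference:

```cpp
// Sketch of the launch-configuration arithmetic behind the fix above.
// With threads = dim3(32, 16), each block covers 16 batch rows, so grid.y
// must be ceil(batchSize / 16); the old ceil(batchSize / 32) left part of
// the batch uncovered. The sizes below are hypothetical examples.
#include <cstdio>

int CeilDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  int frameSize = 100, batchSize = 100;  // hypothetical example sizes
  int threadsX = 32, threadsY = 16;

  int gridX    = CeilDiv(frameSize, threadsX);  // 4 blocks in x
  int gridYOld = CeilDiv(batchSize, 32);        // 4 blocks: covers only 64 rows
  int gridYNew = CeilDiv(batchSize, threadsY);  // 7 blocks: covers all 100 rows

  std::printf("old grid %dx%d covers %d batch rows (need %d)\n",
              gridX, gridYOld, gridYOld * threadsY, batchSize);
  std::printf("new grid %dx%d covers %d batch rows (need %d)\n",
              gridX, gridYNew, gridYNew * threadsY, batchSize);
  return 0;
}
```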