BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 36588b33
Authored Oct 18, 2018 by tensor-tang

fix illegal instruction of rnn1 and text

Parent: 6447155d
Showing 2 changed files with 241 additions and 55 deletions:

paddle/fluid/operators/math/CMakeLists.txt (+1 −1)
paddle/fluid/operators/math/jit_kernel_exp.cc (+240 −54)
paddle/fluid/operators/math/CMakeLists.txt

@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 cc_library(jit_kernel
     SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
-    DEPS cpu_info cblas activation_functions)
+    DEPS cpu_info cblas)
 cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
paddle/fluid/operators/math/jit_kernel_exp.cc

@@ -69,37 +69,225 @@ FOR_EACH_ISA(MKL_FLOAT, kGT16);
 FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
 #endif
 
-#define INTRI8_FLOAT(isa)                                                  \
+namespace detail {
+
+#ifdef __AVX__
+
+#define ALIGN32 __attribute__((aligned(32)))
+
+#define _PS256_CONST(Name, Val)                                      \
+  static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
+                                                 Val, Val, Val, Val}
+
+#define _PI256_CONST(Name, Val)                                    \
+  static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
+                                               Val, Val, Val, Val}
+
+_PI256_CONST(0x7f, 0x7f);
+_PS256_CONST(one, 1.f);
+_PS256_CONST(0p5, 0.5f);
+_PS256_CONST(exp_hi, 88.3762626647949f);
+_PS256_CONST(exp_lo, -88.3762626647949f);
+_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
+_PS256_CONST(cephes_exp_C1, 0.693359375);
+_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
+_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
+_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
+_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
+_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
+_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
+_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
+
+typedef union imm_xmm_union {
+  __m256i imm;
+  __m128i xmm[2];
+} imm_xmm_union;
+
+#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
+  {                                         \
+    imm_xmm_union u ALIGN32;                \
+    u.imm = imm_;                           \
+    xmm0_ = u.xmm[0];                       \
+    xmm1_ = u.xmm[1];                       \
+  }
+
+#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
+  {                                         \
+    imm_xmm_union u ALIGN32;                \
+    u.xmm[0] = xmm0_;                       \
+    u.xmm[1] = xmm1_;                       \
+    imm_ = u.imm;                           \
+  }
+
+#define AVX2_BITOP_USING_SSE2(fn)                           \
+  static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
+    /* use SSE2 to perform the bitop AVX2 */                \
+    __m128i x1, x2;                                         \
+    __m256i ret;                                            \
+    COPY_IMM_TO_XMM(x, x1, x2);                             \
+    x1 = _mm_##fn(x1, y);                                   \
+    x2 = _mm_##fn(x2, y);                                   \
+    COPY_XMM_TO_IMM(x1, x2, ret);                           \
+    return ret;                                             \
+  }
+
+#define AVX2_INTOP_USING_SSE2(fn)                                    \
+  static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \
+    /* use SSE2 to perform the AVX2 integer operation */             \
+    __m128i x1, x2;                                                  \
+    __m128i y1, y2;                                                  \
+    __m256i ret;                                                     \
+    COPY_IMM_TO_XMM(x, x1, x2);                                      \
+    COPY_IMM_TO_XMM(y, y1, y2);                                      \
+    x1 = _mm_##fn(x1, y1);                                           \
+    x2 = _mm_##fn(x2, y2);                                           \
+    COPY_XMM_TO_IMM(x1, x2, ret);                                    \
+    return ret;                                                      \
+  }
+
+AVX2_BITOP_USING_SSE2(slli_epi32);
+AVX2_INTOP_USING_SSE2(add_epi32);
+
+__m256 ExpAVX(__m256 x) {
+  __m256 tmp = _mm256_setzero_ps(), fx;
+  __m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
+  __m256i imm0;
+
+  x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));
+  x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));
+
+  /* express exp(x) as exp(g + n*log(2)) */
+  fx = _mm256_mul_ps(x, *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF));
+  fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));
+  tmp = _mm256_floor_ps(fx);
+  /* if greater, substract 1 */
+  __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
+  mask = _mm256_and_ps(mask, one);
+  fx = _mm256_sub_ps(tmp, mask);
+  tmp = _mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1));
+  __m256 z = _mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));
+  x = _mm256_sub_ps(x, tmp);
+  x = _mm256_sub_ps(x, z);
+  z = _mm256_mul_ps(x, x);
+  __m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));
+  y = _mm256_mul_ps(y, z);
+  y = _mm256_add_ps(y, x);
+  y = _mm256_add_ps(y, one);
+
+  /* build 2^n */
+  imm0 = _mm256_cvttps_epi32(fx);
+  // two AVX2 instructions using SSE2
+  imm0 = avx2_mm256_add_epi32(imm0,
+                              *reinterpret_cast<const __m256i*>(_pi256_0x7f));
+  imm0 = avx2_mm256_slli_epi32(imm0, 23);
+  __m256 pow2n = _mm256_castsi256_ps(imm0);
+  y = _mm256_mul_ps(y, pow2n);
+  return y;
+}
+#endif
+
+#ifdef __AVX2__
+__m256 ExpAVX2(__m256 x) {
+  __m256 tmp = _mm256_setzero_ps(), fx;
+  __m256 one = *reinterpret_cast<const __m256*>(_ps256_one);
+  __m256i imm0;
+
+  x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi));
+  x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo));
+
+  /* express exp(x) as exp(g + n*log(2)) */
+  fx = _mm256_mul_ps(x, *reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF));
+  fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5));
+  tmp = _mm256_floor_ps(fx);
+  /* if greater, substract 1 */
+  __m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
+  mask = _mm256_and_ps(mask, one);
+  fx = _mm256_sub_ps(tmp, mask);
+  tmp = _mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1));
+  __m256 z = _mm256_mul_ps(fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2));
+  x = _mm256_sub_ps(x, tmp);
+  x = _mm256_sub_ps(x, z);
+  z = _mm256_mul_ps(x, x);
+  __m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0);
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1));
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2));
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3));
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4));
+  y = _mm256_mul_ps(y, x);
+  y = _mm256_add_ps(y, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5));
+  y = _mm256_mul_ps(y, z);
+  y = _mm256_add_ps(y, x);
+  y = _mm256_add_ps(y, one);
+
+  /* build 2^n */
+  imm0 = _mm256_cvttps_epi32(fx);
+  // two AVX2 instructions
+  imm0 = _mm256_add_epi32(imm0, *reinterpret_cast<const __m256i*>(_pi256_0x7f));
+  imm0 = _mm256_slli_epi32(imm0, 23);
+  __m256 pow2n = _mm256_castsi256_ps(imm0);
+  y = _mm256_mul_ps(y, pow2n);
+  return y;
+}
+#endif
+
+}  // namespace detail
+
+#define INTRI8_FLOAT(isa, expisa)                                          \
   template <>                                                              \
   void VExpKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
       const {                                                              \
     __m256 tmp = _mm256_loadu_ps(x);                                       \
-    _mm256_storeu_ps(y, detail::Exp(tmp));                                 \
+    _mm256_storeu_ps(y, expisa(tmp));                                      \
   }
-#define INTRI16_FLOAT(isa)                                                  \
+#define INTRI16_FLOAT(isa, expisa)                                          \
   template <>                                                               \
   void VExpKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
       const {                                                               \
     __m256 tmp0 = _mm256_loadu_ps(x);                                       \
     __m256 tmp1 = _mm256_loadu_ps(x + 8);                                   \
-    tmp0 = detail::Exp(tmp0);                                               \
+    tmp0 = expisa(tmp0);                                                    \
-    tmp1 = detail::Exp(tmp1);                                               \
+    tmp1 = expisa(tmp1);                                                    \
     _mm256_storeu_ps(y, tmp0);                                              \
     _mm256_storeu_ps(y + 8, tmp1);                                          \
   }
 
 #ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
+INTRI8_FLOAT(jit::avx, detail::ExpAVX);
-INTRI16_FLOAT(jit::avx);
+INTRI16_FLOAT(jit::avx, detail::ExpAVX);
 #endif
 #ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
+INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
-INTRI16_FLOAT(jit::avx2);
+INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
 #endif
 #ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
+INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
-INTRI16_FLOAT(jit::avx512f);
+INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
 #endif
 // TODO(TJ): eq16 test and complete avx512
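For reference, the routine that ExpAVX and ExpAVX2 vectorize is the classic Cephes single-precision exp: the input is clamped, exp(x) is split as 2^n * exp(g) with n = round(x / ln 2), and exp(g) is approximated by a degree-5 polynomial. A scalar sketch follows (illustrative only, not part of the commit; constant names are ad hoc but the values mirror the _PS256_CONST definitions above):

#include <cmath>

// Scalar sketch of the Cephes-style exp that the AVX code above computes
// eight lanes at a time.
float cephes_expf_sketch(float x) {
  const float kExpHi = 88.3762626647949f;     // _ps256_exp_hi
  const float kExpLo = -88.3762626647949f;    // _ps256_exp_lo
  const float kLog2e = 1.44269504088896341f;  // 1 / ln(2)
  const float kC1 = 0.693359375f;             // ln(2), high part
  const float kC2 = -2.12194440e-4f;          // ln(2), low part
  const float p0 = 1.9875691500E-4f, p1 = 1.3981999507E-3f,
              p2 = 8.3334519073E-3f, p3 = 4.1665795894E-2f,
              p4 = 1.6666665459E-1f, p5 = 5.0000001201E-1f;

  x = std::fmin(std::fmax(x, kExpLo), kExpHi);  // clamp the input range
  float fx = std::floor(x * kLog2e + 0.5f);     // n = round(x / ln 2)
  x -= fx * kC1;                                // g = x - n*ln(2), subtracted
  x -= fx * kC2;                                // in two parts for precision
  float z = x * x;
  float y = ((((p0 * x + p1) * x + p2) * x + p3) * x + p4) * x + p5;
  y = y * z + x + 1.0f;                         // exp(g) ~ 1 + g + g^2*P(g)
  return std::ldexp(y, static_cast<int>(fx));   // scale by 2^n
}

The vector code builds the 2^n factor by adding the exponent bias 0x7f to n and shifting left by 23 bits into the float exponent field. Those two integer operations are AVX2 instructions, so ExpAVX emulates them on two SSE2 halves (avx2_mm256_add_epi32, avx2_mm256_slli_epi32) while ExpAVX2 issues them directly; that split is presumably what keeps AVX-only CPUs from hitting the illegal instruction this commit fixes.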
@@ -135,26 +323,26 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
   std::shared_ptr<const VExpKernel<T>> vexp_;
 };
 
-#define INTRI_SIGMOID(tmp, min, max)              \
+#define INTRI_SIGMOID(tmp, min, max, expisa)      \
   tmp = _mm256_max_ps(tmp, min);                  \
   tmp = _mm256_min_ps(tmp, max);                  \
   tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
-  tmp = detail::Exp(tmp);                         \
+  tmp = expisa(tmp);                              \
   tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
   tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
 
-#define INTRI8_FLOAT(isa)                                                      \
+#define INTRI8_FLOAT(isa, expisa)                                              \
   template <>                                                                  \
   void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
       const {                                                                  \
-    __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);                        \
+    /*use static const??*/ __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
     __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);                        \
     __m256 tmp = _mm256_loadu_ps(x);                                           \
-    INTRI_SIGMOID(tmp, min, max);                                              \
+    INTRI_SIGMOID(tmp, min, max, expisa);                                      \
     _mm256_storeu_ps(y, tmp);                                                  \
   }
-#define INTRI16_FLOAT(isa)                                                     \
+#define INTRI16_FLOAT(isa, expisa)                                             \
   template <>                                                                  \
   void VSigmoidKernelImpl<float, isa, kEQ16>::Compute(const float* x,          \
                                                       float* y) const {        \
@@ -162,13 +350,13 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
     __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);               \
     __m256 tmp0 = _mm256_loadu_ps(x);                                 \
     __m256 tmp1 = _mm256_loadu_ps(x + 8);                             \
-    INTRI_SIGMOID(tmp0, min, max);                                    \
+    INTRI_SIGMOID(tmp0, min, max, expisa);                            \
-    INTRI_SIGMOID(tmp1, min, max);                                    \
+    INTRI_SIGMOID(tmp1, min, max, expisa);                            \
     _mm256_storeu_ps(y, tmp0);                                        \
     _mm256_storeu_ps(y + 8, tmp1);                                    \
   }
-#define INTRI_GT8LT16_FLOAT(isa)                                      \
+#define INTRI_GT8LT16_FLOAT(isa, expisa)                              \
   template <>                                                         \
   VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
       : VSigmoidKernel<float>() {                                     \
@@ -184,7 +372,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
     __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
     __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
     __m256 tmp = _mm256_loadu_ps(x);                    \
-    INTRI_SIGMOID(tmp, min, max);                       \
+    INTRI_SIGMOID(tmp, min, max, expisa);               \
     _mm256_storeu_ps(y, tmp);                           \
     const float min_ = SIGMOID_THRESHOLD_MIN;           \
     const float max_ = SIGMOID_THRESHOLD_MAX;           \
@@ -198,7 +386,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
     }                                                               \
   }
 
-#define INTRI_GT16_FLOAT(isa)                                       \
+#define INTRI_GT16_FLOAT(isa, expisa)                               \
   template <>                                                       \
   VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d)  \
       : VSigmoidKernel<float>() {                                   \
@@ -215,7 +403,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
     __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);       \
     for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) {   \
       __m256 tmp = _mm256_loadu_ps(x + i);                    \
-      INTRI_SIGMOID(tmp, min, max);                           \
+      INTRI_SIGMOID(tmp, min, max, expisa);                   \
      _mm256_storeu_ps(y + i, tmp);                            \
    }                                                          \
    const float min_ = SIGMOID_THRESHOLD_MIN;                  \
@@ -231,22 +419,20 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
 }
 
 #ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI16_FLOAT(jit::avx);
-INTRI_GT8LT16_FLOAT(jit::avx);
-INTRI_GT16_FLOAT(jit::avx);
+INTRI8_FLOAT(jit::avx, detail::ExpAVX);
+INTRI16_FLOAT(jit::avx, detail::ExpAVX);
+INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
+INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
 #endif
 #ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI16_FLOAT(jit::avx2);
-// INTRI_GT8LT16_FLOAT(jit::avx2);
-// INTRI_GT16_FLOAT(jit::avx2);
+INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
+// maybe use avx at gt8lt16 and gt16
 #endif
 #ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI16_FLOAT(jit::avx512f);
-// INTRI_GT8LT16_FLOAT(jit::avx512f);
-// INTRI_GT16_FLOAT(jit::avx512f);
+INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
+// maybe use avx2 at gt8lt16 and gt16
 #endif
 
 #undef INTRI8_FLOAT
@@ -280,36 +466,36 @@ class VTanhKernelImpl : public VTanhKernel<T> {
   std::shared_ptr<const VAddBiasKernel<T>> vaddbias_;
 };
 
-#define INTRI_VTANH(tmp)                                   \
+#define INTRI_VTANH(tmp, expisa)                           \
   tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp);         \
   tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
-  tmp = detail::Exp(tmp);                                  \
+  tmp = expisa(tmp);                                       \
   tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);          \
   tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp);          \
   tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
 
-#define INTRI8_FLOAT(isa)                                                    \
+#define INTRI8_FLOAT(isa, expisa)                                            \
   template <>                                                                \
   void VTanhKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y)  \
       const {                                                                \
     __m256 tmp = _mm256_loadu_ps(x);                                         \
-    INTRI_VTANH(tmp);                                                        \
+    INTRI_VTANH(tmp, expisa);                                                \
     _mm256_storeu_ps(y, tmp);                                                \
   }
-#define INTRI16_FLOAT(isa)                                                   \
+#define INTRI16_FLOAT(isa, expisa)                                           \
   template <>                                                                \
   void VTanhKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
       const {                                                                \
     __m256 tmp0 = _mm256_loadu_ps(x);                                        \
     __m256 tmp1 = _mm256_loadu_ps(x + 8);                                    \
-    INTRI_VTANH(tmp0);                                                       \
+    INTRI_VTANH(tmp0, expisa);                                               \
-    INTRI_VTANH(tmp1);                                                       \
+    INTRI_VTANH(tmp1, expisa);                                               \
     _mm256_storeu_ps(y, tmp0);                                               \
     _mm256_storeu_ps(y + 8, tmp1);                                           \
   }
-#define INTRI_GT8LT16_FLOAT(isa)                                             \
+#define INTRI_GT8LT16_FLOAT(isa, expisa)                                     \
   template <>                                                                \
   VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d)              \
       : VTanhKernel<float>() {                                               \
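Once an exp implementation is substituted for expisa, the sigmoid and tanh macros above reduce to familiar identities. A scalar sketch (illustrative only, not from the commit; the extra arguments stand in for SIGMOID_THRESHOLD_MIN/MAX and EXP_MAX_INPUT):

#include <algorithm>
#include <cmath>

// What INTRI_SIGMOID computes per lane: clamp x, then 1 / (1 + e^-x).
float sigmoid_ref(float x, float min_v, float max_v) {
  x = std::min(std::max(x, min_v), max_v);
  return 1.0f / (1.0f + std::exp(-x));
}

// What INTRI_VTANH computes per lane: tanh(x) = 2 / (1 + e^(-2x)) - 1,
// with the exponent argument capped at EXP_MAX_INPUT (exp_max here).
float tanh_ref(float x, float exp_max) {
  float t = std::min(-2.0f * x, exp_max);
  return 2.0f / (1.0f + std::exp(t)) - 1.0f;
}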
@@ -327,7 +513,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
   void VTanhKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
                                                       float* y) const { \
     __m256 tmp = _mm256_loadu_ps(x);                                  \
-    INTRI_VTANH(tmp);                                                 \
+    INTRI_VTANH(tmp, expisa);                                         \
     _mm256_storeu_ps(y, tmp);                                         \
     x += AVX_FLOAT_BLOCK;                                             \
     y += AVX_FLOAT_BLOCK;                                             \
@@ -337,7 +523,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
     vaddbias_->Compute(-1.f, y, y);                            \
   }
 
-#define INTRI_GT16_FLOAT(isa)                                  \
+#define INTRI_GT16_FLOAT(isa, expisa)                          \
   template <>                                                  \
   VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d)   \
       : VTanhKernel<float>() {                                 \
@@ -356,7 +542,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
       const {                                                  \
     for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) {    \
       __m256 tmp = _mm256_loadu_ps(x + i);                     \
-      INTRI_VTANH(tmp);                                        \
+      INTRI_VTANH(tmp, expisa);                                \
       _mm256_storeu_ps(y + i, tmp);                            \
     }                                                          \
     x += this->end_;                                           \
@@ -368,19 +554,19 @@ class VTanhKernelImpl : public VTanhKernel<T> {
 }
 
 #ifdef __AVX__
-INTRI8_FLOAT(jit::avx);
-INTRI16_FLOAT(jit::avx);
-INTRI_GT8LT16_FLOAT(jit::avx);
-INTRI_GT16_FLOAT(jit::avx);
+INTRI8_FLOAT(jit::avx, detail::ExpAVX);
+INTRI16_FLOAT(jit::avx, detail::ExpAVX);
+INTRI_GT8LT16_FLOAT(jit::avx, detail::ExpAVX);
+INTRI_GT16_FLOAT(jit::avx, detail::ExpAVX);
 #endif
 #ifdef __AVX2__
-INTRI8_FLOAT(jit::avx2);
-INTRI16_FLOAT(jit::avx2);
+INTRI8_FLOAT(jit::avx2, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx2, detail::ExpAVX2);
 // maybe use avx at gt8lt16 and gt16
 #endif
 #ifdef __AVX512F__
-INTRI8_FLOAT(jit::avx512f);
-INTRI16_FLOAT(jit::avx512f);
+INTRI8_FLOAT(jit::avx512f, detail::ExpAVX2);
+INTRI16_FLOAT(jit::avx512f, detail::ExpAVX2);
 // maybe use avx at gt8lt16 and gt16
 #endif
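To make the new dispatch concrete: a line such as INTRI8_FLOAT(jit::avx, detail::ExpAVX) now expands to roughly the following specialization (a sketch of the preprocessor output with whitespace adjusted):

// Sketch of what INTRI8_FLOAT(jit::avx, detail::ExpAVX) expands to.
template <>
void VExpKernelImpl<float, jit::avx, kEQ8>::Compute(const float* x,
                                                    float* y) const {
  __m256 tmp = _mm256_loadu_ps(x);
  _mm256_storeu_ps(y, detail::ExpAVX(tmp));  // exp matched to the kernel's ISA
}

Before this change every ISA variant shared detail::Exp from the activation_functions library (hence the dropped CMake dependency), so an AVX-only machine could apparently end up executing AVX2 instructions; passing the exp routine in as the expisa macro argument ties each instantiation to its own ISA, which is the illegal-instruction fix named in the commit message.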