Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0a9f5f17
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0a9f5f17
编写于
10月 19, 2018
作者:
T
tensor-tang
提交者:
GitHub
10月 19, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13968 from tensor-tang/fix/jit/exp
Fix jit exp
上级
fcb2e810
60ff05e3
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
327 addition
and
134 deletion
+327
-134
paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
+3
-3
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+1
-1
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+201
-60
paddle/fluid/operators/math/jit_kernel_lstm.cc
paddle/fluid/operators/math/jit_kernel_lstm.cc
+122
-70
未找到文件。
paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc
浏览文件 @
0a9f5f17
...
@@ -18,12 +18,12 @@ namespace paddle {
...
@@ -18,12 +18,12 @@ namespace paddle {
namespace
inference
{
namespace
inference
{
using
namespace
framework
;
// NOLINT
using
namespace
framework
;
// NOLINT
static
std
::
vector
<
float
>
result_data
;
struct
DataRecord
{
struct
DataRecord
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>>>
link_step_data_all
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>>>
link_step_data_all
;
std
::
vector
<
size_t
>
lod
;
std
::
vector
<
size_t
>
lod
;
std
::
vector
<
std
::
vector
<
float
>>
rnn_link_data
;
std
::
vector
<
std
::
vector
<
float
>>
rnn_link_data
;
std
::
vector
<
float
>
result_data
;
size_t
num_samples
;
// total number of samples
size_t
num_samples
;
// total number of samples
size_t
batch_iter
{
0
};
size_t
batch_iter
{
0
};
size_t
batch_size
{
1
};
size_t
batch_size
{
1
};
...
@@ -57,6 +57,7 @@ struct DataRecord {
...
@@ -57,6 +57,7 @@ struct DataRecord {
std
::
ifstream
file
(
path
);
std
::
ifstream
file
(
path
);
std
::
string
line
;
std
::
string
line
;
int
num_lines
=
0
;
int
num_lines
=
0
;
result_data
.
clear
();
while
(
std
::
getline
(
file
,
line
))
{
while
(
std
::
getline
(
file
,
line
))
{
num_lines
++
;
num_lines
++
;
std
::
vector
<
std
::
string
>
data
;
std
::
vector
<
std
::
string
>
data
;
...
@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) {
...
@@ -135,13 +136,12 @@ TEST(Analyzer_rnn2, profile) {
if
(
FLAGS_num_threads
==
1
&&
!
FLAGS_test_all_data
)
{
if
(
FLAGS_num_threads
==
1
&&
!
FLAGS_test_all_data
)
{
// the first inference result
// the first inference result
DataRecord
data
(
FLAGS_infer_data
,
FLAGS_batch_size
);
PADDLE_ENFORCE_GT
(
outputs
.
size
(),
0
);
PADDLE_ENFORCE_GT
(
outputs
.
size
(),
0
);
size_t
size
=
GetSize
(
outputs
[
0
]);
size_t
size
=
GetSize
(
outputs
[
0
]);
PADDLE_ENFORCE_GT
(
size
,
0
);
PADDLE_ENFORCE_GT
(
size
,
0
);
float
*
result
=
static_cast
<
float
*>
(
outputs
[
0
].
data
.
data
());
float
*
result
=
static_cast
<
float
*>
(
outputs
[
0
].
data
.
data
());
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
EXPECT_NEAR
(
result
[
i
],
data
.
result_data
[
i
],
1e-3
);
EXPECT_NEAR
(
result
[
i
],
result_data
[
i
],
1e-3
);
}
}
}
}
}
}
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
0a9f5f17
...
@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat)
...
@@ -76,5 +76,5 @@ cc_test(concat_test SRCS concat_test.cc DEPS concat)
cc_test
(
cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info
)
cc_test
(
cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info
)
cc_library
(
jit_kernel
cc_library
(
jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc
DEPS cpu_info cblas
activation_functions
)
DEPS cpu_info cblas
)
cc_test
(
jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel
)
cc_test
(
jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel
)
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
0a9f5f17
...
@@ -27,13 +27,6 @@ limitations under the License. */
...
@@ -27,13 +27,6 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
#ifdef __AVX__
namespace
detail
{
__m256
Exp
(
__m256
a
);
}
// namespace detail
#endif
namespace
jitkernel
{
namespace
jitkernel
{
namespace
jit
=
platform
::
jit
;
namespace
jit
=
platform
::
jit
;
...
@@ -69,37 +62,186 @@ FOR_EACH_ISA(MKL_FLOAT, kGT16);
...
@@ -69,37 +62,186 @@ FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
FOR_EACH_ISA_BLOCK
(
MKL_DOUBLE
);
#endif
#endif
#define INTRI8_FLOAT(isa) \
namespace
detail
{
#ifdef __AVX__
#define ALIGN32 __attribute__((aligned(32)))
#define _PS256_CONST(Name, Val) \
static const float _ps256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
#define _PI256_CONST(Name, Val) \
static const int _pi256_##Name[8] ALIGN32 = {Val, Val, Val, Val, \
Val, Val, Val, Val}
_PI256_CONST
(
0x7f
,
0x7f
);
_PS256_CONST
(
one
,
1.
f
);
_PS256_CONST
(
0
p5
,
0.5
f
);
_PS256_CONST
(
exp_hi
,
88.3762626647949
f
);
_PS256_CONST
(
exp_lo
,
-
88.3762626647949
f
);
_PS256_CONST
(
cephes_LOG2EF
,
1.44269504088896341
);
_PS256_CONST
(
cephes_exp_C1
,
0.693359375
);
_PS256_CONST
(
cephes_exp_C2
,
-
2.12194440e-4
);
_PS256_CONST
(
cephes_exp_p0
,
1.9875691500E-4
);
_PS256_CONST
(
cephes_exp_p1
,
1.3981999507E-3
);
_PS256_CONST
(
cephes_exp_p2
,
8.3334519073E-3
);
_PS256_CONST
(
cephes_exp_p3
,
4.1665795894E-2
);
_PS256_CONST
(
cephes_exp_p4
,
1.6666665459E-1
);
_PS256_CONST
(
cephes_exp_p5
,
5.0000001201E-1
);
typedef
union
imm_xmm_union
{
__m256i
imm
;
__m128i
xmm
[
2
];
}
imm_xmm_union
;
#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) \
{ \
imm_xmm_union u ALIGN32; \
u.imm = imm_; \
xmm0_ = u.xmm[0]; \
xmm1_ = u.xmm[1]; \
}
#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) \
{ \
imm_xmm_union u ALIGN32; \
u.xmm[0] = xmm0_; \
u.xmm[1] = xmm1_; \
imm_ = u.imm; \
}
#define AVX2_BITOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_##fn(__m256i x, int y) { \
/* use SSE2 to perform the bitop AVX2 */
\
__m128i x1, x2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
x1 = _mm_##fn(x1, y); \
x2 = _mm_##fn(x2, y); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
#define AVX2_INTOP_USING_SSE2(fn) \
static inline __m256i avx2_mm256_add_epi32(__m256i x, __m256i y) { \
/* use SSE2 to perform the AVX2 integer operation */
\
__m128i x1, x2; \
__m128i y1, y2; \
__m256i ret; \
COPY_IMM_TO_XMM(x, x1, x2); \
COPY_IMM_TO_XMM(y, y1, y2); \
x1 = _mm_##fn(x1, y1); \
x2 = _mm_##fn(x2, y2); \
COPY_XMM_TO_IMM(x1, x2, ret); \
return ret; \
}
AVX2_BITOP_USING_SSE2
(
slli_epi32
);
AVX2_INTOP_USING_SSE2
(
add_epi32
);
#define AVXEXP_BASE \
__m256 tmp = _mm256_setzero_ps(), fx; \
__m256 one = *reinterpret_cast<const __m256*>(_ps256_one); \
__m256i imm0; \
x = _mm256_min_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_hi)); \
x = _mm256_max_ps(x, *reinterpret_cast<const __m256*>(_ps256_exp_lo)); \
/* express exp(x) as exp(g + n*log(2)) */
\
fx = _mm256_mul_ps(x, \
*reinterpret_cast<const __m256*>(_ps256_cephes_LOG2EF)); \
fx = _mm256_add_ps(fx, *reinterpret_cast<const __m256*>(_ps256_0p5)); \
tmp = _mm256_floor_ps(fx); \
/* if greater, substract 1 */
\
__m256 mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS); \
mask = _mm256_and_ps(mask, one); \
fx = _mm256_sub_ps(tmp, mask); \
tmp = _mm256_mul_ps(fx, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_C1)); \
__m256 z = _mm256_mul_ps( \
fx, *reinterpret_cast<const __m256*>(_ps256_cephes_exp_C2)); \
x = _mm256_sub_ps(x, tmp); \
x = _mm256_sub_ps(x, z); \
z = _mm256_mul_ps(x, x); \
__m256 y = *reinterpret_cast<const __m256*>(_ps256_cephes_exp_p0); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p1)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p2)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p3)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p4)); \
y = _mm256_mul_ps(y, x); \
y = _mm256_add_ps(y, \
*reinterpret_cast<const __m256*>(_ps256_cephes_exp_p5)); \
y = _mm256_mul_ps(y, z); \
y = _mm256_add_ps(y, x); \
y = _mm256_add_ps(y, one); \
/* build 2^n */
\
imm0 = _mm256_cvttps_epi32(fx)
__m256
ExpAVX
(
__m256
x
)
{
AVXEXP_BASE
;
// two AVX2 instructions using SSE2
imm0
=
avx2_mm256_add_epi32
(
imm0
,
*
reinterpret_cast
<
const
__m256i
*>
(
_pi256_0x7f
));
imm0
=
avx2_mm256_slli_epi32
(
imm0
,
23
);
__m256
pow2n
=
_mm256_castsi256_ps
(
imm0
);
y
=
_mm256_mul_ps
(
y
,
pow2n
);
return
y
;
}
#endif
#ifdef __AVX2__
__m256
ExpAVX2
(
__m256
x
)
{
AVXEXP_BASE
;
// two AVX2 instructions
imm0
=
_mm256_add_epi32
(
imm0
,
*
reinterpret_cast
<
const
__m256i
*>
(
_pi256_0x7f
));
imm0
=
_mm256_slli_epi32
(
imm0
,
23
);
__m256
pow2n
=
_mm256_castsi256_ps
(
imm0
);
y
=
_mm256_mul_ps
(
y
,
pow2n
);
return
y
;
}
#endif
}
// namespace detail
#define INTRI8_FLOAT(isa, expisa) \
template <> \
template <> \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
void VExpKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
__m256 tmp = _mm256_loadu_ps(x); \
_mm256_storeu_ps(y,
detail::Exp(tmp));
\
_mm256_storeu_ps(y,
expisa(tmp));
\
}
}
#define INTRI16_FLOAT(isa
)
\
#define INTRI16_FLOAT(isa
, expisa)
\
template <> \
template <> \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
void VExpKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 =
detail::Exp(tmp0);
\
tmp0 =
expisa(tmp0);
\
tmp1 =
detail::Exp(tmp1);
\
tmp1 =
expisa(tmp1);
\
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
_mm256_storeu_ps(y + 8, tmp1); \
}
}
#ifdef __AVX__
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
INTRI8_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI16_FLOAT
(
jit
::
avx
);
INTRI16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
#endif
#endif
#ifdef __AVX2__
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
INTRI8_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx2
);
INTRI16_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
#endif
#endif
#ifdef __AVX512F__
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
INTRI8_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx512f
);
INTRI16_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
#endif
#endif
// TODO(TJ): eq16 test and complete avx512
// TODO(TJ): eq16 test and complete avx512
...
@@ -135,26 +277,27 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
...
@@ -135,26 +277,27 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
std
::
shared_ptr
<
const
VExpKernel
<
T
>>
vexp_
;
std
::
shared_ptr
<
const
VExpKernel
<
T
>>
vexp_
;
};
};
#define INTRI_SIGMOID(tmp, min, max
)
\
#define INTRI_SIGMOID(tmp, min, max
, expisa)
\
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_min_ps(tmp, max); \
tmp = _mm256_min_ps(tmp, max); \
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
tmp =
detail::Exp(tmp);
\
tmp =
expisa(tmp);
\
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
#define INTRI8_FLOAT(isa
)
\
#define INTRI8_FLOAT(isa
, expisa)
\
template <> \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
void VSigmoidKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
const { \
/* TODO(TJ): try to use static const*/
\
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max
);
\
INTRI_SIGMOID(tmp, min, max
, expisa);
\
_mm256_storeu_ps(y, tmp); \
_mm256_storeu_ps(y, tmp); \
}
}
#define INTRI16_FLOAT(isa
)
\
#define INTRI16_FLOAT(isa
, expisa)
\
template <> \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ16>::Compute(const float* x, \
void VSigmoidKernelImpl<float, isa, kEQ16>::Compute(const float* x, \
float* y) const { \
float* y) const { \
...
@@ -162,13 +305,13 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
...
@@ -162,13 +305,13 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_SIGMOID(tmp0, min, max
);
\
INTRI_SIGMOID(tmp0, min, max
, expisa);
\
INTRI_SIGMOID(tmp1, min, max
);
\
INTRI_SIGMOID(tmp1, min, max
, expisa);
\
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
_mm256_storeu_ps(y + 8, tmp1); \
}
}
#define INTRI_GT8LT16_FLOAT(isa
)
\
#define INTRI_GT8LT16_FLOAT(isa
, expisa)
\
template <> \
template <> \
VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
: VSigmoidKernel<float>() { \
...
@@ -184,7 +327,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
...
@@ -184,7 +327,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max
);
\
INTRI_SIGMOID(tmp, min, max
, expisa);
\
_mm256_storeu_ps(y, tmp); \
_mm256_storeu_ps(y, tmp); \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
...
@@ -198,7 +341,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
...
@@ -198,7 +341,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
} \
} \
}
}
#define INTRI_GT16_FLOAT(isa
)
\
#define INTRI_GT16_FLOAT(isa
, expisa)
\
template <> \
template <> \
VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \
VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
: VSigmoidKernel<float>() { \
...
@@ -215,7 +358,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
...
@@ -215,7 +358,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_SIGMOID(tmp, min, max
);
\
INTRI_SIGMOID(tmp, min, max
, expisa);
\
_mm256_storeu_ps(y + i, tmp); \
_mm256_storeu_ps(y + i, tmp); \
} \
} \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float min_ = SIGMOID_THRESHOLD_MIN; \
...
@@ -231,22 +374,20 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
...
@@ -231,22 +374,20 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
}
}
#ifdef __AVX__
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
INTRI8_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI16_FLOAT
(
jit
::
avx
);
INTRI16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI_GT8LT16_FLOAT
(
jit
::
avx
);
INTRI_GT8LT16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI_GT16_FLOAT
(
jit
::
avx
);
INTRI_GT16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
#endif
#endif
#ifdef __AVX2__
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
INTRI8_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx2
);
INTRI16_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
// INTRI_GT8LT16_FLOAT(jit::avx2);
// maybe use avx at gt8lt16 and gt16
// INTRI_GT16_FLOAT(jit::avx2);
#endif
#endif
#ifdef __AVX512F__
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
INTRI8_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx512f
);
INTRI16_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
// INTRI_GT8LT16_FLOAT(jit::avx512f);
// maybe use avx2 at gt8lt16 and gt16
// INTRI_GT16_FLOAT(jit::avx512f);
#endif
#endif
#undef INTRI8_FLOAT
#undef INTRI8_FLOAT
...
@@ -280,36 +421,36 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -280,36 +421,36 @@ class VTanhKernelImpl : public VTanhKernel<T> {
std
::
shared_ptr
<
const
VAddBiasKernel
<
T
>>
vaddbias_
;
std
::
shared_ptr
<
const
VAddBiasKernel
<
T
>>
vaddbias_
;
};
};
#define INTRI_VTANH(tmp
)
\
#define INTRI_VTANH(tmp
, expisa)
\
tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \
tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \
tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
tmp =
detail::Exp(tmp);
\
tmp =
expisa(tmp);
\
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
#define INTRI8_FLOAT(isa
)
\
#define INTRI8_FLOAT(isa
, expisa)
\
template <> \
template <> \
void VTanhKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
void VTanhKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp
);
\
INTRI_VTANH(tmp
, expisa);
\
_mm256_storeu_ps(y, tmp); \
_mm256_storeu_ps(y, tmp); \
}
}
#define INTRI16_FLOAT(isa
)
\
#define INTRI16_FLOAT(isa
, expisa)
\
template <> \
template <> \
void VTanhKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
void VTanhKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_VTANH(tmp0
);
\
INTRI_VTANH(tmp0
, expisa);
\
INTRI_VTANH(tmp1
);
\
INTRI_VTANH(tmp1
, expisa);
\
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
_mm256_storeu_ps(y + 8, tmp1); \
}
}
#define INTRI_GT8LT16_FLOAT(isa
)
\
#define INTRI_GT8LT16_FLOAT(isa
, expisa)
\
template <> \
template <> \
VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d) \
VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
: VTanhKernel<float>() { \
...
@@ -327,7 +468,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -327,7 +468,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
void VTanhKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
void VTanhKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
float* y) const { \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp
);
\
INTRI_VTANH(tmp
, expisa);
\
_mm256_storeu_ps(y, tmp); \
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
...
@@ -337,7 +478,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -337,7 +478,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_->Compute(-1.f, y, y); \
vaddbias_->Compute(-1.f, y, y); \
}
}
#define INTRI_GT16_FLOAT(isa
)
\
#define INTRI_GT16_FLOAT(isa
, expisa)
\
template <> \
template <> \
VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
: VTanhKernel<float>() { \
...
@@ -356,7 +497,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -356,7 +497,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
const { \
const { \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_VTANH(tmp
);
\
INTRI_VTANH(tmp
, expisa);
\
_mm256_storeu_ps(y + i, tmp); \
_mm256_storeu_ps(y + i, tmp); \
} \
} \
x += this->end_; \
x += this->end_; \
...
@@ -368,19 +509,19 @@ class VTanhKernelImpl : public VTanhKernel<T> {
...
@@ -368,19 +509,19 @@ class VTanhKernelImpl : public VTanhKernel<T> {
}
}
#ifdef __AVX__
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
);
INTRI8_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI16_FLOAT
(
jit
::
avx
);
INTRI16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI_GT8LT16_FLOAT
(
jit
::
avx
);
INTRI_GT8LT16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI_GT16_FLOAT
(
jit
::
avx
);
INTRI_GT16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
#endif
#endif
#ifdef __AVX2__
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
);
INTRI8_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx2
);
INTRI16_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
// maybe use avx at gt8lt16 and gt16
// maybe use avx at gt8lt16 and gt16
#endif
#endif
#ifdef __AVX512F__
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
);
INTRI8_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx512f
);
INTRI16_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
// maybe use avx at gt8lt16 and gt16
// maybe use avx at gt8lt16 and gt16
#endif
#endif
...
...
paddle/fluid/operators/math/jit_kernel_lstm.cc
浏览文件 @
0a9f5f17
...
@@ -25,13 +25,18 @@ limitations under the License. */
...
@@ -25,13 +25,18 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
#ifdef __AVX__
namespace
jitkernel
{
namespace
detail
{
namespace
detail
{
__m256
Exp
(
__m256
a
);
#ifdef __AVX__
}
// namespace detail
__m256
ExpAVX
(
__m256
x
);
#endif
#endif
namespace
jitkernel
{
#ifdef __AVX2__
__m256
ExpAVX2
(
__m256
x
);
#endif
}
// namespace detail
namespace
jit
=
platform
::
jit
;
namespace
jit
=
platform
::
jit
;
#ifdef __AVX__
#ifdef __AVX__
...
@@ -43,43 +48,72 @@ class AVXAct {
...
@@ -43,43 +48,72 @@ class AVXAct {
virtual
__m256
Compute
(
__m256
x
)
const
=
0
;
virtual
__m256
Compute
(
__m256
x
)
const
=
0
;
};
};
template
<
act_type
type
>
template
<
act_type
type
,
jit
::
cpu_isa_t
isa
>
class
AVXActImpl
:
public
AVXAct
{
class
AVXActImpl
:
public
AVXAct
{
public:
public:
__m256
Compute
(
__m256
x
)
const
override
{
PADDLE_THROW
(
"Unkown type!"
);
}
__m256
Compute
(
__m256
x
)
const
override
{
PADDLE_THROW
(
"Unkown type!"
);
}
};
};
template
<
>
#define AVX_SIGMOID(isa, expisa) \
__m256
AVXActImpl
<
kSigmoid
>::
Compute
(
__m256
x
)
const
{
template <> \
__m256
ones
=
_mm256_set1_ps
(
1.0
f
);
__m256 AVXActImpl<kSigmoid, isa>::Compute(__m256 x) const { \
x
=
_mm256_max_ps
(
x
,
_mm256_set1_ps
(
SIGMOID_THRESHOLD_MIN
));
__m256 ones = _mm256_set1_ps(1.0f); \
x
=
_mm256_min_ps
(
x
,
_mm256_set1_ps
(
SIGMOID_THRESHOLD_MAX
));
x = _mm256_max_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MIN)); \
x
=
_mm256_sub_ps
(
_mm256_set1_ps
(
0.0
f
),
x
);
x = _mm256_min_ps(x, _mm256_set1_ps(SIGMOID_THRESHOLD_MAX)); \
x
=
detail
::
Exp
(
x
);
x = _mm256_sub_ps(_mm256_set1_ps(0.0f), x); \
x
=
_mm256_add_ps
(
ones
,
x
);
x = expisa(x); \
return
_mm256_div_ps
(
ones
,
x
);
x = _mm256_add_ps(ones, x); \
}
return _mm256_div_ps(ones, x); \
}
template
<
>
#define AVX_TANH(isa, expisa) \
__m256
AVXActImpl
<
kTanh
>::
Compute
(
__m256
x
)
const
{
template <> \
__m256
ones
=
_mm256_set1_ps
(
1.0
f
);
__m256 AVXActImpl<kTanh, isa>::Compute(__m256 x) const { \
x
=
_mm256_mul_ps
(
_mm256_set1_ps
(
-
2.0
f
),
x
);
__m256 ones = _mm256_set1_ps(1.0f); \
x
=
_mm256_min_ps
(
x
,
_mm256_set1_ps
(
EXP_MAX_INPUT
));
x = _mm256_mul_ps(_mm256_set1_ps(-2.0f), x); \
x
=
detail
::
Exp
(
x
);
x = _mm256_min_ps(x, _mm256_set1_ps(EXP_MAX_INPUT)); \
x
=
_mm256_add_ps
(
ones
,
x
);
x = expisa(x); \
x
=
_mm256_div_ps
(
_mm256_set1_ps
(
2.0
f
),
x
);
x = _mm256_add_ps(ones, x); \
return
_mm256_sub_ps
(
x
,
ones
);
x = _mm256_div_ps(_mm256_set1_ps(2.0f), x); \
}
return _mm256_sub_ps(x, ones); \
}
template
<
>
#define AVX_RELU(isa) \
__m256
AVXActImpl
<
kRelu
>::
Compute
(
__m256
x
)
const
{
template <> \
return
_mm256_max_ps
(
x
,
_mm256_setzero_ps
());
__m256 AVXActImpl<kRelu, isa>::Compute(__m256 x) const { \
}
return _mm256_max_ps(x, _mm256_setzero_ps()); \
}
#define AVX_IDENTITY(isa) \
template <> \
__m256 AVXActImpl<kIdentity, isa>::Compute(__m256 x) const { \
return x; \
}
#define FOR_EACH_AVX_ISA(macro_) \
macro_(jit::avx); \
macro_(jit::avx2); \
macro_(jit::avx512f)
FOR_EACH_AVX_ISA
(
AVX_RELU
);
FOR_EACH_AVX_ISA
(
AVX_IDENTITY
);
AVX_SIGMOID
(
jit
::
avx
,
detail
::
ExpAVX
);
AVX_TANH
(
jit
::
avx
,
detail
::
ExpAVX
);
#ifdef __AVX2__
AVX_SIGMOID
(
jit
::
avx2
,
detail
::
ExpAVX2
);
AVX_SIGMOID
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
AVX_TANH
(
jit
::
avx2
,
detail
::
ExpAVX2
);
AVX_TANH
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
#endif
#undef FOR_EACH_AVX_ISA
#undef AVX_IDENTITY
#undef AVX_RELU
#undef AVX_TANH
#undef AVX_SIGMOID
template
<
>
__m256
AVXActImpl
<
kIdentity
>::
Compute
(
__m256
x
)
const
{
return
x
;
}
#endif
#endif
template
<
typename
T
>
template
<
typename
T
>
...
@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -119,23 +153,6 @@ class LSTMKernelImpl : public LSTMKernel<T> {
act_cell_d_
=
GetActKernel
<
T
>
(
act_cell
,
d
);
act_cell_d_
=
GetActKernel
<
T
>
(
act_cell
,
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vmul_d_
=
KernelPool
::
Instance
().
template
Get
<
VMulKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
vadd_d_
=
KernelPool
::
Instance
().
template
Get
<
VAddKernel
<
T
>
>
(
d
);
#ifdef __AVX__
auto
GetAVXAct
=
[
&
](
const
std
::
string
&
type
)
->
std
::
unique_ptr
<
AVXAct
>
{
if
(
type
==
"sigmoid"
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kSigmoid
>
());
}
else
if
(
type
==
"relu"
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kRelu
>
());
}
else
if
(
type
==
"tanh"
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kTanh
>
());
}
else
if
(
type
==
"identity"
||
type
==
""
)
{
return
std
::
unique_ptr
<
AVXAct
>
(
new
AVXActImpl
<
kIdentity
>
());
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
};
avx_act_gate_
=
GetAVXAct
(
act_gate
);
avx_act_cand_
=
GetAVXAct
(
act_cand
);
avx_act_cell_
=
GetAVXAct
(
act_cell
);
#endif
}
}
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
...
@@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -175,26 +192,61 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif
#endif
};
};
#define INTRI8_FLOAT(isa) \
#define INTRI8_FLOAT(isa) \
template <> \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
float* gates, const float* ct_1, float* ct, float* ht, \
const std::string& act_gate, const std::string& act_cand, \
const float* wp_data, float* checked) const { \
const std::string& act_cell, int d) \
/* gates: W_ch, W_ih, W_fh, W_oh */
\
: LSTMKernel<float>() { \
__m256 c, i, f, o; \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \
c = _mm256_loadu_ps(gates); \
if (type == "sigmoid") { \
i = _mm256_loadu_ps(gates + 8); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \
f = _mm256_loadu_ps(gates + 16); \
} else if (type == "relu") { \
o = _mm256_loadu_ps(gates + 24); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
} else if (type == "tanh") { \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \
i = _mm256_loadu_ps(ct_1); \
} else if (type == "identity" || type == "") { \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \
f = _mm256_add_ps(c, f); \
} \
_mm256_storeu_ps(ct, f); \
PADDLE_THROW("Not support type: %s", type); \
/* H_t = act_cell(C_t) * ogated */
\
}; \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
avx_act_gate_ = GetAVXAct(act_gate); \
_mm256_storeu_ps(ht, o); \
avx_act_cand_ = GetAVXAct(act_cand); \
avx_act_cell_ = GetAVXAct(act_cell); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
float* gates, const float* ct_1, float* ct, float* ht, \
const float* wp_data, float* checked) const { \
/* gates: W_ch, W_ih, W_fh, W_oh */
\
__m256 c, i, f, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
f = _mm256_loadu_ps(gates + 16); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
i = _mm256_loadu_ps(ct_1); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
f = _mm256_add_ps(c, f); \
_mm256_storeu_ps(ct, f); \
/* H_t = act_cell(C_t) * ogated */
\
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} \
template <> \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/
\
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */
\
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
}
}
// TODO(TJ): optimize keq16
// TODO(TJ): optimize keq16
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录