Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
046374bc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
046374bc
编写于
11月 15, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vsigmoid jitcode of size 8
上级
ee2a7f1b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
177 addition
and
161 deletion
+177
-161
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+69
-16
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+23
-5
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+2
-0
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+80
-137
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+3
-3
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
046374bc
...
@@ -152,10 +152,6 @@ void ReluJitCode::generate() {
...
@@ -152,10 +152,6 @@ void ReluJitCode::generate() {
ret
();
ret
();
}
}
bool
VExpJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
}
#define ALIGN32 __attribute__((aligned(32)))
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define EXP_LOW -88.3762626647949f
...
@@ -171,6 +167,7 @@ bool VExpJitCode::init(int d) {
...
@@ -171,6 +167,7 @@ bool VExpJitCode::init(int d) {
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
...
@@ -183,24 +180,43 @@ bool VExpJitCode::init(int d) {
...
@@ -183,24 +180,43 @@ bool VExpJitCode::init(int d) {
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 13 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 14 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 15 * AVX_FLOAT_BLOCK * sizeof(float)
static
const
float
exp_float_consts
[]
ALIGN32
=
{
static
const
float
exp_float_consts
[]
ALIGN32
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
)};
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
),
REPEAT_8TIMES
(
EXP_MAX_INPUT
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MAX
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MIN
)};
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
void
VExpJitCode
::
generate
()
{
bool
VExpJitCode
::
init
(
int
d
)
{
// in: ymm0, out: ymm1
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
// use ymm 0~5, rax
}
int
offset
=
0
;
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
void
VExpJitCode
::
exp_ymm
(
ymm_t
&
ymm_src
,
ymm_t
&
ymm_dst
)
{
// use reg rax and ymm 2~5
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_fx
=
ymm_t
(
2
);
ymm_t
ymm_fy
=
ymm_t
(
3
);
ymm_t
ymm_mask
=
ymm_t
(
4
);
ymm_t
ymm_tmp
=
ymm_t
(
5
);
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
...
@@ -269,8 +285,45 @@ void VExpJitCode::generate() {
...
@@ -269,8 +285,45 @@ void VExpJitCode::generate() {
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_tmp
]);
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_tmp
]);
}
}
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
pop
(
reg_ptr_global
);
}
void
VExpJitCode
::
generate
()
{
int
offset
=
0
;
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
exp_ymm
(
ymm_src
,
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
ret
();
}
bool
VSigmoidJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
}
void
VSigmoidJitCode
::
sigmoid_ymm
(
ymm_t
&
ymm_src
,
ymm_t
&
ymm_dst
)
{
// use ymm2
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_tmp
=
ymm_t
(
2
);
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MIN
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vxorps
(
ymm_tmp
,
ymm_tmp
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_tmp
,
ymm_src
);
exp_ymm
(
ymm_src
,
ymm_dst
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
pop
(
reg_ptr_global
);
}
void
VSigmoidJitCode
::
generate
()
{
int
offset
=
0
;
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
sigmoid_ymm
(
ymm_src
,
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
ret
();
ret
();
}
}
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
046374bc
...
@@ -117,18 +117,36 @@ class VExpJitCode : public JitCode {
...
@@ -117,18 +117,36 @@ class VExpJitCode : public JitCode {
static
bool
init
(
int
d
);
static
bool
init
(
int
d
);
void
generate
()
override
;
void
generate
()
override
;
protected:
// compute exp with ymm
void
exp_ymm
(
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
dst
);
private:
private:
int
num_
;
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param2
{
abi_param2
};
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
reg64_t
reg_ptr_global
=
rax
;
class
VSigmoidJitCode
:
public
VExpJitCode
{
public:
DECLARE_JIT_CODE
(
VSigmoidJitCode
);
explicit
VSigmoidJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VExpJitCode
(
d
,
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
// compute sigmoid with ymm
void
sigmoid_ymm
(
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
dst
);
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
ymm_t
ymm_fx
=
ymm_t
(
2
);
ymm_t
ymm_fy
=
ymm_t
(
3
);
ymm_t
ymm_mask
=
ymm_t
(
4
);
ymm_t
ymm_tmp
=
ymm_t
(
5
);
};
};
}
// namespace gen
}
// namespace gen
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
046374bc
...
@@ -29,6 +29,7 @@ namespace jitkernel {
...
@@ -29,6 +29,7 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define EXP_MAX_INPUT 40.0
// TODO(TJ): change AVX_FLOAT_BLOCK to YMM_FLOAT_BLOCK
#define AVX_FLOAT_BLOCK 8
#define AVX_FLOAT_BLOCK 8
#define AVX2_FLOAT_BLOCK 8
#define AVX2_FLOAT_BLOCK 8
#define AVX512_FLOAT_BLOCK 16
#define AVX512_FLOAT_BLOCK 16
...
@@ -124,6 +125,7 @@ template <typename T>
...
@@ -124,6 +125,7 @@ template <typename T>
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{
public:
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
046374bc
...
@@ -43,6 +43,16 @@ void VExpRefer(const T* x, T* y, int n) {
...
@@ -43,6 +43,16 @@ void VExpRefer(const T* x, T* y, int n) {
}
}
}
}
template
<
typename
T
>
void
VSigmoidRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
T
tmp
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
std
::
exp
(
-
tmp
));
}
}
#ifdef PADDLE_WITH_MKLML
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
template
<
typename
T
>
void
VExpMKL
(
const
T
*
x
,
T
*
y
,
int
n
);
void
VExpMKL
(
const
T
*
x
,
T
*
y
,
int
n
);
...
@@ -56,6 +66,20 @@ template <>
...
@@ -56,6 +66,20 @@ template <>
void
VExpMKL
<
double
>
(
const
double
*
x
,
double
*
y
,
int
n
)
{
void
VExpMKL
<
double
>
(
const
double
*
x
,
double
*
y
,
int
n
)
{
platform
::
dynload
::
vdExp
(
n
,
x
,
y
);
platform
::
dynload
::
vdExp
(
n
,
x
,
y
);
}
}
template
<
typename
T
>
void
VSigmoidMKL
(
const
T
*
x
,
T
*
y
,
int
n
)
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
0
)
-
y
[
i
];
}
VExpMKL
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
y
[
i
]);
}
}
#endif
#endif
/* VExp JitKernel */
/* VExp JitKernel */
...
@@ -108,9 +132,65 @@ template <>
...
@@ -108,9 +132,65 @@ template <>
bool
VExpKernelImpl
<
double
>::
useMKL
(
int
d
)
{
bool
VExpKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
return
true
;
}
}
#endif
/* VSigmoid JitKernel */
template
<
typename
T
>
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VSigmoidKernelImpl
(
int
d
)
:
VSigmoidKernel
<
T
>
()
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VSigmoidJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if
(
useMKL
(
d
))
{
this
->
Compute
=
VSigmoidMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
VSigmoidRefer
<
T
>
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
VSigmoidRefer
(
x
,
y
,
this
->
num_
);
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VSigmoidJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VSigmoidKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VSigmoidJitCode
::
init
(
d
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VSigmoidKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VSigmoidKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
#endif
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
REGISTER_JITKERNEL
(
vsigmoid
,
VSigmoidKernel
);
namespace
detail
{
namespace
detail
{
...
@@ -258,31 +338,6 @@ __m256 ExpAVX2(__m256 x) {
...
@@ -258,31 +338,6 @@ __m256 ExpAVX2(__m256 x) {
}
// namespace detail
}
// namespace detail
/* VSigmoid JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VSigmoidKernelImpl
:
public
VSigmoidKernel
<
T
>
{
public:
explicit
VSigmoidKernelImpl
(
int
d
)
:
VSigmoidKernel
<
T
>
()
{
this
->
num_
=
d
;
vexp_
=
KernelPool
::
Instance
().
template
Get
<
VExpKernel
<
T
>
>
(
d
);
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
static_cast
<
T
>
(
0
)
-
y
[
i
];
}
vexp_
->
ComputeDeprecated
(
y
,
y
);
for
(
int
i
=
0
;
i
<
this
->
num_
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
y
[
i
]);
}
}
private:
std
::
shared_ptr
<
const
VExpKernel
<
T
>>
vexp_
;
};
#define INTRI_SIGMOID(tmp, min, max, expisa) \
#define INTRI_SIGMOID(tmp, min, max, expisa) \
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_min_ps(tmp, max); \
tmp = _mm256_min_ps(tmp, max); \
...
@@ -290,120 +345,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
...
@@ -290,120 +345,8 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
tmp = expisa(tmp); \
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ8>::ComputeDeprecated( \
const float* x, float* y) const { \
/* TODO(TJ): try to use static const*/
\
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VSigmoidKernelImpl<float, isa, kEQ16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_SIGMOID(tmp0, min, max, expisa); \
INTRI_SIGMOID(tmp1, min, max, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT8LT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
vexp_ = \
KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
} \
template <> \
void VSigmoidKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y, tmp); \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
y[i] = 0.f - y[i]; \
} \
vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 1.f / (1.f + y[i]); \
} \
}
#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VSigmoidKernelImpl<float, isa, kGT16>::VSigmoidKernelImpl(int d) \
: VSigmoidKernel<float>() { \
this->num_ = d; \
this->rest_ = d % AVX_FLOAT_BLOCK; \
this->end_ = d - this->rest_; \
vexp_ = \
KernelPool::Instance().template Get<VExpKernel<float>>(this->rest_); \
} \
template <> \
void VSigmoidKernelImpl<float, isa, kGT16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); \
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_SIGMOID(tmp, min, max, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
const float min_ = SIGMOID_THRESHOLD_MIN; \
const float max_ = SIGMOID_THRESHOLD_MAX; \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = (x[i] < min_) ? min_ : ((x[i] > max_) ? max_ : x[i]); \
y[i] = 0.f - y[i]; \
} \
vexp_->ComputeDeprecated(y + this->end_, y + this->end_); \
for (int i = this->end_; i < this->num_; ++i) { \
y[i] = 1.f / (1.f + y[i]); \
} \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI_GT8LT16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI_GT16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
// maybe use avx2 at gt8lt16 and gt16
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef INTRI_VSIGMOID
#undef INTRI_VSIGMOID
REGISTER_JITKERNEL_DEPRECATED
(
vsigmoid
,
VSigmoidKernel
);
/* VTanh JitKernel */
/* VTanh JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VTanhKernelImpl
:
public
VTanhKernel
<
T
>
{
class
VTanhKernelImpl
:
public
VTanhKernel
<
T
>
{
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
046374bc
...
@@ -223,7 +223,7 @@ void vsigmoid_better(
...
@@ -223,7 +223,7 @@ void vsigmoid_better(
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
0.
f
-
y
[
i
];
y
[
i
]
=
0.
f
-
y
[
i
];
}
}
vexp
->
Compute
Deprecated
(
y
,
y
);
vexp
->
Compute
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
1.
f
/
(
1.
f
+
y
[
i
]);
y
[
i
]
=
1.
f
/
(
1.
f
+
y
[
i
]);
}
}
...
@@ -254,7 +254,7 @@ TEST(JitKernel, vsigmoid) {
...
@@ -254,7 +254,7 @@ TEST(JitKernel, vsigmoid) {
auto
trefe
=
GetCurrentUS
();
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
...
@@ -288,7 +288,7 @@ void vtanh_better(
...
@@ -288,7 +288,7 @@ void vtanh_better(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
float
a
=
2.
f
,
b
=
-
1.
f
;
const
float
a
=
2.
f
,
b
=
-
1.
f
;
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vsigmoid
->
Compute
Deprecated
(
y
,
y
);
vsigmoid
->
Compute
(
y
,
y
,
n
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录