Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
6a159071
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6a159071
编写于
11月 15, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add vtanh jitcode of size 8
上级
046374bc
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
153 addition
and
166 deletion
+153
-166
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+52
-15
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+20
-0
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+1
-0
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+79
-150
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+1
-1
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
6a159071
...
...
@@ -168,24 +168,26 @@ void ReluJitCode::generate() {
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 3 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 4 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 5 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 6 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 7 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 8 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 9 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 12 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 13 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 14 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 15 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * AVX_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * AVX_FLOAT_BLOCK * sizeof(float)
static
const
float
exp_float_consts
[]
ALIGN32
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
2.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
...
...
@@ -216,6 +218,7 @@ void VExpJitCode::exp_ymm(ymm_t& ymm_src, ymm_t& ymm_dst) {
ymm_t
ymm_fy
=
ymm_t
(
3
);
ymm_t
ymm_mask
=
ymm_t
(
4
);
ymm_t
ymm_tmp
=
ymm_t
(
5
);
assert
(
ymm_src
.
getIdx
()
!=
ymm_dst
.
getIdx
());
// TODO(TJ): use enfore
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
...
...
@@ -327,6 +330,40 @@ void VSigmoidJitCode::generate() {
ret
();
}
bool
VTanhJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
)
&&
d
==
8
;
// only 8 yet
}
void
VTanhJitCode
::
vtanh_ymm
(
ymm_t
&
ymm_src
,
ymm_t
&
ymm_dst
)
{
// y = 2 / (1 + e^(-2x)) - 1
// use ymm2, ymm3
reg64_t
reg_ptr_global
=
rax
;
ymm_t
ymm_tmp
=
ymm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
vsubps
(
ymm_tmp
,
ymm_zero
,
ymm_tmp
);
vmulps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
exp_ymm
(
ymm_src
,
ymm_dst
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vsubps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
pop
(
reg_ptr_global
);
}
void
VTanhJitCode
::
generate
()
{
int
offset
=
0
;
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vtanh_ymm
(
ymm_src
,
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
6a159071
...
...
@@ -149,6 +149,26 @@ class VSigmoidJitCode : public VExpJitCode {
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
class
VTanhJitCode
:
public
VExpJitCode
{
public:
DECLARE_JIT_CODE
(
VTanhJitCode
);
explicit
VTanhJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
VExpJitCode
(
d
,
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
void
generate
()
override
;
// compute sigmoid with ymm
void
vtanh_ymm
(
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
dst
);
private:
int
num_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
6a159071
...
...
@@ -132,6 +132,7 @@ template <typename T>
class
VTanhKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
6a159071
...
...
@@ -45,6 +45,7 @@ void VExpRefer(const T* x, T* y, int n) {
template
<
typename
T
>
void
VSigmoidRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
// y = 1 / (1 + e^-x)
const
T
min
=
SIGMOID_THRESHOLD_MIN
;
const
T
max
=
SIGMOID_THRESHOLD_MAX
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
...
...
@@ -53,6 +54,18 @@ void VSigmoidRefer(const T* x, T* y, int n) {
}
}
template
<
typename
T
>
void
VTanhRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{
// y = 2 * sigmoid(2x) - 1
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
x
[
i
];
}
VSigmoidRefer
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
y
[
i
]
-
static_cast
<
T
>
(
1
);
}
}
#ifdef PADDLE_WITH_MKLML
template
<
typename
T
>
void
VExpMKL
(
const
T
*
x
,
T
*
y
,
int
n
);
...
...
@@ -80,6 +93,17 @@ void VSigmoidMKL(const T* x, T* y, int n) {
y
[
i
]
=
static_cast
<
T
>
(
1
)
/
(
static_cast
<
T
>
(
1
)
+
y
[
i
]);
}
}
template
<
typename
T
>
void
VTanhMKL
(
const
T
*
x
,
T
*
y
,
int
n
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
x
[
i
];
}
VSigmoidMKL
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
static_cast
<
T
>
(
2
)
*
y
[
i
]
-
static_cast
<
T
>
(
1
);
}
}
#endif
/* VExp JitKernel */
...
...
@@ -189,8 +213,63 @@ bool VSigmoidKernelImpl<double>::useMKL(int d) {
}
#endif
/* VTanh JitKernel */
template
<
typename
T
>
class
VTanhKernelImpl
:
public
VTanhKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VTanhKernelImpl
(
int
d
)
:
VTanhKernel
<
T
>
()
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
jitcode_
.
reset
(
new
gen
::
VTanhJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if
(
useMKL
(
d
))
{
this
->
Compute
=
VTanhMKL
<
T
>
;
return
;
}
#endif
this
->
Compute
=
VTanhRefer
<
T
>
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
VTanhRefer
(
x
,
y
,
this
->
num_
);
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
VTanhJitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VTanhKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
VTanhJitCode
::
init
(
d
);
}
#endif
#ifdef PADDLE_WITH_MKLML
template
<
>
bool
VTanhKernelImpl
<
float
>::
useMKL
(
int
d
)
{
return
d
>
512
;
}
template
<
>
bool
VTanhKernelImpl
<
double
>::
useMKL
(
int
d
)
{
return
true
;
}
#endif
REGISTER_JITKERNEL
(
vexp
,
VExpKernel
);
REGISTER_JITKERNEL
(
vsigmoid
,
VSigmoidKernel
);
REGISTER_JITKERNEL
(
vtanh
,
VTanhKernel
);
namespace
detail
{
...
...
@@ -337,156 +416,6 @@ __m256 ExpAVX2(__m256 x) {
#endif
}
// namespace detail
#define INTRI_SIGMOID(tmp, min, max, expisa) \
tmp = _mm256_max_ps(tmp, min); \
tmp = _mm256_min_ps(tmp, max); \
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); \
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp)
#undef INTRI_VSIGMOID
/* VTanh JitKernel */
template
<
typename
T
,
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VTanhKernelImpl
:
public
VTanhKernel
<
T
>
{
public:
explicit
VTanhKernelImpl
(
int
d
)
:
VTanhKernel
<
T
>
()
{
this
->
num_
=
d
;
vscal_
=
KernelPool
::
Instance
().
template
Get
<
VScalKernel
<
T
>
>
(
d
);
vsigmoid_
=
KernelPool
::
Instance
().
template
Get
<
VSigmoidKernel
<
T
>
>
(
d
);
vaddbias_
=
KernelPool
::
Instance
().
template
Get
<
VAddBiasKernel
<
T
>
>
(
d
);
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
const
T
a
=
static_cast
<
T
>
(
2
),
b
=
static_cast
<
T
>
(
-
1
);
vscal_
->
Compute
(
&
a
,
x
,
y
,
this
->
num_
);
vsigmoid_
->
ComputeDeprecated
(
y
,
y
);
vscal_
->
Compute
(
&
a
,
y
,
y
,
this
->
num_
);
vaddbias_
->
Compute
(
&
b
,
y
,
y
,
this
->
num_
);
}
private:
std
::
shared_ptr
<
const
VScalKernel
<
T
>>
vscal_
;
std
::
shared_ptr
<
const
VSigmoidKernel
<
T
>>
vsigmoid_
;
std
::
shared_ptr
<
const
VAddBiasKernel
<
T
>>
vaddbias_
;
};
#define INTRI_VTANH(tmp, expisa) \
tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), tmp); \
tmp = _mm256_min_ps(tmp, _mm256_set1_ps(EXP_MAX_INPUT)); \
tmp = expisa(tmp); \
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp); \
tmp = _mm256_div_ps(_mm256_set1_ps(2.0f), tmp); \
tmp = _mm256_sub_ps(tmp, _mm256_set1_ps(1.0f))
#define INTRI8_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ8>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa, expisa) \
template <> \
void VTanhKernelImpl<float, isa, kEQ16>::ComputeDeprecated(const float* x, \
float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
INTRI_VTANH(tmp0, expisa); \
INTRI_VTANH(tmp1, expisa); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#define INTRI_GT8LT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT8LT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - this->end_; \
vscal_ = \
KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
this->rest_); \
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
this->rest_); \
} \
template <> \
void VTanhKernelImpl<float, isa, kGT8LT16>::ComputeDeprecated( \
const float* x, float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->ComputeDeprecated(y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#define INTRI_GT16_FLOAT(isa, expisa) \
template <> \
VTanhKernelImpl<float, isa, kGT16>::VTanhKernelImpl(int d) \
: VTanhKernel<float>() { \
this->num_ = d; \
this->rest_ = d % AVX_FLOAT_BLOCK; \
this->end_ = d - this->rest_; \
vscal_ = \
KernelPool::Instance().template Get<VScalKernel<float>>(this->rest_); \
vsigmoid_ = KernelPool::Instance().template Get<VSigmoidKernel<float>>( \
this->rest_); \
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<float>>( \
this->rest_); \
} \
template <> \
void VTanhKernelImpl<float, isa, kGT16>::ComputeDeprecated(const float* x, \
float* y) const { \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
INTRI_VTANH(tmp, expisa); \
_mm256_storeu_ps(y + i, tmp); \
} \
x += this->end_; \
y += this->end_; \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->ComputeDeprecated(y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#ifdef __AVX__
INTRI8_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI_GT8LT16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
INTRI_GT16_FLOAT
(
jit
::
avx
,
detail
::
ExpAVX
);
#endif
#ifdef __AVX2__
INTRI8_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx2
,
detail
::
ExpAVX2
);
// maybe use avx at gt8lt16 and gt16
#endif
#ifdef __AVX512F__
INTRI8_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
INTRI16_FLOAT
(
jit
::
avx512f
,
detail
::
ExpAVX2
);
// maybe use avx at gt8lt16 and gt16
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef INTRI_VTANH
REGISTER_JITKERNEL_DEPRECATED
(
vtanh
,
VTanhKernel
);
#undef JITKERNEL_NEW_ACT_IMPL
}
// namespace jitkernel
}
// namespace math
}
// namespace operators
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
6a159071
...
...
@@ -322,7 +322,7 @@ TEST(JitKernel, vtanh) {
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录