Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
1f00723f
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1f00723f
编写于
11月 16, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
exp, sigmoid, tanh jitcode support more size
test=develop
上级
8cda7b3d
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
74 addition
and
72 deletion
+74
-72
paddle/fluid/operators/math/cpu_vec.h
paddle/fluid/operators/math/cpu_vec.h
+9
-9
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+30
-27
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+3
-4
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+6
-6
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+12
-12
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+3
-3
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+11
-11
未找到文件。
paddle/fluid/operators/math/cpu_vec.h
浏览文件 @
1f00723f
...
...
@@ -33,11 +33,11 @@ namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define
AVX
_FLOAT_BLOCK 8
#define
YMM
_FLOAT_BLOCK 8
#define AVX_DOUBLE_BLOCK 4
#define
AVX2
_FLOAT_BLOCK 8
#define
YMM
_FLOAT_BLOCK 8
#define AVX2_DOUBLE_BLOCK 4
#define
AVX512
_FLOAT_BLOCK 16
#define
ZMM
_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8
template
<
typename
T
>
...
...
@@ -88,7 +88,7 @@ template <>
inline
void
vec_scal
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_scal
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
...
...
@@ -142,7 +142,7 @@ template <>
inline
void
vec_bias_sub
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_bias_sub
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
...
...
@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
const
float
*
y
,
const
float
*
z
,
float
*
out
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_cross
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
,
z
,
out
);
return
;
...
...
@@ -257,7 +257,7 @@ template <>
inline
void
vec_add_bias
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_add_bias
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
...
...
@@ -326,7 +326,7 @@ template <>
inline
void
vec_sigmoid
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_sigmoid
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
return
;
...
...
@@ -415,7 +415,7 @@ template <>
inline
void
vec_relu
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
*
4
)
{
vec_relu
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
return
;
...
...
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
1f00723f
...
...
@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
}
else
if
(
scalar_index_
==
2
)
{
vbroadcastss
(
ymm_src2
,
ptr
[
param2
]);
}
for
(
int
i
=
0
;
i
<
num_
/
AVX
_FLOAT_BLOCK
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_
/
YMM
_FLOAT_BLOCK
;
++
i
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
}
...
...
@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
}
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX
_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM
_FLOAT_BLOCK
;
}
int
rest
=
num_
%
AVX
_FLOAT_BLOCK
;
int
rest
=
num_
%
YMM
_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
...
...
@@ -133,23 +133,23 @@ void VXXJitCode::generate() {
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 *
AVX
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_ONE 0 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 *
YMM
_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 *
YMM
_FLOAT_BLOCK * sizeof(float)
static
const
float
exp_float_consts
[]
ALIGN32
=
{
REPEAT_8TIMES
(
1.
f
),
...
...
@@ -177,9 +177,12 @@ bool VActJitCode::init(int d, operand_type type) {
bool
ok
=
MayIUse
(
avx
);
if
(
type
==
operand_type
::
relu
)
{
return
ok
;
}
else
if
(
type
==
operand_type
::
exp
)
{
// exp is slower than mkl when d >= 256
return
ok
&&
d
%
8
==
0
&&
d
<
256
;
}
else
{
// TODO(TJ): support more
return
ok
&&
d
==
8
;
// only 8 yet
return
ok
&&
d
%
8
==
0
;
}
}
...
...
@@ -224,7 +227,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
ymm_dst
,
ymm_src
,
ymm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
AVX
_FLOAT_BLOCK
*
sizeof
(
float
)))
{
i
+=
(
YMM
_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
...
...
@@ -249,7 +252,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
reg64_t
reg_ptr_tmp
=
reg_ptr_global
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
g_tmp_mem
));
vmovdqa
(
ptr
[
reg_ptr_tmp
],
ymm_int
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
AVX
_FLOAT_BLOCK
*
sizeof
(
float
)],
ymm_tmp
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
YMM
_FLOAT_BLOCK
*
sizeof
(
float
)],
ymm_tmp
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
],
xtmp1
);
...
...
@@ -257,7 +260,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
vmovdqa
(
xtmp1
,
ptr
[
reg_ptr_tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)]);
vmovdqa
(
xtmp2
,
ptr
[
reg_ptr_tmp
+
(
AVX
_FLOAT_BLOCK
+
4
/*xmm float block*/
)
*
sizeof
(
float
)]);
(
YMM
_FLOAT_BLOCK
+
4
/*xmm float block*/
)
*
sizeof
(
float
)]);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)],
xtmp1
);
...
...
@@ -317,7 +320,7 @@ void VActJitCode::generate() {
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
AVX
_FLOAT_BLOCK
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_
/
YMM
_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
switch
(
type_
)
{
case
operand_type
::
relu
:
...
...
@@ -338,14 +341,14 @@ void VActJitCode::generate() {
break
;
}
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX
_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM
_FLOAT_BLOCK
;
}
if
(
type_
!=
operand_type
::
relu
)
{
// TODO(TJ): remove me
ret
();
return
;
}
int
rest
=
num_
%
AVX
_FLOAT_BLOCK
;
int
rest
=
num_
%
YMM
_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
1f00723f
...
...
@@ -29,10 +29,9 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
// TODO(TJ): change AVX_FLOAT_BLOCK to YMM_FLOAT_BLOCK
#define AVX_FLOAT_BLOCK 8
#define AVX2_FLOAT_BLOCK 8
#define AVX512_FLOAT_BLOCK 16
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
typedef
enum
{
kLT8
,
kEQ8
,
kGT8LT16
,
kEQ16
,
kGT16
}
jit_block
;
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
1f00723f
...
...
@@ -133,7 +133,7 @@ class VMulKernelImpl : public VMulKernel<T> {
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -184,7 +184,7 @@ class VAddKernelImpl : public VAddKernel<T> {
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -234,7 +234,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
true
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -266,7 +266,7 @@ class VScalKernelImpl : public VScalKernel<T> {
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -315,7 +315,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -349,7 +349,7 @@ class VReluKernelImpl : public VReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
/* init size */
+
d
/
AVX
_FLOAT_BLOCK
*
4
/* instructions */
*
d
/
YMM
_FLOAT_BLOCK
*
4
/* instructions */
*
8
/* average bytes for each instruction */
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
relu
,
sz
>
4096
?
sz
:
4096
));
...
...
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
浏览文件 @
1f00723f
...
...
@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX
_FLOAT_BLOCK; \
this->rest_ = this->num_ %
AVX
_FLOAT_BLOCK; \
this->end_ = this->num_ /
YMM
_FLOAT_BLOCK; \
this->rest_ = this->num_ %
YMM
_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(
AVX
_FLOAT_BLOCK) \
INIT_ALPHA(
YMM
_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
...
...
@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(
AVX
_FLOAT_BLOCK) \
UPDATE_ALPHA(
YMM
_FLOAT_BLOCK) \
} \
seq_offset += this->num_; \
} \
...
...
@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX2_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
AVX2_FLOAT_BLOCK;
\
this->end_ = this->num_ /
YMM_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
YMM_FLOAT_BLOCK;
\
} \
template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(
AVX2_FLOAT_BLOCK)
\
INIT_ALPHA(
YMM_FLOAT_BLOCK)
\
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
...
...
@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(
AVX2_FLOAT_BLOCK)
\
UPDATE_ALPHA(
YMM_FLOAT_BLOCK)
\
} \
seq_offset += this->num_; \
} \
...
...
@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX512_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
AVX512_FLOAT_BLOCK;
\
this->end_ = this->num_ /
ZMM_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
ZMM_FLOAT_BLOCK;
\
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(
AVX512_FLOAT_BLOCK)
\
INIT_ALPHA(
ZMM_FLOAT_BLOCK)
\
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
...
...
@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->num_ + j_offset), \
max_j); \
/* Calculate the offset of next step*/
\
j_offset +=
AVX512_FLOAT_BLOCK;
\
j_offset +=
ZMM_FLOAT_BLOCK;
\
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
1f00723f
...
...
@@ -116,7 +116,7 @@ class VExpKernelImpl : public VExpKernel<T> {
explicit
VExpKernelImpl
(
int
d
)
:
VExpKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
70
*
8
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
exp
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
...
...
@@ -167,7 +167,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
explicit
VSigmoidKernelImpl
(
int
d
)
:
VSigmoidKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
82
*
8
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
sigmoid
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
...
...
@@ -219,7 +219,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
explicit
VTanhKernelImpl
(
int
d
)
:
VTanhKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX_FLOAT_BLOCK
*
4
*
8
;
// should change
size_t
sz
=
96
+
d
/
YMM_FLOAT_BLOCK
*
84
*
8
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
tanh
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
...
...
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
1f00723f
...
...
@@ -94,17 +94,17 @@ namespace jitkernel {
namespace
jit
=
platform
::
jit
;
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa)
\
if (d <
AVX_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kLT8);
\
} else if (d ==
AVX_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kEQ8);
\
} else if (d >
AVX_FLOAT_BLOCK && d < AVX512
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16);
\
} else if (d ==
AVX512
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16);
\
} else {
\
macro_(ker, dtype, isa, kGT16);
\
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d <
YMM_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kLT8); \
} else if (d ==
YMM_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kEQ8); \
} else if (d >
YMM_FLOAT_BLOCK && d < ZMM
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \
} else if (d ==
ZMM
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \
} else { \
macro_(ker, dtype, isa, kGT16); \
}
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录