Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
7f17e561
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
7f17e561
编写于
11月 16, 2018
作者:
T
tensor-tang
提交者:
GitHub
11月 16, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14423 from tensor-tang/fea/jit/act
jitcode act relu, exp, sigmoid, tanh
上级
28bd5b7b
1f00723f
变更
10
展开全部
显示空白变更内容
内联
并排
Showing
10 changed file
with
626 addition
and
503 deletion
+626
-503
paddle/fluid/operators/math/cpu_vec.h
paddle/fluid/operators/math/cpu_vec.h
+9
-9
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+230
-10
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+60
-12
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+9
-25
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+31
-43
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+12
-12
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+224
-350
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+19
-11
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+19
-19
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+13
-12
未找到文件。
paddle/fluid/operators/math/cpu_vec.h
浏览文件 @
7f17e561
...
@@ -33,11 +33,11 @@ namespace math {
...
@@ -33,11 +33,11 @@ namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define
AVX
_FLOAT_BLOCK 8
#define
YMM
_FLOAT_BLOCK 8
#define AVX_DOUBLE_BLOCK 4
#define AVX_DOUBLE_BLOCK 4
#define
AVX2
_FLOAT_BLOCK 8
#define
YMM
_FLOAT_BLOCK 8
#define AVX2_DOUBLE_BLOCK 4
#define AVX2_DOUBLE_BLOCK 4
#define
AVX512
_FLOAT_BLOCK 16
#define
ZMM
_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8
#define AVX512_DOUBLE_BLOCK 8
template
<
typename
T
>
template
<
typename
T
>
...
@@ -88,7 +88,7 @@ template <>
...
@@ -88,7 +88,7 @@ template <>
inline
void
vec_scal
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
inline
void
vec_scal
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
if
(
n
<
block
)
{
vec_scal
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
vec_scal
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
return
;
...
@@ -142,7 +142,7 @@ template <>
...
@@ -142,7 +142,7 @@ template <>
inline
void
vec_bias_sub
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
inline
void
vec_bias_sub
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
if
(
n
<
block
)
{
vec_bias_sub
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
vec_bias_sub
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
return
;
...
@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
...
@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
const
float
*
y
,
const
float
*
z
,
const
float
*
y
,
const
float
*
z
,
float
*
out
)
{
float
*
out
)
{
#ifdef __AVX__
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
if
(
n
<
block
)
{
vec_cross
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
,
z
,
out
);
vec_cross
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
,
z
,
out
);
return
;
return
;
...
@@ -257,7 +257,7 @@ template <>
...
@@ -257,7 +257,7 @@ template <>
inline
void
vec_add_bias
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
inline
void
vec_add_bias
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
if
(
n
<
block
)
{
vec_add_bias
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
vec_add_bias
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
return
;
...
@@ -326,7 +326,7 @@ template <>
...
@@ -326,7 +326,7 @@ template <>
inline
void
vec_sigmoid
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
inline
void
vec_sigmoid
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
float
*
y
)
{
#ifdef __AVX__
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
if
(
n
<
block
)
{
vec_sigmoid
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
vec_sigmoid
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
return
;
return
;
...
@@ -415,7 +415,7 @@ template <>
...
@@ -415,7 +415,7 @@ template <>
inline
void
vec_relu
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
inline
void
vec_relu
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
float
*
y
)
{
#ifdef __AVX__
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
*
4
)
{
if
(
n
<
block
*
4
)
{
vec_relu
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
vec_relu
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
return
;
return
;
...
...
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
7f17e561
...
@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
...
@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
}
else
if
(
scalar_index_
==
2
)
{
}
else
if
(
scalar_index_
==
2
)
{
vbroadcastss
(
ymm_src2
,
ptr
[
param2
]);
vbroadcastss
(
ymm_src2
,
ptr
[
param2
]);
}
}
for
(
int
i
=
0
;
i
<
num_
/
AVX
_FLOAT_BLOCK
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_
/
YMM
_FLOAT_BLOCK
;
++
i
)
{
if
(
scalar_index_
!=
1
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
}
}
...
@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
...
@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
}
}
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX
_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM
_FLOAT_BLOCK
;
}
}
int
rest
=
num_
%
AVX
_FLOAT_BLOCK
;
int
rest
=
num_
%
YMM
_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
if
(
rest
>=
4
)
{
if
(
scalar_index_
!=
1
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
...
@@ -118,18 +118,237 @@ void VXXJitCode::generate() {
...
@@ -118,18 +118,237 @@ void VXXJitCode::generate() {
ret
();
ret
();
}
}
bool
ReluJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
);
}
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
void
ReluJitCode
::
generate
()
{
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
int
offset
=
0
;
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
static
const
float
exp_float_consts
[]
ALIGN32
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
2.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
),
REPEAT_8TIMES
(
EXP_MAX_INPUT
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MAX
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MIN
)};
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
bool
VActJitCode
::
init
(
int
d
,
operand_type
type
)
{
bool
ok
=
MayIUse
(
avx
);
if
(
type
==
operand_type
::
relu
)
{
return
ok
;
}
else
if
(
type
==
operand_type
::
exp
)
{
// exp is slower than mkl when d >= 256
return
ok
&&
d
%
8
==
0
&&
d
<
256
;
}
else
{
// TODO(TJ): support more
return
ok
&&
d
%
8
==
0
;
}
}
void
VActJitCode
::
relu_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
ymm_t
&
ymm_zero
)
{
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
}
void
VActJitCode
::
exp_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
assert
(
ymm_src
.
getIdx
()
!=
ymm_dst
.
getIdx
());
// TODO(TJ): use enfore
// check all idx can not equal
ymm_t
ymm_fx
=
ymm_t
(
fx_idx
);
ymm_t
ymm_fy
=
ymm_t
(
fy_idx
);
ymm_t
ymm_mask
=
ymm_t
(
mask_idx
);
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
// express exp(x) as exp(g + n*log(2))
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmulps
(
ymm_fx
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vaddps
(
ymm_fx
,
ymm_fx
,
ymm_tmp
);
vroundps
(
ymm_fy
,
ymm_fx
,
0x01
);
// if greater, substract 1
vcmpgtps
(
ymm_mask
,
ymm_fy
,
ymm_fx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vandps
(
ymm_mask
,
ymm_mask
,
ymm_tmp
);
vsubps
(
ymm_fx
,
ymm_fy
,
ymm_mask
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C1
]);
vmulps
(
ymm_fy
,
ymm_fx
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
ymm_t
ymm_z
=
ymm_t
(
ymm_mask
.
getIdx
());
vmulps
(
ymm_z
,
ymm_fx
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_fy
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_z
);
vmulps
(
ymm_z
,
ymm_src
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
ymm_dst
,
ymm_src
,
ymm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
YMM_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
}
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_z
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
// build 2^n
ymm_t
ymm_int
=
ymm_fx
;
vcvttps2dq
(
ymm_int
,
ymm_fx
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_int_0x7f
));
vmovdqa
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
if
(
MayIUse
(
avx2
))
{
vpaddd
(
ymm_int
,
ymm_int
,
ymm_tmp
);
vpslld
(
ymm_int
,
ymm_int
,
23
);
}
else
if
(
MayIUse
(
avx
))
{
xmm_t
xtmp1
=
xmm_t
(
ymm_int
.
getIdx
());
xmm_t
xtmp2
=
xmm_t
(
ymm_tmp
.
getIdx
());
reg64_t
reg_ptr_tmp
=
reg_ptr_global
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
g_tmp_mem
));
vmovdqa
(
ptr
[
reg_ptr_tmp
],
ymm_int
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
YMM_FLOAT_BLOCK
*
sizeof
(
float
)],
ymm_tmp
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
],
xtmp1
);
// next 128bits
vmovdqa
(
xtmp1
,
ptr
[
reg_ptr_tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)]);
vmovdqa
(
xtmp2
,
ptr
[
reg_ptr_tmp
+
(
YMM_FLOAT_BLOCK
+
4
/*xmm float block*/
)
*
sizeof
(
float
)]);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)],
xtmp1
);
// load out
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_tmp
]);
}
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
sigmoid_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
// y = 1 / (1 + e^-x)
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MIN
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vxorps
(
ymm_tmp
,
ymm_tmp
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_tmp
,
ymm_src
);
exp_ymm
(
ymm_dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
tanh_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
// y = 2 / (1 + e^(-2x)) - 1
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
ymm_t
ymm_zero
=
ymm_t
(
mask_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vsubps
(
ymm_tmp
,
ymm_zero
,
ymm_tmp
);
vmulps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
exp_ymm
(
ymm_dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vsubps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
generate
()
{
xmm_t
xmm_zero
=
xmm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_ymm
(
ymm_dst
,
ymm_src
,
ymm_zero
);
break
;
case
operand_type
::
exp
:
exp_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
identity
:
break
;
default:
break
;
}
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX
_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM
_FLOAT_BLOCK
;
}
}
int
rest
=
num_
%
AVX_FLOAT_BLOCK
;
if
(
type_
!=
operand_type
::
relu
)
{
// TODO(TJ): remove me
ret
();
return
;
}
int
rest
=
num_
%
YMM_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
...
@@ -151,6 +370,7 @@ void ReluJitCode::generate() {
...
@@ -151,6 +370,7 @@ void ReluJitCode::generate() {
}
}
ret
();
ret
();
}
}
}
// namespace gen
}
// namespace gen
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
7f17e561
...
@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
...
@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
mul
=
0
,
add
}
operand_type
;
typedef
enum
{
mul
=
0
,
add
,
sub
,
relu
,
exp
,
sigmoid
,
tanh
,
identity
}
operand_type
;
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
class
VXXJitCode
:
public
JitCode
{
...
@@ -85,26 +94,65 @@ class VXXJitCode : public JitCode {
...
@@ -85,26 +94,65 @@ class VXXJitCode : public JitCode {
ymm_t
ymm_zero
=
ymm_t
(
3
);
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
};
class
Relu
JitCode
:
public
JitCode
{
class
VAct
JitCode
:
public
JitCode
{
public:
public:
DECLARE_JIT_CODE
(
ReluJitCode
);
const
char
*
name
()
const
override
{
explicit
ReluJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
std
::
string
base
=
"VActJitCode"
;
switch
(
type_
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
return
base
.
c_str
();
}
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
,
type_
(
type
)
{}
static
bool
init
(
int
d
);
static
bool
init
(
int
d
,
operand_type
type
);
void
generate
()
override
;
void
generate
()
override
;
private:
protected:
// compute relu with ymm
void
relu_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
zero
);
// compute exp with ymm
void
exp_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
// compute sigmoid with ymm
void
sigmoid_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
// compute tanh with ymm
void
tanh_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
protected:
int
num_
;
int
num_
;
operand_type
type_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param2
{
abi_param2
};
xmm_t
xmm_zero
=
xmm_t
(
0
);
xmm_t
xmm_src
=
xmm_t
(
0
);
xmm_t
xmm_src
=
xmm_t
(
1
);
ymm_t
ymm_src
=
ymm_t
(
0
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_zero
=
ymm_t
(
0
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_src
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
};
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
7f17e561
...
@@ -29,9 +29,9 @@ namespace jitkernel {
...
@@ -29,9 +29,9 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define EXP_MAX_INPUT 40.0
#define
AVX_FLOAT_BLOCK 8
#define
XMM_FLOAT_BLOCK 4
#define
AVX2
_FLOAT_BLOCK 8
#define
YMM
_FLOAT_BLOCK 8
#define
AVX512
_FLOAT_BLOCK 16
#define
ZMM
_FLOAT_BLOCK 16
typedef
enum
{
kLT8
,
kEQ8
,
kGT8LT16
,
kEQ16
,
kGT16
}
jit_block
;
typedef
enum
{
kLT8
,
kEQ8
,
kGT8LT16
,
kEQ16
,
kGT16
}
jit_block
;
...
@@ -97,39 +97,23 @@ class VAddBiasKernel : public Kernel {
...
@@ -97,39 +97,23 @@ class VAddBiasKernel : public Kernel {
template
<
typename
T
>
template
<
typename
T
>
class
VActKernel
:
public
Kernel
{
class
VActKernel
:
public
Kernel
{
public:
public:
v
irtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
v
oid
(
*
Compute
)(
const
T
*
,
T
*
,
int
)
;
};
};
template
<
typename
T
>
template
<
typename
T
>
class
VReluKernel
:
public
VActKernel
<
T
>
{
class
VReluKernel
:
public
VActKernel
<
T
>
{};
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
template
<
typename
T
>
template
<
typename
T
>
class
VIdentityKernel
:
public
VActKernel
<
T
>
{
class
VIdentityKernel
:
public
VActKernel
<
T
>
{};
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
template
<
typename
T
>
class
VExpKernel
:
public
VActKernel
<
T
>
{
class
VExpKernel
:
public
VActKernel
<
T
>
{};
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
template
<
typename
T
>
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{};
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
template
<
typename
T
>
class
VTanhKernel
:
public
VActKernel
<
T
>
{
class
VTanhKernel
:
public
VActKernel
<
T
>
{};
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
template
<
typename
T
>
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{
class
LSTMKernel
:
public
Kernel
{
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
7f17e561
...
@@ -25,10 +25,6 @@ limitations under the License. */
...
@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h"
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
...
@@ -128,23 +124,16 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
...
@@ -128,23 +124,16 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
#endif
#endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */
/* VMUL JitKernel */
template
<
typename
T
>
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
0
,
false
,
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
this
->
Compute
=
...
@@ -191,11 +180,11 @@ bool VMulKernelImpl<double>::useMKL(int d) {
...
@@ -191,11 +180,11 @@ bool VMulKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
false
,
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
this
->
Compute
=
...
@@ -241,11 +230,11 @@ bool VAddKernelImpl<double>::useMKL(int d) {
...
@@ -241,11 +230,11 @@ bool VAddKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
true
,
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
true
,
sz
>
4096
?
sz
:
4096
));
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
this
->
Compute
=
...
@@ -273,11 +262,11 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
...
@@ -273,11 +262,11 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
1
,
false
,
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
this
->
Compute
=
...
@@ -322,11 +311,11 @@ bool VScalKernelImpl<double>::useMKL(int d) {
...
@@ -322,11 +311,11 @@ bool VScalKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
1
,
false
,
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
this
->
Compute
=
...
@@ -355,15 +344,15 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
...
@@ -355,15 +344,15 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
template
<
typename
T
>
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
public:
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
if
(
useJIT
(
d
))
{
size_t
sz
=
96
/*init*/
+
size_t
sz
=
96
/* init size */
+
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions*/
*
d
/
YMM_FLOAT_BLOCK
*
4
/* instructions */
*
8
/*everage byte for each instruction*/
;
8
/* average bytes for each instruction */
;
jitcode_
.
reset
(
new
gen
::
ReluJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
relu
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
return
;
}
}
...
@@ -371,24 +360,32 @@ class VReluKernelImpl : public VReluKernel<T> {
...
@@ -371,24 +360,32 @@ class VReluKernelImpl : public VReluKernel<T> {
this
->
Compute
=
VReluRefer
<
T
>
;
this
->
Compute
=
VReluRefer
<
T
>
;
}
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
VReluRefer
(
x
,
y
,
this
->
num_
);
}
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
private:
private:
std
::
unique_ptr
<
gen
::
Relu
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
VAct
JitCode
>
jitcode_
{
nullptr
};
#endif
#endif
};
};
#ifdef PADDLE_WITH_XBYAK
#ifdef PADDLE_WITH_XBYAK
template
<
>
template
<
>
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
ReluJitCode
::
init
(
d
);
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
relu
);
}
}
#endif
#endif
#undef DECLARE_STATIC_FUNC
template
<
typename
T
>
inline
void
VIdentityRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{}
/* An empty JitKernel */
template
<
typename
T
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
Compute
=
VIdentityRefer
<
T
>
;
}
};
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
...
@@ -396,16 +393,7 @@ REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
...
@@ -396,16 +393,7 @@ REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddbias
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vaddbias
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
REGISTER_JITKERNEL
(
videntity
,
VIdentityKernel
);
/* An empty JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{}
};
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
}
// namespace jitkernel
}
// namespace jitkernel
}
// namespace math
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
浏览文件 @
7f17e561
...
@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
...
@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \
int tag_num) \
: CRFDecodeKernel<float>() { \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX
_FLOAT_BLOCK; \
this->end_ = this->num_ /
YMM
_FLOAT_BLOCK; \
this->rest_ = this->num_ %
AVX
_FLOAT_BLOCK; \
this->rest_ = this->num_ %
YMM
_FLOAT_BLOCK; \
} \
} \
template <> \
template <> \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
int* track) const { \
INIT_ALPHA(
AVX
_FLOAT_BLOCK) \
INIT_ALPHA(
YMM
_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
constexpr int state_trans_base_idx = 2; \
...
@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
...
@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
trans_offset += this->num_; \
} \
} \
UPDATE_ALPHA(
AVX
_FLOAT_BLOCK) \
UPDATE_ALPHA(
YMM
_FLOAT_BLOCK) \
} \
} \
seq_offset += this->num_; \
seq_offset += this->num_; \
} \
} \
...
@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
...
@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX2_FLOAT_BLOCK;
\
this->end_ = this->num_ /
YMM_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
AVX2_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
YMM_FLOAT_BLOCK;
\
} \
} \
template <> \
template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
int* track) const { \
INIT_ALPHA(
AVX2_FLOAT_BLOCK)
\
INIT_ALPHA(
YMM_FLOAT_BLOCK)
\
/* Use the column-major strategy to get the location of maximum score.*/
\
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
constexpr int state_trans_base_idx = 2; \
...
@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
...
@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
trans_offset += this->num_; \
} \
} \
UPDATE_ALPHA(
AVX2_FLOAT_BLOCK)
\
UPDATE_ALPHA(
YMM_FLOAT_BLOCK)
\
} \
} \
seq_offset += this->num_; \
seq_offset += this->num_; \
} \
} \
...
@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
...
@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \
int tag_num) \
: CRFDecodeKernel<float>() { \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX512_FLOAT_BLOCK;
\
this->end_ = this->num_ /
ZMM_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
AVX512_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
ZMM_FLOAT_BLOCK;
\
} \
} \
template <> \
template <> \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
int* track) const { \
INIT_ALPHA(
AVX512_FLOAT_BLOCK)
\
INIT_ALPHA(
ZMM_FLOAT_BLOCK)
\
/* Use the column-major strategy to get the location of maximum score.*/
\
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
constexpr int state_trans_base_idx = 2; \
...
@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
...
@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->num_ + j_offset), \
this->num_ + j_offset), \
max_j); \
max_j); \
/* Calculate the offset of next step*/
\
/* Calculate the offset of next step*/
\
j_offset +=
AVX512_FLOAT_BLOCK;
\
j_offset +=
ZMM_FLOAT_BLOCK;
\
if (j == this->end_ - 1) { \
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
j_offset += last_offset; \
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
7f17e561
此差异已折叠。
点击以展开。
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
7f17e561
...
@@ -15,12 +15,20 @@ limitations under the License. */
...
@@ -15,12 +15,20 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
namespace
jitkernel
{
namespace
jitkernel
{
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
std::string ker_class##Impl<float>::name(int d) { \
...
@@ -87,13 +95,13 @@ namespace jitkernel {
...
@@ -87,13 +95,13 @@ namespace jitkernel {
namespace
jit
=
platform
::
jit
;
namespace
jit
=
platform
::
jit
;
// TODO(TJ): below defines are deprecated, would be remove recently
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d <
AVX_FLOAT_BLOCK) {
\
if (d <
YMM_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kLT8); \
macro_(ker, dtype, isa, kLT8); \
} else if (d ==
AVX_FLOAT_BLOCK) {
\
} else if (d ==
YMM_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kEQ8); \
macro_(ker, dtype, isa, kEQ8); \
} else if (d >
AVX_FLOAT_BLOCK && d < AVX512
_FLOAT_BLOCK) { \
} else if (d >
YMM_FLOAT_BLOCK && d < ZMM
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \
macro_(ker, dtype, isa, kGT8LT16); \
} else if (d ==
AVX512
_FLOAT_BLOCK) { \
} else if (d ==
ZMM
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \
macro_(ker, dtype, isa, kEQ16); \
} else { \
} else { \
macro_(ker, dtype, isa, kGT16); \
macro_(ker, dtype, isa, kGT16); \
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
7f17e561
...
@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
...
@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
T
*
checked
)
const
override
{
// gates: W_ch, W_ih, W_fh, W_oh
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d
_
);
act_gate_d3_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d3
_
);
/* C_t = C_t-1 * fgated + cand_gated * igated */
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d_
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3
_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d
_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
}
...
@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
...
@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
,
d2_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
,
d2_
);
act_gate_d2_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d
_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d2
_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* get ogated*/
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3
_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d
_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d_
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* get outgated, put W_oc * C_t on igated */
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3
_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d
_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
}
...
@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
...
@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
}
}
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
override
{
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
override
{
act_gate_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_gate_d_
->
Compute
(
gates
,
gates
,
d_
);
act_state_d_
->
Compute
Deprecated
(
gates
+
d2_
,
gates
+
d2
_
);
act_state_d_
->
Compute
(
gates
+
d2_
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
,
d_
);
}
}
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
// W: {W_update, W_reset; W_state}
// W: {W_update, W_reset; W_state}
act_gate_d2_
->
Compute
Deprecated
(
gates
,
gates
);
act_gate_d2_
->
Compute
(
gates
,
gates
,
d2_
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
,
d_
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
,
d_
);
}
}
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
T
*
y
=
gates
+
d2_
;
T
*
y
=
gates
+
d2_
;
act_state_d_
->
Compute
Deprecated
(
y
,
y
);
act_state_d_
->
Compute
(
y
,
y
,
d_
);
// out = zt*ht~ + (1-zt)*ht_1
// out = zt*ht~ + (1-zt)*ht_1
for
(
int
i
=
0
;
i
<
d_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
d_
;
++
i
)
{
ht
[
i
]
=
gates
[
i
]
*
y
[
i
]
+
(
static_cast
<
T
>
(
1
)
-
gates
[
i
])
*
ht_1
[
i
];
ht
[
i
]
=
gates
[
i
]
*
y
[
i
]
+
(
static_cast
<
T
>
(
1
)
-
gates
[
i
])
*
ht_1
[
i
];
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
7f17e561
...
@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
...
@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
ComputeDeprecated
(
x_data
,
ztgt_data
);
// ker->Compute(x_data, ztgt_data);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
...
@@ -222,7 +223,7 @@ void vsigmoid_better(
...
@@ -222,7 +223,7 @@ void vsigmoid_better(
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
0.
f
-
y
[
i
];
y
[
i
]
=
0.
f
-
y
[
i
];
}
}
vexp
->
Compute
Deprecated
(
y
,
y
);
vexp
->
Compute
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
1.
f
/
(
1.
f
+
y
[
i
]);
y
[
i
]
=
1.
f
/
(
1.
f
+
y
[
i
]);
}
}
...
@@ -253,7 +254,7 @@ TEST(JitKernel, vsigmoid) {
...
@@ -253,7 +254,7 @@ TEST(JitKernel, vsigmoid) {
auto
trefe
=
GetCurrentUS
();
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
...
@@ -287,7 +288,7 @@ void vtanh_better(
...
@@ -287,7 +288,7 @@ void vtanh_better(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
float
a
=
2.
f
,
b
=
-
1.
f
;
const
float
a
=
2.
f
,
b
=
-
1.
f
;
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vsigmoid
->
Compute
Deprecated
(
y
,
y
);
vsigmoid
->
Compute
(
y
,
y
,
n
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
}
}
...
@@ -321,7 +322,7 @@ TEST(JitKernel, vtanh) {
...
@@ -321,7 +322,7 @@ TEST(JitKernel, vtanh) {
auto
trefe
=
GetCurrentUS
();
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
}
auto
ttgte
=
GetCurrentUS
();
auto
ttgte
=
GetCurrentUS
();
...
@@ -344,8 +345,8 @@ void lstm_ctht_ref(
...
@@ -344,8 +345,8 @@ void lstm_ctht_ref(
const
std
::
shared_ptr
<
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VExpKernel
<
float
>>&
vexp_1
,
const
paddle
::
operators
::
math
::
jitkernel
::
VExpKernel
<
float
>>&
vexp_1
,
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
vsigmoid_3d
->
Compute
Deprecated
(
gates
+
d
,
gates
+
d
);
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
,
3
*
d
);
vtanh_d
->
Compute
Deprecated
(
gates
,
gates
);
vtanh_d
->
Compute
(
gates
,
gates
,
d
);
const
float
*
i
=
gates
+
d
,
*
f
=
gates
+
d
*
2
,
*
o
=
gates
+
d
*
3
;
const
float
*
i
=
gates
+
d
,
*
f
=
gates
+
d
*
2
,
*
o
=
gates
+
d
*
3
;
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
...
@@ -355,7 +356,7 @@ void lstm_ctht_ref(
...
@@ -355,7 +356,7 @@ void lstm_ctht_ref(
// H_t = act_cell(C_t) * ogated
// H_t = act_cell(C_t) * ogated
float
tmp
=
ct
[
k
]
*
2
;
float
tmp
=
ct
[
k
]
*
2
;
tmp
=
0.
f
-
((
tmp
<
min
)
?
min
:
((
tmp
>
max
)
?
max
:
tmp
));
tmp
=
0.
f
-
((
tmp
<
min
)
?
min
:
((
tmp
>
max
)
?
max
:
tmp
));
vexp_1
->
Compute
Deprecated
(
&
tmp
,
&
tmp
);
vexp_1
->
Compute
(
&
tmp
,
&
tmp
,
1
);
tmp
=
2.
f
/
(
1.
f
+
tmp
)
-
1.
f
;
tmp
=
2.
f
/
(
1.
f
+
tmp
)
-
1.
f
;
ht
[
k
]
=
tmp
*
o
[
k
];
ht
[
k
]
=
tmp
*
o
[
k
];
}
}
...
@@ -373,13 +374,13 @@ void lstm_ctht_better(
...
@@ -373,13 +374,13 @@ void lstm_ctht_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd_d
,
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd_d
,
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
int
d2
=
d
*
2
;
int
d2
=
d
*
2
;
vsigmoid_3d
->
Compute
Deprecated
(
gates
+
d
,
gates
+
d
);
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
,
3
*
d
);
vtanh_d
->
Compute
Deprecated
(
gates
,
gates
);
vtanh_d
->
Compute
(
gates
,
gates
,
d
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
/* H_t = act_cell(C_t) * ogated */
/* H_t = act_cell(C_t) * ogated */
vtanh_d
->
Compute
Deprecated
(
ct
,
gates
+
d2
);
vtanh_d
->
Compute
(
ct
,
gates
+
d2
,
d
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
}
}
...
@@ -736,7 +737,7 @@ void vaddrelu_better(
...
@@ -736,7 +737,7 @@ void vaddrelu_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VReluKernel
<
float
>>&
vrelu
,
const
paddle
::
operators
::
math
::
jitkernel
::
VReluKernel
<
float
>>&
vrelu
,
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
d
)
{
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
d
)
{
vadd
->
Compute
(
x
,
y
,
z
,
d
);
vadd
->
Compute
(
x
,
y
,
z
,
d
);
vrelu
->
Compute
Deprecated
(
z
,
z
);
vrelu
->
Compute
(
z
,
z
,
d
);
}
}
TEST
(
JitKernel
,
vaddrelu
)
{
TEST
(
JitKernel
,
vaddrelu
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录