Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
cccc9906
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cccc9906
编写于
11月 16, 2018
作者:
P
peizhilin
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'windows/build' into windows/online
test=develop
上级
e41461c5
764f97de
变更
11
展开全部
隐藏空白更改
内联
并排
Showing
11 changed file
with
632 addition
and
505 deletion
+632
-505
paddle/fluid/operators/math/cpu_vec.h
paddle/fluid/operators/math/cpu_vec.h
+9
-9
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+230
-10
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+60
-12
paddle/fluid/operators/math/jit_kernel.h
paddle/fluid/operators/math/jit_kernel.h
+9
-25
paddle/fluid/operators/math/jit_kernel_blas.cc
paddle/fluid/operators/math/jit_kernel_blas.cc
+31
-43
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+12
-12
paddle/fluid/operators/math/jit_kernel_exp.cc
paddle/fluid/operators/math/jit_kernel_exp.cc
+224
-350
paddle/fluid/operators/math/jit_kernel_macro.h
paddle/fluid/operators/math/jit_kernel_macro.h
+19
-11
paddle/fluid/operators/math/jit_kernel_rnn.cc
paddle/fluid/operators/math/jit_kernel_rnn.cc
+19
-19
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+13
-12
python/paddle/fluid/tests/book/test_label_semantic_roles.py
python/paddle/fluid/tests/book/test_label_semantic_roles.py
+6
-2
未找到文件。
paddle/fluid/operators/math/cpu_vec.h
浏览文件 @
cccc9906
...
...
@@ -33,11 +33,11 @@ namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define
AVX
_FLOAT_BLOCK 8
#define
YMM
_FLOAT_BLOCK 8
#define AVX_DOUBLE_BLOCK 4
#define
AVX2
_FLOAT_BLOCK 8
#define
YMM
_FLOAT_BLOCK 8
#define AVX2_DOUBLE_BLOCK 4
#define
AVX512
_FLOAT_BLOCK 16
#define
ZMM
_FLOAT_BLOCK 16
#define AVX512_DOUBLE_BLOCK 8
template
<
typename
T
>
...
...
@@ -88,7 +88,7 @@ template <>
inline
void
vec_scal
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_scal
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
...
...
@@ -142,7 +142,7 @@ template <>
inline
void
vec_bias_sub
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_bias_sub
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
...
...
@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
const
float
*
y
,
const
float
*
z
,
float
*
out
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_cross
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
,
z
,
out
);
return
;
...
...
@@ -257,7 +257,7 @@ template <>
inline
void
vec_add_bias
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
a
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_add_bias
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
a
,
x
,
y
);
return
;
...
...
@@ -326,7 +326,7 @@ template <>
inline
void
vec_sigmoid
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
)
{
vec_sigmoid
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
return
;
...
...
@@ -415,7 +415,7 @@ template <>
inline
void
vec_relu
<
float
,
platform
::
jit
::
avx
>
(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
#ifdef __AVX__
constexpr
int
block
=
AVX
_FLOAT_BLOCK
;
constexpr
int
block
=
YMM
_FLOAT_BLOCK
;
if
(
n
<
block
*
4
)
{
vec_relu
<
float
,
platform
::
jit
::
isa_any
>
(
n
,
x
,
y
);
return
;
...
...
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
cccc9906
...
...
@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
}
else
if
(
scalar_index_
==
2
)
{
vbroadcastss
(
ymm_src2
,
ptr
[
param2
]);
}
for
(
int
i
=
0
;
i
<
num_
/
AVX
_FLOAT_BLOCK
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_
/
YMM
_FLOAT_BLOCK
;
++
i
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
ymm_src1
,
ptr
[
param1
+
offset
]);
}
...
...
@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_dst
);
}
vmovups
(
ptr
[
param3
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX
_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM
_FLOAT_BLOCK
;
}
int
rest
=
num_
%
AVX
_FLOAT_BLOCK
;
int
rest
=
num_
%
YMM
_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
if
(
scalar_index_
!=
1
)
{
vmovups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
...
...
@@ -118,18 +118,237 @@ void VXXJitCode::generate() {
ret
();
}
bool
ReluJitCode
::
init
(
int
d
)
{
return
MayIUse
(
avx
);
}
#define ALIGN32 __attribute__((aligned(32)))
#define EXP_HIG 88.3762626647949f
#define EXP_LOW -88.3762626647949f
#define CEPHES_LOG2EF 1.44269504088896341
#define CEPHES_EXP_C1 0.693359375
#define CEPHES_EXP_C2 -2.12194440e-4
#define CEPHES_EXP_P0 1.9875691500E-4
#define CEPHES_EXP_P1 1.3981999507E-3
#define CEPHES_EXP_P2 8.3334519073E-3
#define CEPHES_EXP_P3 4.1665795894E-2
#define CEPHES_EXP_P4 1.6666665459E-1
#define CEPHES_EXP_P5 5.0000001201E-1
void
ReluJitCode
::
generate
()
{
int
offset
=
0
;
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
static
const
float
exp_float_consts
[]
ALIGN32
=
{
REPEAT_8TIMES
(
1.
f
),
REPEAT_8TIMES
(
2.
f
),
REPEAT_8TIMES
(
0.5
f
),
REPEAT_8TIMES
(
EXP_HIG
),
REPEAT_8TIMES
(
EXP_LOW
),
REPEAT_8TIMES
(
CEPHES_LOG2EF
),
REPEAT_8TIMES
(
CEPHES_EXP_C1
),
REPEAT_8TIMES
(
CEPHES_EXP_C2
),
REPEAT_8TIMES
(
CEPHES_EXP_P0
),
REPEAT_8TIMES
(
CEPHES_EXP_P1
),
REPEAT_8TIMES
(
CEPHES_EXP_P2
),
REPEAT_8TIMES
(
CEPHES_EXP_P3
),
REPEAT_8TIMES
(
CEPHES_EXP_P4
),
REPEAT_8TIMES
(
CEPHES_EXP_P5
),
REPEAT_8TIMES
(
EXP_MAX_INPUT
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MAX
),
REPEAT_8TIMES
(
SIGMOID_THRESHOLD_MIN
)};
static
const
int
exp_int_0x7f
[]
ALIGN32
=
{
REPEAT_8TIMES
(
0x7f
)};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
bool
VActJitCode
::
init
(
int
d
,
operand_type
type
)
{
bool
ok
=
MayIUse
(
avx
);
if
(
type
==
operand_type
::
relu
)
{
return
ok
;
}
else
if
(
type
==
operand_type
::
exp
)
{
// exp is slower than mkl when d >= 256
return
ok
&&
d
%
8
==
0
&&
d
<
256
;
}
else
{
// TODO(TJ): support more
return
ok
&&
d
%
8
==
0
;
}
}
void
VActJitCode
::
relu_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
ymm_t
&
ymm_zero
)
{
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
}
void
VActJitCode
::
exp_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
assert
(
ymm_src
.
getIdx
()
!=
ymm_dst
.
getIdx
());
// TODO(TJ): use enfore
// check all idx can not equal
ymm_t
ymm_fx
=
ymm_t
(
fx_idx
);
ymm_t
ymm_fy
=
ymm_t
(
fy_idx
);
ymm_t
ymm_mask
=
ymm_t
(
mask_idx
);
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
// express exp(x) as exp(g + n*log(2))
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmulps
(
ymm_fx
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vaddps
(
ymm_fx
,
ymm_fx
,
ymm_tmp
);
vroundps
(
ymm_fy
,
ymm_fx
,
0x01
);
// if greater, substract 1
vcmpgtps
(
ymm_mask
,
ymm_fy
,
ymm_fx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vandps
(
ymm_mask
,
ymm_mask
,
ymm_tmp
);
vsubps
(
ymm_fx
,
ymm_fy
,
ymm_mask
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C1
]);
vmulps
(
ymm_fy
,
ymm_fx
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
ymm_t
ymm_z
=
ymm_t
(
ymm_mask
.
getIdx
());
vmulps
(
ymm_z
,
ymm_fx
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_fy
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_z
);
vmulps
(
ymm_z
,
ymm_src
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
ymm_dst
,
ymm_src
,
ymm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
YMM_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
}
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_z
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
// build 2^n
ymm_t
ymm_int
=
ymm_fx
;
vcvttps2dq
(
ymm_int
,
ymm_fx
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_int_0x7f
));
vmovdqa
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
if
(
MayIUse
(
avx2
))
{
vpaddd
(
ymm_int
,
ymm_int
,
ymm_tmp
);
vpslld
(
ymm_int
,
ymm_int
,
23
);
}
else
if
(
MayIUse
(
avx
))
{
xmm_t
xtmp1
=
xmm_t
(
ymm_int
.
getIdx
());
xmm_t
xtmp2
=
xmm_t
(
ymm_tmp
.
getIdx
());
reg64_t
reg_ptr_tmp
=
reg_ptr_global
;
mov
(
reg_ptr_tmp
,
reinterpret_cast
<
size_t
>
(
g_tmp_mem
));
vmovdqa
(
ptr
[
reg_ptr_tmp
],
ymm_int
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
YMM_FLOAT_BLOCK
*
sizeof
(
float
)],
ymm_tmp
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
],
xtmp1
);
// next 128bits
vmovdqa
(
xtmp1
,
ptr
[
reg_ptr_tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)]);
vmovdqa
(
xtmp2
,
ptr
[
reg_ptr_tmp
+
(
YMM_FLOAT_BLOCK
+
4
/*xmm float block*/
)
*
sizeof
(
float
)]);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)],
xtmp1
);
// load out
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_tmp
]);
}
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
sigmoid_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
// y = 1 / (1 + e^-x)
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MAX
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_SIGMOID_MIN
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vxorps
(
ymm_tmp
,
ymm_tmp
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_tmp
,
ymm_src
);
exp_ymm
(
ymm_dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
tanh_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
// y = 2 / (1 + e^(-2x)) - 1
ymm_t
ymm_tmp
=
ymm_t
(
tmp_idx
);
ymm_t
ymm_zero
=
ymm_t
(
mask_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
for
(
int
i
=
0
;
i
<
num_
/
AVX_FLOAT_BLOCK
;
++
i
)
{
vsubps
(
ymm_tmp
,
ymm_zero
,
ymm_tmp
);
vmulps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
exp_ymm
(
ymm_dst
,
ymm_src
,
fx_idx
,
fy_idx
,
mask_idx
,
tmp_idx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_TWO
]);
vdivps
(
ymm_dst
,
ymm_tmp
,
ymm_dst
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_ONE
]);
vsubps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
generate
()
{
xmm_t
xmm_zero
=
xmm_t
(
2
);
ymm_t
ymm_zero
=
ymm_t
(
2
);
if
(
type_
==
operand_type
::
relu
)
{
vxorps
(
ymm_zero
,
ymm_zero
,
ymm_zero
);
}
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
num_
/
YMM_FLOAT_BLOCK
;
++
i
)
{
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_ymm
(
ymm_dst
,
ymm_src
,
ymm_zero
);
break
;
case
operand_type
::
exp
:
exp_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
sigmoid
:
sigmoid_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
tanh
:
tanh_ymm
(
ymm_dst
,
ymm_src
,
2
,
3
,
4
,
5
);
break
;
case
operand_type
::
identity
:
break
;
default:
break
;
}
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
AVX_FLOAT_BLOCK
;
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
if
(
type_
!=
operand_type
::
relu
)
{
// TODO(TJ): remove me
ret
();
return
;
}
int
rest
=
num_
%
AVX
_FLOAT_BLOCK
;
int
rest
=
num_
%
YMM
_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
...
...
@@ -151,6 +370,7 @@ void ReluJitCode::generate() {
}
ret
();
}
}
// namespace gen
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
cccc9906
...
...
@@ -29,7 +29,16 @@ using ymm_t = const Xbyak::Ymm;
using
zmm_t
=
const
Xbyak
::
Zmm
;
using
Label
=
Xbyak
::
Label
;
typedef
enum
{
mul
=
0
,
add
}
operand_type
;
typedef
enum
{
mul
=
0
,
add
,
sub
,
relu
,
exp
,
sigmoid
,
tanh
,
identity
}
operand_type
;
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class
VXXJitCode
:
public
JitCode
{
...
...
@@ -85,26 +94,65 @@ class VXXJitCode : public JitCode {
ymm_t
ymm_zero
=
ymm_t
(
3
);
};
class
Relu
JitCode
:
public
JitCode
{
class
VAct
JitCode
:
public
JitCode
{
public:
DECLARE_JIT_CODE
(
ReluJitCode
);
explicit
ReluJitCode
(
int
d
,
size_t
code_size
=
256
*
1024
,
const
char
*
name
()
const
override
{
std
::
string
base
=
"VActJitCode"
;
switch
(
type_
)
{
case
operand_type
::
relu
:
base
+=
"_Relu"
;
break
;
case
operand_type
::
exp
:
base
+=
"_Exp"
;
break
;
case
operand_type
::
sigmoid
:
base
+=
"_Sigmoid"
;
break
;
case
operand_type
::
tanh
:
base
+=
"_Tanh"
;
break
;
case
operand_type
::
identity
:
base
+=
"_Identity"
;
break
;
default:
break
;
}
return
base
.
c_str
();
}
explicit
VActJitCode
(
int
d
,
operand_type
type
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
{}
static
bool
init
(
int
d
);
:
JitCode
(
code_size
,
code_ptr
),
num_
(
d
)
,
type_
(
type
)
{}
static
bool
init
(
int
d
,
operand_type
type
);
void
generate
()
override
;
private:
protected:
// compute relu with ymm
void
relu_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
zero
);
// compute exp with ymm
void
exp_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
// compute sigmoid with ymm
void
sigmoid_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
// compute tanh with ymm
void
tanh_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
protected:
int
num_
;
operand_type
type_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
xmm_t
xmm_zero
=
xmm_t
(
0
);
xmm_t
xmm_src
=
xmm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
xmm_t
xmm_src
=
xmm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
0
);
ymm_t
ymm_zero
=
ymm_t
(
0
);
ymm_t
ymm_src
=
ymm_t
(
1
);
xmm_t
xmm_dst
=
xmm_t
(
1
);
ymm_t
ymm_dst
=
ymm_t
(
1
);
};
...
...
paddle/fluid/operators/math/jit_kernel.h
浏览文件 @
cccc9906
...
...
@@ -29,9 +29,9 @@ namespace jitkernel {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define
AVX_FLOAT_BLOCK 8
#define
AVX2
_FLOAT_BLOCK 8
#define
AVX512
_FLOAT_BLOCK 16
#define
XMM_FLOAT_BLOCK 4
#define
YMM
_FLOAT_BLOCK 8
#define
ZMM
_FLOAT_BLOCK 16
typedef
enum
{
kLT8
,
kEQ8
,
kGT8LT16
,
kEQ16
,
kGT16
}
jit_block
;
...
...
@@ -97,39 +97,23 @@ class VAddBiasKernel : public Kernel {
template
<
typename
T
>
class
VActKernel
:
public
Kernel
{
public:
v
irtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
v
oid
(
*
Compute
)(
const
T
*
,
T
*
,
int
)
;
};
template
<
typename
T
>
class
VReluKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
void
(
*
Compute
)(
const
T
*
,
T
*
,
int
);
};
class
VReluKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VIdentityKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
class
VIdentityKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VExpKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
class
VExpKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
class
VSigmoidKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
VTanhKernel
:
public
VActKernel
<
T
>
{
public:
virtual
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
=
0
;
};
class
VTanhKernel
:
public
VActKernel
<
T
>
{};
template
<
typename
T
>
class
LSTMKernel
:
public
Kernel
{
...
...
paddle/fluid/operators/math/jit_kernel_blas.cc
浏览文件 @
cccc9906
...
...
@@ -25,10 +25,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace
paddle
{
namespace
operators
{
namespace
math
{
...
...
@@ -128,23 +124,16 @@ void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
#endif
#define DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
/* VMUL JitKernel */
template
<
typename
T
>
class
VMulKernelImpl
:
public
VMulKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VMulKernelImpl
(
int
d
)
:
VMulKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
// roughly estimate the size of code
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -191,11 +180,11 @@ bool VMulKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
class
VAddKernelImpl
:
public
VAddKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddKernelImpl
(
int
d
)
:
VAddKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -241,11 +230,11 @@ bool VAddKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
class
VAddReluKernelImpl
:
public
VAddReluKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddReluKernelImpl
(
int
d
)
:
VAddReluKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
0
,
true
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -273,11 +262,11 @@ bool VAddReluKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
class
VScalKernelImpl
:
public
VScalKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VScalKernelImpl
(
int
d
)
:
VScalKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
mul
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -322,11 +311,11 @@ bool VScalKernelImpl<double>::useMKL(int d) {
template
<
typename
T
>
class
VAddBiasKernelImpl
:
public
VAddBiasKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VAddBiasKernelImpl
(
int
d
)
:
VAddBiasKernel
<
T
>
()
{
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
+
d
/
AVX
_FLOAT_BLOCK
*
4
*
8
;
size_t
sz
=
96
+
d
/
YMM
_FLOAT_BLOCK
*
4
*
8
;
jitcode_
.
reset
(
new
gen
::
VXXJitCode
(
d
,
gen
::
operand_type
::
add
,
1
,
false
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
...
...
@@ -355,15 +344,15 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
template
<
typename
T
>
class
VReluKernelImpl
:
public
VReluKernel
<
T
>
{
public:
DECLARE_STATIC_FUNC
;
JITKERNEL_
DECLARE_STATIC_FUNC
;
explicit
VReluKernelImpl
(
int
d
)
:
VReluKernel
<
T
>
()
{
this
->
num_
=
d
;
// TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if
(
useJIT
(
d
))
{
size_t
sz
=
96
/*init*/
+
d
/
AVX_FLOAT_BLOCK
*
4
/* instructions*/
*
8
/*everage byte for each instruction*/
;
jitcode_
.
reset
(
new
gen
::
ReluJitCode
(
d
,
sz
>
4096
?
sz
:
4096
));
size_t
sz
=
96
/* init size */
+
d
/
YMM_FLOAT_BLOCK
*
4
/* instructions */
*
8
/* average bytes for each instruction */
;
jitcode_
.
reset
(
new
gen
::
VActJitCode
(
d
,
gen
::
operand_type
::
relu
,
sz
>
4096
?
sz
:
4096
));
this
->
Compute
=
jitcode_
->
getCode
<
void
(
*
)(
const
T
*
,
T
*
,
int
)
>
();
return
;
}
...
...
@@ -371,24 +360,32 @@ class VReluKernelImpl : public VReluKernel<T> {
this
->
Compute
=
VReluRefer
<
T
>
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{
VReluRefer
(
x
,
y
,
this
->
num_
);
}
#ifdef PADDLE_WITH_XBYAK
private:
std
::
unique_ptr
<
gen
::
Relu
JitCode
>
jitcode_
{
nullptr
};
std
::
unique_ptr
<
gen
::
VAct
JitCode
>
jitcode_
{
nullptr
};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template
<
>
bool
VReluKernelImpl
<
float
>::
useJIT
(
int
d
)
{
return
gen
::
ReluJitCode
::
init
(
d
);
return
gen
::
VActJitCode
::
init
(
d
,
gen
::
operand_type
::
relu
);
}
#endif
#undef DECLARE_STATIC_FUNC
template
<
typename
T
>
inline
void
VIdentityRefer
(
const
T
*
x
,
T
*
y
,
int
n
)
{}
/* An empty JitKernel */
template
<
typename
T
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
JITKERNEL_DECLARE_STATIC_FUNC
;
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
Compute
=
VIdentityRefer
<
T
>
;
}
};
REGISTER_JITKERNEL
(
vmul
,
VMulKernel
);
REGISTER_JITKERNEL
(
vadd
,
VAddKernel
);
...
...
@@ -396,16 +393,7 @@ REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL
(
vscal
,
VScalKernel
);
REGISTER_JITKERNEL
(
vaddbias
,
VAddBiasKernel
);
REGISTER_JITKERNEL
(
vrelu
,
VReluKernel
);
/* An empty JitKernel */
template
<
typename
T
,
platform
::
jit
::
cpu_isa_t
isa
,
jit_block
>
class
VIdentityKernelImpl
:
public
VIdentityKernel
<
T
>
{
public:
explicit
VIdentityKernelImpl
(
int
d
)
:
VIdentityKernel
<
T
>
()
{
this
->
num_
=
d
;
}
void
ComputeDeprecated
(
const
T
*
x
,
T
*
y
)
const
override
{}
};
REGISTER_JITKERNEL_DEPRECATED
(
videntity
,
VIdentityKernel
);
REGISTER_JITKERNEL
(
videntity
,
VIdentityKernel
);
}
// namespace jitkernel
}
// namespace math
...
...
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
浏览文件 @
cccc9906
...
...
@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX
_FLOAT_BLOCK; \
this->rest_ = this->num_ %
AVX
_FLOAT_BLOCK; \
this->end_ = this->num_ /
YMM
_FLOAT_BLOCK; \
this->rest_ = this->num_ %
YMM
_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(
AVX
_FLOAT_BLOCK) \
INIT_ALPHA(
YMM
_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
...
...
@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(
AVX
_FLOAT_BLOCK) \
UPDATE_ALPHA(
YMM
_FLOAT_BLOCK) \
} \
seq_offset += this->num_; \
} \
...
...
@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX2_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
AVX2_FLOAT_BLOCK;
\
this->end_ = this->num_ /
YMM_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
YMM_FLOAT_BLOCK;
\
} \
template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(
AVX2_FLOAT_BLOCK)
\
INIT_ALPHA(
YMM_FLOAT_BLOCK)
\
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
...
...
@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(
AVX2_FLOAT_BLOCK)
\
UPDATE_ALPHA(
YMM_FLOAT_BLOCK)
\
} \
seq_offset += this->num_; \
} \
...
...
@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ /
AVX512_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
AVX512_FLOAT_BLOCK;
\
this->end_ = this->num_ /
ZMM_FLOAT_BLOCK;
\
this->rest_ = this->num_ %
ZMM_FLOAT_BLOCK;
\
} \
template <> \
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(
AVX512_FLOAT_BLOCK)
\
INIT_ALPHA(
ZMM_FLOAT_BLOCK)
\
/* Use the column-major strategy to get the location of maximum score.*/
\
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
...
...
@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
this->num_ + j_offset), \
max_j); \
/* Calculate the offset of next step*/
\
j_offset +=
AVX512_FLOAT_BLOCK;
\
j_offset +=
ZMM_FLOAT_BLOCK;
\
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
...
...
paddle/fluid/operators/math/jit_kernel_exp.cc
浏览文件 @
cccc9906
此差异已折叠。
点击以展开。
paddle/fluid/operators/math/jit_kernel_macro.h
浏览文件 @
cccc9906
...
...
@@ -15,12 +15,20 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
namespace
jitkernel
{
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
...
...
@@ -86,17 +94,17 @@ namespace jitkernel {
namespace
jit
=
platform
::
jit
;
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa)
\
if (d <
AVX_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kLT8);
\
} else if (d ==
AVX_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kEQ8);
\
} else if (d >
AVX_FLOAT_BLOCK && d < AVX512
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16);
\
} else if (d ==
AVX512
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16);
\
} else {
\
macro_(ker, dtype, isa, kGT16);
\
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d <
YMM_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kLT8); \
} else if (d ==
YMM_FLOAT_BLOCK) {
\
macro_(ker, dtype, isa, kEQ8); \
} else if (d >
YMM_FLOAT_BLOCK && d < ZMM
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \
} else if (d ==
ZMM
_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \
} else { \
macro_(ker, dtype, isa, kGT16); \
}
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
...
...
paddle/fluid/operators/math/jit_kernel_rnn.cc
浏览文件 @
cccc9906
...
...
@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
void
ComputeCtHt
(
T
*
gates
,
const
T
*
ct_1
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
,
T
*
checked
)
const
override
{
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d
_
);
act_gate_d3_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d3
_
);
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3
_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
...
...
@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_
->
Compute
(
wp_data
,
ct_1
,
checked
,
d_
);
vmul_d_
->
Compute
(
wp_data
+
d_
,
ct_1
,
checked
+
d_
,
d_
);
vadd_d2_
->
Compute
(
checked
,
gates
+
d_
,
gates
+
d_
,
d2_
);
act_gate_d2_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d
_
);
act_gate_d2_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d2
_
);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
gates
+
d_
,
d_
);
vmul_d_
->
Compute
(
ct_1
,
gates
+
d2_
,
gates
+
d2_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d2_
,
ct
,
d_
);
/* get ogated*/
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3
_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d
_
);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
void
ComputeC1H1
(
T
*
gates
,
T
*
ct
,
T
*
ht
,
const
T
*
wp_data
)
const
override
{
/* C_t = igated * cgated*/
act_gate_d_
->
Compute
Deprecated
(
gates
+
d_
,
gates
+
d_
);
act_cand_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_gate_d_
->
Compute
(
gates
+
d_
,
gates
+
d_
,
d_
);
act_cand_d_
->
Compute
(
gates
,
gates
,
d_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d_
,
ct
,
d_
);
/* get outgated, put W_oc * C_t on igated */
vmul_d_
->
Compute
(
wp_data
+
d2_
,
ct
,
gates
+
d_
,
d_
);
vadd_d_
->
Compute
(
gates
+
d_
,
gates
+
d3_
,
gates
+
d3_
,
d_
);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_
->
Compute
Deprecated
(
gates
+
d3_
,
gates
+
d3
_
);
act_cell_d_
->
Compute
Deprecated
(
ct
,
gates
+
d2
_
);
act_gate_d_
->
Compute
(
gates
+
d3_
,
gates
+
d3_
,
d
_
);
act_cell_d_
->
Compute
(
ct
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
+
d2_
,
gates
+
d3_
,
ht
,
d_
);
}
...
...
@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
}
void
ComputeH1
(
T
*
gates
,
T
*
ht
)
const
override
{
act_gate_d_
->
Compute
Deprecated
(
gates
,
gates
);
act_state_d_
->
Compute
Deprecated
(
gates
+
d2_
,
gates
+
d2
_
);
act_gate_d_
->
Compute
(
gates
,
gates
,
d_
);
act_state_d_
->
Compute
(
gates
+
d2_
,
gates
+
d2_
,
d
_
);
vmul_d_
->
Compute
(
gates
,
gates
+
d2_
,
ht
,
d_
);
}
void
ComputeHtPart1
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
// W: {W_update, W_reset; W_state}
act_gate_d2_
->
Compute
Deprecated
(
gates
,
gates
);
act_gate_d2_
->
Compute
(
gates
,
gates
,
d2_
);
vmul_d_
->
Compute
(
ht_1
,
gates
+
d_
,
ht
,
d_
);
}
void
ComputeHtPart2
(
T
*
gates
,
const
T
*
ht_1
,
T
*
ht
)
const
override
{
T
*
y
=
gates
+
d2_
;
act_state_d_
->
Compute
Deprecated
(
y
,
y
);
act_state_d_
->
Compute
(
y
,
y
,
d_
);
// out = zt*ht~ + (1-zt)*ht_1
for
(
int
i
=
0
;
i
<
d_
;
++
i
)
{
ht
[
i
]
=
gates
[
i
]
*
y
[
i
]
+
(
static_cast
<
T
>
(
1
)
-
gates
[
i
])
*
ht_1
[
i
];
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
cccc9906
...
...
@@ -181,7 +181,8 @@ TEST(JitKernel, vexp) {
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
ComputeDeprecated
(
x_data
,
ztgt_data
);
// ker->Compute(x_data, ztgt_data);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -222,7 +223,7 @@ void vsigmoid_better(
y
[
i
]
=
(
x
[
i
]
<
min
)
?
min
:
((
x
[
i
]
>
max
)
?
max
:
x
[
i
]);
y
[
i
]
=
0.
f
-
y
[
i
];
}
vexp
->
Compute
Deprecated
(
y
,
y
);
vexp
->
Compute
(
y
,
y
,
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
y
[
i
]
=
1.
f
/
(
1.
f
+
y
[
i
]);
}
...
...
@@ -253,7 +254,7 @@ TEST(JitKernel, vsigmoid) {
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -287,7 +288,7 @@ void vtanh_better(
const
int
n
,
const
float
*
x
,
float
*
y
)
{
const
float
a
=
2.
f
,
b
=
-
1.
f
;
vscal
->
Compute
(
&
a
,
x
,
y
,
n
);
vsigmoid
->
Compute
Deprecated
(
y
,
y
);
vsigmoid
->
Compute
(
y
,
y
,
n
);
vscal
->
Compute
(
&
a
,
y
,
y
,
n
);
vaddbias
->
Compute
(
&
b
,
y
,
y
,
n
);
}
...
...
@@ -321,7 +322,7 @@ TEST(JitKernel, vtanh) {
auto
trefe
=
GetCurrentUS
();
auto
ttgts
=
GetCurrentUS
();
for
(
int
i
=
0
;
i
<
repeat
;
++
i
)
{
ker
->
Compute
Deprecated
(
x_data
,
ztgt_data
);
ker
->
Compute
(
x_data
,
ztgt_data
,
d
);
}
auto
ttgte
=
GetCurrentUS
();
...
...
@@ -344,8 +345,8 @@ void lstm_ctht_ref(
const
std
::
shared_ptr
<
const
paddle
::
operators
::
math
::
jitkernel
::
VExpKernel
<
float
>>&
vexp_1
,
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
vsigmoid_3d
->
Compute
Deprecated
(
gates
+
d
,
gates
+
d
);
vtanh_d
->
Compute
Deprecated
(
gates
,
gates
);
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
,
3
*
d
);
vtanh_d
->
Compute
(
gates
,
gates
,
d
);
const
float
*
i
=
gates
+
d
,
*
f
=
gates
+
d
*
2
,
*
o
=
gates
+
d
*
3
;
const
float
min
=
SIGMOID_THRESHOLD_MIN
;
const
float
max
=
SIGMOID_THRESHOLD_MAX
;
...
...
@@ -355,7 +356,7 @@ void lstm_ctht_ref(
// H_t = act_cell(C_t) * ogated
float
tmp
=
ct
[
k
]
*
2
;
tmp
=
0.
f
-
((
tmp
<
min
)
?
min
:
((
tmp
>
max
)
?
max
:
tmp
));
vexp_1
->
Compute
Deprecated
(
&
tmp
,
&
tmp
);
vexp_1
->
Compute
(
&
tmp
,
&
tmp
,
1
);
tmp
=
2.
f
/
(
1.
f
+
tmp
)
-
1.
f
;
ht
[
k
]
=
tmp
*
o
[
k
];
}
...
...
@@ -373,13 +374,13 @@ void lstm_ctht_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VAddKernel
<
float
>>&
vadd_d
,
const
int
d
,
float
*
gates
,
const
float
*
ct_1
,
float
*
ct
,
float
*
ht
)
{
int
d2
=
d
*
2
;
vsigmoid_3d
->
Compute
Deprecated
(
gates
+
d
,
gates
+
d
);
vtanh_d
->
Compute
Deprecated
(
gates
,
gates
);
vsigmoid_3d
->
Compute
(
gates
+
d
,
gates
+
d
,
3
*
d
);
vtanh_d
->
Compute
(
gates
,
gates
,
d
);
vmul_d
->
Compute
(
gates
,
gates
+
d
,
gates
+
d
,
d
);
vmul_d
->
Compute
(
ct_1
,
gates
+
d2
,
gates
+
d2
,
d
);
vadd_d
->
Compute
(
gates
+
d
,
gates
+
d2
,
ct
,
d
);
/* H_t = act_cell(C_t) * ogated */
vtanh_d
->
Compute
Deprecated
(
ct
,
gates
+
d2
);
vtanh_d
->
Compute
(
ct
,
gates
+
d2
,
d
);
vmul_d
->
Compute
(
gates
+
d2
,
gates
+
d
*
3
,
ht
,
d
);
}
...
...
@@ -736,7 +737,7 @@ void vaddrelu_better(
const
paddle
::
operators
::
math
::
jitkernel
::
VReluKernel
<
float
>>&
vrelu
,
const
float
*
x
,
const
float
*
y
,
float
*
z
,
int
d
)
{
vadd
->
Compute
(
x
,
y
,
z
,
d
);
vrelu
->
Compute
Deprecated
(
z
,
z
);
vrelu
->
Compute
(
z
,
z
,
d
);
}
TEST
(
JitKernel
,
vaddrelu
)
{
...
...
python/paddle/fluid/tests/book/test_label_semantic_roles.py
浏览文件 @
cccc9906
...
...
@@ -38,7 +38,7 @@ depth = 8
mix_hidden_lr
=
1e-3
IS_SPARSE
=
True
PASS_NUM
=
1
PASS_NUM
=
2
BATCH_SIZE
=
10
embedding_name
=
'emb'
...
...
@@ -196,7 +196,7 @@ def train(use_cuda, save_dirname=None, is_local=True):
print
(
"second per batch: "
+
str
((
time
.
time
(
)
-
start_time
)
/
batch_id
))
# Set the threshold low to speed up the CI test
if
float
(
cost
)
<
6
0.0
:
if
float
(
cost
)
<
8
0.0
:
if
save_dirname
is
not
None
:
# TODO(liuyiqun): Change the target to crf_decode
fluid
.
io
.
save_inference_model
(
save_dirname
,
[
...
...
@@ -208,6 +208,10 @@ def train(use_cuda, save_dirname=None, is_local=True):
batch_id
=
batch_id
+
1
raise
RuntimeError
(
"This model should save_inference_model and return, but not reach here, please check!"
)
if
is_local
:
train_loop
(
fluid
.
default_main_program
())
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录