Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ba3eaed7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ba3eaed7
编写于
11月 16, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
exp support all size
上级
d239801b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
113 addition
and
14 deletion
+113
-14
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+103
-11
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+6
-2
paddle/fluid/operators/math/jit_kernel_test.cc
paddle/fluid/operators/math/jit_kernel_test.cc
+4
-1
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
ba3eaed7
...
...
@@ -81,10 +81,10 @@ void VXXJitCode::generate() {
}
if
(
rest
>=
2
)
{
if
(
scalar_index_
!=
1
)
{
vmov
ups
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmov
q
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmov
ups
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vmov
q
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
if
(
type_
==
operand_type
::
mul
)
{
vmulps
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
...
...
@@ -100,10 +100,10 @@ void VXXJitCode::generate() {
}
if
(
rest
>
0
)
{
if
(
scalar_index_
!=
1
)
{
vmov
up
s
(
xmm_src1
,
ptr
[
param1
+
offset
]);
vmov
s
s
(
xmm_src1
,
ptr
[
param1
+
offset
]);
}
if
(
scalar_index_
!=
2
)
{
vmov
up
s
(
xmm_src2
,
ptr
[
param2
+
offset
]);
vmov
s
s
(
xmm_src2
,
ptr
[
param2
+
offset
]);
}
if
(
type_
==
operand_type
::
mul
)
{
vmulss
(
xmm_dst
,
xmm_src1
,
xmm_src2
);
...
...
@@ -179,7 +179,7 @@ bool VActJitCode::init(int d, operand_type type) {
return
ok
;
}
else
if
(
type
==
operand_type
::
exp
)
{
// exp is slower than mkl when d >= 256
return
ok
&&
d
%
8
==
0
&&
d
<
256
;
return
ok
;
//&& d % 4
== 0 && d < 256;
}
else
{
// TODO(TJ): support more
return
ok
&&
d
%
8
==
0
;
...
...
@@ -190,6 +190,10 @@ void VActJitCode::relu_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, ymm_t& ymm_zero) {
vmaxps
(
ymm_dst
,
ymm_zero
,
ymm_src
);
}
void
VActJitCode
::
relu_xmm
(
xmm_t
&
xmm_dst
,
xmm_t
&
xmm_src
,
xmm_t
&
xmm_zero
)
{
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
}
void
VActJitCode
::
exp_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
assert
(
ymm_src
.
getIdx
()
!=
ymm_dst
.
getIdx
());
// TODO(TJ): use enfore
...
...
@@ -271,6 +275,65 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
exp_xmm
(
xmm_t
&
ymm_dst
,
xmm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
assert
(
ymm_src
.
getIdx
()
!=
ymm_dst
.
getIdx
());
// TODO(TJ): use enfore
// check all idx can not equal
xmm_t
ymm_fx
=
xmm_t
(
fx_idx
);
xmm_t
ymm_fy
=
xmm_t
(
fy_idx
);
xmm_t
ymm_mask
=
xmm_t
(
mask_idx
);
xmm_t
ymm_tmp
=
xmm_t
(
tmp_idx
);
reg64_t
reg_ptr_global
=
rax
;
push
(
reg_ptr_global
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_HIG
]);
vminps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOW
]);
vmaxps
(
ymm_src
,
ymm_src
,
ymm_tmp
);
// express exp(x) as exp(g + n*log(2))
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_LOG2EF
]);
vmulps
(
ymm_fx
,
ymm_src
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_0P5
]);
vaddps
(
ymm_fx
,
ymm_fx
,
ymm_tmp
);
vroundps
(
ymm_fy
,
ymm_fx
,
0x01
);
// if greater, substract 1
vcmpgtps
(
ymm_mask
,
ymm_fy
,
ymm_fx
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vandps
(
ymm_mask
,
ymm_mask
,
ymm_tmp
);
vsubps
(
ymm_fx
,
ymm_fy
,
ymm_mask
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C1
]);
vmulps
(
ymm_fy
,
ymm_fx
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
xmm_t
ymm_z
=
xmm_t
(
ymm_mask
.
getIdx
());
vmulps
(
ymm_z
,
ymm_fx
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_fy
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_z
);
vmulps
(
ymm_z
,
ymm_src
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P0
]);
vmulps
(
ymm_dst
,
ymm_src
,
ymm_tmp
);
for
(
size_t
i
=
OFFSET_EXP_P1
;
i
<
OFFSET_EXP_P5
;
i
+=
(
YMM_FLOAT_BLOCK
*
sizeof
(
float
)))
{
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
i
]);
// P1~P4
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
}
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_P5
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_z
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
// build 2^n
xmm_t
ymm_int
=
ymm_fx
;
vcvttps2dq
(
ymm_int
,
ymm_fx
);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_int_0x7f
));
vmovdqa
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vpaddd
(
ymm_int
,
ymm_int
,
ymm_tmp
);
vpslld
(
ymm_int
,
ymm_int
,
23
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
pop
(
reg_ptr_global
);
}
void
VActJitCode
::
sigmoid_ymm
(
ymm_t
&
ymm_dst
,
ymm_t
&
ymm_src
,
int
fx_idx
,
int
fy_idx
,
int
mask_idx
,
int
tmp_idx
)
{
// y = 1 / (1 + e^-x)
...
...
@@ -343,7 +406,7 @@ void VActJitCode::generate() {
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
offset
+=
sizeof
(
float
)
*
YMM_FLOAT_BLOCK
;
}
if
(
type_
!=
operand_type
::
relu
)
{
if
(
type_
!=
operand_type
::
relu
&&
type_
!=
operand_type
::
exp
)
{
// TODO(TJ): remove me
ret
();
return
;
...
...
@@ -351,21 +414,50 @@ void VActJitCode::generate() {
int
rest
=
num_
%
YMM_FLOAT_BLOCK
;
if
(
rest
>=
4
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_xmm
(
xmm_dst
,
xmm_src
,
xmm_zero
);
break
;
case
operand_type
::
exp
:
exp_xmm
(
xmm_dst
,
xmm_src
,
2
,
3
,
4
,
5
);
break
;
default:
break
;
}
vmovups
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
4
;
rest
-=
4
;
}
if
(
rest
>=
2
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
vmovq
(
xmm_src
,
ptr
[
param1
+
offset
]);
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_xmm
(
xmm_dst
,
xmm_src
,
xmm_zero
);
break
;
case
operand_type
::
exp
:
exp_xmm
(
xmm_dst
,
xmm_src
,
2
,
3
,
4
,
5
);
break
;
default:
break
;
}
vmovq
(
ptr
[
param2
+
offset
],
xmm_dst
);
offset
+=
sizeof
(
float
)
*
2
;
rest
-=
2
;
}
if
(
rest
>
0
)
{
vmovups
(
xmm_src
,
ptr
[
param1
+
offset
]);
vmaxps
(
xmm_dst
,
xmm_zero
,
xmm_src
);
// vmovups();
vmovss
(
xmm_src
,
ptr
[
param1
+
offset
]);
switch
(
type_
)
{
case
operand_type
::
relu
:
relu_xmm
(
xmm_dst
,
xmm_src
,
xmm_zero
);
break
;
case
operand_type
::
exp
:
exp_xmm
(
xmm_dst
,
xmm_src
,
2
,
3
,
4
,
5
);
break
;
default:
break
;
}
vmovss
(
ptr
[
param2
+
offset
],
xmm_dst
);
}
ret
();
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
ba3eaed7
...
...
@@ -127,13 +127,17 @@ class VActJitCode : public JitCode {
void
generate
()
override
;
protected:
// compute relu with ymm
// compute relu with ymm
, xmm
void
relu_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
const
Xbyak
::
Ymm
&
zero
);
void
relu_xmm
(
const
Xbyak
::
Xmm
&
dst
,
const
Xbyak
::
Xmm
&
src
,
const
Xbyak
::
Xmm
&
zero
);
// compute exp with ymm
// compute exp with ymm
, xmm
void
exp_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
void
exp_xmm
(
const
Xbyak
::
Xmm
&
dst
,
const
Xbyak
::
Xmm
&
src
,
int
fx_idx
=
2
,
int
fy_idx
=
3
,
int
mask_idx
=
4
,
int
tmp_idx
=
5
);
// compute sigmoid with ymm
void
sigmoid_ymm
(
const
Xbyak
::
Ymm
&
dst
,
const
Xbyak
::
Ymm
&
src
,
int
fx_idx
=
2
,
...
...
paddle/fluid/operators/math/jit_kernel_test.cc
浏览文件 @
ba3eaed7
...
...
@@ -33,6 +33,9 @@ limitations under the License. */
constexpr
int
repeat
=
20000
;
// TODO(TJ): benchmark and test should be seperated,
// benchmark should verify more sizes
inline
double
GetCurrentUS
()
{
struct
timeval
time
;
gettimeofday
(
&
time
,
NULL
);
...
...
@@ -156,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
TEST
(
JitKernel
,
vexp
)
{
namespace
jit
=
paddle
::
operators
::
math
::
jitkernel
;
for
(
int
d
:
{
7
,
8
,
1
5
,
16
,
30
,
128
,
256
})
{
for
(
int
d
:
{
7
,
8
,
1
2
,
15
,
16
,
20
,
30
,
128
,
256
})
{
std
::
vector
<
float
>
x
(
d
);
std
::
vector
<
float
>
zref
(
d
),
ztgt
(
d
);
RandomVec
<
float
>
(
d
,
x
.
data
(),
-
2.
f
,
2.
f
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录