Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ee2a7f1b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ee2a7f1b
编写于
11月 15, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine exp and fix error on avx
test=develop
上级
1e06a32a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
15 addition
and
19 deletion
+15
-19
paddle/fluid/operators/math/jit_code.cc
paddle/fluid/operators/math/jit_code.cc
+15
-18
paddle/fluid/operators/math/jit_code.h
paddle/fluid/operators/math/jit_code.h
+0
-1
未找到文件。
paddle/fluid/operators/math/jit_code.cc
浏览文件 @
ee2a7f1b
...
@@ -197,10 +197,8 @@ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
...
@@ -197,10 +197,8 @@ static const int exp_int_0x7f[] ALIGN32 = {REPEAT_8TIMES(0x7f)};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
static
int
g_tmp_mem
[
16
]
ALIGN32
=
{
0
};
void
VExpJitCode
::
generate
()
{
void
VExpJitCode
::
generate
()
{
preCode
();
// push some?
// in: ymm0, out: ymm1
// in: ymm0, out: ymm1
// use ymm 0~5
(and ymm 14~15 if avx only)
// use ymm 0~5
, rax
int
offset
=
0
;
int
offset
=
0
;
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
vmovups
(
ymm_src
,
ptr
[
param1
+
offset
]);
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_ptr_global
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
...
@@ -222,7 +220,8 @@ void VExpJitCode::generate() {
...
@@ -222,7 +220,8 @@ void VExpJitCode::generate() {
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C1
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C1
]);
vmulps
(
ymm_fy
,
ymm_fx
,
ymm_tmp
);
vmulps
(
ymm_fy
,
ymm_fx
,
ymm_tmp
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
+
OFFSET_EXP_C2
]);
vmulps
(
ymm_z
,
ymm_fx
,
ymm_tmp
);
// ymm_z use same with mask
ymm_t
ymm_z
=
ymm_t
(
ymm_mask
.
getIdx
());
vmulps
(
ymm_z
,
ymm_fx
,
ymm_tmp
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_fy
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_fy
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_z
);
vsubps
(
ymm_src
,
ymm_src
,
ymm_z
);
vmulps
(
ymm_z
,
ymm_src
,
ymm_src
);
vmulps
(
ymm_z
,
ymm_src
,
ymm_src
);
...
@@ -240,7 +239,6 @@ void VExpJitCode::generate() {
...
@@ -240,7 +239,6 @@ void VExpJitCode::generate() {
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_src
);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vmovaps
(
ymm_tmp
,
ptr
[
reg_ptr_global
]);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
vaddps
(
ymm_dst
,
ymm_dst
,
ymm_tmp
);
// build 2^n
// build 2^n
ymm_t
ymm_int
=
ymm_fx
;
ymm_t
ymm_int
=
ymm_fx
;
vcvttps2dq
(
ymm_int
,
ymm_fx
);
vcvttps2dq
(
ymm_int
,
ymm_fx
);
...
@@ -250,31 +248,30 @@ void VExpJitCode::generate() {
...
@@ -250,31 +248,30 @@ void VExpJitCode::generate() {
vpaddd
(
ymm_int
,
ymm_int
,
ymm_tmp
);
vpaddd
(
ymm_int
,
ymm_int
,
ymm_tmp
);
vpslld
(
ymm_int
,
ymm_int
,
23
);
vpslld
(
ymm_int
,
ymm_int
,
23
);
}
else
if
(
MayIUse
(
avx
))
{
}
else
if
(
MayIUse
(
avx
))
{
// use ymm_int, ymm_tmp and reg_ptr_global
xmm_t
xtmp1
=
xmm_t
(
ymm_int
.
getIdx
());
xmm_t
xtmp
1
=
xmm_t
(
ymm_int
);
// or magic number should equal the ymm_int
xmm_t
xtmp
2
=
xmm_t
(
ymm_tmp
.
getIdx
());
xmm_t
xtmp2
=
xmm_t
(
ymm_tmp
);
// or magic number should equal the ymm_tmp
reg64_t
reg_ptr_tmp
=
reg_ptr_global
;
mov
(
reg_ptr_
global
,
reinterpret_cast
<
size_t
>
(
g_tmp_mem
));
mov
(
reg_ptr_
tmp
,
reinterpret_cast
<
size_t
>
(
g_tmp_mem
));
vmovdqa
(
ptr
[
reg_ptr_
global
],
ymm_int
);
vmovdqa
(
ptr
[
reg_ptr_
tmp
],
ymm_int
);
vmovdqa
(
ptr
[
reg_ptr_
global
+
AVX_FLOAT_BLOCK
*
sizeof
(
float
)],
ymm_tmp
);
vmovdqa
(
ptr
[
reg_ptr_
tmp
+
AVX_FLOAT_BLOCK
*
sizeof
(
float
)],
ymm_tmp
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_
global
],
xtmp1
);
vmovdqa
(
ptr
[
reg_ptr_
tmp
],
xtmp1
);
// next 128bits
// next 128bits
vmovdqa
(
xtmp1
,
ptr
[
reg_ptr_
global
+
4
/*xmm float block*/
*
sizeof
(
float
)]);
vmovdqa
(
xtmp1
,
ptr
[
reg_ptr_
tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)]);
vmovdqa
(
xtmp2
,
vmovdqa
(
xtmp2
,
ptr
[
reg_ptr_
global
+
ptr
[
reg_ptr_
tmp
+
(
AVX_FLOAT_BLOCK
+
4
/*xmm float block*/
)
*
sizeof
(
float
)]);
(
AVX_FLOAT_BLOCK
+
4
/*xmm float block*/
)
*
sizeof
(
float
)]);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpaddd
(
xtmp1
,
xtmp1
,
xtmp2
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vpslld
(
xtmp1
,
xtmp1
,
23
);
vmovdqa
(
ptr
[
reg_ptr_
global
+
4
/*xmm float block*/
*
sizeof
(
float
)],
xtmp1
);
vmovdqa
(
ptr
[
reg_ptr_
tmp
+
4
/*xmm float block*/
*
sizeof
(
float
)],
xtmp1
);
// load out
// load out
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_
global
]);
vmovdqa
(
ymm_int
,
ptr
[
reg_ptr_
tmp
]);
}
}
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
vmulps
(
ymm_dst
,
ymm_dst
,
ymm_int
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
vmovups
(
ptr
[
param2
+
offset
],
ymm_dst
);
// ret();
ret
();
postCode
();
}
}
}
// namespace gen
}
// namespace gen
...
...
paddle/fluid/operators/math/jit_code.h
浏览文件 @
ee2a7f1b
...
@@ -128,7 +128,6 @@ class VExpJitCode : public JitCode {
...
@@ -128,7 +128,6 @@ class VExpJitCode : public JitCode {
ymm_t
ymm_fx
=
ymm_t
(
2
);
ymm_t
ymm_fx
=
ymm_t
(
2
);
ymm_t
ymm_fy
=
ymm_t
(
3
);
ymm_t
ymm_fy
=
ymm_t
(
3
);
ymm_t
ymm_mask
=
ymm_t
(
4
);
ymm_t
ymm_mask
=
ymm_t
(
4
);
ymm_t
ymm_z
=
ymm_t
(
4
);
ymm_t
ymm_tmp
=
ymm_t
(
5
);
ymm_t
ymm_tmp
=
ymm_t
(
5
);
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录