Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
123b98f4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
123b98f4
编写于
1月 07, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine heigth and codesize and support all pool
test=develop
上级
0145f40f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
26 addition
and
36 deletion
+26
-36
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+2
-1
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+14
-13
paddle/fluid/operators/jit/gen/seqpool.h
paddle/fluid/operators/jit/gen/seqpool.h
+8
-20
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+2
-2
未找到文件。
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
123b98f4
...
@@ -192,7 +192,8 @@ void BenchGRUKernel() {
...
@@ -192,7 +192,8 @@ void BenchGRUKernel() {
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
BenchSeqPoolKernel
()
{
void
BenchSeqPoolKernel
()
{
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
,
jit
::
SeqPoolType
::
kAvg
,
jit
::
SeqPoolType
::
kSqrt
};
for
(
auto
type
:
pool_types
)
{
for
(
auto
type
:
pool_types
)
{
for
(
int
w
:
TestSizes
())
{
for
(
int
w
:
TestSizes
())
{
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
123b98f4
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
* limitations under the License. */
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include
<stddef.h> // offsetof
#include
"paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/cpu_info.h"
...
@@ -22,9 +22,6 @@ namespace operators {
...
@@ -22,9 +22,6 @@ namespace operators {
namespace
jit
{
namespace
jit
{
namespace
gen
{
namespace
gen
{
thread_local
float
ALIGN32_BEG
float_h
[
1
]
ALIGN32_END
=
{
1.
f
};
// TODO(TJ): try move to private
void
SeqPoolJitCode
::
genCode
()
{
void
SeqPoolJitCode
::
genCode
()
{
constexpr
int
block
=
YMM_FLOAT_BLOCK
;
constexpr
int
block
=
YMM_FLOAT_BLOCK
;
constexpr
int
max_num_regs
=
8
;
constexpr
int
max_num_regs
=
8
;
...
@@ -33,10 +30,17 @@ void SeqPoolJitCode::genCode() {
...
@@ -33,10 +30,17 @@ void SeqPoolJitCode::genCode() {
int
rest_num_regs
=
num_block
%
max_num_regs
;
int
rest_num_regs
=
num_block
%
max_num_regs
;
mov
(
reg32_int_h
,
dword
[
param_attr
]);
mov
(
reg32_int_h
,
dword
[
param_attr
]);
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
float_h
));
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovups
(
xmm_t
(
1
),
ptr
[
reg_tmp
+
OFFSET_EXP_ONE
]);
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
fp_h_
));
fild
(
dword
[
param_attr
]);
fild
(
dword
[
param_attr
]);
fstp
(
dword
[
reg_tmp
]);
fstp
(
dword
[
reg_tmp
]);
mov
(
reg32_fp_h
,
dword
[
reg_tmp
]);
vmovss
(
xmm_t
(
0
),
ptr
[
reg_tmp
]);
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
vsqrtps
(
xmm_t
(
0
),
xmm_t
(
0
));
}
vdivps
(
xmm_t
(
1
),
xmm_t
(
1
),
xmm_t
(
0
));
vmovss
(
ptr
[
reg_tmp
],
xmm_t
(
1
));
}
}
const
int
group_len
=
max_num_regs
*
block
*
sizeof
(
float
);
const
int
group_len
=
max_num_regs
*
block
*
sizeof
(
float
);
for
(
int
g
=
0
;
g
<
num_groups
;
++
g
)
{
for
(
int
g
=
0
;
g
<
num_groups
;
++
g
)
{
...
@@ -45,7 +49,6 @@ void SeqPoolJitCode::genCode() {
...
@@ -45,7 +49,6 @@ void SeqPoolJitCode::genCode() {
if
(
rest_num_regs
>
0
)
{
if
(
rest_num_regs
>
0
)
{
pool_height
<
ymm_t
>
(
num_groups
*
group_len
,
block
,
rest_num_regs
);
pool_height
<
ymm_t
>
(
num_groups
*
group_len
,
block
,
rest_num_regs
);
}
}
// part of rest_w * height
// part of rest_w * height
const
int
rest
=
w_
%
block
;
const
int
rest
=
w_
%
block
;
pool_height_of_rest_width
(
rest
,
(
w_
-
rest
)
*
sizeof
(
float
),
max_num_regs
);
pool_height_of_rest_width
(
rest
,
(
w_
-
rest
)
*
sizeof
(
float
),
max_num_regs
);
...
@@ -58,12 +61,10 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
...
@@ -58,12 +61,10 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
return
platform
::
MayIUse
(
platform
::
avx
);
return
platform
::
MayIUse
(
platform
::
avx
);
}
}
size_t
CodeSize
(
const
seq_pool_attr_t
&
attr
)
const
override
{
size_t
CodeSize
(
const
seq_pool_attr_t
&
attr
)
const
override
{
// TODO(TJ): remove attr.h when enabled height
return
96
+
bool
yes
=
((
attr
.
w
/
YMM_FLOAT_BLOCK
+
4
/* for rest */
)
*
attr
.
type
==
SeqPoolType
::
kAvg
||
attr
.
type
==
SeqPoolType
::
kSqrt
;
4
/* load, mul and save */
+
return
96
/* basic */
+
256
)
*
((
attr
.
w
/
YMM_FLOAT_BLOCK
+
4
/* rest */
)
*
2
/* for sum */
*
(
attr
.
h
+
(
yes
?
3
:
1
/*for avg or sqrt*/
)))
*
8
;
8
;
}
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
...
...
paddle/fluid/operators/jit/gen/seqpool.h
浏览文件 @
123b98f4
...
@@ -16,7 +16,6 @@
...
@@ -16,7 +16,6 @@
#include <string>
#include <string>
#include "glog/logging.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for ones
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
@@ -31,9 +30,11 @@ class SeqPoolJitCode : public JitCode {
...
@@ -31,9 +30,11 @@ class SeqPoolJitCode : public JitCode {
size_t
code_size
=
256
*
1024
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
w_
(
attr
.
w
),
type_
(
attr
.
type
)
{
:
JitCode
(
code_size
,
code_ptr
),
w_
(
attr
.
w
),
type_
(
attr
.
type
)
{
if
(
type_
!=
SeqPoolType
::
kSum
)
{
if
(
!
(
type_
==
SeqPoolType
::
kSum
||
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
))
{
LOG
(
FATAL
)
<<
"Only support sum pool yet "
;
LOG
(
FATAL
)
<<
"Only support sum pool yet "
;
}
}
fp_h_
[
0
]
=
1.
f
;
this
->
genCode
();
this
->
genCode
();
}
}
...
@@ -82,15 +83,8 @@ class SeqPoolJitCode : public JitCode {
...
@@ -82,15 +83,8 @@ class SeqPoolJitCode : public JitCode {
L
(
l_h_done
);
L
(
l_h_done
);
// save right now
// save right now
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
fp_h_
));
vmovups
(
JMM
(
max_num_regs
),
ptr
[
reg_tmp
+
OFFSET_EXP_ONE
]);
vbroadcastss
(
JMM
(
max_num_regs
),
ptr
[
reg_tmp
]);
movd
(
JMM
(
max_num_regs
+
1
),
reg32_fp_h
);
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
vsqrtps
(
JMM
(
max_num_regs
+
1
),
JMM
(
max_num_regs
+
1
));
}
vdivps
(
JMM
(
max_num_regs
+
2
),
JMM
(
max_num_regs
),
JMM
(
max_num_regs
+
1
));
vbroadcastss
(
JMM
(
max_num_regs
),
JMM
(
max_num_regs
+
2
));
// TODO(TJ): fix me
}
}
offset
=
w_offset
;
offset
=
w_offset
;
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
...
@@ -144,15 +138,8 @@ class SeqPoolJitCode : public JitCode {
...
@@ -144,15 +138,8 @@ class SeqPoolJitCode : public JitCode {
L
(
l_h_done
);
L
(
l_h_done
);
// save right now
// save right now
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
fp_h_
));
vmovups
(
xmm_t
(
max_num_regs
),
ptr
[
reg_tmp
+
OFFSET_EXP_ONE
]);
vbroadcastss
(
xmm_t
(
max_num_regs
),
ptr
[
reg_tmp
]);
movd
(
xmm_t
(
max_num_regs
+
1
),
reg32_fp_h
);
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
vsqrtps
(
xmm_t
(
max_num_regs
+
1
),
xmm_t
(
max_num_regs
+
1
));
}
vdivps
(
xmm_t
(
max_num_regs
+
2
),
xmm_t
(
max_num_regs
),
xmm_t
(
max_num_regs
+
1
));
vbroadcastss
(
xmm_t
(
max_num_regs
),
xmm_t
(
max_num_regs
+
2
));
for
(
int
i
=
0
;
i
<
rest_used_num_regs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
rest_used_num_regs
;
++
i
)
{
vmulps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
max_num_regs
));
vmulps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
max_num_regs
));
}
}
...
@@ -206,6 +193,7 @@ class SeqPoolJitCode : public JitCode {
...
@@ -206,6 +193,7 @@ class SeqPoolJitCode : public JitCode {
}
}
private:
private:
float
ALIGN32_BEG
fp_h_
[
1
]
ALIGN32_END
;
int
w_
;
int
w_
;
SeqPoolType
type_
;
SeqPoolType
type_
;
reg64_t
param_src
{
abi_param1
};
reg64_t
param_src
{
abi_param1
};
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
123b98f4
...
@@ -436,8 +436,8 @@ void TestGRUKernel() {
...
@@ -436,8 +436,8 @@ void TestGRUKernel() {
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
template
<
paddle
::
operators
::
jit
::
KernelType
KT
,
typename
T
,
typename
PlaceType
>
void
TestSeqPoolKernel
()
{
void
TestSeqPoolKernel
()
{
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
VLOG
(
10
)
<<
"===== Test JITKernel "
<<
jit
::
to_string
(
KT
);
// TODO(TJ): support more
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
jit
::
SeqPoolType
::
kSum
,
jit
::
SeqPoolType
::
kAvg
,
jit
::
SeqPoolType
::
kSqrt
};
for
(
auto
type
:
pool_types
)
{
for
(
auto
type
:
pool_types
)
{
for
(
int
w
:
TestSizes
())
{
for
(
int
w
:
TestSizes
())
{
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录