Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e0591dee
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e0591dee
编写于
1月 04, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance seqpool jitcode
上级
92201d39
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
126 addition
and
67 deletion
+126
-67
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+2
-2
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+2
-53
paddle/fluid/operators/jit/gen/seqpool.h
paddle/fluid/operators/jit/gen/seqpool.h
+122
-12
未找到文件。
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
e0591dee
...
@@ -194,8 +194,8 @@ template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
...
@@ -194,8 +194,8 @@ template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void
BenchSeqPoolKernel
()
{
void
BenchSeqPoolKernel
()
{
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
for
(
auto
type
:
pool_types
)
{
for
(
auto
type
:
pool_types
)
{
for
(
int
h
:
TestSizes
())
{
for
(
int
w
:
TestSizes
())
{
for
(
int
w
:
TestSizes
())
{
for
(
int
h
:
TestSizes
())
{
const
jit
::
seq_pool_attr_t
attr
(
h
,
w
,
type
);
const
jit
::
seq_pool_attr_t
attr
(
h
,
w
,
type
);
std
::
vector
<
T
>
x
(
h
*
w
),
y
(
w
);
std
::
vector
<
T
>
x
(
h
*
w
),
y
(
w
);
RandomVec
<
T
>
(
h
*
w
,
x
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
h
*
w
,
x
.
data
(),
-
2.
f
,
2.
f
);
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
e0591dee
...
@@ -35,7 +35,6 @@ void SeqPoolJitCode::genCode() {
...
@@ -35,7 +35,6 @@ void SeqPoolJitCode::genCode() {
mov
(
reg32_scalar
,
scalar
);
mov
(
reg32_scalar
,
scalar
);
}
}
// TODO(TJ): make height load from params
const
int
group_len
=
max_num_regs
*
block
*
sizeof
(
float
);
const
int
group_len
=
max_num_regs
*
block
*
sizeof
(
float
);
for
(
int
g
=
0
;
g
<
num_groups
;
++
g
)
{
for
(
int
g
=
0
;
g
<
num_groups
;
++
g
)
{
pool_height
<
ymm_t
>
(
g
*
group_len
,
block
,
max_num_regs
);
pool_height
<
ymm_t
>
(
g
*
group_len
,
block
,
max_num_regs
);
...
@@ -44,59 +43,9 @@ void SeqPoolJitCode::genCode() {
...
@@ -44,59 +43,9 @@ void SeqPoolJitCode::genCode() {
pool_height
<
ymm_t
>
(
num_groups
*
group_len
,
block
,
rest_num_regs
);
pool_height
<
ymm_t
>
(
num_groups
*
group_len
,
block
,
rest_num_regs
);
}
}
//
rest par
t
//
part of rest_w * heigh
t
const
int
rest
=
w_
%
block
;
const
int
rest
=
w_
%
block
;
const
bool
has_block4
=
rest
/
4
>
0
;
pool_height_of_rest_width
(
rest
,
(
w_
-
rest
)
*
sizeof
(
float
),
max_num_regs
);
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
const
int
w_offset
=
num_block
*
YMM_FLOAT_BLOCK
*
sizeof
(
float
);
for
(
int
h
=
0
;
h
<
h_
;
++
h
)
{
int
offset
=
h
*
w_
*
sizeof
(
float
)
+
w_offset
;
const
int
shift_regs
=
(
h
==
0
)
?
0
:
max_num_regs
;
int
reg_idx
=
0
;
if
(
has_block4
)
{
vmovups
(
xmm_t
(
reg_idx
+
shift_regs
),
ptr
[
param1
+
offset
]);
offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
}
if
(
has_block2
)
{
vmovq
(
xmm_t
(
reg_idx
+
shift_regs
),
ptr
[
param1
+
offset
]);
offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
xmm_t
(
reg_idx
+
shift_regs
),
ptr
[
param1
+
offset
]);
reg_idx
++
;
}
rest_num_regs
=
reg_idx
;
if
(
h
>
0
)
{
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
}
}
}
// save right now
int
offset
=
w_offset
;
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vbroadcastss
(
xmm_t
(
max_num_regs
-
1
),
reg32_scalar
);
for
(
int
i
=
0
;
i
<
rest_num_regs
;
++
i
)
{
vmulps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
max_num_regs
-
1
));
}
}
int
reg_idx
=
0
;
if
(
has_block4
)
{
vmovups
(
ptr
[
param2
+
offset
],
xmm_t
(
reg_idx
));
offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
}
if
(
has_block2
)
{
vmovq
(
ptr
[
param2
+
offset
],
xmm_t
(
reg_idx
));
offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
ptr
[
param2
+
offset
],
xmm_t
(
reg_idx
));
}
ret
();
ret
();
}
}
...
...
paddle/fluid/operators/jit/gen/seqpool.h
浏览文件 @
e0591dee
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <string>
#include <string>
#include "glog/logging.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -45,8 +46,6 @@ class SeqPoolJitCode : public JitCode {
...
@@ -45,8 +46,6 @@ class SeqPoolJitCode : public JitCode {
base
+=
"_Sqrt"
;
base
+=
"_Sqrt"
;
}
}
base
+=
(
"_W"
+
std
::
to_string
(
w_
));
base
+=
(
"_W"
+
std
::
to_string
(
w_
));
// TODO(TJ): make h load from params
base
+=
(
"_H"
+
std
::
to_string
(
h_
));
return
base
.
c_str
();
return
base
.
c_str
();
}
}
void
genCode
()
override
;
void
genCode
()
override
;
...
@@ -54,25 +53,36 @@ class SeqPoolJitCode : public JitCode {
...
@@ -54,25 +53,36 @@ class SeqPoolJitCode : public JitCode {
protected:
protected:
template
<
typename
JMM
>
template
<
typename
JMM
>
void
pool_height
(
int
w_offset
,
int
block
,
int
max_num_regs
)
{
void
pool_height
(
int
w_offset
,
int
block
,
int
max_num_regs
)
{
for
(
int
h
=
0
;
h
<
h_
;
++
h
)
{
int
offset
=
w_offset
;
int
offset
=
h
*
w_
*
sizeof
(
float
)
+
w_offset
;
const
int
shift_regs
=
(
h
==
0
)
?
0
:
max_num_regs
;
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
vmovups
(
JMM
(
i
+
shift_regs
),
ptr
[
param1
+
offset
]);
vmovups
(
JMM
(
i
),
ptr
[
param1
+
offset
]);
offset
+=
sizeof
(
float
)
*
block
;
offset
+=
sizeof
(
float
)
*
block
;
}
}
if
(
h
>
0
)
{
if
(
h_
>
1
)
{
// sum anyway
Label
l_next_h
;
mov
(
reg_h
,
1
);
mov
(
reg_tmp
,
param1
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
)
+
w_offset
);
L
(
l_next_h
);
{
mov
(
reg_ptr_src_i
,
reg_tmp
);
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
vmovups
(
JMM
(
i
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
// sum anyway
vaddps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
i
+
max_num_regs
));
vaddps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
i
+
max_num_regs
));
add
(
reg_ptr_src_i
,
sizeof
(
float
)
*
block
);
}
}
inc
(
reg_h
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
));
cmp
(
reg_h
,
h_
);
jl
(
l_next_h
,
T_NEAR
);
}
}
}
}
// save right now
// save right now
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vbroadcastss
(
JMM
(
max_num_regs
),
reg32_scalar
);
vbroadcastss
(
JMM
(
max_num_regs
),
reg32_scalar
);
}
}
int
offset
=
w_offset
;
offset
=
w_offset
;
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vmulps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
max_num_regs
));
vmulps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
max_num_regs
));
...
@@ -82,6 +92,102 @@ class SeqPoolJitCode : public JitCode {
...
@@ -82,6 +92,102 @@ class SeqPoolJitCode : public JitCode {
}
}
}
}
void
pool_height_of_rest_width
(
int
rest
,
int
w_offset
,
int
max_num_regs
)
{
const
int
rest_used_num_regs
=
load_rest
(
rest
,
w_offset
,
0
);
const
bool
has_block4
=
rest
/
4
>
0
;
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
if
(
h_
>
1
)
{
Label
l_next_h
;
mov
(
reg_h
,
1
);
mov
(
reg_tmp
,
param1
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
)
+
w_offset
);
L
(
l_next_h
);
{
// int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset,
// max_num_regs);
int
reg_idx
=
0
;
mov
(
reg_ptr_src_i
,
reg_tmp
);
if
(
has_block4
)
{
vmovups
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
add
(
reg_ptr_src_i
,
sizeof
(
float
)
*
4
);
reg_idx
++
;
}
if
(
has_block2
)
{
vmovups
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
add
(
reg_ptr_src_i
,
sizeof
(
float
)
*
2
);
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
reg_idx
++
;
}
PADDLE_ENFORCE_EQ
(
reg_idx
,
rest_used_num_regs
,
"All heights should use same regs"
);
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
}
inc
(
reg_h
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
));
cmp
(
reg_h
,
h_
);
jl
(
l_next_h
,
T_NEAR
);
}
}
// save right now
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vbroadcastss
(
xmm_t
(
max_num_regs
-
1
),
reg32_scalar
);
for
(
int
i
=
0
;
i
<
rest_used_num_regs
;
++
i
)
{
vmulps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
max_num_regs
-
1
));
}
}
save_rest
(
rest
,
w_offset
);
}
// return the number of used regs, use start from reg 0
int
load_rest
(
int
rest
,
int
w_offset
,
const
int
num_shift_regs
,
const
int
reg_start
=
0
)
{
const
bool
has_block4
=
rest
/
4
>
0
;
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
int
reg_idx
=
reg_start
;
if
(
has_block4
)
{
vmovups
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param1
+
w_offset
]);
w_offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
}
if
(
has_block2
)
{
vmovq
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param1
+
w_offset
]);
w_offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param1
+
w_offset
]);
reg_idx
++
;
}
return
reg_idx
;
}
// use reg start from 0
void
save_rest
(
int
rest
,
int
w_offset
,
int
reg_start
=
0
)
{
const
bool
has_block4
=
rest
/
4
>
0
;
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
int
reg_idx
=
reg_start
;
if
(
has_block4
)
{
vmovups
(
ptr
[
param2
+
w_offset
],
xmm_t
(
reg_idx
));
w_offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
}
if
(
has_block2
)
{
vmovq
(
ptr
[
param2
+
w_offset
],
xmm_t
(
reg_idx
));
w_offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
ptr
[
param2
+
w_offset
],
xmm_t
(
reg_idx
));
}
}
private:
private:
int
h_
;
int
h_
;
int
w_
;
int
w_
;
...
@@ -90,6 +196,10 @@ class SeqPoolJitCode : public JitCode {
...
@@ -90,6 +196,10 @@ class SeqPoolJitCode : public JitCode {
reg64_t
param2
{
abi_param2
};
reg64_t
param2
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
reg64_t
param3
{
abi_param3
};
reg32_t
reg32_scalar
{
r8d
};
reg32_t
reg32_scalar
{
r8d
};
reg64_t
reg_h
{
r9
};
reg64_t
reg_ptr_src_i
{
r10
};
reg64_t
reg_tmp
{
r11
};
};
};
}
// namespace gen
}
// namespace gen
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录