Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c50060bb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c50060bb
编写于
12月 29, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add jitcode impl and use it
上级
142bb417
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
239 addition
and
5 deletion
+239
-5
paddle/fluid/operators/jit/gen/CMakeLists.txt
paddle/fluid/operators/jit/gen/CMakeLists.txt
+1
-0
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+132
-0
paddle/fluid/operators/jit/gen/seqpool.h
paddle/fluid/operators/jit/gen/seqpool.h
+98
-0
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+5
-2
paddle/fluid/operators/math/sequence_pooling.cc
paddle/fluid/operators/math/sequence_pooling.cc
+3
-3
未找到文件。
paddle/fluid/operators/jit/gen/CMakeLists.txt
浏览文件 @
c50060bb
...
@@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1)
...
@@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1)
USE_JITKERNEL_GEN
(
kGRUHtPart1
)
USE_JITKERNEL_GEN
(
kGRUHtPart1
)
USE_JITKERNEL_GEN
(
kGRUHtPart2
)
USE_JITKERNEL_GEN
(
kGRUHtPart2
)
USE_JITKERNEL_GEN
(
kNCHW16CMulNC
)
USE_JITKERNEL_GEN
(
kNCHW16CMulNC
)
USE_JITKERNEL_GEN
(
kSeqPool
)
paddle/fluid/operators/jit/gen/seqpool.cc
0 → 100644
浏览文件 @
c50060bb
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
void
SeqPoolJitCode
::
genCode
()
{
constexpr
int
block
=
YMM_FLOAT_BLOCK
;
constexpr
int
max_num_regs
=
8
;
const
int
num_block
=
w_
/
block
;
const
int
num_groups
=
num_block
/
max_num_regs
;
int
rest_num_regs
=
num_block
%
max_num_regs
;
if
(
type_
==
SeqPoolType
::
kAvg
)
{
float
scalar
=
1.
f
/
h_
;
mov
(
reg32_scalar
,
scalar
);
}
else
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
float
scalar
=
1.
f
/
std
::
sqrt
(
static_cast
<
float
>
(
h_
));
mov
(
reg32_scalar
,
scalar
);
}
// TODO(TJ): make height load from params
const
int
group_len
=
max_num_regs
*
block
*
sizeof
(
float
);
for
(
int
g
=
0
;
g
<
num_groups
;
++
g
)
{
pool_height
<
ymm_t
>
(
g
*
group_len
,
block
,
max_num_regs
);
}
if
(
rest_num_regs
>
0
)
{
pool_height
<
ymm_t
>
(
num_groups
*
group_len
,
block
,
rest_num_regs
);
}
// rest part
const
int
rest
=
w_
%
block
;
const
bool
has_block4
=
rest
/
4
>
0
;
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
const
int
w_offset
=
num_block
*
YMM_FLOAT_BLOCK
*
sizeof
(
float
);
for
(
int
h
=
0
;
h
<
h_
;
++
h
)
{
int
offset
=
h
*
w_
*
sizeof
(
float
)
+
w_offset
;
const
int
shift_regs
=
(
h
==
0
)
?
0
:
max_num_regs
;
int
reg_idx
=
0
;
if
(
has_block4
)
{
vmovups
(
xmm_t
(
reg_idx
+
shift_regs
),
ptr
[
param1
+
offset
]);
offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
}
if
(
has_block2
)
{
vmovq
(
xmm_t
(
reg_idx
+
shift_regs
),
ptr
[
param1
+
offset
]);
offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
xmm_t
(
reg_idx
+
shift_regs
),
ptr
[
param1
+
offset
]);
reg_idx
++
;
}
rest_num_regs
=
reg_idx
;
if
(
h
>
0
)
{
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
}
}
}
// save right now
int
offset
=
w_offset
;
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vbroadcastss
(
xmm_t
(
max_num_regs
-
1
),
reg32_scalar
);
for
(
int
i
=
0
;
i
<
rest_num_regs
;
++
i
)
{
vmulps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
max_num_regs
-
1
));
}
}
int
reg_idx
=
0
;
if
(
has_block4
)
{
vmovups
(
ptr
[
param2
+
offset
],
xmm_t
(
reg_idx
));
offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
}
if
(
has_block2
)
{
vmovq
(
ptr
[
param2
+
offset
],
xmm_t
(
reg_idx
));
offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
}
if
(
has_block1
)
{
vmovss
(
ptr
[
param2
+
offset
],
xmm_t
(
reg_idx
));
}
ret
();
}
class
SeqPoolCreator
:
public
JitCodeCreator
<
seq_pool_attr_t
>
{
public:
bool
UseMe
(
const
seq_pool_attr_t
&
attr
)
const
override
{
return
platform
::
MayIUse
(
platform
::
avx
);
}
size_t
CodeSize
(
const
seq_pool_attr_t
&
attr
)
const
override
{
// TODO(TJ): remove attr.h when enabled height
bool
yes
=
attr
.
type
==
SeqPoolType
::
kAvg
||
attr
.
type
==
SeqPoolType
::
kSqrt
;
return
96
/* basic */
+
((
attr
.
w
/
YMM_FLOAT_BLOCK
+
4
/* rest */
)
*
2
/* for sum */
*
(
attr
.
h
+
(
yes
?
3
:
1
/*for avg or sqrt*/
)))
*
8
;
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
seq_pool_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
w
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
h
,
0
);
return
make_unique
<
SeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
}
};
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
namespace
gen
=
paddle
::
operators
::
jit
::
gen
;
REGISTER_JITKERNEL_GEN
(
kSeqPool
,
gen
::
SeqPoolCreator
);
paddle/fluid/operators/jit/gen/seqpool.h
0 → 100644
浏览文件 @
c50060bb
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace
paddle
{
namespace
operators
{
namespace
jit
{
namespace
gen
{
class
SeqPoolJitCode
:
public
JitCode
{
public:
explicit
SeqPoolJitCode
(
const
seq_pool_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
h_
(
attr
.
h
),
w_
(
attr
.
w
),
type_
(
attr
.
type
)
{
if
(
type_
!=
SeqPoolType
::
kSum
)
{
LOG
(
FATAL
)
<<
"Only support sum pool yet "
;
}
this
->
genCode
();
}
virtual
const
char
*
name
()
const
{
std
::
string
base
=
"SeqPoolJitCode"
;
if
(
type_
==
SeqPoolType
::
kSum
)
{
base
+=
"_Sum"
;
}
else
if
(
type_
==
SeqPoolType
::
kAvg
)
{
base
+=
"_Avg"
;
}
else
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
base
+=
"_Sqrt"
;
}
base
+=
(
"_W"
+
std
::
to_string
(
w_
));
// TODO(TJ): make h load from params
base
+=
(
"_H"
+
std
::
to_string
(
h_
));
return
base
.
c_str
();
}
void
genCode
()
override
;
protected:
template
<
typename
JMM
>
void
pool_height
(
int
w_offset
,
int
block
,
int
max_num_regs
)
{
for
(
int
h
=
0
;
h
<
h_
;
++
h
)
{
int
offset
=
h
*
w_
*
sizeof
(
float
)
+
w_offset
;
const
int
shift_regs
=
(
h
==
0
)
?
0
:
max_num_regs
;
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
vmovups
(
JMM
(
i
+
shift_regs
),
ptr
[
param1
+
offset
]);
offset
+=
sizeof
(
float
)
*
block
;
}
if
(
h
>
0
)
{
// sum anyway
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
vaddps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
i
+
max_num_regs
));
}
}
}
// save right now
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vbroadcastss
(
JMM
(
max_num_regs
),
reg32_scalar
);
}
int
offset
=
w_offset
;
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vmulps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
max_num_regs
));
}
vmovups
(
ptr
[
param2
+
offset
],
JMM
(
i
));
offset
+=
sizeof
(
float
)
*
block
;
}
}
private:
int
h_
;
int
w_
;
SeqPoolType
type_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
reg32_t
reg32_scalar
{
r8d
};
};
}
// namespace gen
}
// namespace jit
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
c50060bb
...
@@ -44,8 +44,11 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
...
@@ -44,8 +44,11 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
template
<
>
template
<
>
size_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
size_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
size_t
key
=
static_cast
<
size_t
>
(
attr
.
type
);
size_t
key
=
attr
.
w
;
return
key
+
(
attr
.
w
<<
act_type_shift
);
// TODO(TJ): support height, then removed it from key
constexpr
int
w_shift
=
30
;
return
(
key
<<
act_type_shift
)
+
static_cast
<
int
>
(
attr
.
type
)
+
(
static_cast
<
size_t
>
(
attr
.
h
)
<<
(
act_type_shift
+
w_shift
));
}
}
}
// namespace jit
}
// namespace jit
...
...
paddle/fluid/operators/math/sequence_pooling.cc
浏览文件 @
c50060bb
...
@@ -255,11 +255,11 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
...
@@ -255,11 +255,11 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
jit
::
seq_pool_attr_t
attr
;
jit
::
seq_pool_attr_t
attr
;
attr
.
w
=
input
.
numel
()
/
input
.
dims
()[
0
];
attr
.
w
=
input
.
numel
()
/
input
.
dims
()[
0
];
attr
.
type
=
jit
::
SeqPoolType
::
kSum
;
attr
.
type
=
jit
::
SeqPoolType
::
kSum
;
auto
seqpool
=
jit
::
Get
<
jit
::
kSeqPool
,
jit
::
SeqPoolTuples
<
T
>
,
platform
::
CPUPlace
>
(
attr
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
i
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
i
)
{
attr
.
h
=
static_cast
<
int
>
(
lod
[
i
+
1
]
-
lod
[
i
]);
attr
.
h
=
static_cast
<
int
>
(
lod
[
i
+
1
]
-
lod
[
i
]);
auto
seqpool
=
jit
::
Get
<
jit
::
kSeqPool
,
jit
::
SeqPoolTuples
<
T
>
,
platform
::
CPUPlace
>
(
attr
);
seqpool
(
src
,
dst
,
&
attr
);
seqpool
(
src
,
dst
,
&
attr
);
dst
+=
attr
.
w
;
dst
+=
attr
.
w
;
src
+=
attr
.
h
*
attr
.
w
;
src
+=
attr
.
h
*
attr
.
w
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录