Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
dbc8f893
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dbc8f893
编写于
9月 18, 2019
作者:
石
石晓伟
提交者:
GitHub
9月 18, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify the device binding logic of the pass, test=develop (#2060)
上级
3682a9df
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
104 addition
and
18 deletion
+104
-18
lite/core/mir/fusion/conv_activation_fuse_pass.cc
lite/core/mir/fusion/conv_activation_fuse_pass.cc
+2
-1
lite/core/mir/fusion/conv_bn_fuse_pass.cc
lite/core/mir/fusion/conv_bn_fuse_pass.cc
+2
-1
lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
+2
-1
lite/core/mir/fusion/fc_fuse_pass.cc
lite/core/mir/fusion/fc_fuse_pass.cc
+2
-1
lite/core/mir/fusion/shuffle_channel_fuse_pass.cc
lite/core/mir/fusion/shuffle_channel_fuse_pass.cc
+2
-1
lite/core/mir/io_copy_kernel_pick_pass.cc
lite/core/mir/io_copy_kernel_pick_pass.cc
+2
-1
lite/core/mir/pass_registry.h
lite/core/mir/pass_registry.h
+5
-0
lite/core/mir/pass_utils.cc
lite/core/mir/pass_utils.cc
+64
-3
lite/core/mir/pass_utils.h
lite/core/mir/pass_utils.h
+4
-0
lite/core/mir/subgraph/generate_npu_program_pass.cc
lite/core/mir/subgraph/generate_npu_program_pass.cc
+1
-1
lite/core/mir/type_layout_cast_pass.cc
lite/core/mir/type_layout_cast_pass.cc
+3
-1
lite/core/mir/type_precision_cast_pass.cc
lite/core/mir/type_precision_cast_pass.cc
+3
-1
lite/core/mir/type_target_cast_pass.cc
lite/core/mir/type_target_cast_pass.cc
+3
-1
lite/core/op_registry.h
lite/core/op_registry.h
+7
-3
lite/core/optimizer.h
lite/core/optimizer.h
+2
-2
未找到文件。
lite/core/mir/fusion/conv_activation_fuse_pass.cc
浏览文件 @
dbc8f893
...
@@ -39,4 +39,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -39,4 +39,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS
(
lite_conv_activation_fuse_pass
,
REGISTER_MIR_PASS
(
lite_conv_activation_fuse_pass
,
paddle
::
lite
::
mir
::
ConvActivationFusePass
)
paddle
::
lite
::
mir
::
ConvActivationFusePass
)
.
BindTargets
({
TARGET
(
kAny
)});
.
BindTargets
({
TARGET
(
kAny
)})
.
BindKernel
(
"conv2d"
);
lite/core/mir/fusion/conv_bn_fuse_pass.cc
浏览文件 @
dbc8f893
...
@@ -35,4 +35,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -35,4 +35,5 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
// namespace paddle
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_conv_bn_fuse_pass
,
paddle
::
lite
::
mir
::
ConvBNFusePass
)
REGISTER_MIR_PASS
(
lite_conv_bn_fuse_pass
,
paddle
::
lite
::
mir
::
ConvBNFusePass
)
.
BindTargets
({
TARGET
(
kAny
)});
.
BindTargets
({
TARGET
(
kAny
)})
.
BindKernel
(
"elementwise_add"
);
lite/core/mir/fusion/elementwise_add_activation_fuse_pass.cc
浏览文件 @
dbc8f893
...
@@ -34,4 +34,5 @@ void ElementwiseAddActivationFusePass::Apply(
...
@@ -34,4 +34,5 @@ void ElementwiseAddActivationFusePass::Apply(
REGISTER_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
,
REGISTER_MIR_PASS
(
lite_elementwise_add_activation_fuse_pass
,
paddle
::
lite
::
mir
::
ElementwiseAddActivationFusePass
)
paddle
::
lite
::
mir
::
ElementwiseAddActivationFusePass
)
.
BindTargets
({
TARGET
(
kAny
)});
.
BindTargets
({
TARGET
(
kAny
)})
.
BindKernel
(
"fusion_elementwise_add_activation"
);
lite/core/mir/fusion/fc_fuse_pass.cc
浏览文件 @
dbc8f893
...
@@ -32,4 +32,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -32,4 +32,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
// namespace paddle
}
// namespace paddle
REGISTER_MIR_PASS
(
lite_fc_fuse_pass
,
paddle
::
lite
::
mir
::
FcFusePass
)
REGISTER_MIR_PASS
(
lite_fc_fuse_pass
,
paddle
::
lite
::
mir
::
FcFusePass
)
.
BindTargets
({
TARGET
(
kAny
)});
.
BindTargets
({
TARGET
(
kAny
)})
.
BindKernel
(
"fc"
);
lite/core/mir/fusion/shuffle_channel_fuse_pass.cc
浏览文件 @
dbc8f893
...
@@ -36,4 +36,5 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
@@ -36,4 +36,5 @@ void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS
(
lite_shuffle_channel_fuse_pass
,
REGISTER_MIR_PASS
(
lite_shuffle_channel_fuse_pass
,
paddle
::
lite
::
mir
::
ShuffleChannelFusePass
)
paddle
::
lite
::
mir
::
ShuffleChannelFusePass
)
.
BindTargets
({
TARGET
(
kAny
)});
.
BindTargets
({
TARGET
(
kAny
)})
.
BindKernel
(
"shuffle_channel"
);
lite/core/mir/io_copy_kernel_pick_pass.cc
浏览文件 @
dbc8f893
...
@@ -72,4 +72,5 @@ class IoCopyKernelPickPass : public StmtPass {
...
@@ -72,4 +72,5 @@ class IoCopyKernelPickPass : public StmtPass {
REGISTER_MIR_PASS
(
io_copy_kernel_pick_pass
,
REGISTER_MIR_PASS
(
io_copy_kernel_pick_pass
,
paddle
::
lite
::
mir
::
IoCopyKernelPickPass
)
paddle
::
lite
::
mir
::
IoCopyKernelPickPass
)
.
BindTargets
({
TARGET
(
kAny
)});
.
BindTargets
({
TARGET
(
kAny
)})
.
BindKernel
(
"io_copy"
);
lite/core/mir/pass_registry.h
浏览文件 @
dbc8f893
...
@@ -39,6 +39,11 @@ class PassRegistry {
...
@@ -39,6 +39,11 @@ class PassRegistry {
pass_
->
BindKernel
(
name
,
place
);
pass_
->
BindKernel
(
name
,
place
);
return
*
this
;
return
*
this
;
}
}
PassRegistry
&
BindKernel
(
const
std
::
string
&
name
)
{
pass_
->
BindKernel
(
name
,
Place
(
TARGET
(
kAny
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
)));
return
*
this
;
}
bool
Touch
()
const
{
return
true
;
}
bool
Touch
()
const
{
return
true
;
}
private:
private:
...
...
lite/core/mir/pass_utils.cc
浏览文件 @
dbc8f893
...
@@ -16,10 +16,72 @@
...
@@ -16,10 +16,72 @@
#include <set>
#include <set>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
using
lite_api
::
Place
;
namespace
{
template
<
typename
T
>
class
Types
final
{
public:
explicit
Types
(
const
std
::
set
<
T
>&
types
)
:
types_
(
types
)
{}
~
Types
()
=
default
;
std
::
set
<
T
>
ValidSet
(
const
T
&
element
)
const
;
private:
const
std
::
set
<
T
>
types_
;
};
template
<
typename
T
>
std
::
set
<
T
>
Types
<
T
>::
ValidSet
(
const
T
&
element
)
const
{
if
(
element
==
T
::
kAny
)
{
return
types_
;
}
else
if
(
element
==
T
::
kUnk
)
{
LOG
(
FATAL
)
<<
"The type of the kernel's place is unknown."
;
}
return
std
::
set
<
T
>
({
element
});
}
bool
ExpandPlaces
(
std
::
set
<
Place
>*
places
,
const
Place
&
place
)
{
static
const
Types
<
TargetType
>
target_set
({
TARGET
(
kHost
),
TARGET
(
kX86
),
TARGET
(
kCUDA
),
TARGET
(
kARM
),
TARGET
(
kOpenCL
),
TARGET
(
kNPU
),
TARGET
(
kFPGA
)});
static
const
Types
<
PrecisionType
>
precision_set
(
{
PRECISION
(
kFloat
),
PRECISION
(
kInt8
),
PRECISION
(
kFP16
),
PRECISION
(
kAny
)});
static
const
Types
<
DataLayoutType
>
layout_set
(
{
DATALAYOUT
(
kNCHW
),
DATALAYOUT
(
kAny
),
DATALAYOUT
(
kNHWC
)});
for
(
const
auto
&
target
:
target_set
.
ValidSet
(
place
.
target
))
{
for
(
const
auto
&
precision
:
precision_set
.
ValidSet
(
place
.
precision
))
{
for
(
const
auto
&
layout
:
layout_set
.
ValidSet
(
place
.
layout
))
{
places
->
insert
(
Place
(
target
,
precision
,
layout
));
}
}
}
}
}
// anonymous namespace
bool
KernelRegistered
(
const
std
::
string
name
,
const
Place
&
place
)
{
std
::
set
<
Place
>
places
;
ExpandPlaces
(
&
places
,
place
);
for
(
const
auto
&
p
:
places
)
{
if
(
!
KernelRegistry
::
Global
()
.
Create
(
name
,
p
.
target
,
p
.
precision
,
p
.
layout
)
.
empty
())
{
return
true
;
}
}
return
false
;
}
bool
PassMatchesTarget
(
const
mir
::
Pass
&
pass
,
TargetType
target
)
{
bool
PassMatchesTarget
(
const
mir
::
Pass
&
pass
,
TargetType
target
)
{
const
auto
&
targets
=
pass
.
Targets
();
const
auto
&
targets
=
pass
.
Targets
();
if
(
targets
.
find
(
TARGET
(
kAny
))
!=
targets
.
end
())
return
true
;
if
(
targets
.
find
(
TARGET
(
kAny
))
!=
targets
.
end
())
return
true
;
...
@@ -30,10 +92,9 @@ bool PassMatchesKernels(const mir::Pass& pass) {
...
@@ -30,10 +92,9 @@ bool PassMatchesKernels(const mir::Pass& pass) {
const
auto
&
kernels
=
pass
.
GetBoundKernels
();
const
auto
&
kernels
=
pass
.
GetBoundKernels
();
for
(
const
auto
&
kernel
:
kernels
)
{
for
(
const
auto
&
kernel
:
kernels
)
{
for
(
const
auto
&
place
:
kernel
.
second
)
{
for
(
const
auto
&
place
:
kernel
.
second
)
{
if
(
KernelRegistry
::
Global
()
if
(
!
KernelRegistered
(
kernel
.
first
,
place
))
{
.
Create
(
kernel
.
first
,
place
.
target
,
place
.
precision
,
place
.
layout
)
.
empty
())
return
false
;
return
false
;
}
}
}
}
}
return
true
;
return
true
;
...
...
lite/core/mir/pass_utils.h
浏览文件 @
dbc8f893
...
@@ -14,11 +14,15 @@
...
@@ -14,11 +14,15 @@
#pragma once
#pragma once
#include <string>
#include "lite/core/mir/pass.h"
#include "lite/core/mir/pass.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
// Query if the specified kernel has been registered.
bool
KernelRegistered
(
const
std
::
string
name
,
const
Place
&
place
);
// Check if the pass hits the hardware target.
// Check if the pass hits the hardware target.
bool
PassMatchesTarget
(
const
mir
::
Pass
&
pass
,
TargetType
target
);
bool
PassMatchesTarget
(
const
mir
::
Pass
&
pass
,
TargetType
target
);
...
...
lite/core/mir/subgraph/generate_npu_program_pass.cc
浏览文件 @
dbc8f893
...
@@ -215,4 +215,4 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
...
@@ -215,4 +215,4 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
REGISTER_MIR_PASS
(
generate_npu_program_pass
,
REGISTER_MIR_PASS
(
generate_npu_program_pass
,
paddle
::
lite
::
mir
::
subgraph
::
GenerateNPUProgramPass
)
paddle
::
lite
::
mir
::
subgraph
::
GenerateNPUProgramPass
)
.
BindTargets
({
TARGET
(
k
Any
)});
.
BindTargets
({
TARGET
(
k
NPU
)});
lite/core/mir/type_layout_cast_pass.cc
浏览文件 @
dbc8f893
...
@@ -174,4 +174,6 @@ void TypeLayoutTransformPass::SetValidPlaces(
...
@@ -174,4 +174,6 @@ void TypeLayoutTransformPass::SetValidPlaces(
REGISTER_MIR_PASS
(
type_layout_cast_pass
,
REGISTER_MIR_PASS
(
type_layout_cast_pass
,
paddle
::
lite
::
mir
::
TypeLayoutTransformPass
)
paddle
::
lite
::
mir
::
TypeLayoutTransformPass
)
.
BindTargets
({
TARGET
(
kAny
)});
.
BindTargets
({
TARGET
(
kAny
)})
.
BindKernel
(
"layout_once"
)
.
BindKernel
(
"layout"
);
lite/core/mir/type_precision_cast_pass.cc
浏览文件 @
dbc8f893
...
@@ -180,4 +180,6 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
...
@@ -180,4 +180,6 @@ void PrecisionCastPass::SetValidPlaces(const std::vector<Place>& valid_places) {
REGISTER_MIR_PASS
(
type_precision_cast_pass
,
REGISTER_MIR_PASS
(
type_precision_cast_pass
,
paddle
::
lite
::
mir
::
PrecisionCastPass
)
paddle
::
lite
::
mir
::
PrecisionCastPass
)
.
BindTargets
({
TARGET
(
kAny
)});
.
BindTargets
({
TARGET
(
kAny
)})
.
BindKernel
(
"calib_once"
)
.
BindKernel
(
"calib"
);
lite/core/mir/type_target_cast_pass.cc
浏览文件 @
dbc8f893
...
@@ -180,4 +180,6 @@ void TypeTargetTransformPass::SetValidPlaces(
...
@@ -180,4 +180,6 @@ void TypeTargetTransformPass::SetValidPlaces(
REGISTER_MIR_PASS
(
type_target_cast_pass
,
REGISTER_MIR_PASS
(
type_target_cast_pass
,
paddle
::
lite
::
mir
::
TypeTargetTransformPass
)
paddle
::
lite
::
mir
::
TypeTargetTransformPass
)
.
BindTargets
({
TARGET
(
kAny
)});
.
BindTargets
({
TARGET
(
kAny
)})
.
BindKernel
(
"io_copy_once"
)
.
BindKernel
(
"io_copy"
);
lite/core/op_registry.h
浏览文件 @
dbc8f893
...
@@ -174,9 +174,13 @@ class KernelRegistry final {
...
@@ -174,9 +174,13 @@ class KernelRegistry final {
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
Create
(
const
std
::
string
&
op_type
)
{
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
Create
(
const
std
::
string
&
op_type
)
{
using
kernel_registor_t
=
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>
;
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>
;
return
registries_
[
GetKernelOffset
<
Target
,
Precision
,
Layout
>
()]
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
kernel_list
;
.
template
get
<
kernel_registor_t
*
>()
if
(
registries_
[
GetKernelOffset
<
Target
,
Precision
,
Layout
>
()].
valid
())
{
->
Creates
(
op_type
);
kernel_list
=
registries_
[
GetKernelOffset
<
Target
,
Precision
,
Layout
>
()]
.
template
get
<
kernel_registor_t
*
>()
->
Creates
(
op_type
);
}
return
kernel_list
;
}
}
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
Create
(
const
std
::
string
&
op_type
,
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
Create
(
const
std
::
string
&
op_type
,
...
...
lite/core/optimizer.h
浏览文件 @
dbc8f893
...
@@ -193,9 +193,9 @@ class Optimizer {
...
@@ -193,9 +193,9 @@ class Optimizer {
matched
=
true
;
matched
=
true
;
}
}
}
}
matched
=
matched
||
PassMatchesKernels
(
*
pass
);
matched
=
matched
&&
PassMatchesKernels
(
*
pass
);
if
(
!
matched
)
{
if
(
!
matched
)
{
LOG
(
INFO
)
<<
"
Skip "
<<
x
<<
" pass
because the target does not match."
;
LOG
(
INFO
)
<<
"
- Skip "
<<
x
<<
"
because the target does not match."
;
}
else
{
}
else
{
pass
->
Apply
(
graph_
);
pass
->
Apply
(
graph_
);
LOG
(
INFO
)
<<
"== Finished running: "
<<
x
;
LOG
(
INFO
)
<<
"== Finished running: "
<<
x
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录