Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
8b950a4f
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8b950a4f
编写于
4月 15, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mir implementation
上级
f41d73b3
变更
27
隐藏空白更改
内联
并排
Showing
27 changed file
with
634 addition
and
65 deletion
+634
-65
paddle/fluid/lite/core/executor.h
paddle/fluid/lite/core/executor.h
+1
-1
paddle/fluid/lite/core/kernel.cc
paddle/fluid/lite/core/kernel.cc
+0
-5
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+23
-6
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+7
-2
paddle/fluid/lite/core/mir/demo_pass.cc
paddle/fluid/lite/core/mir/demo_pass.cc
+33
-0
paddle/fluid/lite/core/mir/node.cc
paddle/fluid/lite/core/mir/node.cc
+14
-0
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+78
-6
paddle/fluid/lite/core/mir/pass.cc
paddle/fluid/lite/core/mir/pass.cc
+14
-0
paddle/fluid/lite/core/mir/pass.h
paddle/fluid/lite/core/mir/pass.h
+37
-0
paddle/fluid/lite/core/mir/pass_manager.cc
paddle/fluid/lite/core/mir/pass_manager.cc
+30
-0
paddle/fluid/lite/core/mir/pass_manager.h
paddle/fluid/lite/core/mir/pass_manager.h
+80
-0
paddle/fluid/lite/core/mir/pass_manager_test.cc
paddle/fluid/lite/core/mir/pass_manager_test.cc
+30
-0
paddle/fluid/lite/core/mir/pass_registry.cc
paddle/fluid/lite/core/mir/pass_registry.cc
+21
-0
paddle/fluid/lite/core/mir/pass_registry.h
paddle/fluid/lite/core/mir/pass_registry.h
+37
-0
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+14
-0
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+74
-0
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+10
-0
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+34
-10
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+43
-0
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+38
-27
paddle/fluid/lite/kernels/host/fc_compute.cc
paddle/fluid/lite/kernels/host/fc_compute.cc
+5
-2
paddle/fluid/lite/operators/fc_op.h
paddle/fluid/lite/operators/fc_op.h
+4
-1
paddle/fluid/lite/operators/fc_op_test.cc
paddle/fluid/lite/operators/fc_op_test.cc
+1
-1
paddle/fluid/lite/operators/mul_op.h
paddle/fluid/lite/operators/mul_op.h
+2
-1
paddle/fluid/lite/operators/relu_op.cc
paddle/fluid/lite/operators/relu_op.cc
+1
-1
paddle/fluid/lite/operators/relu_op.h
paddle/fluid/lite/operators/relu_op.h
+1
-1
paddle/fluid/lite/operators/scale_op.cc
paddle/fluid/lite/operators/scale_op.cc
+2
-1
未找到文件。
paddle/fluid/lite/core/executor.h
浏览文件 @
8b950a4f
...
...
@@ -52,7 +52,7 @@ class Executor {
ops_
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
// pick initial kernel
ops_
.
back
()
->
PickKernel
(
valid_places_
);
ops_
.
back
()
->
Attach
(
*
op_desc
,
exec_scope_
);
ops_
.
back
()
->
Attach
Impl
(
*
op_desc
,
exec_scope_
);
}
}
...
...
paddle/fluid/lite/core/kernel.cc
浏览文件 @
8b950a4f
...
...
@@ -17,11 +17,6 @@
namespace
paddle
{
namespace
lite
{
bool
operator
==
(
const
Place
&
a
,
const
Place
&
b
)
{
return
a
.
target
==
b
.
target
&&
a
.
precision
==
b
.
precision
&&
a
.
layout
==
b
.
layout
;
}
bool
operator
<
(
const
Place
&
a
,
const
Place
&
b
)
{
if
(
a
.
target
!=
b
.
target
)
return
a
.
target
<
b
.
target
;
...
...
paddle/fluid/lite/core/kernel.h
浏览文件 @
8b950a4f
...
...
@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
...
...
@@ -51,6 +52,7 @@ class KernelBase {
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
virtual
~
KernelBase
()
=
default
;
...
...
@@ -66,17 +68,21 @@ class KernelBase {
* registered in the `TypeSystem`.
*/
struct
ParamType
{
// For unsupported types.
size_t
element_type_hash
{};
Place
tensor_place
{};
const
Type
*
type_
;
ParamType
()
=
default
;
ParamType
(
size_t
element_type_hash
)
:
element_type_hash
(
element_type_hash
)
{}
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{}
};
/*
* The data types of kernel parameters.
* The data types of kernel parameters. It is used to track the type of kernel's
* inputs and outputs.
*/
struct
ParamTypes
{
std
::
vector
<
std
::
vector
<
ParamType
>>
inputs
;
...
...
@@ -115,6 +121,8 @@ struct ParamTypes {
*/
class
ParamTypeRegistry
{
public:
enum
class
IO
:
int
{
kInput
=
0
,
kOutput
};
template
<
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
>
/*
...
...
@@ -130,7 +138,12 @@ class ParamTypeRegistry {
NewInstance
(
const
std
::
string
&
kernel_type
)
:
kernel_type_
(
kernel_type
)
{}
NewInstance
&
BindInput
(
int
offset
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
(
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kInput
>
(
kernel_type_
,
Place
{
target
,
precision
,
layout
},
offset
,
ptype
);
return
*
this
;
}
NewInstance
&
BindOutput
(
int
offset
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kOutput
>
(
kernel_type_
,
Place
{
target
,
precision
,
layout
},
offset
,
ptype
);
return
*
this
;
}
...
...
@@ -141,8 +154,12 @@ class ParamTypeRegistry {
std
::
string
kernel_type_
;
};
template
<
IO
io
>
void
Register
(
const
std
::
string
&
kernel_type
,
const
Place
&
place
,
int
offset
,
ParamType
data_type
)
{}
ParamType
data_type
)
{
KernelIdTy
key
{
kernel_type
,
place
,
io
,
offset
};
types_
[
key
]
=
data_type
;
}
ParamType
Retrive
(
const
Place
&
place
,
int
offset
);
...
...
@@ -155,16 +172,15 @@ class ParamTypeRegistry {
ParamTypeRegistry
()
=
default
;
public:
enum
class
IO
:
int
{
kInput
=
0
,
kOutput
};
// Identification for a Kernel.
struct
KernelIdT
{
struct
KernelIdT
y
{
std
::
string
kernel_type
;
Place
place
;
IO
io
;
int
offset
;
};
using
key_t
=
KernelIdT
;
using
key_t
=
KernelIdT
y
;
struct
KeyCmp
{
bool
operator
()(
const
key_t
&
a
,
const
key_t
&
b
)
const
;
};
...
...
@@ -188,6 +204,7 @@ class OpKernel : public KernelBase {
TargetType
target
()
const
override
{
return
Target
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
void
Touch
()
{}
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
8b950a4f
cc_library
(
mir_pass SRCS pass.cc
)
cc_library
(
mir_node SRCS node.cc
)
cc_library
(
mir_ssa_graph SRCS ssa_graph.cc
)
\ No newline at end of file
cc_library
(
mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node
)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_demo_pass SRCS demo_pass.cc DEPS mir_pass
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_demo_pass
)
paddle/fluid/lite/core/mir/demo_pass.cc
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
DemoPass
:
public
mir
::
Pass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{}
};
bool
RegisterDemoPass
()
{
return
PassManager
::
Global
().
AddNewPass
(
"demo"
,
new
DemoPass
);
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/node.cc
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h"
paddle/fluid/lite/core/mir/node.h
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
// Node in a MIR graph.
class
Node
{
public:
// Tell is instruction.
bool
IsInstruct
()
const
;
// Tell is an argument.
bool
IsArgument
()
const
;
};
std
::
list
<
Node
*>
inlinks
;
std
::
list
<
Node
*>
outlinks
;
Node
()
=
default
;
enum
class
Role
{
kUnk
=
-
1
,
kArgument
,
kInstruct
,
kNumRoles
/*should be last*/
};
struct
Instruct
{
std
::
string
op_type
;
Place
place
;
// The kernel instances this Instruct contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
};
struct
Argument
{
std
::
string
name
;
Place
place
;
};
// Set roles.
Argument
&
AsArgument
()
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
kArgument
);
return
*
argument_
;
}
role_
=
Role
::
kArgument
;
argument_
.
reset
(
new
Argument
);
return
*
argument_
;
}
Instruct
&
AsInstruct
()
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
kInstruct
);
return
*
instruct_
;
}
role_
=
Role
::
kInstruct
;
instruct_
.
reset
(
new
Instruct
);
return
*
instruct_
;
}
// Check roles.
bool
IsRoleSet
()
const
{
return
role_
==
Role
::
kUnk
;
}
bool
IsInstruct
()
const
{
return
role_
==
Role
::
kInstruct
;
}
bool
IsArgument
()
const
{
return
role_
==
Role
::
kArgument
;
}
private:
// Either instruct_ or argument_ is used.
std
::
unique_ptr
<
Instruct
>
instruct_
;
std
::
unique_ptr
<
Argument
>
argument_
;
Role
role_
{
Role
::
kUnk
};
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
\ No newline at end of file
}
// namespace paddle
paddle/fluid/lite/core/mir/pass.cc
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
paddle/fluid/lite/core/mir/pass.h
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
Pass
{
public:
virtual
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
=
0
;
const
std
::
string
&
name
()
const
{
return
name_
;
}
virtual
~
Pass
()
=
default
;
private:
std
::
string
name_
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_manager.cc
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
PassManager
::
PassManager
()
{}
// Manually register here.
extern
bool
RegisterDemoPass
();
static
bool
xx
__attribute__
((
unused
))
=
RegisterDemoPass
();
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_manager.h
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <map>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
PassManager
{
public:
static
PassManager
&
Global
()
{
static
PassManager
x
;
return
x
;
}
PassManager
();
void
Run
()
{
for
(
auto
&
pass
:
passes_
)
{
LOG
(
INFO
)
<<
"Running MIR pass "
<<
pass
->
name
();
pass
->
Apply
(
graph_
);
}
}
bool
AddNewPass
(
const
std
::
string
&
name
,
Pass
*
pass
)
{
passes_
.
emplace_back
(
pass
);
pass_map_
.
emplace
(
name
,
passes_
.
back
().
get
());
return
true
;
}
// Clear all the passes.
void
Clear
()
{
passes_
.
clear
();
}
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>::
iterator
passes_begin
()
{
return
passes_
.
begin
();
}
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>::
iterator
passes_end
()
{
return
passes_
.
end
();
}
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>::
const_iterator
passes_const_begin
()
const
{
return
passes_
.
begin
();
}
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>::
const_iterator
passes_const_end
()
const
{
return
passes_
.
end
();
}
Pass
*
LookUp
(
const
std
::
string
&
key
)
{
auto
it
=
pass_map_
.
find
(
key
);
CHECK
(
it
!=
pass_map_
.
end
());
return
it
->
second
;
}
private:
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
list
<
std
::
unique_ptr
<
mir
::
Pass
>>
passes_
;
std
::
map
<
std
::
string
,
mir
::
Pass
*>
pass_map_
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_manager_test.cc
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include <gtest/gtest.h>
namespace
paddle
{
namespace
lite
{
namespace
mir
{
TEST
(
PassManager
,
test
)
{
auto
*
pass
=
PassManager
::
Global
().
LookUp
(
"demo"
);
LOG
(
INFO
)
<<
"pass: "
<<
pass
;
ASSERT_TRUE
(
pass
!=
nullptr
);
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_registry.cc
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/pass_registry.h
0 → 100644
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
PassRegistry
{
public:
PassRegistry
(
const
std
::
string
&
name
,
mir
::
Pass
*
pass
)
{
LOG
(
INFO
)
<<
"Registry add MIR pass "
<<
name
;
PassManager
::
Global
().
AddNewPass
(
name
,
pass
);
}
bool
Touch
()
const
{
return
true
;
}
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
8b950a4f
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
// A program is used to represent a code program, in Paddle, a code program
// contains:
// - main block, which is a list of OpLite
// - scope: which contains all the weights
struct
Program
{
std
::
list
<
std
::
unique_ptr
<
OpLite
>>
ops
;
lite
::
Scope
*
scope
;
};
// An Graph for MIR. It is built from a list of Op and a scope.
class
GraphBase
{};
class
SSAGraph
:
GraphBase
{
public:
// @param program: the op program
// @param valid_places: the valid places user set for the system.
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
for
(
auto
&
op
:
program
.
ops
)
{
node_storage_
.
emplace_back
();
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_kernel
=
node_storage_
.
back
().
AsInstruct
();
new_kernel
.
valid_kernels
=
op
->
CreateKernels
(
valid_places
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
// collect inputs and outputs
for
(
const
std
::
string
&
name
:
op
->
input_names
())
{
new_node
.
inlinks
.
push_back
(
arguments_
.
at
(
name
));
}
for
(
const
std
::
string
&
name
:
op
->
output_names
())
{
new_node
.
outlinks
.
push_back
(
arguments_
.
at
(
name
));
}
}
}
std
::
vector
<
mir
::
Node
*>
TopoloticalOrder
()
const
;
private:
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
8b950a4f
...
...
@@ -44,5 +44,15 @@ void OpLite::PickKernel(const std::vector<Place> &valid_places,
}
}
bool
OpLite
::
Run
()
{
CHECK
(
kernel_
);
SyncInputEvents
();
kernel_
->
Run
();
RecordOutputEvents
();
return
true
;
}
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
8b950a4f
...
...
@@ -36,6 +36,11 @@ struct Registry {
void
Touch
()
{}
};
namespace
mir
{
class
Node
;
class
SSAGraph
;
}
/**
* The base class of an light-weight operators, currently just used in inference
* to eliminate overhead of some operations in current framework.
...
...
@@ -71,19 +76,13 @@ class OpLite : public Registry {
// Inference the outputs' shape.
virtual
bool
InferShape
()
const
{
return
true
;
}
// Run this operator.
virtual
bool
Run
()
{
CHECK
(
kernel_
);
SyncInputEvents
();
kernel_
->
Run
();
virtual
bool
Run
();
RecordOutputEvents
();
return
true
;
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
ExtractInputsAndOutputs
(
opdesc
);
return
AttachImpl
(
opdesc
,
scope
);
}
// Attach it with the runtime environment.
virtual
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
// Human-readable information.
virtual
std
::
string
DebugString
()
const
=
0
;
...
...
@@ -92,9 +91,29 @@ class OpLite : public Registry {
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
virtual
~
OpLite
()
=
default
;
protected:
// Attach it with the runtime environment.
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
void
ExtractInputsAndOutputs
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
Inputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
Outputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
output_names_
.
push_back
(
x
);
}
}
}
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
...
...
@@ -113,12 +132,17 @@ class OpLite : public Registry {
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
=
""
);
friend
class
mir
::
Node
;
friend
class
mir
::
SSAGraph
;
protected:
std
::
unique_ptr
<
OpContext
>
op_context_
;
std
::
unique_ptr
<
KernelBase
>
kernel_
;
std
::
string
op_type_
;
std
::
vector
<
Place
>
valid_places_
;
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
std
::
list
<
std
::
string
>
input_names_
;
std
::
list
<
std
::
string
>
output_names_
;
};
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.cc
浏览文件 @
8b950a4f
...
...
@@ -13,3 +13,46 @@
// limitations under the License.
#include "paddle/fluid/lite/core/type_system.h"
namespace
paddle
{
namespace
lite
{
// ------------------------- GetType specification ----------------------------
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
false
/*is_tensor*/
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
UnsupportedTy
x
;
return
&
x
;
}
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
true
/*is_tensor*/
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
TensorFp32NCHWTy
x
(
TargetType
::
kX86
);
return
&
x
;
}
template
<
>
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
,
int
device
)
{
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
}
template
<
>
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kX86
:
return
Get
<
false
,
true
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
default:
LOG
(
FATAL
)
<<
"unsupported target "
<<
TargetToStr
(
target
);
return
nullptr
;
}
}
// ------------------------- end GetType specification ------------------------
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/type_system.h
浏览文件 @
8b950a4f
...
...
@@ -82,7 +82,7 @@ class DataTypeBase {
* Datatype with device info considered.
* NOTE A Type with different device is treated as different DeviceDataType.
*/
class
DeviceData
Type
:
public
DataTypeBase
{
class
Type
:
public
DataTypeBase
{
public:
TargetType
target
()
const
{
return
place_
.
target
;
}
PrecisionType
precision
()
const
{
return
place_
.
precision
;
}
...
...
@@ -90,23 +90,31 @@ class DeviceDataType : public DataTypeBase {
const
Place
&
place
()
const
{
return
place_
;
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
bool
operator
==
(
const
DeviceData
Type
&
other
)
{
bool
operator
==
(
const
Type
&
other
)
{
return
id_
==
other
.
id
()
&&
place_
==
other
.
place
();
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a instruction to transform a type to another.
virtual
bool
TypeCastable
(
const
DeviceDataType
&
type
)
const
{
return
id_
==
type
.
id
();
}
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
template
<
bool
is_unknown
,
bool
is_tensor
=
true
,
TargetType
target
=
TargetType
::
kHost
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
>
// Get a type.
static
const
Type
*
Get
();
template
<
typename
TypeTy
>
static
const
Type
*
Get
(
TargetType
target
=
TargetType
::
kHost
);
virtual
~
DeviceData
Type
()
=
default
;
virtual
~
Type
()
=
default
;
protected:
DeviceData
Type
(
ID
id
,
const
std
::
string
&
name
,
bool
is_tensor
,
TargetType
target
=
TargetType
::
kHost
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
)
Type
(
ID
id
,
const
std
::
string
&
name
,
bool
is_tensor
,
TargetType
target
=
TargetType
::
kHost
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
)
:
DataTypeBase
(
id
,
is_tensor
),
place_
{
target
,
precision
,
layout
},
name_
(
name
)
{}
...
...
@@ -117,30 +125,33 @@ class DeviceDataType : public DataTypeBase {
};
// -------------------------------- predefined types ---------------------------
class
Void
:
public
DeviceDataType
{
// TODO(Superjomn) make all the Types' constructs protected to make sure there
// is only one instance across the system.
class
VoidTy
:
public
Type
{
public:
VoidTy
()
:
Type
(
ID
::
Void
,
"Void"
,
false
/*is_tensor*/
)
{}
};
class
UnsupportedTy
:
public
Type
{
public:
Void
()
:
DeviceDataType
(
ID
::
Void
,
"Voi
d"
,
false
/*is_tensor*/
)
{}
UnsupportedTy
()
:
Type
(
ID
::
Unsupported
,
"Unsupporte
d"
,
false
/*is_tensor*/
)
{}
};
class
TensorFp32NCHW
:
public
DeviceData
Type
{
class
TensorFp32NCHW
Ty
:
public
Type
{
public:
TensorFp32NCHW
(
TargetType
target
)
:
DeviceDataType
(
ID
::
Tensor_Fp32_NCHW
,
"TensorFp32NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
)
{}
TensorFp32NCHWTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Fp32_NCHW
,
"TensorFp32NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
)
{}
};
class
TensorInt8NCHW
:
public
DeviceData
Type
{
class
TensorInt8NCHW
Ty
:
public
Type
{
public:
TensorInt8NCHW
(
TargetType
target
)
:
DeviceDataType
(
ID
::
Tensor_Int8_NCHW
,
"TensorInt8NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
TensorInt8NCHWTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Int8_NCHW
,
"TensorInt8NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
};
class
TensorInt64NCHW
:
public
DeviceData
Type
{
class
TensorInt64NCHW
Ty
:
public
Type
{
public:
TensorInt64NCHW
(
TargetType
target
)
:
DeviceDataType
(
ID
::
Tensor_Int64_NCHW
,
"TensorInt64NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
TensorInt64NCHWTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Int64_NCHW
,
"TensorInt64NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
};
// ------------------------- end predefined types ---------------------------
...
...
paddle/fluid/lite/kernels/host/fc_compute.cc
浏览文件 @
8b950a4f
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/lite/kernels/host/fc_compute.h"
#include <Eigen/Core>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -51,6 +52,8 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
)
.
BindInput
(
0
,
{
typeid
(
paddle
::
lite
::
Tensor
).
hash_code
(),
paddle
::
lite
::
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}})
.
BindInput
(
0
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kX86
))})
.
BindOutput
(
0
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kX86
))})
.
Finalize
();
paddle/fluid/lite/operators/fc_op.h
浏览文件 @
8b950a4f
...
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
...
...
@@ -44,7 +46,8 @@ class FcOpLite : public OpLite {
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
Attach
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"Input"
).
front
();
auto
W
=
op_desc
.
Input
(
"W"
).
front
();
auto
bias
=
op_desc
.
Input
(
"Bias"
).
front
();
...
...
paddle/fluid/lite/operators/fc_op_test.cc
浏览文件 @
8b950a4f
...
...
@@ -61,7 +61,7 @@ TEST(fc_op_lite, test) {
fc
.
SetValidPlaces
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc
.
PickKernel
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc
.
Attach
(
desc
,
&
scope
);
fc
.
Attach
Impl
(
desc
,
&
scope
);
fc
.
Run
();
for
(
int
i
=
0
;
i
<
10
*
20
;
i
++
)
{
...
...
paddle/fluid/lite/operators/mul_op.h
浏览文件 @
8b950a4f
...
...
@@ -37,7 +37,8 @@ class MulOpLite : public OpLite {
bool
InferShape
()
const
override
;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
Attach
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
input
=
op_desc
.
Input
(
"X"
).
front
();
auto
W
=
op_desc
.
Input
(
"Y"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
paddle/fluid/lite/operators/relu_op.cc
浏览文件 @
8b950a4f
...
...
@@ -31,7 +31,7 @@ bool ReluOp::InferShape() const {
return
true
;
}
bool
ReluOp
::
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
ReluOp
::
Attach
Impl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
input
=
const_cast
<
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"Input"
).
front
())
->
Get
<
Tensor
>
());
param_
.
output
=
...
...
paddle/fluid/lite/operators/relu_op.h
浏览文件 @
8b950a4f
...
...
@@ -32,7 +32,7 @@ class ReluOp : public OpLite {
bool
InferShape
()
const
override
;
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
Attach
Impl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
std
::
string
DebugString
()
const
override
{
return
"tanh"
;
}
...
...
paddle/fluid/lite/operators/scale_op.cc
浏览文件 @
8b950a4f
...
...
@@ -44,7 +44,8 @@ class ScaleOp : public OpLite {
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
Attach
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
framework
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
out
=
op_desc
.
Output
(
"Out"
).
front
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录