Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
cdb12e59
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cdb12e59
编写于
4月 16, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ssa_graph test
上级
8b950a4f
变更
20
显示空白变更内容
内联
并排
Showing
20 changed file
with
455 addition
and
11 deletion
+455
-11
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+1
-1
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+14
-2
paddle/fluid/lite/core/mir/demo_pass.cc
paddle/fluid/lite/core/mir/demo_pass.cc
+5
-1
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+13
-0
paddle/fluid/lite/core/mir/generate_program_pass.h
paddle/fluid/lite/core/mir/generate_program_pass.h
+33
-0
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
+71
-0
paddle/fluid/lite/core/mir/graph_visualize_pass.h
paddle/fluid/lite/core/mir/graph_visualize_pass.h
+37
-0
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+13
-0
paddle/fluid/lite/core/mir/io_complement_pass.h
paddle/fluid/lite/core/mir/io_complement_pass.h
+33
-0
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+12
-0
paddle/fluid/lite/core/mir/pass.h
paddle/fluid/lite/core/mir/pass.h
+39
-0
paddle/fluid/lite/core/mir/pass_manager.cc
paddle/fluid/lite/core/mir/pass_manager.cc
+2
-4
paddle/fluid/lite/core/mir/pass_manager_test.cc
paddle/fluid/lite/core/mir/pass_manager_test.cc
+3
-0
paddle/fluid/lite/core/mir/pass_registry.h
paddle/fluid/lite/core/mir/pass_registry.h
+12
-0
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+23
-2
paddle/fluid/lite/core/mir/ssa_graph_test.cc
paddle/fluid/lite/core/mir/ssa_graph_test.cc
+102
-0
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+13
-0
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+27
-0
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+1
-1
paddle/fluid/lite/kernels/host/CMakeLists.txt
paddle/fluid/lite/kernels/host/CMakeLists.txt
+1
-0
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
cdb12e59
cc_library
(
memory_lite SRCS memory.cc
)
cc_library
(
memory_lite SRCS memory.cc
)
cc_library
(
tensor_lite SRCS tensor.cc DEPS memory_lite
)
cc_library
(
tensor_lite SRCS tensor.cc DEPS memory_lite
)
cc_library
(
kernel_lite SRCS kernel.cc
)
cc_library
(
kernel_lite SRCS kernel.cc
DEPS type_system
)
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
op_registry_lite SRCS op_registry.cc
)
cc_library
(
op_registry_lite SRCS op_registry.cc
)
cc_library
(
scope_lite SRCS scope.cc
)
cc_library
(
scope_lite SRCS scope.cc
)
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
cdb12e59
...
@@ -3,6 +3,18 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
...
@@ -3,6 +3,18 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass SRCS pass.cc DEPS mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph
)
cc_library
(
mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager
)
cc_library
(
mir_demo_pass SRCS demo_pass.cc DEPS mir_pass
)
cc_library
(
mir_passes
SRCS static_kernel_pick_pass.cc
io_complement_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
demo_pass.cc
DEPS mir_pass
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_demo_pass
)
cc_test
(
test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes
)
cc_test
(
test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_ssa_graph scope_lite op_lite
proto_desc ops_lite
host_kernels
mir_passes
)
paddle/fluid/lite/core/mir/demo_pass.cc
浏览文件 @
cdb12e59
...
@@ -19,15 +19,19 @@ namespace paddle {
...
@@ -19,15 +19,19 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
mir
{
namespace
mir
{
class
DemoPass
:
public
mir
::
Pass
{
class
DemoPass
:
public
mir
::
Debug
Pass
{
public:
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{}
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{}
};
};
/*
bool RegisterDemoPass() {
bool RegisterDemoPass() {
return PassManager::Global().AddNewPass("demo", new DemoPass);
return PassManager::Global().AddNewPass("demo", new DemoPass);
}
}
*/
}
// namespace mir
}
// namespace mir
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
REGISTER_MIR_PASS
(
demo
,
paddle
::
lite
::
mir
::
DemoPass
);
paddle/fluid/lite/core/mir/generate_program_pass.cc
0 → 100644
浏览文件 @
cdb12e59
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
paddle/fluid/lite/core/mir/generate_program_pass.h
0 → 100644
浏览文件 @
cdb12e59
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
/*
* GenerateProgramPass will build the execution program for executor from a mir
* graph.
*/
class
GenerateProgramPass
:
public
Pass
{
public:
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
0 → 100644
浏览文件 @
cdb12e59
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include <set>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
GraphVisualizePass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
Visualize
(
graph
.
get
());
}
std
::
string
Visualize
(
mir
::
SSAGraph
*
graph
)
{
inference
::
analysis
::
Dot
dot
;
int
id
=
0
;
std
::
set
<
std
::
string
>
exists_args
;
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
std
::
string
key
;
if
(
node
.
IsArgument
())
{
key
=
node
.
AsArgument
().
name
;
}
else
{
key
=
node
.
AsInstruct
().
op_type
+
std
::
to_string
(
id
++
);
}
if
(
node
.
IsInstruct
())
{
dot
.
AddNode
(
key
,
{});
for
(
auto
&
x
:
node
.
inlinks
)
{
auto
name
=
x
->
AsArgument
().
name
;
if
(
!
exists_args
.
count
(
name
))
{
dot
.
AddNode
(
name
,
{});
}
dot
.
AddEdge
(
name
,
key
,
{});
exists_args
.
insert
(
name
);
}
for
(
auto
&
x
:
node
.
outlinks
)
{
auto
name
=
x
->
AsArgument
().
name
;
if
(
!
exists_args
.
count
(
name
))
{
dot
.
AddNode
(
name
,
{});
}
dot
.
AddEdge
(
key
,
name
,
{});
exists_args
.
insert
(
name
);
}
}
}
auto
res
=
dot
.
Build
();
LOG
(
INFO
)
<<
"dot:
\n
"
<<
res
;
return
res
;
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
graph_visualze
,
paddle
::
lite
::
mir
::
GraphVisualizePass
);
paddle/fluid/lite/core/mir/graph_visualize_pass.h
0 → 100644
浏览文件 @
cdb12e59
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
/*
* GraphVisualizePass helps to visualize an mir graph by exporting a DOT
* language file.
*/
class
GraphVisualizePass
:
public
DebugPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
};
std
::
string
Visualize
(
mir
::
SSAGraph
*
graph
);
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/io_complement_pass.cc
0 → 100644
浏览文件 @
cdb12e59
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
paddle/fluid/lite/core/mir/io_complement_pass.h
0 → 100644
浏览文件 @
cdb12e59
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
/*
* IoComplementPass complement the necessary instruction to make data
* transferring or transformation between different places.
*/
class
IoComplementPass
:
public
Pass
{
public:
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/mir/node.h
浏览文件 @
cdb12e59
...
@@ -51,6 +51,18 @@ class Node {
...
@@ -51,6 +51,18 @@ class Node {
Place
place
;
Place
place
;
};
};
Argument
&
AsArgument
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArgument
();
x
.
name
=
name
;
return
x
;
}
Instruct
&
AsInstruct
(
const
std
::
string
&
op_type
)
{
auto
&
x
=
AsInstruct
();
x
.
op_type
=
op_type
;
return
x
;
}
// Set roles.
// Set roles.
Argument
&
AsArgument
()
{
Argument
&
AsArgument
()
{
if
(
role_
!=
Role
::
kUnk
)
{
if
(
role_
!=
Role
::
kUnk
)
{
...
...
paddle/fluid/lite/core/mir/pass.h
浏览文件 @
cdb12e59
...
@@ -22,14 +22,53 @@ namespace mir {
...
@@ -22,14 +22,53 @@ namespace mir {
class
Pass
{
class
Pass
{
public:
public:
// Some appoint here, one pass should be only one of the following kinds.
enum
class
Kind
{
// Will modify the program/graph topology.
kProgramWise
=
0
,
// Will modify the instruction, with the graph topology fixed.
kInstructionWise
,
// Will not modify the IR, just collect information or visualization.
kDebug
,
};
Pass
(
Kind
kind
)
:
kind_
(
kind
)
{}
virtual
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
=
0
;
virtual
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
=
0
;
void
set_name
(
const
std
::
string
&
name
)
{
name_
=
name
;
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
const
std
::
string
&
name
()
const
{
return
name_
;
}
void
set_doc
(
const
std
::
string
&
doc
)
{
doc_
=
doc
;
}
const
std
::
string
&
doc
()
const
{
return
doc_
;
}
Kind
kind
()
const
{
return
kind_
;
}
bool
is_debug_pass
()
const
{
return
kind_
==
Kind
::
kDebug
;
}
bool
is_program_pass
()
const
{
return
kind_
==
Kind
::
kProgramWise
;
}
bool
is_instruction_pass
()
const
{
return
kind_
==
Kind
::
kInstructionWise
;
}
virtual
~
Pass
()
=
default
;
virtual
~
Pass
()
=
default
;
private:
private:
const
Kind
kind_
;
std
::
string
name_
;
std
::
string
name_
;
std
::
string
doc_
;
};
// Different kinds.
class
ProgramPass
:
public
Pass
{
public:
ProgramPass
()
:
Pass
(
Kind
::
kProgramWise
)
{}
};
class
InstructionPass
:
public
Pass
{
public:
InstructionPass
()
:
Pass
(
Kind
::
kInstructionWise
)
{}
};
class
DebugPass
:
public
Pass
{
public:
DebugPass
()
:
Pass
(
Kind
::
kDebug
)
{}
};
};
}
// namespace mir
}
// namespace mir
...
...
paddle/fluid/lite/core/mir/pass_manager.cc
浏览文件 @
cdb12e59
...
@@ -21,10 +21,8 @@ namespace mir {
...
@@ -21,10 +21,8 @@ namespace mir {
PassManager
::
PassManager
()
{}
PassManager
::
PassManager
()
{}
// Manually register here.
extern
bool
RegisterDemoPass
();
static
bool
xx
__attribute__
((
unused
))
=
RegisterDemoPass
();
}
// namespace mir
}
// namespace mir
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
USE_MIR_PASS
(
demo
);
paddle/fluid/lite/core/mir/pass_manager_test.cc
浏览文件 @
cdb12e59
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -28,3 +29,5 @@ TEST(PassManager, test) {
...
@@ -28,3 +29,5 @@ TEST(PassManager, test) {
}
// namespace mir
}
// namespace mir
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
USE_MIR_PASS
(
demo
);
paddle/fluid/lite/core/mir/pass_registry.h
浏览文件 @
cdb12e59
...
@@ -32,6 +32,18 @@ class PassRegistry {
...
@@ -32,6 +32,18 @@ class PassRegistry {
bool
Touch
()
const
{
return
true
;
}
bool
Touch
()
const
{
return
true
;
}
};
};
#define REGISTER_MIR_PASS(name__, class__) \
paddle::lite::mir::PassRegistry mir_pass_registry##name__(#name__, \
new class__); \
bool mir_pass_registry##name__##_fake() { \
return mir_pass_registry##name__.Touch(); \
}
#define USE_MIR_PASS(name__) \
extern bool mir_pass_registry##name__##_fake(); \
static bool mir_pass_usage##name__ __attribute__((unused)) = \
mir_pass_registry##name__##_fake();
}
// namespace mir
}
// namespace mir
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
cdb12e59
...
@@ -30,8 +30,9 @@ namespace mir {
...
@@ -30,8 +30,9 @@ namespace mir {
// - main block, which is a list of OpLite
// - main block, which is a list of OpLite
// - scope: which contains all the weights
// - scope: which contains all the weights
struct
Program
{
struct
Program
{
std
::
list
<
std
::
string
>
inputs
;
std
::
list
<
std
::
unique_ptr
<
OpLite
>>
ops
;
std
::
list
<
std
::
unique_ptr
<
OpLite
>>
ops
;
lite
::
Scope
*
scope
;
std
::
unique_ptr
<
lite
::
Scope
>
scope
;
};
};
// An Graph for MIR. It is built from a list of Op and a scope.
// An Graph for MIR. It is built from a list of Op and a scope.
...
@@ -42,21 +43,38 @@ class SSAGraph : GraphBase {
...
@@ -42,21 +43,38 @@ class SSAGraph : GraphBase {
// @param program: the op program
// @param program: the op program
// @param valid_places: the valid places user set for the system.
// @param valid_places: the valid places user set for the system.
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
void
Build
(
const
Program
&
program
,
const
std
::
vector
<
Place
>
&
valid_places
)
{
// create inputs
for
(
const
auto
&
name
:
program
.
inputs
)
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
();
arg
.
name
=
name
;
arguments_
[
name
]
=
&
new_node
;
}
for
(
auto
&
op
:
program
.
ops
)
{
for
(
auto
&
op
:
program
.
ops
)
{
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
// TODO(Superjomn) remove one valid_places here.
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_kernel
=
node_storage_
.
back
().
AsInstruct
();
auto
&
new_kernel
=
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
);
new_kernel
.
valid_kernels
=
op
->
CreateKernels
(
valid_places
);
new_kernel
.
valid_kernels
=
op
->
CreateKernels
(
valid_places
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
// collect inputs and outputs
// collect inputs and outputs
for
(
const
std
::
string
&
name
:
op
->
input_names
())
{
for
(
const
std
::
string
&
name
:
op
->
input_names
())
{
new_node
.
inlinks
.
push_back
(
arguments_
.
at
(
name
));
new_node
.
inlinks
.
push_back
(
arguments_
.
at
(
name
));
}
}
for
(
const
std
::
string
&
name
:
op
->
output_names
())
{
for
(
const
std
::
string
&
name
:
op
->
output_names
())
{
if
(
!
arguments_
.
count
(
name
))
{
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
arg
=
new_node
.
AsArgument
(
name
);
arg
.
name
=
name
;
arguments_
.
emplace
(
name
,
&
new_node
);
}
new_node
.
outlinks
.
push_back
(
arguments_
.
at
(
name
));
new_node
.
outlinks
.
push_back
(
arguments_
.
at
(
name
));
}
}
}
}
...
@@ -64,6 +82,9 @@ class SSAGraph : GraphBase {
...
@@ -64,6 +82,9 @@ class SSAGraph : GraphBase {
std
::
vector
<
mir
::
Node
*>
TopoloticalOrder
()
const
;
std
::
vector
<
mir
::
Node
*>
TopoloticalOrder
()
const
;
const
std
::
list
<
mir
::
Node
>
&
nodes
()
const
{
return
node_storage_
;
}
std
::
list
<
mir
::
Node
>
&
mutable_nodes
()
{
return
node_storage_
;
}
private:
private:
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
...
...
paddle/fluid/lite/core/mir/ssa_graph_test.cc
0 → 100644
浏览文件 @
cdb12e59
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
BuildFc
(
framework
::
ProgramDesc
*
desc
,
const
std
::
string
&
x
,
const
std
::
string
&
w
,
const
std
::
string
&
b
,
const
std
::
string
&
out
)
{
auto
*
fc
=
desc
->
MutableBlock
(
0
)
->
AppendOp
();
fc
->
SetInput
(
"Input"
,
{
x
});
fc
->
SetInput
(
"W"
,
{
w
});
fc
->
SetInput
(
"Bias"
,
{
b
});
fc
->
SetOutput
(
"Out"
,
{
out
});
}
Program
FakeProgram
()
{
Program
program
;
program
.
scope
.
reset
(
new
lite
::
Scope
);
auto
add_fc
=
[
&
](
int
id
,
std
::
string
x
)
{
// create variables
std
::
string
w1
=
"w"
+
std
::
to_string
(
id
);
std
::
string
b1
=
"b"
+
std
::
to_string
(
id
);
std
::
string
out1
=
"out"
+
std
::
to_string
(
id
);
auto
w1v
=
program
.
scope
->
Var
(
w1
)
->
GetMutable
<
Tensor
>
();
auto
b1v
=
program
.
scope
->
Var
(
b1
)
->
GetMutable
<
Tensor
>
();
auto
out1v
=
program
.
scope
->
Var
(
out1
)
->
GetMutable
<
Tensor
>
();
framework
::
OpDesc
desc
;
desc
.
SetInput
(
"Input"
,
{
x
});
desc
.
SetInput
(
"W"
,
{
w1
});
desc
.
SetInput
(
"Bias"
,
{
b1
});
desc
.
SetOutput
(
"Out"
,
{
out1
});
desc
.
SetType
(
"fc"
);
desc
.
SetAttr
(
"in_num_col_dims"
,
1
);
desc
.
Flush
();
// add to input
program
.
inputs
.
push_back
(
w1
);
program
.
inputs
.
push_back
(
b1
);
auto
fc_op
=
LiteOpRegistry
::
Global
().
Create
(
"fc"
);
fc_op
->
PickKernel
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
fc_op
->
Attach
(
desc
,
program
.
scope
.
get
());
program
.
ops
.
emplace_back
(
std
::
move
(
fc_op
));
w1v
->
Resize
({
100
,
100
});
b1v
->
Resize
({
100
,
1
});
out1v
->
Resize
({
100
,
100
});
return
out1
;
};
// x1, w1, b1 -fc-> out1
// out1, w2, b2 -fc-> out2
std
::
string
x
=
"x"
;
program
.
inputs
.
push_back
(
x
);
auto
*
xv
=
program
.
scope
->
Var
(
x
)
->
GetMutable
<
Tensor
>
();
xv
->
Resize
({
100
,
100
});
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
x
=
add_fc
(
i
,
x
);
}
return
program
;
}
TEST
(
SSAGraph
,
test
)
{
auto
program
=
FakeProgram
();
SSAGraph
graph
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
graph
.
Build
(
program
,
places
);
Visualize
(
&
graph
);
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
fc
);
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
);
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
0 → 100644
浏览文件 @
cdb12e59
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
0 → 100644
浏览文件 @
cdb12e59
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
class
StaticKernelPickPass
:
public
mir
::
Pass
{};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/type_system.cc
浏览文件 @
cdb12e59
...
@@ -35,7 +35,7 @@ Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
...
@@ -35,7 +35,7 @@ Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
}
}
template
<
>
template
<
>
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
,
int
device
)
{
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
)
{
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
DataLayoutType
::
kNCHW
>
();
}
}
...
...
paddle/fluid/lite/kernels/host/CMakeLists.txt
浏览文件 @
cdb12e59
...
@@ -8,6 +8,7 @@ cc_library(host_kernels DEPS
...
@@ -8,6 +8,7 @@ cc_library(host_kernels DEPS
relu_compute_host
relu_compute_host
mul_compute_host
mul_compute_host
scale_compute_host
scale_compute_host
DEPS kernel_lite
)
)
cc_test
(
test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite
)
cc_test
(
test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录