Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
f1ca00a4
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f1ca00a4
编写于
5月 02, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename some concepts
Instruction to Stmt
上级
ec27aa46
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
118 addition
and
116 deletion
+118
-116
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+4
-0
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
+4
-4
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+4
-5
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
+6
-6
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
+6
-6
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+35
-35
paddle/fluid/lite/core/mir/pass.h
paddle/fluid/lite/core/mir/pass.h
+5
-5
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
+3
-3
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+9
-9
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+7
-7
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+2
-2
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+1
-1
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+21
-22
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+9
-9
paddle/fluid/lite/core/optimizer_test.cc
paddle/fluid/lite/core/optimizer_test.cc
+1
-1
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+1
-1
未找到文件。
paddle/fluid/lite/core/kernel.h
浏览文件 @
f1ca00a4
...
...
@@ -104,12 +104,16 @@ class KernelBase {
mutable
operators
::
param_t
param_
;
// The corresponding op type.
std
::
string
op_type_
{};
// The extra identity to help defficiate a specific kernel, op_type_ + alias_
// is the unique ID for the kernel.
std
::
string
alias_
{};
};
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
// TODO(Superjomn) Consider to add a Platform type to differentiate CUDNN,
// MKLDNN, plain CUDA C implementations.
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
DataLayout
=
DataLayoutType
::
kNCHW
>
class
OpKernel
:
public
KernelBase
{
...
...
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -24,13 +24,13 @@ class ArgumentTypeDisplayPass : public DebugPass {
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
LOG
(
INFO
)
<<
"== Argument types =="
;
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsArg
ument
())
continue
;
if
(
!
node
.
IsArg
())
continue
;
auto
*
type
=
node
.
AsArg
ument
().
type
;
auto
*
type
=
node
.
AsArg
().
type
;
if
(
type
)
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
ument
().
name
<<
" type: "
<<
*
type
;
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
().
name
<<
" type: "
<<
*
type
;
}
else
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
ument
().
name
<<
" type: UNK"
;
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
().
name
<<
" type: UNK"
;
}
}
LOG
(
INFO
)
<<
"---------------------"
;
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -23,11 +23,10 @@ namespace mir {
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
LOG
(
INFO
)
<<
"final program
\n
"
<<
Visualize
(
graph
.
get
());
for
(
auto
&
item
:
graph
->
InstructTopologicalOrder
())
{
if
(
item
->
IsInstruct
())
{
auto
&
instruct
=
item
->
AsInstruct
();
LOG
(
INFO
)
<<
instruct
;
insts_
.
emplace_back
(
instruct
.
op
,
std
::
move
(
instruct
.
valid_kernels
.
front
()));
if
(
item
->
IsStmt
())
{
auto
&
stmt
=
item
->
AsStmt
();
LOG
(
INFO
)
<<
stmt
;
insts_
.
emplace_back
(
stmt
.
op
,
std
::
move
(
stmt
.
valid_kernels
.
front
()));
}
}
}
...
...
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -34,16 +34,16 @@ std::string Visualize(mir::SSAGraph* graph) {
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
std
::
string
key
;
if
(
node
.
IsArg
ument
())
{
key
=
node
.
AsArg
ument
().
name
;
if
(
node
.
IsArg
())
{
key
=
node
.
AsArg
().
name
;
}
else
{
key
=
node
.
As
Instruc
t
().
op_type
+
std
::
to_string
(
id
++
);
key
=
node
.
As
Stm
t
().
op_type
+
std
::
to_string
(
id
++
);
}
if
(
node
.
Is
Instruc
t
())
{
if
(
node
.
Is
Stm
t
())
{
dot
.
AddNode
(
key
,
{
Dot
::
Attr
(
"shape"
,
"box"
)});
for
(
auto
&
x
:
node
.
inlinks
)
{
auto
name
=
x
->
AsArg
ument
().
name
;
auto
name
=
x
->
AsArg
().
name
;
if
(
!
exists_args
.
count
(
name
))
{
dot
.
AddNode
(
name
,
{});
}
...
...
@@ -51,7 +51,7 @@ std::string Visualize(mir::SSAGraph* graph) {
exists_args
.
insert
(
name
);
}
for
(
auto
&
x
:
node
.
outlinks
)
{
auto
name
=
x
->
AsArg
ument
().
name
;
auto
name
=
x
->
AsArg
().
name
;
if
(
!
exists_args
.
count
(
name
))
{
dot
.
AddNode
(
name
,
{});
}
...
...
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -19,20 +19,20 @@ namespace paddle {
namespace
lite
{
namespace
mir
{
class
IoCopyKernelPickPass
:
public
Instruction
Pass
{
class
IoCopyKernelPickPass
:
public
Stmt
Pass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
Is
Instruc
t
())
continue
;
auto
&
inst
=
node
.
As
Instruc
t
();
if
(
!
node
.
Is
Stm
t
())
continue
;
auto
&
inst
=
node
.
As
Stm
t
();
if
(
inst
.
op_type
!=
"io_copy"
)
continue
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
auto
&
kernels
=
node
.
As
Instruc
t
().
valid_kernels
;
auto
&
kernels
=
node
.
As
Stm
t
().
valid_kernels
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
ument
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
ument
().
type
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
().
type
;
LOG
(
INFO
)
<<
"input type "
<<
*
inty
;
LOG
(
INFO
)
<<
"output type "
<<
*
outy
;
...
...
paddle/fluid/lite/core/mir/node.h
浏览文件 @
f1ca00a4
...
...
@@ -34,15 +34,15 @@ class Node {
Node
()
=
default
;
enum
class
Role
{
kArg
ument
=
0
,
k
Instruc
t
,
kArg
=
0
,
k
Stm
t
,
kNumRoles
,
/*should be last*/
kUnk
,
};
struct
Instruc
t
{
struct
Stm
t
{
std
::
string
op_type
;
// The kernel instances this
Instruc
t contains.
// The kernel instances this
Statemen
t contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
...
...
@@ -62,13 +62,13 @@ class Node {
return
*
valid_kernels
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Instruc
t
&
other
)
{
os
<<
"
Instruc
t "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Stm
t
&
other
)
{
os
<<
"
Statemen
t "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
return
os
;
}
};
struct
Arg
ument
{
struct
Arg
{
std
::
string
name
;
const
Type
*
type
{};
// Weight is a special kind of argument, it is marked as weight explicitly
...
...
@@ -76,16 +76,16 @@ class Node {
bool
is_weight
{
false
};
};
Arg
ument
&
AsArgument
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
ument
();
Arg
&
AsArg
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
return
x
;
}
Instruct
&
AsInstruc
t
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
As
Instruc
t
();
Stmt
&
AsStm
t
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
As
Stm
t
();
x
.
op_type
=
op_type
;
x
.
op
=
op
;
x
.
valid_kernels
=
std
::
move
(
kernels
);
...
...
@@ -93,23 +93,23 @@ class Node {
}
// Set roles.
Arg
ument
&
AsArgument
()
{
Arg
&
AsArg
()
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
kArg
ument
);
return
*
arg
ument
_
;
CHECK
(
role_
==
Role
::
kArg
);
return
*
arg_
;
}
role_
=
Role
::
kArg
ument
;
arg
ument_
.
reset
(
new
Argument
);
return
*
arg
ument
_
;
role_
=
Role
::
kArg
;
arg
_
.
reset
(
new
Arg
);
return
*
arg_
;
}
Instruct
&
AsInstruc
t
()
{
Stmt
&
AsStm
t
()
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
k
Instruc
t
);
return
*
instruc
t_
;
CHECK
(
role_
==
Role
::
k
Stm
t
);
return
*
stm
t_
;
}
role_
=
Role
::
k
Instruc
t
;
instruct_
.
reset
(
new
Instruc
t
);
return
*
instruc
t_
;
role_
=
Role
::
k
Stm
t
;
stmt_
.
reset
(
new
Stm
t
);
return
*
stm
t_
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Node
&
other
)
{
...
...
@@ -117,26 +117,26 @@ class Node {
if
(
!
other
.
IsRoleSet
())
{
os
<<
"Unk role node"
;
}
if
(
other
.
IsArg
ument
())
{
auto
&
arg
=
other
.
AsArg
ument
();
if
(
other
.
IsArg
())
{
auto
&
arg
=
other
.
AsArg
();
os
<<
"Argument "
<<
arg
.
name
;
}
if
(
other
.
Is
Instruc
t
())
{
auto
&
arg
=
other
.
As
Instruc
t
();
os
<<
"
Instruc
t "
<<
arg
.
op_type
;
if
(
other
.
Is
Stm
t
())
{
auto
&
arg
=
other
.
As
Stm
t
();
os
<<
"
Statemen
t "
<<
arg
.
op_type
;
}
return
os
;
}
// Check roles.
bool
IsRoleSet
()
const
{
return
role_
!=
Role
::
kUnk
;
}
bool
Is
Instruct
()
const
{
return
role_
==
Role
::
kInstruc
t
;
}
bool
IsArg
ument
()
const
{
return
role_
==
Role
::
kArgument
;
}
bool
Is
Stmt
()
const
{
return
role_
==
Role
::
kStm
t
;
}
bool
IsArg
()
const
{
return
role_
==
Role
::
kArg
;
}
private:
// Either
instruc
t_ or argument_ is used.
std
::
unique_ptr
<
Instruct
>
instruc
t_
;
std
::
unique_ptr
<
Arg
ument
>
argument
_
;
// Either
stm
t_ or argument_ is used.
std
::
unique_ptr
<
Stmt
>
stm
t_
;
std
::
unique_ptr
<
Arg
>
arg
_
;
Role
role_
{
Role
::
kUnk
};
};
...
...
paddle/fluid/lite/core/mir/pass.h
浏览文件 @
f1ca00a4
...
...
@@ -26,8 +26,8 @@ class Pass {
enum
class
Kind
{
// Will modify the program/graph topology.
kProgramWise
=
0
,
// Will modify the
instruction
, with the graph topology fixed.
k
Instruction
Wise
,
// Will modify the
statement
, with the graph topology fixed.
k
Stmt
Wise
,
// Will not modify the IR, just collect information or visualization.
kDebug
,
};
...
...
@@ -45,7 +45,7 @@ class Pass {
Kind
kind
()
const
{
return
kind_
;
}
bool
is_debug_pass
()
const
{
return
kind_
==
Kind
::
kDebug
;
}
bool
is_program_pass
()
const
{
return
kind_
==
Kind
::
kProgramWise
;
}
bool
is_
instruction_pass
()
const
{
return
kind_
==
Kind
::
kInstruction
Wise
;
}
bool
is_
stmt_pass
()
const
{
return
kind_
==
Kind
::
kStmt
Wise
;
}
virtual
~
Pass
()
=
default
;
...
...
@@ -61,9 +61,9 @@ class ProgramPass : public Pass {
ProgramPass
()
:
Pass
(
Kind
::
kProgramWise
)
{}
};
class
Instruction
Pass
:
public
Pass
{
class
Stmt
Pass
:
public
Pass
{
public:
InstructionPass
()
:
Pass
(
Kind
::
kInstruction
Wise
)
{}
StmtPass
()
:
Pass
(
Kind
::
kStmt
Wise
)
{}
};
class
DebugPass
:
public
Pass
{
...
...
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -19,7 +19,7 @@ namespace paddle {
namespace
lite
{
namespace
mir
{
class
RuntimeContextAssignPass
:
public
Instruction
Pass
{
class
RuntimeContextAssignPass
:
public
Stmt
Pass
{
public:
RuntimeContextAssignPass
()
{
#ifdef LITE_WITH_CUDA
...
...
@@ -29,9 +29,9 @@ class RuntimeContextAssignPass : public InstructionPass {
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
Is
Instruc
t
())
continue
;
if
(
!
node
.
Is
Stm
t
())
continue
;
auto
&
inst
=
node
.
As
Instruc
t
();
auto
&
inst
=
node
.
As
Stm
t
();
switch
(
inst
.
picked_kernel
().
target
())
{
case
TARGET
(
kHost
):
...
...
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
f1ca00a4
...
...
@@ -37,14 +37,14 @@ std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() {
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
adj_list
;
for
(
auto
&
n
:
mutable_nodes
())
{
if
(
!
n
.
Is
Instruc
t
())
continue
;
if
(
!
n
.
Is
Stm
t
())
continue
;
if
(
adj_list
.
find
(
&
n
)
==
adj_list
.
end
())
{
adj_list
[
&
n
]
=
std
::
set
<
mir
::
Node
*>
();
}
std
::
vector
<
mir
::
Node
*>
nodes
;
for
(
auto
&
var
:
n
.
inlinks
)
{
for
(
auto
&
adj_n
:
var
->
inlinks
)
{
PADDLE_ENFORCE
(
adj_n
->
Is
Instruc
t
());
PADDLE_ENFORCE
(
adj_n
->
Is
Stm
t
());
nodes
.
push_back
(
adj_n
);
}
}
...
...
@@ -96,7 +96,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
ument
(
name
);
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
...
...
@@ -109,7 +109,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
ument
(
name
);
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
...
...
@@ -122,7 +122,7 @@ Node *SSAGraph::GraphCreateInstructNode(
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
node_storage_
.
back
().
As
Instruc
t
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
);
node_storage_
.
back
().
As
Stm
t
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
...
...
@@ -202,14 +202,14 @@ bool SSAGraph::CheckNodesRoleSet() {
bool
SSAGraph
::
CheckLinksRoleSet
()
{
for
(
auto
&
node
:
mutable_nodes
())
{
CHECK_OR_FALSE
(
node
.
IsRoleSet
());
if
(
!
node
.
Is
Instruc
t
())
continue
;
if
(
!
node
.
Is
Stm
t
())
continue
;
for
(
auto
*
x
:
node
.
inlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArg
ument
());
CHECK_OR_FALSE
(
x
->
IsArg
());
}
for
(
auto
*
x
:
node
.
outlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArg
ument
());
CHECK_OR_FALSE
(
x
->
IsArg
());
}
}
return
true
;
...
...
@@ -219,7 +219,7 @@ Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_
.
emplace_back
();
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate argument called "
<<
name
;
arguments_
[
name
]
=
&
node_storage_
.
back
();
node_storage_
.
back
().
AsArg
ument
(
name
);
node_storage_
.
back
().
AsArg
(
name
);
return
&
node_storage_
.
back
();
}
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
f1ca00a4
...
...
@@ -76,7 +76,7 @@ class SSAGraph : GraphBase {
void
MarkArgumentWeights
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
weights
)
{
arguments_
[
name
]
->
AsArg
ument
().
is_weight
=
true
;
arguments_
[
name
]
->
AsArg
().
is_weight
=
true
;
}
}
...
...
@@ -115,9 +115,9 @@ static void DirectedLink(Node *a, Node *b) {
static
void
LocalInferenceType
(
Node
*
a
,
Node
*
b
,
const
std
::
string
&
arg_name
)
{
// instr -> output argument
if
(
a
->
Is
Instruct
()
&&
b
->
IsArgument
())
{
auto
&
inst
=
a
->
As
Instruc
t
();
auto
&
output
=
b
->
AsArg
ument
();
if
(
a
->
Is
Stmt
()
&&
b
->
IsArg
())
{
auto
&
inst
=
a
->
As
Stm
t
();
auto
&
output
=
b
->
AsArg
();
if
(
!
output
.
type
)
{
output
.
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
...
...
@@ -125,9 +125,9 @@ static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
}
// input argument -> instr
if
(
a
->
IsArg
ument
()
&&
b
->
IsInstruc
t
())
{
auto
&
input
=
a
->
AsArg
ument
();
auto
&
inst
=
b
->
As
Instruc
t
();
if
(
a
->
IsArg
()
&&
b
->
IsStm
t
())
{
auto
&
input
=
a
->
AsArg
();
auto
&
inst
=
b
->
As
Stm
t
();
if
(
!
input
.
type
)
{
input
.
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
}
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -33,8 +33,8 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK
(
graph
)
<<
"graph not valid"
;
// sort kernels by the factors.
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
Is
Instruc
t
())
continue
;
auto
&
instruct
=
node
.
As
Instruc
t
();
if
(
!
node
.
Is
Stm
t
())
continue
;
auto
&
instruct
=
node
.
As
Stm
t
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
size_t
score
=
KernelGrade
(
*
kernel
);
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
浏览文件 @
f1ca00a4
...
...
@@ -33,7 +33,7 @@ namespace mir {
* - kernel_pick_factors, the factors to consider in picking kernels.
* Set them first before execute the pass.
*/
class
StaticKernelPickPass
:
public
mir
::
Instruction
Pass
{
class
StaticKernelPickPass
:
public
mir
::
Stmt
Pass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -33,7 +33,7 @@ void TypeTargetTransformPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK
(
!
valid_places_
.
empty
());
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
Is
Instruc
t
())
continue
;
if
(
!
node
->
Is
Stm
t
())
continue
;
auto
inlinks
=
node
->
inlinks
;
for
(
auto
*
in
:
inlinks
)
{
ComplementInputs
(
graph
.
get
(),
node
,
in
);
...
...
@@ -49,22 +49,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
std
::
find
(
inst_node
->
inlinks
.
begin
(),
inst_node
->
inlinks
.
end
(),
in
))
return
;
CHECK
(
inst_node
->
Is
Instruc
t
());
auto
&
inst
=
inst_node
->
As
Instruc
t
();
CHECK
(
inst_node
->
Is
Stm
t
());
auto
&
inst
=
inst_node
->
As
Stm
t
();
CHECK
(
in
->
IsRoleSet
());
CHECK
(
in
->
IsArg
ument
());
auto
in_arg_name
=
in
->
AsArg
ument
().
name
;
CHECK
(
in
->
IsArg
());
auto
in_arg_name
=
in
->
AsArg
().
name
;
std
::
string
tmp
;
CHECK
(
inst
.
op_info
()
->
GetInputArgname
(
in_arg_name
,
&
tmp
));
auto
decl_arg_type
=
inst
.
picked_kernel
().
GetInputDeclType
(
tmp
);
CHECK
(
in
->
AsArg
ument
().
type
);
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
ument
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
ument
().
name
CHECK
(
in
->
AsArg
().
type
);
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
().
name
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
*
in
->
AsArg
ument
().
type
<<
" -> "
<<
*
decl_arg_type
;
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArg
ument
().
type
,
*
decl_arg_type
,
in
->
AsArgument
().
name
,
graph
,
inst_node
,
valid_places_
);
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
->
AsArg
().
name
,
graph
,
inst_node
,
valid_places_
);
}
}
...
...
@@ -73,7 +73,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
())
<<
"valid_place should be set"
;
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy
Instruc
t Node.
// So there will be a new Argument node and a new IoCopy
Statemen
t Node.
auto
node_id
=
[
&
]
{
return
graph
->
nodes
().
size
();
};
auto
io_copy_output_name
=
var
+
"/trans/"
+
std
::
to_string
(
node_id
());
...
...
@@ -85,7 +85,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
// CHECK(io_copy_op);
// Create the new var manually.
inst_node
->
As
Instruc
t
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
inst_node
->
As
Stm
t
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
lite
::
OpDesc
op_desc
;
...
...
@@ -93,16 +93,16 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
As
Instruc
t
().
op
->
scope
());
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
As
Stm
t
().
op
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
io_copy_inst
->
As
Instruc
t
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
io_copy_inst
->
As
Stm
t
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
// Remove the old link
RemoveDirectedLink
(
graph
->
Argument
(
var
),
inst_node
);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
auto
&
inst
=
inst_node
->
As
Instruc
t
();
auto
&
inst
=
inst_node
->
As
Stm
t
();
auto
inst_program_desc
=
inst
.
op_info
()
->
desc
();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
...
...
@@ -111,20 +111,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
auto
desc_dummy
=
inst_node
->
As
Instruc
t
().
op
->
op_info
()
->
desc
();
auto
desc_dummy
=
inst_node
->
As
Stm
t
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
lite
::
OpDesc
desc_fake
(
desc_dummy
);
inst_node
->
AsInstruct
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsInstruct
().
op
->
scope
());
inst_node
->
AsStmt
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsStmt
().
op
->
scope
());
std
::
string
tmp
;
if
(
inst_node
->
As
Instruc
t
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
if
(
inst_node
->
As
Stm
t
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
}
for
(
auto
&
kernel
:
inst_node
->
As
Instruc
t
().
valid_kernels
)
{
inst_node
->
As
Instruc
t
().
op
->
AttachKernel
(
kernel
.
get
());
for
(
auto
&
kernel
:
inst_node
->
As
Stm
t
().
valid_kernels
)
{
inst_node
->
As
Stm
t
().
op
->
AttachKernel
(
kernel
.
get
());
}
graph
->
CheckValid
();
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
f1ca00a4
...
...
@@ -34,8 +34,8 @@ class VariablePlaceInferencePass : public DebugPass {
CHECK
(
!
graph
->
inputs
().
empty
())
<<
"graph's inputs should be set"
;
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
if
(
v
->
Is
Instruc
t
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
As
Instruc
t
().
op_type
;
if
(
v
->
Is
Stm
t
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
As
Stm
t
().
op_type
;
continue
;
}
...
...
@@ -49,9 +49,9 @@ class VariablePlaceInferencePass : public DebugPass {
void
CheckAllArgumentTypeDetermined
(
SSAGraph
*
graph
)
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
node
.
IsArg
ument
())
{
CHECK
(
node
.
AsArg
ument
().
type
)
<<
"node "
<<
node
.
AsArgument
().
name
<<
" type not determined, "
<<
&
node
;
if
(
node
.
IsArg
())
{
CHECK
(
node
.
AsArg
().
type
)
<<
"node "
<<
node
.
AsArg
().
name
<<
" type not determined, "
<<
&
node
;
}
}
}
...
...
@@ -59,7 +59,7 @@ class VariablePlaceInferencePass : public DebugPass {
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
VLOG
(
3
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
auto
&
inst
=
x
->
As
Instruc
t
();
auto
&
inst
=
x
->
As
Stm
t
();
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
// LOG(INFO) << "- inferencing type " <<
...
...
@@ -76,7 +76,7 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
ument
();
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
node
;
arg_node
.
type
=
type
;
...
...
@@ -94,9 +94,9 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
ument
();
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
node
->
AsArg
ument
().
type
=
type
;
node
->
AsArg
().
type
=
type
;
VLOG
(
3
)
<<
"set type "
<<
*
type
;
}
}
...
...
paddle/fluid/lite/core/optimizer_test.cc
浏览文件 @
f1ca00a4
...
...
@@ -38,7 +38,7 @@ TEST(Optimizer, test) {
optimizer
.
Run
(
std
::
move
(
program
),
places
);
auto
runtime_program
=
optimizer
.
GenRuntimeProgram
();
LOG
(
INFO
)
<<
"num
instruction
s "
<<
runtime_program
->
num_instructions
();
LOG
(
INFO
)
<<
"num
statement
s "
<<
runtime_program
->
num_instructions
();
}
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
f1ca00a4
...
...
@@ -152,7 +152,7 @@ class Type : public DataTypeBase {
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a
instruction
to transform a type to another.
// is is possible to add a
statement
to transform a type to another.
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
template
<
bool
is_unknown
,
bool
is_tensor
=
true
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录