Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f1ca00a4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f1ca00a4
编写于
5月 02, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename some concepts
Instruction to Stmt
上级
ec27aa46
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
118 addition
and
116 deletion
+118
-116
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+4
-0
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
+4
-4
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+4
-5
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
+6
-6
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
+6
-6
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+35
-35
paddle/fluid/lite/core/mir/pass.h
paddle/fluid/lite/core/mir/pass.h
+5
-5
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
+3
-3
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+9
-9
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+7
-7
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+2
-2
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+1
-1
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+21
-22
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+9
-9
paddle/fluid/lite/core/optimizer_test.cc
paddle/fluid/lite/core/optimizer_test.cc
+1
-1
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+1
-1
未找到文件。
paddle/fluid/lite/core/kernel.h
浏览文件 @
f1ca00a4
...
...
@@ -104,12 +104,16 @@ class KernelBase {
mutable
operators
::
param_t
param_
;
// The corresponding op type.
std
::
string
op_type_
{};
// The extra identity to help defficiate a specific kernel, op_type_ + alias_
// is the unique ID for the kernel.
std
::
string
alias_
{};
};
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
// TODO(Superjomn) Consider to add a Platform type to differentiate CUDNN,
// MKLDNN, plain CUDA C implementations.
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
DataLayout
=
DataLayoutType
::
kNCHW
>
class
OpKernel
:
public
KernelBase
{
...
...
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -24,13 +24,13 @@ class ArgumentTypeDisplayPass : public DebugPass {
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
LOG
(
INFO
)
<<
"== Argument types =="
;
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsArg
ument
())
continue
;
if
(
!
node
.
IsArg
())
continue
;
auto
*
type
=
node
.
AsArg
ument
().
type
;
auto
*
type
=
node
.
AsArg
().
type
;
if
(
type
)
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
ument
().
name
<<
" type: "
<<
*
type
;
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
().
name
<<
" type: "
<<
*
type
;
}
else
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
ument
().
name
<<
" type: UNK"
;
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
().
name
<<
" type: UNK"
;
}
}
LOG
(
INFO
)
<<
"---------------------"
;
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -23,11 +23,10 @@ namespace mir {
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
LOG
(
INFO
)
<<
"final program
\n
"
<<
Visualize
(
graph
.
get
());
for
(
auto
&
item
:
graph
->
InstructTopologicalOrder
())
{
if
(
item
->
IsInstruct
())
{
auto
&
instruct
=
item
->
AsInstruct
();
LOG
(
INFO
)
<<
instruct
;
insts_
.
emplace_back
(
instruct
.
op
,
std
::
move
(
instruct
.
valid_kernels
.
front
()));
if
(
item
->
IsStmt
())
{
auto
&
stmt
=
item
->
AsStmt
();
LOG
(
INFO
)
<<
stmt
;
insts_
.
emplace_back
(
stmt
.
op
,
std
::
move
(
stmt
.
valid_kernels
.
front
()));
}
}
}
...
...
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -34,16 +34,16 @@ std::string Visualize(mir::SSAGraph* graph) {
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
std
::
string
key
;
if
(
node
.
IsArg
ument
())
{
key
=
node
.
AsArg
ument
().
name
;
if
(
node
.
IsArg
())
{
key
=
node
.
AsArg
().
name
;
}
else
{
key
=
node
.
As
Instruc
t
().
op_type
+
std
::
to_string
(
id
++
);
key
=
node
.
As
Stm
t
().
op_type
+
std
::
to_string
(
id
++
);
}
if
(
node
.
Is
Instruc
t
())
{
if
(
node
.
Is
Stm
t
())
{
dot
.
AddNode
(
key
,
{
Dot
::
Attr
(
"shape"
,
"box"
)});
for
(
auto
&
x
:
node
.
inlinks
)
{
auto
name
=
x
->
AsArg
ument
().
name
;
auto
name
=
x
->
AsArg
().
name
;
if
(
!
exists_args
.
count
(
name
))
{
dot
.
AddNode
(
name
,
{});
}
...
...
@@ -51,7 +51,7 @@ std::string Visualize(mir::SSAGraph* graph) {
exists_args
.
insert
(
name
);
}
for
(
auto
&
x
:
node
.
outlinks
)
{
auto
name
=
x
->
AsArg
ument
().
name
;
auto
name
=
x
->
AsArg
().
name
;
if
(
!
exists_args
.
count
(
name
))
{
dot
.
AddNode
(
name
,
{});
}
...
...
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -19,20 +19,20 @@ namespace paddle {
namespace
lite
{
namespace
mir
{
class
IoCopyKernelPickPass
:
public
Instruction
Pass
{
class
IoCopyKernelPickPass
:
public
Stmt
Pass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
Is
Instruc
t
())
continue
;
auto
&
inst
=
node
.
As
Instruc
t
();
if
(
!
node
.
Is
Stm
t
())
continue
;
auto
&
inst
=
node
.
As
Stm
t
();
if
(
inst
.
op_type
!=
"io_copy"
)
continue
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
auto
&
kernels
=
node
.
As
Instruc
t
().
valid_kernels
;
auto
&
kernels
=
node
.
As
Stm
t
().
valid_kernels
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
ument
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
ument
().
type
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
().
type
;
LOG
(
INFO
)
<<
"input type "
<<
*
inty
;
LOG
(
INFO
)
<<
"output type "
<<
*
outy
;
...
...
paddle/fluid/lite/core/mir/node.h
浏览文件 @
f1ca00a4
...
...
@@ -34,15 +34,15 @@ class Node {
Node
()
=
default
;
enum
class
Role
{
kArg
ument
=
0
,
k
Instruc
t
,
kArg
=
0
,
k
Stm
t
,
kNumRoles
,
/*should be last*/
kUnk
,
};
struct
Instruc
t
{
struct
Stm
t
{
std
::
string
op_type
;
// The kernel instances this
Instruc
t contains.
// The kernel instances this
Statemen
t contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
...
...
@@ -62,13 +62,13 @@ class Node {
return
*
valid_kernels
.
front
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Instruc
t
&
other
)
{
os
<<
"
Instruc
t "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Stm
t
&
other
)
{
os
<<
"
Statemen
t "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
return
os
;
}
};
struct
Arg
ument
{
struct
Arg
{
std
::
string
name
;
const
Type
*
type
{};
// Weight is a special kind of argument, it is marked as weight explicitly
...
...
@@ -76,16 +76,16 @@ class Node {
bool
is_weight
{
false
};
};
Arg
ument
&
AsArgument
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
ument
();
Arg
&
AsArg
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
();
x
.
name
=
name
;
return
x
;
}
Instruct
&
AsInstruc
t
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
As
Instruc
t
();
Stmt
&
AsStm
t
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
As
Stm
t
();
x
.
op_type
=
op_type
;
x
.
op
=
op
;
x
.
valid_kernels
=
std
::
move
(
kernels
);
...
...
@@ -93,23 +93,23 @@ class Node {
}
// Set roles.
Arg
ument
&
AsArgument
()
{
Arg
&
AsArg
()
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
kArg
ument
);
return
*
arg
ument
_
;
CHECK
(
role_
==
Role
::
kArg
);
return
*
arg_
;
}
role_
=
Role
::
kArg
ument
;
arg
ument_
.
reset
(
new
Argument
);
return
*
arg
ument
_
;
role_
=
Role
::
kArg
;
arg
_
.
reset
(
new
Arg
);
return
*
arg_
;
}
Instruct
&
AsInstruc
t
()
{
Stmt
&
AsStm
t
()
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
k
Instruc
t
);
return
*
instruc
t_
;
CHECK
(
role_
==
Role
::
k
Stm
t
);
return
*
stm
t_
;
}
role_
=
Role
::
k
Instruc
t
;
instruct_
.
reset
(
new
Instruc
t
);
return
*
instruc
t_
;
role_
=
Role
::
k
Stm
t
;
stmt_
.
reset
(
new
Stm
t
);
return
*
stm
t_
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Node
&
other
)
{
...
...
@@ -117,26 +117,26 @@ class Node {
if
(
!
other
.
IsRoleSet
())
{
os
<<
"Unk role node"
;
}
if
(
other
.
IsArg
ument
())
{
auto
&
arg
=
other
.
AsArg
ument
();
if
(
other
.
IsArg
())
{
auto
&
arg
=
other
.
AsArg
();
os
<<
"Argument "
<<
arg
.
name
;
}
if
(
other
.
Is
Instruc
t
())
{
auto
&
arg
=
other
.
As
Instruc
t
();
os
<<
"
Instruc
t "
<<
arg
.
op_type
;
if
(
other
.
Is
Stm
t
())
{
auto
&
arg
=
other
.
As
Stm
t
();
os
<<
"
Statemen
t "
<<
arg
.
op_type
;
}
return
os
;
}
// Check roles.
bool
IsRoleSet
()
const
{
return
role_
!=
Role
::
kUnk
;
}
bool
Is
Instruct
()
const
{
return
role_
==
Role
::
kInstruc
t
;
}
bool
IsArg
ument
()
const
{
return
role_
==
Role
::
kArgument
;
}
bool
Is
Stmt
()
const
{
return
role_
==
Role
::
kStm
t
;
}
bool
IsArg
()
const
{
return
role_
==
Role
::
kArg
;
}
private:
// Either
instruc
t_ or argument_ is used.
std
::
unique_ptr
<
Instruct
>
instruc
t_
;
std
::
unique_ptr
<
Arg
ument
>
argument
_
;
// Either
stm
t_ or argument_ is used.
std
::
unique_ptr
<
Stmt
>
stm
t_
;
std
::
unique_ptr
<
Arg
>
arg
_
;
Role
role_
{
Role
::
kUnk
};
};
...
...
paddle/fluid/lite/core/mir/pass.h
浏览文件 @
f1ca00a4
...
...
@@ -26,8 +26,8 @@ class Pass {
enum
class
Kind
{
// Will modify the program/graph topology.
kProgramWise
=
0
,
// Will modify the
instruction
, with the graph topology fixed.
k
Instruction
Wise
,
// Will modify the
statement
, with the graph topology fixed.
k
Stmt
Wise
,
// Will not modify the IR, just collect information or visualization.
kDebug
,
};
...
...
@@ -45,7 +45,7 @@ class Pass {
Kind
kind
()
const
{
return
kind_
;
}
bool
is_debug_pass
()
const
{
return
kind_
==
Kind
::
kDebug
;
}
bool
is_program_pass
()
const
{
return
kind_
==
Kind
::
kProgramWise
;
}
bool
is_
instruction_pass
()
const
{
return
kind_
==
Kind
::
kInstruction
Wise
;
}
bool
is_
stmt_pass
()
const
{
return
kind_
==
Kind
::
kStmt
Wise
;
}
virtual
~
Pass
()
=
default
;
...
...
@@ -61,9 +61,9 @@ class ProgramPass : public Pass {
ProgramPass
()
:
Pass
(
Kind
::
kProgramWise
)
{}
};
class
Instruction
Pass
:
public
Pass
{
class
Stmt
Pass
:
public
Pass
{
public:
InstructionPass
()
:
Pass
(
Kind
::
kInstruction
Wise
)
{}
StmtPass
()
:
Pass
(
Kind
::
kStmt
Wise
)
{}
};
class
DebugPass
:
public
Pass
{
...
...
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -19,7 +19,7 @@ namespace paddle {
namespace
lite
{
namespace
mir
{
class
RuntimeContextAssignPass
:
public
Instruction
Pass
{
class
RuntimeContextAssignPass
:
public
Stmt
Pass
{
public:
RuntimeContextAssignPass
()
{
#ifdef LITE_WITH_CUDA
...
...
@@ -29,9 +29,9 @@ class RuntimeContextAssignPass : public InstructionPass {
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
Is
Instruc
t
())
continue
;
if
(
!
node
.
Is
Stm
t
())
continue
;
auto
&
inst
=
node
.
As
Instruc
t
();
auto
&
inst
=
node
.
As
Stm
t
();
switch
(
inst
.
picked_kernel
().
target
())
{
case
TARGET
(
kHost
):
...
...
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
f1ca00a4
...
...
@@ -37,14 +37,14 @@ std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() {
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
adj_list
;
for
(
auto
&
n
:
mutable_nodes
())
{
if
(
!
n
.
Is
Instruc
t
())
continue
;
if
(
!
n
.
Is
Stm
t
())
continue
;
if
(
adj_list
.
find
(
&
n
)
==
adj_list
.
end
())
{
adj_list
[
&
n
]
=
std
::
set
<
mir
::
Node
*>
();
}
std
::
vector
<
mir
::
Node
*>
nodes
;
for
(
auto
&
var
:
n
.
inlinks
)
{
for
(
auto
&
adj_n
:
var
->
inlinks
)
{
PADDLE_ENFORCE
(
adj_n
->
Is
Instruc
t
());
PADDLE_ENFORCE
(
adj_n
->
Is
Stm
t
());
nodes
.
push_back
(
adj_n
);
}
}
...
...
@@ -96,7 +96,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
ument
(
name
);
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
...
...
@@ -109,7 +109,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
ument
(
name
);
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
}
}
...
...
@@ -122,7 +122,7 @@ Node *SSAGraph::GraphCreateInstructNode(
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
node_storage_
.
back
().
As
Instruc
t
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
);
node_storage_
.
back
().
As
Stm
t
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
...
...
@@ -202,14 +202,14 @@ bool SSAGraph::CheckNodesRoleSet() {
bool
SSAGraph
::
CheckLinksRoleSet
()
{
for
(
auto
&
node
:
mutable_nodes
())
{
CHECK_OR_FALSE
(
node
.
IsRoleSet
());
if
(
!
node
.
Is
Instruc
t
())
continue
;
if
(
!
node
.
Is
Stm
t
())
continue
;
for
(
auto
*
x
:
node
.
inlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArg
ument
());
CHECK_OR_FALSE
(
x
->
IsArg
());
}
for
(
auto
*
x
:
node
.
outlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArg
ument
());
CHECK_OR_FALSE
(
x
->
IsArg
());
}
}
return
true
;
...
...
@@ -219,7 +219,7 @@ Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_
.
emplace_back
();
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate argument called "
<<
name
;
arguments_
[
name
]
=
&
node_storage_
.
back
();
node_storage_
.
back
().
AsArg
ument
(
name
);
node_storage_
.
back
().
AsArg
(
name
);
return
&
node_storage_
.
back
();
}
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
f1ca00a4
...
...
@@ -76,7 +76,7 @@ class SSAGraph : GraphBase {
void
MarkArgumentWeights
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
weights
)
{
arguments_
[
name
]
->
AsArg
ument
().
is_weight
=
true
;
arguments_
[
name
]
->
AsArg
().
is_weight
=
true
;
}
}
...
...
@@ -115,9 +115,9 @@ static void DirectedLink(Node *a, Node *b) {
static
void
LocalInferenceType
(
Node
*
a
,
Node
*
b
,
const
std
::
string
&
arg_name
)
{
// instr -> output argument
if
(
a
->
Is
Instruct
()
&&
b
->
IsArgument
())
{
auto
&
inst
=
a
->
As
Instruc
t
();
auto
&
output
=
b
->
AsArg
ument
();
if
(
a
->
Is
Stmt
()
&&
b
->
IsArg
())
{
auto
&
inst
=
a
->
As
Stm
t
();
auto
&
output
=
b
->
AsArg
();
if
(
!
output
.
type
)
{
output
.
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
...
...
@@ -125,9 +125,9 @@ static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
}
// input argument -> instr
if
(
a
->
IsArg
ument
()
&&
b
->
IsInstruc
t
())
{
auto
&
input
=
a
->
AsArg
ument
();
auto
&
inst
=
b
->
As
Instruc
t
();
if
(
a
->
IsArg
()
&&
b
->
IsStm
t
())
{
auto
&
input
=
a
->
AsArg
();
auto
&
inst
=
b
->
As
Stm
t
();
if
(
!
input
.
type
)
{
input
.
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
}
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -33,8 +33,8 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK
(
graph
)
<<
"graph not valid"
;
// sort kernels by the factors.
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
Is
Instruc
t
())
continue
;
auto
&
instruct
=
node
.
As
Instruc
t
();
if
(
!
node
.
Is
Stm
t
())
continue
;
auto
&
instruct
=
node
.
As
Stm
t
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
size_t
score
=
KernelGrade
(
*
kernel
);
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
浏览文件 @
f1ca00a4
...
...
@@ -33,7 +33,7 @@ namespace mir {
* - kernel_pick_factors, the factors to consider in picking kernels.
* Set them first before execute the pass.
*/
class
StaticKernelPickPass
:
public
mir
::
Instruction
Pass
{
class
StaticKernelPickPass
:
public
mir
::
Stmt
Pass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
f1ca00a4
...
...
@@ -33,7 +33,7 @@ void TypeTargetTransformPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK
(
!
valid_places_
.
empty
());
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
Is
Instruc
t
())
continue
;
if
(
!
node
->
Is
Stm
t
())
continue
;
auto
inlinks
=
node
->
inlinks
;
for
(
auto
*
in
:
inlinks
)
{
ComplementInputs
(
graph
.
get
(),
node
,
in
);
...
...
@@ -49,22 +49,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
std
::
find
(
inst_node
->
inlinks
.
begin
(),
inst_node
->
inlinks
.
end
(),
in
))
return
;
CHECK
(
inst_node
->
Is
Instruc
t
());
auto
&
inst
=
inst_node
->
As
Instruc
t
();
CHECK
(
inst_node
->
Is
Stm
t
());
auto
&
inst
=
inst_node
->
As
Stm
t
();
CHECK
(
in
->
IsRoleSet
());
CHECK
(
in
->
IsArg
ument
());
auto
in_arg_name
=
in
->
AsArg
ument
().
name
;
CHECK
(
in
->
IsArg
());
auto
in_arg_name
=
in
->
AsArg
().
name
;
std
::
string
tmp
;
CHECK
(
inst
.
op_info
()
->
GetInputArgname
(
in_arg_name
,
&
tmp
));
auto
decl_arg_type
=
inst
.
picked_kernel
().
GetInputDeclType
(
tmp
);
CHECK
(
in
->
AsArg
ument
().
type
);
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
ument
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
ument
().
name
CHECK
(
in
->
AsArg
().
type
);
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
().
name
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
*
in
->
AsArg
ument
().
type
<<
" -> "
<<
*
decl_arg_type
;
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArg
ument
().
type
,
*
decl_arg_type
,
in
->
AsArgument
().
name
,
graph
,
inst_node
,
valid_places_
);
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
->
AsArg
().
name
,
graph
,
inst_node
,
valid_places_
);
}
}
...
...
@@ -73,7 +73,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
())
<<
"valid_place should be set"
;
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy
Instruc
t Node.
// So there will be a new Argument node and a new IoCopy
Statemen
t Node.
auto
node_id
=
[
&
]
{
return
graph
->
nodes
().
size
();
};
auto
io_copy_output_name
=
var
+
"/trans/"
+
std
::
to_string
(
node_id
());
...
...
@@ -85,7 +85,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
// CHECK(io_copy_op);
// Create the new var manually.
inst_node
->
As
Instruc
t
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
inst_node
->
As
Stm
t
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
lite
::
OpDesc
op_desc
;
...
...
@@ -93,16 +93,16 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
As
Instruc
t
().
op
->
scope
());
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
As
Stm
t
().
op
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
io_copy_inst
->
As
Instruc
t
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
io_copy_inst
->
As
Stm
t
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
// Remove the old link
RemoveDirectedLink
(
graph
->
Argument
(
var
),
inst_node
);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
auto
&
inst
=
inst_node
->
As
Instruc
t
();
auto
&
inst
=
inst_node
->
As
Stm
t
();
auto
inst_program_desc
=
inst
.
op_info
()
->
desc
();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
...
...
@@ -111,20 +111,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
auto
desc_dummy
=
inst_node
->
As
Instruc
t
().
op
->
op_info
()
->
desc
();
auto
desc_dummy
=
inst_node
->
As
Stm
t
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
lite
::
OpDesc
desc_fake
(
desc_dummy
);
inst_node
->
AsInstruct
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsInstruct
().
op
->
scope
());
inst_node
->
AsStmt
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsStmt
().
op
->
scope
());
std
::
string
tmp
;
if
(
inst_node
->
As
Instruc
t
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
if
(
inst_node
->
As
Stm
t
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
}
for
(
auto
&
kernel
:
inst_node
->
As
Instruc
t
().
valid_kernels
)
{
inst_node
->
As
Instruc
t
().
op
->
AttachKernel
(
kernel
.
get
());
for
(
auto
&
kernel
:
inst_node
->
As
Stm
t
().
valid_kernels
)
{
inst_node
->
As
Stm
t
().
op
->
AttachKernel
(
kernel
.
get
());
}
graph
->
CheckValid
();
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
f1ca00a4
...
...
@@ -34,8 +34,8 @@ class VariablePlaceInferencePass : public DebugPass {
CHECK
(
!
graph
->
inputs
().
empty
())
<<
"graph's inputs should be set"
;
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
if
(
v
->
Is
Instruc
t
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
As
Instruc
t
().
op_type
;
if
(
v
->
Is
Stm
t
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
As
Stm
t
().
op_type
;
continue
;
}
...
...
@@ -49,9 +49,9 @@ class VariablePlaceInferencePass : public DebugPass {
void
CheckAllArgumentTypeDetermined
(
SSAGraph
*
graph
)
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
node
.
IsArg
ument
())
{
CHECK
(
node
.
AsArg
ument
().
type
)
<<
"node "
<<
node
.
AsArgument
().
name
<<
" type not determined, "
<<
&
node
;
if
(
node
.
IsArg
())
{
CHECK
(
node
.
AsArg
().
type
)
<<
"node "
<<
node
.
AsArg
().
name
<<
" type not determined, "
<<
&
node
;
}
}
}
...
...
@@ -59,7 +59,7 @@ class VariablePlaceInferencePass : public DebugPass {
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
VLOG
(
3
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
auto
&
inst
=
x
->
As
Instruc
t
();
auto
&
inst
=
x
->
As
Stm
t
();
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
// LOG(INFO) << "- inferencing type " <<
...
...
@@ -76,7 +76,7 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
ument
();
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
node
;
arg_node
.
type
=
type
;
...
...
@@ -94,9 +94,9 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
ument
();
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
node
->
AsArg
ument
().
type
=
type
;
node
->
AsArg
().
type
=
type
;
VLOG
(
3
)
<<
"set type "
<<
*
type
;
}
}
...
...
paddle/fluid/lite/core/optimizer_test.cc
浏览文件 @
f1ca00a4
...
...
@@ -38,7 +38,7 @@ TEST(Optimizer, test) {
optimizer
.
Run
(
std
::
move
(
program
),
places
);
auto
runtime_program
=
optimizer
.
GenRuntimeProgram
();
LOG
(
INFO
)
<<
"num
instruction
s "
<<
runtime_program
->
num_instructions
();
LOG
(
INFO
)
<<
"num
statement
s "
<<
runtime_program
->
num_instructions
();
}
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
f1ca00a4
...
...
@@ -152,7 +152,7 @@ class Type : public DataTypeBase {
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a
instruction
to transform a type to another.
// is is possible to add a
statement
to transform a type to another.
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
template
<
bool
is_unknown
,
bool
is_tensor
=
true
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录