Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
f1ca00a4
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
699
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f1ca00a4
编写于
5月 02, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename some concepts
Instruction to Stmt
上级
ec27aa46
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
118 addition
and
116 deletion
+118
-116
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+4
-0
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
+4
-4
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+4
-5
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
+6
-6
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
+6
-6
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+35
-35
paddle/fluid/lite/core/mir/pass.h
paddle/fluid/lite/core/mir/pass.h
+5
-5
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
+3
-3
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+9
-9
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+7
-7
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+2
-2
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+1
-1
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+21
-22
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+9
-9
paddle/fluid/lite/core/optimizer_test.cc
paddle/fluid/lite/core/optimizer_test.cc
+1
-1
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+1
-1
未找到文件。
paddle/fluid/lite/core/kernel.h
浏览文件 @
f1ca00a4
...
@@ -104,12 +104,16 @@ class KernelBase {
...
@@ -104,12 +104,16 @@ class KernelBase {
mutable
operators
::
param_t
param_
;
mutable
operators
::
param_t
param_
;
// The corresponding op type.
// The corresponding op type.
std
::
string
op_type_
{};
std
::
string
op_type_
{};
// The extra identity to help defficiate a specific kernel, op_type_ + alias_
// is the unique ID for the kernel.
std
::
string
alias_
{};
std
::
string
alias_
{};
};
};
// Light-weight kernel implementation.
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// The OpKernel is designed to implement the specific algorithm on a target
// device.
// device.
// TODO(Superjomn) Consider to add a Platform type to differentiate CUDNN,
// MKLDNN, plain CUDA C implementations.
template
<
TargetType
Target
,
PrecisionType
Precision
,
template
<
TargetType
Target
,
PrecisionType
Precision
,
DataLayoutType
DataLayout
=
DataLayoutType
::
kNCHW
>
DataLayoutType
DataLayout
=
DataLayoutType
::
kNCHW
>
class
OpKernel
:
public
KernelBase
{
class
OpKernel
:
public
KernelBase
{
...
...
paddle/fluid/lite/core/mir/argument_type_display_pass.cc
浏览文件 @
f1ca00a4
...
@@ -24,13 +24,13 @@ class ArgumentTypeDisplayPass : public DebugPass {
...
@@ -24,13 +24,13 @@ class ArgumentTypeDisplayPass : public DebugPass {
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
LOG
(
INFO
)
<<
"== Argument types =="
;
LOG
(
INFO
)
<<
"== Argument types =="
;
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
IsArg
ument
())
continue
;
if
(
!
node
.
IsArg
())
continue
;
auto
*
type
=
node
.
AsArg
ument
().
type
;
auto
*
type
=
node
.
AsArg
().
type
;
if
(
type
)
{
if
(
type
)
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
ument
().
name
<<
" type: "
<<
*
type
;
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
().
name
<<
" type: "
<<
*
type
;
}
else
{
}
else
{
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
ument
().
name
<<
" type: UNK"
;
LOG
(
INFO
)
<<
"* ARG "
<<
node
.
AsArg
().
name
<<
" type: UNK"
;
}
}
}
}
LOG
(
INFO
)
<<
"---------------------"
;
LOG
(
INFO
)
<<
"---------------------"
;
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
f1ca00a4
...
@@ -23,11 +23,10 @@ namespace mir {
...
@@ -23,11 +23,10 @@ namespace mir {
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
LOG
(
INFO
)
<<
"final program
\n
"
<<
Visualize
(
graph
.
get
());
LOG
(
INFO
)
<<
"final program
\n
"
<<
Visualize
(
graph
.
get
());
for
(
auto
&
item
:
graph
->
InstructTopologicalOrder
())
{
for
(
auto
&
item
:
graph
->
InstructTopologicalOrder
())
{
if
(
item
->
IsInstruct
())
{
if
(
item
->
IsStmt
())
{
auto
&
instruct
=
item
->
AsInstruct
();
auto
&
stmt
=
item
->
AsStmt
();
LOG
(
INFO
)
<<
instruct
;
LOG
(
INFO
)
<<
stmt
;
insts_
.
emplace_back
(
instruct
.
op
,
insts_
.
emplace_back
(
stmt
.
op
,
std
::
move
(
stmt
.
valid_kernels
.
front
()));
std
::
move
(
instruct
.
valid_kernels
.
front
()));
}
}
}
}
}
}
...
...
paddle/fluid/lite/core/mir/graph_visualize_pass.cc
浏览文件 @
f1ca00a4
...
@@ -34,16 +34,16 @@ std::string Visualize(mir::SSAGraph* graph) {
...
@@ -34,16 +34,16 @@ std::string Visualize(mir::SSAGraph* graph) {
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
std
::
string
key
;
std
::
string
key
;
if
(
node
.
IsArg
ument
())
{
if
(
node
.
IsArg
())
{
key
=
node
.
AsArg
ument
().
name
;
key
=
node
.
AsArg
().
name
;
}
else
{
}
else
{
key
=
node
.
As
Instruc
t
().
op_type
+
std
::
to_string
(
id
++
);
key
=
node
.
As
Stm
t
().
op_type
+
std
::
to_string
(
id
++
);
}
}
if
(
node
.
Is
Instruc
t
())
{
if
(
node
.
Is
Stm
t
())
{
dot
.
AddNode
(
key
,
{
Dot
::
Attr
(
"shape"
,
"box"
)});
dot
.
AddNode
(
key
,
{
Dot
::
Attr
(
"shape"
,
"box"
)});
for
(
auto
&
x
:
node
.
inlinks
)
{
for
(
auto
&
x
:
node
.
inlinks
)
{
auto
name
=
x
->
AsArg
ument
().
name
;
auto
name
=
x
->
AsArg
().
name
;
if
(
!
exists_args
.
count
(
name
))
{
if
(
!
exists_args
.
count
(
name
))
{
dot
.
AddNode
(
name
,
{});
dot
.
AddNode
(
name
,
{});
}
}
...
@@ -51,7 +51,7 @@ std::string Visualize(mir::SSAGraph* graph) {
...
@@ -51,7 +51,7 @@ std::string Visualize(mir::SSAGraph* graph) {
exists_args
.
insert
(
name
);
exists_args
.
insert
(
name
);
}
}
for
(
auto
&
x
:
node
.
outlinks
)
{
for
(
auto
&
x
:
node
.
outlinks
)
{
auto
name
=
x
->
AsArg
ument
().
name
;
auto
name
=
x
->
AsArg
().
name
;
if
(
!
exists_args
.
count
(
name
))
{
if
(
!
exists_args
.
count
(
name
))
{
dot
.
AddNode
(
name
,
{});
dot
.
AddNode
(
name
,
{});
}
}
...
...
paddle/fluid/lite/core/mir/io_copy_kernel_pick_pass.cc
浏览文件 @
f1ca00a4
...
@@ -19,20 +19,20 @@ namespace paddle {
...
@@ -19,20 +19,20 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
mir
{
namespace
mir
{
class
IoCopyKernelPickPass
:
public
Instruction
Pass
{
class
IoCopyKernelPickPass
:
public
Stmt
Pass
{
public:
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
Is
Instruc
t
())
continue
;
if
(
!
node
.
Is
Stm
t
())
continue
;
auto
&
inst
=
node
.
As
Instruc
t
();
auto
&
inst
=
node
.
As
Stm
t
();
if
(
inst
.
op_type
!=
"io_copy"
)
continue
;
if
(
inst
.
op_type
!=
"io_copy"
)
continue
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
LOG
(
INFO
)
<<
"....> picking a IO COPY kernel"
;
auto
&
kernels
=
node
.
As
Instruc
t
().
valid_kernels
;
auto
&
kernels
=
node
.
As
Stm
t
().
valid_kernels
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
ument
().
type
;
const
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArg
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
ument
().
type
;
const
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArg
().
type
;
LOG
(
INFO
)
<<
"input type "
<<
*
inty
;
LOG
(
INFO
)
<<
"input type "
<<
*
inty
;
LOG
(
INFO
)
<<
"output type "
<<
*
outy
;
LOG
(
INFO
)
<<
"output type "
<<
*
outy
;
...
...
paddle/fluid/lite/core/mir/node.h
浏览文件 @
f1ca00a4
...
@@ -34,15 +34,15 @@ class Node {
...
@@ -34,15 +34,15 @@ class Node {
Node
()
=
default
;
Node
()
=
default
;
enum
class
Role
{
enum
class
Role
{
kArg
ument
=
0
,
kArg
=
0
,
k
Instruc
t
,
k
Stm
t
,
kNumRoles
,
/*should be last*/
kNumRoles
,
/*should be last*/
kUnk
,
kUnk
,
};
};
struct
Instruc
t
{
struct
Stm
t
{
std
::
string
op_type
;
std
::
string
op_type
;
// The kernel instances this
Instruc
t contains.
// The kernel instances this
Statemen
t contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
// TODO(Superjomn) make this a shared_ptr for resource safety.
// TODO(Superjomn) make this a shared_ptr for resource safety.
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
std
::
shared_ptr
<
OpLite
>
op
;
// we hold op to run InferShape
...
@@ -62,13 +62,13 @@ class Node {
...
@@ -62,13 +62,13 @@ class Node {
return
*
valid_kernels
.
front
();
return
*
valid_kernels
.
front
();
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Instruc
t
&
other
)
{
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Stm
t
&
other
)
{
os
<<
"
Instruc
t "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
os
<<
"
Statemen
t "
<<
other
.
op_type
<<
" "
<<
other
.
place
();
return
os
;
return
os
;
}
}
};
};
struct
Arg
ument
{
struct
Arg
{
std
::
string
name
;
std
::
string
name
;
const
Type
*
type
{};
const
Type
*
type
{};
// Weight is a special kind of argument, it is marked as weight explicitly
// Weight is a special kind of argument, it is marked as weight explicitly
...
@@ -76,16 +76,16 @@ class Node {
...
@@ -76,16 +76,16 @@ class Node {
bool
is_weight
{
false
};
bool
is_weight
{
false
};
};
};
Arg
ument
&
AsArgument
(
const
std
::
string
&
name
)
{
Arg
&
AsArg
(
const
std
::
string
&
name
)
{
auto
&
x
=
AsArg
ument
();
auto
&
x
=
AsArg
();
x
.
name
=
name
;
x
.
name
=
name
;
return
x
;
return
x
;
}
}
Instruct
&
AsInstruc
t
(
const
std
::
string
&
op_type
,
Stmt
&
AsStm
t
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
const
std
::
shared_ptr
<
OpLite
>&
op
)
{
auto
&
x
=
As
Instruc
t
();
auto
&
x
=
As
Stm
t
();
x
.
op_type
=
op_type
;
x
.
op_type
=
op_type
;
x
.
op
=
op
;
x
.
op
=
op
;
x
.
valid_kernels
=
std
::
move
(
kernels
);
x
.
valid_kernels
=
std
::
move
(
kernels
);
...
@@ -93,23 +93,23 @@ class Node {
...
@@ -93,23 +93,23 @@ class Node {
}
}
// Set roles.
// Set roles.
Arg
ument
&
AsArgument
()
{
Arg
&
AsArg
()
{
if
(
role_
!=
Role
::
kUnk
)
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
kArg
ument
);
CHECK
(
role_
==
Role
::
kArg
);
return
*
arg
ument
_
;
return
*
arg_
;
}
}
role_
=
Role
::
kArg
ument
;
role_
=
Role
::
kArg
;
arg
ument_
.
reset
(
new
Argument
);
arg
_
.
reset
(
new
Arg
);
return
*
arg
ument
_
;
return
*
arg_
;
}
}
Instruct
&
AsInstruc
t
()
{
Stmt
&
AsStm
t
()
{
if
(
role_
!=
Role
::
kUnk
)
{
if
(
role_
!=
Role
::
kUnk
)
{
CHECK
(
role_
==
Role
::
k
Instruc
t
);
CHECK
(
role_
==
Role
::
k
Stm
t
);
return
*
instruc
t_
;
return
*
stm
t_
;
}
}
role_
=
Role
::
k
Instruc
t
;
role_
=
Role
::
k
Stm
t
;
instruct_
.
reset
(
new
Instruc
t
);
stmt_
.
reset
(
new
Stm
t
);
return
*
instruc
t_
;
return
*
stm
t_
;
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Node
&
other
)
{
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Node
&
other
)
{
...
@@ -117,26 +117,26 @@ class Node {
...
@@ -117,26 +117,26 @@ class Node {
if
(
!
other
.
IsRoleSet
())
{
if
(
!
other
.
IsRoleSet
())
{
os
<<
"Unk role node"
;
os
<<
"Unk role node"
;
}
}
if
(
other
.
IsArg
ument
())
{
if
(
other
.
IsArg
())
{
auto
&
arg
=
other
.
AsArg
ument
();
auto
&
arg
=
other
.
AsArg
();
os
<<
"Argument "
<<
arg
.
name
;
os
<<
"Argument "
<<
arg
.
name
;
}
}
if
(
other
.
Is
Instruc
t
())
{
if
(
other
.
Is
Stm
t
())
{
auto
&
arg
=
other
.
As
Instruc
t
();
auto
&
arg
=
other
.
As
Stm
t
();
os
<<
"
Instruc
t "
<<
arg
.
op_type
;
os
<<
"
Statemen
t "
<<
arg
.
op_type
;
}
}
return
os
;
return
os
;
}
}
// Check roles.
// Check roles.
bool
IsRoleSet
()
const
{
return
role_
!=
Role
::
kUnk
;
}
bool
IsRoleSet
()
const
{
return
role_
!=
Role
::
kUnk
;
}
bool
Is
Instruct
()
const
{
return
role_
==
Role
::
kInstruc
t
;
}
bool
Is
Stmt
()
const
{
return
role_
==
Role
::
kStm
t
;
}
bool
IsArg
ument
()
const
{
return
role_
==
Role
::
kArgument
;
}
bool
IsArg
()
const
{
return
role_
==
Role
::
kArg
;
}
private:
private:
// Either
instruc
t_ or argument_ is used.
// Either
stm
t_ or argument_ is used.
std
::
unique_ptr
<
Instruct
>
instruc
t_
;
std
::
unique_ptr
<
Stmt
>
stm
t_
;
std
::
unique_ptr
<
Arg
ument
>
argument
_
;
std
::
unique_ptr
<
Arg
>
arg
_
;
Role
role_
{
Role
::
kUnk
};
Role
role_
{
Role
::
kUnk
};
};
};
...
...
paddle/fluid/lite/core/mir/pass.h
浏览文件 @
f1ca00a4
...
@@ -26,8 +26,8 @@ class Pass {
...
@@ -26,8 +26,8 @@ class Pass {
enum
class
Kind
{
enum
class
Kind
{
// Will modify the program/graph topology.
// Will modify the program/graph topology.
kProgramWise
=
0
,
kProgramWise
=
0
,
// Will modify the
instruction
, with the graph topology fixed.
// Will modify the
statement
, with the graph topology fixed.
k
Instruction
Wise
,
k
Stmt
Wise
,
// Will not modify the IR, just collect information or visualization.
// Will not modify the IR, just collect information or visualization.
kDebug
,
kDebug
,
};
};
...
@@ -45,7 +45,7 @@ class Pass {
...
@@ -45,7 +45,7 @@ class Pass {
Kind
kind
()
const
{
return
kind_
;
}
Kind
kind
()
const
{
return
kind_
;
}
bool
is_debug_pass
()
const
{
return
kind_
==
Kind
::
kDebug
;
}
bool
is_debug_pass
()
const
{
return
kind_
==
Kind
::
kDebug
;
}
bool
is_program_pass
()
const
{
return
kind_
==
Kind
::
kProgramWise
;
}
bool
is_program_pass
()
const
{
return
kind_
==
Kind
::
kProgramWise
;
}
bool
is_
instruction_pass
()
const
{
return
kind_
==
Kind
::
kInstruction
Wise
;
}
bool
is_
stmt_pass
()
const
{
return
kind_
==
Kind
::
kStmt
Wise
;
}
virtual
~
Pass
()
=
default
;
virtual
~
Pass
()
=
default
;
...
@@ -61,9 +61,9 @@ class ProgramPass : public Pass {
...
@@ -61,9 +61,9 @@ class ProgramPass : public Pass {
ProgramPass
()
:
Pass
(
Kind
::
kProgramWise
)
{}
ProgramPass
()
:
Pass
(
Kind
::
kProgramWise
)
{}
};
};
class
Instruction
Pass
:
public
Pass
{
class
Stmt
Pass
:
public
Pass
{
public:
public:
InstructionPass
()
:
Pass
(
Kind
::
kInstruction
Wise
)
{}
StmtPass
()
:
Pass
(
Kind
::
kStmt
Wise
)
{}
};
};
class
DebugPass
:
public
Pass
{
class
DebugPass
:
public
Pass
{
...
...
paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc
浏览文件 @
f1ca00a4
...
@@ -19,7 +19,7 @@ namespace paddle {
...
@@ -19,7 +19,7 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
mir
{
namespace
mir
{
class
RuntimeContextAssignPass
:
public
Instruction
Pass
{
class
RuntimeContextAssignPass
:
public
Stmt
Pass
{
public:
public:
RuntimeContextAssignPass
()
{
RuntimeContextAssignPass
()
{
#ifdef LITE_WITH_CUDA
#ifdef LITE_WITH_CUDA
...
@@ -29,9 +29,9 @@ class RuntimeContextAssignPass : public InstructionPass {
...
@@ -29,9 +29,9 @@ class RuntimeContextAssignPass : public InstructionPass {
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
Is
Instruc
t
())
continue
;
if
(
!
node
.
Is
Stm
t
())
continue
;
auto
&
inst
=
node
.
As
Instruc
t
();
auto
&
inst
=
node
.
As
Stm
t
();
switch
(
inst
.
picked_kernel
().
target
())
{
switch
(
inst
.
picked_kernel
().
target
())
{
case
TARGET
(
kHost
):
case
TARGET
(
kHost
):
...
...
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
f1ca00a4
...
@@ -37,14 +37,14 @@ std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() {
...
@@ -37,14 +37,14 @@ std::map<mir::Node *, std::set<mir::Node *>> SSAGraph::BuildOperationAdjList() {
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
adj_list
;
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
adj_list
;
for
(
auto
&
n
:
mutable_nodes
())
{
for
(
auto
&
n
:
mutable_nodes
())
{
if
(
!
n
.
Is
Instruc
t
())
continue
;
if
(
!
n
.
Is
Stm
t
())
continue
;
if
(
adj_list
.
find
(
&
n
)
==
adj_list
.
end
())
{
if
(
adj_list
.
find
(
&
n
)
==
adj_list
.
end
())
{
adj_list
[
&
n
]
=
std
::
set
<
mir
::
Node
*>
();
adj_list
[
&
n
]
=
std
::
set
<
mir
::
Node
*>
();
}
}
std
::
vector
<
mir
::
Node
*>
nodes
;
std
::
vector
<
mir
::
Node
*>
nodes
;
for
(
auto
&
var
:
n
.
inlinks
)
{
for
(
auto
&
var
:
n
.
inlinks
)
{
for
(
auto
&
adj_n
:
var
->
inlinks
)
{
for
(
auto
&
adj_n
:
var
->
inlinks
)
{
PADDLE_ENFORCE
(
adj_n
->
Is
Instruc
t
());
PADDLE_ENFORCE
(
adj_n
->
Is
Stm
t
());
nodes
.
push_back
(
adj_n
);
nodes
.
push_back
(
adj_n
);
}
}
}
}
...
@@ -96,7 +96,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
...
@@ -96,7 +96,7 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
VLOG
(
5
)
<<
"create arg node "
<<
name
;
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
ument
(
name
);
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
arguments_
[
name
]
=
&
new_node
;
}
}
}
}
...
@@ -109,7 +109,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
...
@@ -109,7 +109,7 @@ void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
VLOG
(
5
)
<<
"create arg node "
<<
name
;
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArg
ument
(
name
);
new_node
.
AsArg
(
name
);
arguments_
[
name
]
=
&
new_node
;
arguments_
[
name
]
=
&
new_node
;
}
}
}
}
...
@@ -122,7 +122,7 @@ Node *SSAGraph::GraphCreateInstructNode(
...
@@ -122,7 +122,7 @@ Node *SSAGraph::GraphCreateInstructNode(
op
->
SetValidPlaces
(
valid_places
);
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
auto
kernels
=
op
->
CreateKernels
(
valid_places
);
node_storage_
.
back
().
As
Instruc
t
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
);
node_storage_
.
back
().
As
Stm
t
(
op
->
op_type_
,
std
::
move
(
kernels
),
op
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
...
@@ -202,14 +202,14 @@ bool SSAGraph::CheckNodesRoleSet() {
...
@@ -202,14 +202,14 @@ bool SSAGraph::CheckNodesRoleSet() {
bool
SSAGraph
::
CheckLinksRoleSet
()
{
bool
SSAGraph
::
CheckLinksRoleSet
()
{
for
(
auto
&
node
:
mutable_nodes
())
{
for
(
auto
&
node
:
mutable_nodes
())
{
CHECK_OR_FALSE
(
node
.
IsRoleSet
());
CHECK_OR_FALSE
(
node
.
IsRoleSet
());
if
(
!
node
.
Is
Instruc
t
())
continue
;
if
(
!
node
.
Is
Stm
t
())
continue
;
for
(
auto
*
x
:
node
.
inlinks
)
{
for
(
auto
*
x
:
node
.
inlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArg
ument
());
CHECK_OR_FALSE
(
x
->
IsArg
());
}
}
for
(
auto
*
x
:
node
.
outlinks
)
{
for
(
auto
*
x
:
node
.
outlinks
)
{
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsRoleSet
());
CHECK_OR_FALSE
(
x
->
IsArg
ument
());
CHECK_OR_FALSE
(
x
->
IsArg
());
}
}
}
}
return
true
;
return
true
;
...
@@ -219,7 +219,7 @@ Node *SSAGraph::NewArgumentNode(const std::string &name) {
...
@@ -219,7 +219,7 @@ Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate argument called "
<<
name
;
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate argument called "
<<
name
;
arguments_
[
name
]
=
&
node_storage_
.
back
();
arguments_
[
name
]
=
&
node_storage_
.
back
();
node_storage_
.
back
().
AsArg
ument
(
name
);
node_storage_
.
back
().
AsArg
(
name
);
return
&
node_storage_
.
back
();
return
&
node_storage_
.
back
();
}
}
...
...
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
f1ca00a4
...
@@ -76,7 +76,7 @@ class SSAGraph : GraphBase {
...
@@ -76,7 +76,7 @@ class SSAGraph : GraphBase {
void
MarkArgumentWeights
(
const
Program
&
program
)
{
void
MarkArgumentWeights
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
weights
)
{
for
(
const
auto
&
name
:
program
.
weights
)
{
arguments_
[
name
]
->
AsArg
ument
().
is_weight
=
true
;
arguments_
[
name
]
->
AsArg
().
is_weight
=
true
;
}
}
}
}
...
@@ -115,9 +115,9 @@ static void DirectedLink(Node *a, Node *b) {
...
@@ -115,9 +115,9 @@ static void DirectedLink(Node *a, Node *b) {
static
void
LocalInferenceType
(
Node
*
a
,
Node
*
b
,
const
std
::
string
&
arg_name
)
{
static
void
LocalInferenceType
(
Node
*
a
,
Node
*
b
,
const
std
::
string
&
arg_name
)
{
// instr -> output argument
// instr -> output argument
if
(
a
->
Is
Instruct
()
&&
b
->
IsArgument
())
{
if
(
a
->
Is
Stmt
()
&&
b
->
IsArg
())
{
auto
&
inst
=
a
->
As
Instruc
t
();
auto
&
inst
=
a
->
As
Stm
t
();
auto
&
output
=
b
->
AsArg
ument
();
auto
&
output
=
b
->
AsArg
();
if
(
!
output
.
type
)
{
if
(
!
output
.
type
)
{
output
.
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
output
.
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
...
@@ -125,9 +125,9 @@ static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
...
@@ -125,9 +125,9 @@ static void LocalInferenceType(Node *a, Node *b, const std::string &arg_name) {
}
}
// input argument -> instr
// input argument -> instr
if
(
a
->
IsArg
ument
()
&&
b
->
IsInstruc
t
())
{
if
(
a
->
IsArg
()
&&
b
->
IsStm
t
())
{
auto
&
input
=
a
->
AsArg
ument
();
auto
&
input
=
a
->
AsArg
();
auto
&
inst
=
b
->
As
Instruc
t
();
auto
&
inst
=
b
->
As
Stm
t
();
if
(
!
input
.
type
)
{
if
(
!
input
.
type
)
{
input
.
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
input
.
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
}
}
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
f1ca00a4
...
@@ -33,8 +33,8 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
...
@@ -33,8 +33,8 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK
(
graph
)
<<
"graph not valid"
;
CHECK
(
graph
)
<<
"graph not valid"
;
// sort kernels by the factors.
// sort kernels by the factors.
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
!
node
.
Is
Instruc
t
())
continue
;
if
(
!
node
.
Is
Stm
t
())
continue
;
auto
&
instruct
=
node
.
As
Instruc
t
();
auto
&
instruct
=
node
.
As
Stm
t
();
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
size_t
score
=
KernelGrade
(
*
kernel
);
size_t
score
=
KernelGrade
(
*
kernel
);
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
浏览文件 @
f1ca00a4
...
@@ -33,7 +33,7 @@ namespace mir {
...
@@ -33,7 +33,7 @@ namespace mir {
* - kernel_pick_factors, the factors to consider in picking kernels.
* - kernel_pick_factors, the factors to consider in picking kernels.
* Set them first before execute the pass.
* Set them first before execute the pass.
*/
*/
class
StaticKernelPickPass
:
public
mir
::
Instruction
Pass
{
class
StaticKernelPickPass
:
public
mir
::
Stmt
Pass
{
public:
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
...
...
paddle/fluid/lite/core/mir/type_target_transform_pass.cc
浏览文件 @
f1ca00a4
...
@@ -33,7 +33,7 @@ void TypeTargetTransformPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
...
@@ -33,7 +33,7 @@ void TypeTargetTransformPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
CHECK
(
!
valid_places_
.
empty
());
CHECK
(
!
valid_places_
.
empty
());
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
!
node
->
Is
Instruc
t
())
continue
;
if
(
!
node
->
Is
Stm
t
())
continue
;
auto
inlinks
=
node
->
inlinks
;
auto
inlinks
=
node
->
inlinks
;
for
(
auto
*
in
:
inlinks
)
{
for
(
auto
*
in
:
inlinks
)
{
ComplementInputs
(
graph
.
get
(),
node
,
in
);
ComplementInputs
(
graph
.
get
(),
node
,
in
);
...
@@ -49,22 +49,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
...
@@ -49,22 +49,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
std
::
find
(
inst_node
->
inlinks
.
begin
(),
inst_node
->
inlinks
.
end
(),
in
))
std
::
find
(
inst_node
->
inlinks
.
begin
(),
inst_node
->
inlinks
.
end
(),
in
))
return
;
return
;
CHECK
(
inst_node
->
Is
Instruc
t
());
CHECK
(
inst_node
->
Is
Stm
t
());
auto
&
inst
=
inst_node
->
As
Instruc
t
();
auto
&
inst
=
inst_node
->
As
Stm
t
();
CHECK
(
in
->
IsRoleSet
());
CHECK
(
in
->
IsRoleSet
());
CHECK
(
in
->
IsArg
ument
());
CHECK
(
in
->
IsArg
());
auto
in_arg_name
=
in
->
AsArg
ument
().
name
;
auto
in_arg_name
=
in
->
AsArg
().
name
;
std
::
string
tmp
;
std
::
string
tmp
;
CHECK
(
inst
.
op_info
()
->
GetInputArgname
(
in_arg_name
,
&
tmp
));
CHECK
(
inst
.
op_info
()
->
GetInputArgname
(
in_arg_name
,
&
tmp
));
auto
decl_arg_type
=
inst
.
picked_kernel
().
GetInputDeclType
(
tmp
);
auto
decl_arg_type
=
inst
.
picked_kernel
().
GetInputDeclType
(
tmp
);
CHECK
(
in
->
AsArg
ument
().
type
);
CHECK
(
in
->
AsArg
().
type
);
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
ument
().
type
,
*
decl_arg_type
))
{
if
(
!
TargetCompatibleTo
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
))
{
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
ument
().
name
LOG
(
INFO
)
<<
"found Target unmatched tensor: "
<<
in
->
AsArg
().
name
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
" for kernel "
<<
inst
.
op
->
DebugString
()
<<
" "
<<
*
in
->
AsArg
ument
().
type
<<
" -> "
<<
*
decl_arg_type
;
<<
*
in
->
AsArg
().
type
<<
" -> "
<<
*
decl_arg_type
;
// Add an IoCopy instruction to make the input compatible with other dist.
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst
(
*
in
->
AsArg
ument
().
type
,
*
decl_arg_type
,
in
->
AsArgument
().
name
,
AddIoCopyInst
(
*
in
->
AsArg
().
type
,
*
decl_arg_type
,
in
->
AsArg
().
name
,
graph
,
graph
,
inst_node
,
valid_places_
);
inst_node
,
valid_places_
);
}
}
}
}
...
@@ -73,7 +73,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -73,7 +73,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
)
{
Node
*
inst_node
,
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
())
<<
"valid_place should be set"
;
CHECK
(
!
valid_places
.
empty
())
<<
"valid_place should be set"
;
// var -> new_transform_op -> new_var -> inst
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy
Instruc
t Node.
// So there will be a new Argument node and a new IoCopy
Statemen
t Node.
auto
node_id
=
[
&
]
{
return
graph
->
nodes
().
size
();
};
auto
node_id
=
[
&
]
{
return
graph
->
nodes
().
size
();
};
auto
io_copy_output_name
=
var
+
"/trans/"
+
std
::
to_string
(
node_id
());
auto
io_copy_output_name
=
var
+
"/trans/"
+
std
::
to_string
(
node_id
());
...
@@ -85,7 +85,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -85,7 +85,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
// CHECK(io_copy_op);
// CHECK(io_copy_op);
// Create the new var manually.
// Create the new var manually.
inst_node
->
As
Instruc
t
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
inst_node
->
As
Stm
t
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
// Create IoCopy Instruction.
// Create IoCopy Instruction.
lite
::
OpDesc
op_desc
;
lite
::
OpDesc
op_desc
;
...
@@ -93,16 +93,16 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -93,16 +93,16 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetInput
(
"Input"
,
{
var
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
op_desc
.
SetOutput
(
"Out"
,
{
io_copy_output_name
});
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
As
Instruc
t
().
op
->
scope
());
io_copy_op
->
Attach
(
op_desc
,
inst_node
->
As
Stm
t
().
op
->
scope
());
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
auto
kernels
=
io_copy_op
->
CreateKernels
(
valid_places
);
io_copy_inst
->
As
Instruc
t
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
io_copy_inst
->
As
Stm
t
(
"io_copy"
,
std
::
move
(
kernels
),
io_copy_op
);
// Remove the old link
// Remove the old link
RemoveDirectedLink
(
graph
->
Argument
(
var
),
inst_node
);
RemoveDirectedLink
(
graph
->
Argument
(
var
),
inst_node
);
// Update the original instruction OpDesc.
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Update its input to the io_copy_output_name
auto
&
inst
=
inst_node
->
As
Instruc
t
();
auto
&
inst
=
inst_node
->
As
Stm
t
();
auto
inst_program_desc
=
inst
.
op_info
()
->
desc
();
auto
inst_program_desc
=
inst
.
op_info
()
->
desc
();
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
...
@@ -111,20 +111,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
...
@@ -111,20 +111,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink
(
io_copy_output_arg
,
inst_node
);
DirectedLink
(
io_copy_output_arg
,
inst_node
);
// reset opdesc and update kernel information
// reset opdesc and update kernel information
auto
desc_dummy
=
inst_node
->
As
Instruc
t
().
op
->
op_info
()
->
desc
();
auto
desc_dummy
=
inst_node
->
As
Stm
t
().
op
->
op_info
()
->
desc
();
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
UpdateInputTo
(
&
desc_dummy
,
var
,
io_copy_output_name
);
lite
::
OpDesc
desc_fake
(
desc_dummy
);
lite
::
OpDesc
desc_fake
(
desc_dummy
);
inst_node
->
AsInstruct
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsStmt
().
op
->
Attach
(
desc_fake
,
inst_node
->
AsStmt
().
op
->
scope
());
inst_node
->
AsInstruct
().
op
->
scope
());
std
::
string
tmp
;
std
::
string
tmp
;
if
(
inst_node
->
As
Instruc
t
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
if
(
inst_node
->
As
Stm
t
().
op_info
()
->
GetInputArgname
(
"a"
,
&
tmp
))
{
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
CHECK
(
false
)
<<
"get old a "
<<
tmp
;
}
}
for
(
auto
&
kernel
:
inst_node
->
As
Instruc
t
().
valid_kernels
)
{
for
(
auto
&
kernel
:
inst_node
->
As
Stm
t
().
valid_kernels
)
{
inst_node
->
As
Instruc
t
().
op
->
AttachKernel
(
kernel
.
get
());
inst_node
->
As
Stm
t
().
op
->
AttachKernel
(
kernel
.
get
());
}
}
graph
->
CheckValid
();
graph
->
CheckValid
();
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
f1ca00a4
...
@@ -34,8 +34,8 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -34,8 +34,8 @@ class VariablePlaceInferencePass : public DebugPass {
CHECK
(
!
graph
->
inputs
().
empty
())
<<
"graph's inputs should be set"
;
CHECK
(
!
graph
->
inputs
().
empty
())
<<
"graph's inputs should be set"
;
for
(
const
auto
&
v
:
graph
->
inputs
())
{
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
// the feed op might in the inputs
if
(
v
->
Is
Instruc
t
())
{
if
(
v
->
Is
Stm
t
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
As
Instruc
t
().
op_type
;
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
As
Stm
t
().
op_type
;
continue
;
continue
;
}
}
...
@@ -49,9 +49,9 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -49,9 +49,9 @@ class VariablePlaceInferencePass : public DebugPass {
void
CheckAllArgumentTypeDetermined
(
SSAGraph
*
graph
)
{
void
CheckAllArgumentTypeDetermined
(
SSAGraph
*
graph
)
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
node
.
IsArg
ument
())
{
if
(
node
.
IsArg
())
{
CHECK
(
node
.
AsArg
ument
().
type
)
<<
"node "
<<
node
.
AsArgument
().
name
CHECK
(
node
.
AsArg
().
type
)
<<
"node "
<<
node
.
AsArg
().
name
<<
" type not determined, "
<<
&
node
;
<<
" type not determined, "
<<
&
node
;
}
}
}
}
}
}
...
@@ -59,7 +59,7 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -59,7 +59,7 @@ class VariablePlaceInferencePass : public DebugPass {
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
VLOG
(
3
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
VLOG
(
3
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
auto
&
inst
=
x
->
As
Instruc
t
();
auto
&
inst
=
x
->
As
Stm
t
();
// The IoCopyOp is a tool operator, it won't support the type inference.
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
// LOG(INFO) << "- inferencing type " <<
// LOG(INFO) << "- inferencing type " <<
...
@@ -76,7 +76,7 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -76,7 +76,7 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
ument
();
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
if
(
!
arg_node
.
type
)
{
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
node
;
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
node
;
arg_node
.
type
=
type
;
arg_node
.
type
=
type
;
...
@@ -94,9 +94,9 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -94,9 +94,9 @@ class VariablePlaceInferencePass : public DebugPass {
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArg
ument
();
auto
&
arg_node
=
node
->
AsArg
();
if
(
!
arg_node
.
type
)
{
if
(
!
arg_node
.
type
)
{
node
->
AsArg
ument
().
type
=
type
;
node
->
AsArg
().
type
=
type
;
VLOG
(
3
)
<<
"set type "
<<
*
type
;
VLOG
(
3
)
<<
"set type "
<<
*
type
;
}
}
}
}
...
...
paddle/fluid/lite/core/optimizer_test.cc
浏览文件 @
f1ca00a4
...
@@ -38,7 +38,7 @@ TEST(Optimizer, test) {
...
@@ -38,7 +38,7 @@ TEST(Optimizer, test) {
optimizer
.
Run
(
std
::
move
(
program
),
places
);
optimizer
.
Run
(
std
::
move
(
program
),
places
);
auto
runtime_program
=
optimizer
.
GenRuntimeProgram
();
auto
runtime_program
=
optimizer
.
GenRuntimeProgram
();
LOG
(
INFO
)
<<
"num
instruction
s "
<<
runtime_program
->
num_instructions
();
LOG
(
INFO
)
<<
"num
statement
s "
<<
runtime_program
->
num_instructions
();
}
}
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
f1ca00a4
...
@@ -152,7 +152,7 @@ class Type : public DataTypeBase {
...
@@ -152,7 +152,7 @@ class Type : public DataTypeBase {
}
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a
instruction
to transform a type to another.
// is is possible to add a
statement
to transform a type to another.
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
template
<
bool
is_unknown
,
bool
is_tensor
=
true
,
template
<
bool
is_unknown
,
bool
is_tensor
=
true
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录