Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0245a2dd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0245a2dd
编写于
4月 28, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add variable inference pass tester
and code clean
上级
28d27145
变更
23
隐藏空白更改
内联
并排
Showing
23 changed file
with
272 addition
and
157 deletion
+272
-157
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+1
-4
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+11
-0
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+2
-34
paddle/fluid/lite/core/mir/io_complement_pass.h
paddle/fluid/lite/core/mir/io_complement_pass.h
+0
-4
paddle/fluid/lite/core/mir/pass_registry.h
paddle/fluid/lite/core/mir/pass_registry.h
+1
-1
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+6
-5
paddle/fluid/lite/core/mir/ssa_graph_test.cc
paddle/fluid/lite/core/mir/ssa_graph_test.cc
+1
-1
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+1
-2
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
+4
-4
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+15
-10
paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc
...fluid/lite/core/mir/variable_place_inference_pass_test.cc
+80
-0
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+1
-1
paddle/fluid/lite/core/op_registry.cc
paddle/fluid/lite/core/op_registry.cc
+1
-1
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+6
-6
paddle/fluid/lite/core/optimizer.cc
paddle/fluid/lite/core/optimizer.cc
+0
-28
paddle/fluid/lite/core/optimizer.h
paddle/fluid/lite/core/optimizer.h
+26
-6
paddle/fluid/lite/core/optimizer_test.cc
paddle/fluid/lite/core/optimizer_test.cc
+1
-1
paddle/fluid/lite/core/program.h
paddle/fluid/lite/core/program.h
+1
-1
paddle/fluid/lite/core/program_fake_utils.h
paddle/fluid/lite/core/program_fake_utils.h
+63
-0
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+2
-47
paddle/fluid/lite/kernels/cuda/use_kernels.h
paddle/fluid/lite/kernels/cuda/use_kernels.h
+24
-0
paddle/fluid/lite/kernels/host/use_kernels.h
paddle/fluid/lite/kernels/host/use_kernels.h
+22
-0
paddle/fluid/lite/model_parser/runtime.h
paddle/fluid/lite/model_parser/runtime.h
+3
-1
未找到文件。
paddle/fluid/lite/core/kernel.h
浏览文件 @
0245a2dd
...
@@ -96,10 +96,7 @@ class KernelBase {
...
@@ -96,10 +96,7 @@ class KernelBase {
return
type
->
type
;
return
type
->
type
;
}
}
void
set_alias
(
const
std
::
string
&
x
)
{
void
set_alias
(
const
std
::
string
&
x
)
{
alias_
=
x
;
}
alias_
=
x
;
LOG
(
INFO
)
<<
"kernel "
<<
op_type
()
<<
" setting alias "
<<
alias
();
}
const
std
::
string
&
alias
()
const
{
return
alias_
;
}
const
std
::
string
&
alias
()
const
{
return
alias_
;
}
virtual
Place
place
()
const
=
0
;
virtual
Place
place
()
const
=
0
;
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
0245a2dd
...
@@ -24,3 +24,14 @@ cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
...
@@ -24,3 +24,14 @@ cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
mir_pass_manager
mir_pass_manager
program_fake_utils
program_fake_utils
)
)
cc_test
(
test_variable_place_infrence_pass SRCS variable_place_inference_pass_test.cc DEPS
ops_lite
host_kernels
kernels_cuda
mir_passes
mir_pass_manager
optimizer_lite
program_fake_utils
target_wrapper_host
target_wrapper_cuda
)
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
0245a2dd
...
@@ -36,10 +36,7 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
...
@@ -36,10 +36,7 @@ void IoComplementPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
ComplementInputs
(
graph
.
get
(),
node
,
in
);
ComplementInputs
(
graph
.
get
(),
node
,
in
);
}
}
}
}
VLOG
(
3
)
<<
"
\n
"
<<
Visualize
(
graph
.
get
());
// PickIoCopyKernel(graph.get());
LOG
(
INFO
)
<<
"
\n
"
<<
Visualize
(
graph
.
get
());
}
}
void
IoComplementPass
::
ComplementInputs
(
SSAGraph
*
graph
,
Node
*
inst_node
,
void
IoComplementPass
::
ComplementInputs
(
SSAGraph
*
graph
,
Node
*
inst_node
,
...
@@ -96,6 +93,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
...
@@ -96,6 +93,7 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
// create Op and kernels.
// create Op and kernels.
auto
io_copy_op
=
LiteOpRegistry
::
Global
().
Create
(
"io_copy"
);
auto
io_copy_op
=
LiteOpRegistry
::
Global
().
Create
(
"io_copy"
);
CHECK
(
io_copy_op
)
<<
"create op ["
<<
io_copy_op
<<
"] failed"
;
// CHECK(io_copy_op);
// CHECK(io_copy_op);
// Create the new var manually.
// Create the new var manually.
inst_node
->
AsInstruct
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
inst_node
->
AsInstruct
().
op
->
scope
()
->
Var
(
io_copy_output_name
);
...
@@ -144,36 +142,6 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
...
@@ -144,36 +142,6 @@ void IoComplementPass::AddIoCopyInst(const Type& from, const Type& to,
graph
->
CheckValid
();
graph
->
CheckValid
();
}
}
void
IoComplementPass
::
PickIoCopyKernel
(
SSAGraph
*
graph
)
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
node
.
IsInstruct
()
&&
node
.
AsInstruct
().
op_type
==
"io_copy"
)
{
auto
&
kernels
=
node
.
AsInstruct
().
valid_kernels
;
CHECK
(
!
kernels
.
empty
())
<<
"No valid kernels found for IoCopy Op"
;
for
(
auto
&
kernel
:
kernels
)
{
CHECK_EQ
(
node
.
inlinks
.
size
(),
1UL
);
CHECK_EQ
(
node
.
outlinks
.
size
(),
1UL
);
auto
*
inty
=
node
.
inlinks
.
front
()
->
AsArgument
().
type
;
auto
*
outy
=
node
.
outlinks
.
front
()
->
AsArgument
().
type
;
const
Type
*
in_arg_ty
=
kernel
->
GetInputDeclType
(
"Input"
);
if
(
TypeCompatibleTo
(
*
inty
,
*
in_arg_ty
))
{
const
Type
*
out_arg_ty
=
kernel
->
GetOutputDeclType
(
"Out"
);
// Both the input and output type matches, remove other kernels
// directly.
if
(
out_arg_ty
->
target
()
==
outy
->
target
())
{
LOG
(
INFO
)
<<
"get a IOCopy kernel"
;
auto
x
=
std
::
move
(
kernel
);
kernels
.
clear
();
kernels
.
emplace_back
(
std
::
move
(
x
));
break
;
}
}
}
}
}
// Check the compatiblity.
}
void
IoComplementPass
::
SetValidPlaces
(
const
std
::
vector
<
Place
>&
valid_places
)
{
void
IoComplementPass
::
SetValidPlaces
(
const
std
::
vector
<
Place
>&
valid_places
)
{
CHECK
(
!
valid_places
.
empty
());
CHECK
(
!
valid_places
.
empty
());
valid_places_
=
valid_places
;
valid_places_
=
valid_places
;
...
...
paddle/fluid/lite/core/mir/io_complement_pass.h
浏览文件 @
0245a2dd
...
@@ -26,7 +26,6 @@ static void UpdateInputTo(framework::proto::OpDesc* desc,
...
@@ -26,7 +26,6 @@ static void UpdateInputTo(framework::proto::OpDesc* desc,
for
(
auto
&
item
:
*
desc
->
mutable_inputs
())
{
for
(
auto
&
item
:
*
desc
->
mutable_inputs
())
{
for
(
auto
&
input
:
*
item
.
mutable_arguments
())
{
for
(
auto
&
input
:
*
item
.
mutable_arguments
())
{
if
(
input
==
from
)
{
if
(
input
==
from
)
{
LOG
(
INFO
)
<<
"** update input argument from "
<<
from
<<
" to "
<<
to
;
input
=
to
;
input
=
to
;
}
}
}
}
...
@@ -49,9 +48,6 @@ class IoComplementPass : public ProgramPass {
...
@@ -49,9 +48,6 @@ class IoComplementPass : public ProgramPass {
void
SetValidPlaces
(
const
std
::
vector
<
Place
>&
valid_places
);
void
SetValidPlaces
(
const
std
::
vector
<
Place
>&
valid_places
);
// Pick the right kernel of IoCopy considering the input and output Type.
void
PickIoCopyKernel
(
SSAGraph
*
graph
);
const
std
::
vector
<
Place
>&
valid_places
()
const
{
return
valid_places_
;
};
const
std
::
vector
<
Place
>&
valid_places
()
const
{
return
valid_places_
;
};
private:
private:
...
...
paddle/fluid/lite/core/mir/pass_registry.h
浏览文件 @
0245a2dd
...
@@ -25,7 +25,7 @@ namespace mir {
...
@@ -25,7 +25,7 @@ namespace mir {
class
PassRegistry
{
class
PassRegistry
{
public:
public:
PassRegistry
(
const
std
::
string
&
name
,
mir
::
Pass
*
pass
)
{
PassRegistry
(
const
std
::
string
&
name
,
mir
::
Pass
*
pass
)
{
LOG
(
INFO
)
<<
"Registry add MIR pass "
<<
name
;
VLOG
(
2
)
<<
"Registry add MIR pass "
<<
name
;
PassManager
::
Global
().
AddNewPass
(
name
,
pass
);
PassManager
::
Global
().
AddNewPass
(
name
,
pass
);
}
}
...
...
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
0245a2dd
...
@@ -91,7 +91,9 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
...
@@ -91,7 +91,9 @@ std::vector<mir::Node *> SSAGraph::InstructTopologicalOrder() {
void
SSAGraph
::
GraphCreateTmpVarNodes
(
const
Program
&
program
)
{
void
SSAGraph
::
GraphCreateTmpVarNodes
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
for
(
const
auto
&
name
:
program
.
tmp_vars
)
{
LOG
(
INFO
)
<<
"create arg node "
<<
name
;
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate creating temp variable: "
<<
name
;
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArgument
(
name
);
new_node
.
AsArgument
(
name
);
...
@@ -102,7 +104,9 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
...
@@ -102,7 +104,9 @@ void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
void
SSAGraph
::
GraphCreateWeightVarNodes
(
const
Program
&
program
)
{
void
SSAGraph
::
GraphCreateWeightVarNodes
(
const
Program
&
program
)
{
// create weight nodes.
// create weight nodes.
for
(
const
auto
&
name
:
program
.
weights
)
{
for
(
const
auto
&
name
:
program
.
weights
)
{
LOG
(
INFO
)
<<
"create arg node "
<<
name
;
CHECK
(
!
arguments_
.
count
(
name
))
<<
"duplicate creating weight variable: "
<<
name
;
VLOG
(
5
)
<<
"create arg node "
<<
name
;
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
new_node
.
AsArgument
(
name
);
new_node
.
AsArgument
(
name
);
...
@@ -134,10 +138,8 @@ void SSAGraph::Build(const Program &program,
...
@@ -134,10 +138,8 @@ void SSAGraph::Build(const Program &program,
for
(
auto
&
op
:
program
.
ops
)
{
for
(
auto
&
op
:
program
.
ops
)
{
auto
*
op_node
=
GraphCreateInstructNode
(
program
,
op
,
valid_places
);
auto
*
op_node
=
GraphCreateInstructNode
(
program
,
op
,
valid_places
);
LOG
(
INFO
)
<<
"checking op "
<<
op
->
op_type_
;
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
auto
*
arg
=
Argument
(
name
);
auto
*
arg
=
Argument
(
name
);
LOG
(
INFO
)
<<
"input "
<<
name
;
CHECK
(
arg
->
IsRoleSet
());
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
arg
,
op_node
);
DirectedLink
(
arg
,
op_node
);
}
}
...
@@ -145,7 +147,6 @@ void SSAGraph::Build(const Program &program,
...
@@ -145,7 +147,6 @@ void SSAGraph::Build(const Program &program,
if
(
!
arguments_
.
count
(
name
))
{
if
(
!
arguments_
.
count
(
name
))
{
NewArgumentNode
(
name
);
NewArgumentNode
(
name
);
}
}
LOG
(
INFO
)
<<
"output "
<<
name
;
auto
*
arg
=
arguments_
.
at
(
name
);
auto
*
arg
=
arguments_
.
at
(
name
);
CHECK
(
arg
->
IsRoleSet
());
CHECK
(
arg
->
IsRoleSet
());
DirectedLink
(
op_node
,
arg
);
DirectedLink
(
op_node
,
arg
);
...
...
paddle/fluid/lite/core/mir/ssa_graph_test.cc
浏览文件 @
0245a2dd
...
@@ -35,7 +35,7 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
...
@@ -35,7 +35,7 @@ void BuildFc(framework::ProgramDesc* desc, const std::string& x,
}
}
TEST
(
SSAGraph
,
test
)
{
TEST
(
SSAGraph
,
test
)
{
auto
program
=
FakeProgram
();
auto
program
=
ProgramFaker
();
SSAGraph
graph
;
SSAGraph
graph
;
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
std
::
vector
<
Place
>
places
{{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}};
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
0245a2dd
...
@@ -38,7 +38,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
...
@@ -38,7 +38,6 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
std
::
vector
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
KernelBase
>>>
scored
;
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
for
(
auto
&&
kernel
:
instruct
.
valid_kernels
)
{
size_t
score
=
KernelGrade
(
*
kernel
);
size_t
score
=
KernelGrade
(
*
kernel
);
LOG
(
INFO
)
<<
"kernel "
<<
kernel
->
summary
()
<<
" "
<<
score
;
scored
.
emplace_back
(
score
,
std
::
move
(
kernel
));
scored
.
emplace_back
(
score
,
std
::
move
(
kernel
));
}
}
...
@@ -49,7 +48,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
...
@@ -49,7 +48,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// TODO(Superjomn) reconsider this.
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
LOG
(
INFO
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
VLOG
(
2
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
}
}
}
}
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.h
浏览文件 @
0245a2dd
...
@@ -74,10 +74,10 @@ class StaticKernelPickPass : public mir::InstructionPass {
...
@@ -74,10 +74,10 @@ class StaticKernelPickPass : public mir::InstructionPass {
score
+=
kMax
/
static_cast
<
int
>
(
score
+=
kMax
/
static_cast
<
int
>
(
core
::
KernelPickFactor
::
Factor
::
DataLayoutFirst
);
core
::
KernelPickFactor
::
Factor
::
DataLayoutFirst
);
}
}
LOG
(
INFO
)
<<
"picker tactic "
<<
kernel_pick_factors_
;
VLOG
(
4
)
<<
"picker tactic "
<<
kernel_pick_factors_
;
LOG
(
INFO
)
<<
"kernel place "
<<
kernel
.
place
();
VLOG
(
4
)
<<
"kernel place "
<<
kernel
.
place
();
LOG
(
INFO
)
<<
"picker place "
<<
place
();
VLOG
(
4
)
<<
"picker place "
<<
place
();
LOG
(
INFO
)
<<
"score "
<<
score
;
VLOG
(
4
)
<<
"score "
<<
score
;
// The data layout is not considered, for the input and output arguments
// The data layout is not considered, for the input and output arguments
// might have different data layout.
// might have different data layout.
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
0245a2dd
...
@@ -51,49 +51,54 @@ class VariablePlaceInferencePass : public DebugPass {
...
@@ -51,49 +51,54 @@ class VariablePlaceInferencePass : public DebugPass {
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
for
(
auto
&
node
:
graph
->
mutable_nodes
())
{
if
(
node
.
IsArgument
())
{
if
(
node
.
IsArgument
())
{
CHECK
(
node
.
AsArgument
().
type
)
<<
"node "
<<
node
.
AsArgument
().
name
CHECK
(
node
.
AsArgument
().
type
)
<<
"node "
<<
node
.
AsArgument
().
name
<<
" type not determined
"
;
<<
" type not determined
, "
<<
&
node
;
}
}
}
}
}
}
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
LOG
(
INFO
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
VLOG
(
3
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
auto
&
inst
=
x
->
AsInstruct
();
auto
&
inst
=
x
->
AsInstruct
();
// The IoCopyOp is a tool operator, it won't support the type inference.
// The IoCopyOp is a tool operator, it won't support the type inference.
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
if
(
inst
.
op_type
==
"io_copy"
)
continue
;
// LOG(INFO) << "- inferencing type " <<
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
// deal with inputs
VLOG
(
4
)
<<
"inferencing op "
<<
inst
.
op_type
;
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
input_argnames
())
{
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
input_argnames
())
{
LOG
(
INFO
)
<<
"-- input arg_name "
<<
arg_name
;
VLOG
(
3
)
<<
"-- input arg_name "
<<
arg_name
;
// check if inputs's place is set, if not set, update them with the
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
// kernel's declaration.
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
type
=
inst
.
picked_kernel
().
GetInputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input_argument
().
at
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
input_argument
().
at
(
arg_name
);
for
(
auto
&
arg_name
:
arg_names
)
{
for
(
auto
&
arg_name
:
arg_names
)
{
LOG
(
INFO
)
<<
"--- var "
<<
arg_name
;
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
type
)
continue
;
if
(
!
arg_node
.
type
)
{
arg_node
.
type
=
type
;
VLOG
(
4
)
<<
"set type "
<<
*
type
<<
" "
<<
node
;
arg_node
.
type
=
type
;
}
}
}
}
}
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
for
(
auto
&
arg_name
:
inst
.
op_info
()
->
output_argnames
())
{
LOG
(
INFO
)
<<
"-- output arg_name "
<<
arg_name
;
VLOG
(
3
)
<<
"-- output arg_name "
<<
arg_name
;
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
type
=
inst
.
picked_kernel
().
GetOutputDeclType
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output_argument
().
at
(
arg_name
);
auto
arg_names
=
inst
.
op_info
()
->
output_argument
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
for
(
auto
&
arg_name
:
arg_names
)
{
LOG
(
INFO
)
<<
"--- var "
<<
arg_name
;
VLOG
(
3
)
<<
"--- var "
<<
arg_name
;
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
auto
*
node
=
graph
->
RetrieveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
type
)
continue
;
if
(
!
arg_node
.
type
)
{
node
->
AsArgument
().
type
=
type
;
node
->
AsArgument
().
type
=
type
;
VLOG
(
3
)
<<
"set type "
<<
*
type
;
}
}
}
}
}
}
}
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass_test.cc
0 → 100644
浏览文件 @
0245a2dd
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/program_fake_utils.h"
#include "paddle/fluid/lite/kernels/cuda/use_kernels.h"
#include "paddle/fluid/lite/kernels/host/use_kernels.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
TEST
(
variable_place_inference_pass
,
test
)
{
std
::
shared_ptr
<
Scope
>
scope
(
new
lite
::
Scope
);
ProgramFaker
program_faker
;
program_faker
.
AddFeed
(
"a"
,
0
);
program_faker
.
AddMul
(
"a"
,
"W"
,
"a1"
);
program_faker
.
AddMul
(
"a1"
,
"W1"
,
"a2"
);
program_faker
.
AddFetch
(
"a2"
,
0
);
program_faker
.
CreateVars
(
scope
.
get
());
auto
*
desc
=
program_faker
.
program
();
Optimizer
optimizer
;
std
::
vector
<
Place
>
places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
),
},
Place
{
TARGET
(
kHost
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
),
},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
),
},
Place
{
TARGET
(
kCUDA
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
),
},
});
Program
program
(
*
desc
,
scope
,
places
);
core
::
KernelPickFactor
factor
;
factor
.
ConsiderTarget
();
std
::
vector
<
std
::
string
>
passes
({
"static_kernel_pick_pass"
,
//
"argument_type_display_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"io_complement_pass"
,
//
});
Place
prefered_place
{
TARGET
(
kCUDA
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
),
};
optimizer
.
KernelPickPreferPlace
(
prefered_place
);
optimizer
.
Run
(
std
::
move
(
program
),
places
,
factor
,
passes
);
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
USE_LITE_OP
(
mul
);
USE_LITE_OP
(
feed
);
USE_LITE_OP
(
fetch
);
USE_LITE_OP
(
io_copy
);
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
0245a2dd
...
@@ -35,7 +35,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
...
@@ -35,7 +35,7 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
}
}
CHECK
(
!
kernels
.
empty
())
<<
"No kernel found for Op "
<<
op_type_
;
CHECK
(
!
kernels
.
empty
())
<<
"No kernel found for Op "
<<
op_type_
;
LOG
(
INFO
)
<<
"op "
<<
op_type_
<<
" get "
<<
kernels
.
size
()
<<
" kernels"
;
VLOG
(
2
)
<<
"op "
<<
op_type_
<<
" get "
<<
kernels
.
size
()
<<
" kernels"
;
return
kernels
;
return
kernels
;
}
}
...
...
paddle/fluid/lite/core/op_registry.cc
浏览文件 @
0245a2dd
...
@@ -21,7 +21,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
...
@@ -21,7 +21,7 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
,
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
)
{
DataLayoutType
layout
)
{
Place
place
{
target
,
precision
,
layout
};
Place
place
{
target
,
precision
,
layout
};
LOG
(
INFO
)
<<
"creating "
<<
op_type
<<
" kernel for "
<<
place
;
VLOG
(
5
)
<<
"creating "
<<
op_type
<<
" kernel for "
<<
place
;
#define CREATE_KERNEL1(target__, precision__) \
#define CREATE_KERNEL1(target__, precision__) \
switch (layout) { \
switch (layout) { \
case DATALAYOUT(kNCHW): \
case DATALAYOUT(kNCHW): \
...
...
paddle/fluid/lite/core/op_registry.h
浏览文件 @
0245a2dd
...
@@ -81,9 +81,9 @@ class KernelRegistry final {
...
@@ -81,9 +81,9 @@ class KernelRegistry final {
void
Register
(
const
std
::
string
&
name
,
void
Register
(
const
std
::
string
&
name
,
typename
KernelRegistryForTarget
<
Target
,
Precision
,
typename
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>::
creator_t
&&
creator
)
{
Layout
>::
creator_t
&&
creator
)
{
LOG
(
INFO
)
<<
"register for "
<<
TargetToStr
(
Target
)
<<
":"
VLOG
(
3
)
<<
"register for "
<<
TargetToStr
(
Target
)
<<
":"
<<
PrecisionToStr
(
Precision
)
<<
"//"
<<
PrecisionToStr
(
Precision
)
<<
"//"
<<
GetKernelOffset
<
Target
,
Precision
,
Layout
>
();
<<
GetKernelOffset
<
Target
,
Precision
,
Layout
>
();
using
kernel_registor_t
=
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>
;
KernelRegistryForTarget
<
Target
,
Precision
,
Layout
>
;
auto
&
varient
=
registries_
[
GetKernelOffset
<
Target
,
Precision
,
Layout
>
()];
auto
&
varient
=
registries_
[
GetKernelOffset
<
Target
,
Precision
,
Layout
>
()];
...
@@ -144,9 +144,9 @@ class KernelRegistor : public lite::Registor<KernelType> {
...
@@ -144,9 +144,9 @@ class KernelRegistor : public lite::Registor<KernelType> {
public:
public:
KernelRegistor
(
const
std
::
string
&
op_type
,
const
std
::
string
&
alias
)
KernelRegistor
(
const
std
::
string
&
op_type
,
const
std
::
string
&
alias
)
:
Registor
<
KernelType
>
([
=
]
{
:
Registor
<
KernelType
>
([
=
]
{
LOG
(
INFO
)
<<
"Register kernel "
<<
op_type
<<
" for "
VLOG
(
3
)
<<
"Register kernel "
<<
op_type
<<
" for "
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
)
<<
TargetToStr
(
target
)
<<
" "
<<
PrecisionToStr
(
precision
)
<<
" "
<<
DataLayoutToStr
(
layout
)
<<
" alias "
<<
alias
;
<<
" "
<<
DataLayoutToStr
(
layout
)
<<
" alias "
<<
alias
;
KernelRegistry
::
Global
().
Register
<
target
,
precision
,
layout
>
(
KernelRegistry
::
Global
().
Register
<
target
,
precision
,
layout
>
(
op_type
,
[
=
]()
->
std
::
unique_ptr
<
KernelType
>
{
op_type
,
[
=
]()
->
std
::
unique_ptr
<
KernelType
>
{
std
::
unique_ptr
<
KernelType
>
x
(
new
KernelType
);
std
::
unique_ptr
<
KernelType
>
x
(
new
KernelType
);
...
...
paddle/fluid/lite/core/optimizer.cc
浏览文件 @
0245a2dd
...
@@ -27,33 +27,5 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
...
@@ -27,33 +27,5 @@ void Optimizer::SpecifyKernelPickTactic(core::KernelPickFactor factor) {
*
pass
->
mutable_kernel_pick_factors
()
=
factor
;
*
pass
->
mutable_kernel_pick_factors
()
=
factor
;
}
}
void
Optimizer
::
RunPasses
()
{
std
::
vector
<
std
::
string
>
passes
({
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"io_complement_pass"
,
//
"argument_type_display_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"io_copy_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"runtime_context_assign_pass"
,
//
});
for
(
auto
&
pass_type
:
passes
)
{
LOG
(
INFO
)
<<
".. running pass "
<<
pass_type
;
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
(
pass_type
);
CHECK
(
pass
);
if
(
pass
->
name
()
==
"io_complement_pass"
)
{
auto
*
_pass
=
dynamic_cast
<
mir
::
IoComplementPass
*>
(
pass
);
_pass
->
SetValidPlaces
(
valid_places_
);
CHECK
(
!
_pass
->
valid_places
().
empty
());
_pass
->
Apply
(
graph_
);
}
else
{
pass
->
Apply
(
graph_
);
}
}
// mir::PassManager::Global().Run(graph_);
}
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/optimizer.h
浏览文件 @
0245a2dd
...
@@ -41,8 +41,24 @@ class Optimizer {
...
@@ -41,8 +41,24 @@ class Optimizer {
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
.
reset
(
new
mir
::
SSAGraph
);
graph_
->
Build
(
program
,
valid_places
);
graph_
->
Build
(
program
,
valid_places
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
SpecifyKernelPickTactic
(
kernel_pick_factor
);
// InitIoComplement();
InitIoComplement
();
RunPasses
();
if
(
passes
.
empty
())
{
RunPasses
(
std
::
vector
<
std
::
string
>
{{
"static_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"io_complement_pass"
,
//
"argument_type_display_pass"
,
//
"variable_place_inference_pass"
,
//
"argument_type_display_pass"
,
//
"io_copy_kernel_pick_pass"
,
//
"variable_place_inference_pass"
,
//
"runtime_context_assign_pass"
,
//
}});
}
else
{
RunPasses
(
passes
);
}
exec_scope_
=
program
.
exec_scope
;
exec_scope_
=
program
.
exec_scope
;
}
}
...
@@ -86,11 +102,15 @@ class Optimizer {
...
@@ -86,11 +102,15 @@ class Optimizer {
protected:
protected:
void
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
);
void
SpecifyKernelPickTactic
(
core
::
KernelPickFactor
factor
);
// Run the default passes registered in the PassManager.
void
RunPasses
();
// Specify the passes and run them.
// Specify the passes and run them.
void
RunPasses
(
std
::
vector
<
std
::
string
>&
passes
);
void
RunPasses
(
const
std
::
vector
<
std
::
string
>&
passes
)
{
for
(
auto
&
x
:
passes
)
{
LOG
(
INFO
)
<<
"== Running pass "
<<
x
;
auto
*
pass
=
mir
::
PassManager
::
Global
().
LookUp
(
x
);
CHECK
(
pass
);
pass
->
Apply
(
graph_
);
}
}
private:
private:
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
std
::
unique_ptr
<
mir
::
SSAGraph
>
graph_
;
...
...
paddle/fluid/lite/core/optimizer_test.cc
浏览文件 @
0245a2dd
...
@@ -25,7 +25,7 @@ namespace lite {
...
@@ -25,7 +25,7 @@ namespace lite {
TEST
(
Optimizer
,
test
)
{
TEST
(
Optimizer
,
test
)
{
Optimizer
optimizer
;
Optimizer
optimizer
;
auto
program
=
FakeProgram
();
auto
program
=
ProgramFaker
();
std
::
vector
<
Place
>
places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
std
::
vector
<
Place
>
places
({
Place
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)}});
auto
*
pick_pass
=
auto
*
pick_pass
=
...
...
paddle/fluid/lite/core/program.h
浏览文件 @
0245a2dd
...
@@ -64,7 +64,7 @@ struct Program {
...
@@ -64,7 +64,7 @@ struct Program {
for
(
auto
*
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
for
(
auto
*
op_desc
:
program
.
Block
(
0
).
AllOps
())
{
auto
op_type
=
op_desc
->
Type
();
auto
op_type
=
op_desc
->
Type
();
// if (op_type == "feed" || op_type == "fetch") continue;
// if (op_type == "feed" || op_type == "fetch") continue;
LOG
(
INFO
)
<<
"create Op ["
<<
op_type
<<
"]"
;
VLOG
(
4
)
<<
"create Op ["
<<
op_type
<<
"]"
;
ops
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
ops
.
emplace_back
(
LiteOpRegistry
::
Global
().
Create
(
op_type
));
// pick initial kernel
// pick initial kernel
ops
.
back
()
->
PickKernel
(
valid_places
);
ops
.
back
()
->
PickKernel
(
valid_places
);
...
...
paddle/fluid/lite/core/program_fake_utils.h
浏览文件 @
0245a2dd
...
@@ -71,5 +71,68 @@ Program FakeProgram() {
...
@@ -71,5 +71,68 @@ Program FakeProgram() {
return
program
;
return
program
;
}
}
class
ProgramFaker
{
public:
ProgramFaker
()
{}
framework
::
ProgramDesc
*
program
()
{
desc_
.
Flush
();
return
&
desc_
;
}
void
CreateVars
(
lite
::
Scope
*
scope
)
{
for
(
auto
&
var
:
tmp_vars_
)
{
auto
*
x
=
scope
->
Var
(
var
);
x
->
GetMutable
<
lite
::
Tensor
>
();
}
for
(
auto
&
x
:
tmp_vars_
)
{
desc_
.
MutableBlock
(
0
)
->
Var
(
x
);
}
}
void
AddMul
(
const
std
::
string
&
X
,
const
std
::
string
&
Y
,
const
std
::
string
&
out
)
{
tmp_vars_
.
insert
(
X
);
tmp_vars_
.
insert
(
Y
);
tmp_vars_
.
insert
(
out
);
auto
*
block
=
desc_
.
MutableBlock
(
0
);
auto
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"mul"
);
op
->
SetInput
(
"X"
,
{
X
});
op
->
SetInput
(
"Y"
,
{
Y
});
op
->
SetOutput
(
"Out"
,
{
Y
});
op
->
SetAttr
(
"x_num_col_dims"
,
1
);
op
->
SetAttr
(
"y_num_col_dims"
,
1
);
}
void
AddFeed
(
const
std
::
string
&
Out
,
int
col
)
{
tmp_vars_
.
insert
(
Out
);
auto
*
block
=
desc_
.
MutableBlock
(
0
);
auto
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"feed"
);
op
->
SetInput
(
"X"
,
{
"feed"
});
op
->
SetOutput
(
"Out"
,
{
Out
});
op
->
SetAttr
(
"col"
,
col
);
}
void
AddFetch
(
const
std
::
string
&
Input
,
int
col
)
{
tmp_vars_
.
insert
(
Input
);
auto
*
block
=
desc_
.
MutableBlock
(
0
);
auto
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"fetch"
);
op
->
SetInput
(
"X"
,
{
Input
});
op
->
SetOutput
(
"Out"
,
{
"fetch"
});
op
->
SetAttr
(
"col"
,
col
);
}
private:
std
::
set
<
std
::
string
>
tmp_vars_
;
std
::
vector
<
std
::
string
>
weight_vars_
;
framework
::
ProgramDesc
desc_
;
};
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/type_system.h
浏览文件 @
0245a2dd
...
@@ -142,6 +142,8 @@ class Type : public DataTypeBase {
...
@@ -142,6 +142,8 @@ class Type : public DataTypeBase {
}
}
if
(
other
.
is_tensor_
)
{
if
(
other
.
is_tensor_
)
{
os
<<
"<Tensor:"
;
os
<<
"<Tensor:"
;
}
else
{
os
<<
"<"
;
}
}
os
<<
TargetToStr
(
other
.
target
())
<<
"/"
os
<<
TargetToStr
(
other
.
target
())
<<
"/"
<<
PrecisionToStr
(
other
.
precision
())
<<
"/"
<<
PrecisionToStr
(
other
.
precision
())
<<
"/"
...
@@ -256,53 +258,6 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
...
@@ -256,53 +258,6 @@ const Type* LookupType(DataTypeBase::ID type_id, bool is_unknown,
bool
is_tensor
,
Place
place
);
bool
is_tensor
,
Place
place
);
// ------------------------- end predefined types ---------------------------
// ------------------------- end predefined types ---------------------------
// NOTE TypeSystem has some overhead, and better to be used in analysis phase.
class
TypeSystem
{
private:
// Put all valid types for Variables here!
TypeSystem
()
{
// Tensor is a valid data type for Variable.
Register
<
Tensor
>
(
"tensor"
);
}
public:
static
TypeSystem
&
Global
()
{
static
TypeSystem
x
;
return
x
;
}
template
<
typename
T
>
void
Register
(
const
std
::
string
&
type
)
{
size_t
hash
=
typeid
(
T
).
hash_code
();
CHECK
(
!
types_
.
count
(
hash
))
<<
"duplicate register type "
<<
type
<<
" found!"
;
types_
[
hash
]
=
type
;
names_
.
insert
(
type
);
}
template
<
typename
T
>
bool
Contains
()
const
{
return
types_
.
count
(
typeid
(
T
).
hash_code
());
}
bool
Contains
(
size_t
hash
)
const
{
return
types_
.
count
(
hash
);
}
bool
Contains
(
const
std
::
string
&
type
)
{
return
names_
.
count
(
type
);
}
std
::
string
DebugInfo
()
const
{
std
::
stringstream
ss
;
for
(
const
auto
&
it
:
types_
)
{
ss
<<
it
.
second
<<
"
\n
"
;
}
return
ss
.
str
();
}
private:
std
::
unordered_map
<
size_t
/*hash*/
,
std
::
string
/*name*/
>
types_
;
TypeSystem
(
const
TypeSystem
&
)
=
delete
;
std
::
unordered_set
<
std
::
string
>
names_
;
};
/*
/*
* ParamType is used to represent a data type of a parameter for the kernel. It
* ParamType is used to represent a data type of a parameter for the kernel. It
* can represent any Variable data type.
* can represent any Variable data type.
...
...
paddle/fluid/lite/kernels/cuda/use_kernels.h
0 → 100644
浏览文件 @
0245a2dd
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/op_registry.h"
// TODO(Superjomn) make this file a library, that will make compile dependency
// easier.
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL
(
mul
,
kCUDA
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
host_to_device
);
USE_LITE_KERNEL
(
io_copy
,
kCUDA
,
kAny
,
kAny
,
device_to_host
);
#endif
paddle/fluid/lite/kernels/host/use_kernels.h
0 → 100644
浏览文件 @
0245a2dd
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/op_registry.h"
USE_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
kNCHW
,
def
);
USE_LITE_KERNEL
(
feed
,
kHost
,
kAny
,
kAny
,
def
);
USE_LITE_KERNEL
(
fetch
,
kHost
,
kAny
,
kAny
,
def
);
paddle/fluid/lite/model_parser/runtime.h
浏览文件 @
0245a2dd
...
@@ -95,7 +95,7 @@ class OpDesc {
...
@@ -95,7 +95,7 @@ class OpDesc {
std
::
string
op_type
;
std
::
string
op_type
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
inputs
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
inputs
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
outputs
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
outputs
;
std
::
map
<
std
::
string
,
variant
<
int
,
std
::
string
>>
attrs
;
std
::
map
<
std
::
string
,
variant
<
int
,
float
,
std
::
string
>>
attrs
;
};
};
class
BlockDesc
{
class
BlockDesc
{
...
@@ -112,6 +112,8 @@ class BlockDesc {
...
@@ -112,6 +112,8 @@ class BlockDesc {
class
ProgramDesc
{
class
ProgramDesc
{
public:
public:
void
Parse
(
const
framework
::
proto
::
ProgramDesc
&
desc
);
void
Parse
(
const
framework
::
proto
::
ProgramDesc
&
desc
);
BlockDesc
block
;
};
};
}
// namespace lite
}
// namespace lite
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录