Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
65bfecc9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
65bfecc9
编写于
4月 18, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make kernel param-type-recorder and typesystem work
上级
1efa91dd
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
459 addition
and
75 deletion
+459
-75
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+2
-1
paddle/fluid/lite/core/kernel.cc
paddle/fluid/lite/core/kernel.cc
+9
-20
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+33
-2
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+1
-0
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+1
-1
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+3
-1
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+10
-1
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+1
-0
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+74
-1
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+60
-24
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+1
-0
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
+34
-0
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+116
-0
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+7
-0
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+79
-20
paddle/fluid/lite/core/optimizer_test.cc
paddle/fluid/lite/core/optimizer_test.cc
+7
-0
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+12
-0
paddle/fluid/lite/kernels/host/fc_compute.cc
paddle/fluid/lite/kernels/host/fc_compute.cc
+9
-4
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
65bfecc9
cc_library
(
memory_lite SRCS memory.cc
)
cc_library
(
memory_lite SRCS memory.cc
)
cc_library
(
target_wrapper_lite SRCS target_wrapper.cc
)
cc_library
(
tensor_lite SRCS tensor.cc DEPS memory_lite
)
cc_library
(
tensor_lite SRCS tensor.cc DEPS memory_lite
)
cc_library
(
kernel_lite SRCS kernel.cc DEPS type_system
)
cc_library
(
kernel_lite SRCS kernel.cc DEPS type_system
target_wrapper_lite
)
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
op_registry_lite SRCS op_registry.cc
)
cc_library
(
op_registry_lite SRCS op_registry.cc
)
cc_library
(
scope_lite SRCS scope.cc
)
cc_library
(
scope_lite SRCS scope.cc
)
...
...
paddle/fluid/lite/core/kernel.cc
浏览文件 @
65bfecc9
...
@@ -17,29 +17,18 @@
...
@@ -17,29 +17,18 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
bool
operator
<
(
const
Place
&
a
,
const
Place
&
b
)
{
if
(
a
.
target
!=
b
.
target
)
return
a
.
target
<
b
.
target
;
else
if
(
a
.
precision
!=
b
.
precision
)
return
a
.
precision
<
b
.
precision
;
else
if
(
a
.
layout
!=
b
.
layout
)
return
a
.
layout
<
b
.
layout
;
return
true
;
}
bool
ParamTypeRegistry
::
KeyCmp
::
operator
()(
bool
ParamTypeRegistry
::
KeyCmp
::
operator
()(
const
ParamTypeRegistry
::
key_t
&
a
,
const
ParamTypeRegistry
::
key_t
&
a
,
const
ParamTypeRegistry
::
key_t
&
b
)
const
{
const
ParamTypeRegistry
::
key_t
&
b
)
const
{
if
(
a
.
kernel_type
!=
b
.
kernel_type
)
return
a
.
hash
()
<
b
.
hash
();
return
a
.
kernel_type
<
b
.
kernel_type
;
}
else
if
(
a
.
io
!=
b
.
io
)
return
a
.
io
<
b
.
io
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
else
if
(
a
.
arg_name
!=
b
.
arg_name
)
const
ParamTypeRegistry
::
KernelIdTy
&
other
)
{
return
a
.
arg_name
<
b
.
arg_name
;
std
::
string
io_s
=
other
.
io
==
ParamTypeRegistry
::
IO
::
kInput
?
"in"
:
"out"
;
else
if
(
!
(
a
.
place
==
b
.
place
))
{
os
<<
other
.
kernel_type
<<
":"
<<
other
.
arg_name
<<
":"
<<
io_s
<<
":"
return
a
.
place
<
b
.
place
;
<<
other
.
place
.
DebugString
();
}
return
os
;
return
true
;
}
}
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/kernel.h
浏览文件 @
65bfecc9
...
@@ -55,6 +55,7 @@ class KernelBase {
...
@@ -55,6 +55,7 @@ class KernelBase {
void
Torch
()
{}
void
Torch
()
{}
virtual
Place
place
()
const
=
0
;
virtual
TargetType
target
()
const
=
0
;
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
...
@@ -87,7 +88,9 @@ struct ParamType {
...
@@ -87,7 +88,9 @@ struct ParamType {
:
element_type_hash
(
element_type_hash
)
{}
:
element_type_hash
(
element_type_hash
)
{}
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{
tensor_place
=
type_
->
place
();
}
std
::
string
DebugString
()
const
{
return
tensor_place
.
DebugString
();
}
};
};
/*
/*
...
@@ -167,15 +170,32 @@ class ParamTypeRegistry {
...
@@ -167,15 +170,32 @@ class ParamTypeRegistry {
const
std
::
string
&
arg_name
,
ParamType
data_type
)
{
const
std
::
string
&
arg_name
,
ParamType
data_type
)
{
KernelIdTy
key
{
kernel_type
,
place
,
io
,
arg_name
};
KernelIdTy
key
{
kernel_type
,
place
,
io
,
arg_name
};
types_
[
key
]
=
data_type
;
types_
[
key
]
=
data_type
;
CHECK
(
types_
.
count
(
key
));
}
}
ParamType
Retrive
(
const
Place
&
place
,
int
offset
);
template
<
IO
io
>
const
ParamType
*
Retrieve
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
KernelIdTy
key
{
op_type
,
place
,
io
,
arg_name
};
LOG
(
INFO
)
<<
"Looking for "
<<
key
;
auto
it
=
types_
.
find
(
key
);
if
(
it
==
types_
.
end
())
return
nullptr
;
return
&
it
->
second
;
}
static
ParamTypeRegistry
&
Global
()
{
static
ParamTypeRegistry
&
Global
()
{
static
ParamTypeRegistry
x
;
static
ParamTypeRegistry
x
;
return
x
;
return
x
;
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ParamTypeRegistry
&
other
)
{
for
(
auto
&
item
:
other
.
types_
)
{
os
<<
item
.
first
<<
" "
<<
item
.
second
.
DebugString
()
<<
"
\n
"
;
}
return
os
;
}
private:
private:
ParamTypeRegistry
()
=
default
;
ParamTypeRegistry
()
=
default
;
...
@@ -186,6 +206,16 @@ class ParamTypeRegistry {
...
@@ -186,6 +206,16 @@ class ParamTypeRegistry {
Place
place
;
Place
place
;
IO
io
;
IO
io
;
std
::
string
arg_name
;
std
::
string
arg_name
;
size_t
hash
()
const
{
std
::
hash
<
std
::
string
>
h
;
size_t
hash
=
h
(
kernel_type
);
hash
=
hash_combine
(
hash
,
place
.
hash
());
hash
=
hash_combine
(
hash
,
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
io
)));
hash
=
hash_combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
arg_name
));
return
hash
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
KernelIdTy
&
other
);
};
};
using
key_t
=
KernelIdTy
;
using
key_t
=
KernelIdTy
;
...
@@ -213,6 +243,7 @@ class OpKernel : public KernelBase {
...
@@ -213,6 +243,7 @@ class OpKernel : public KernelBase {
TargetType
target
()
const
override
{
return
Target
;
}
TargetType
target
()
const
override
{
return
Target
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
Place
place
()
const
override
{
return
Place
{
Target
,
Precision
,
DataLayout
};
}
std
::
string
name
()
const
override
{
std
::
string
name
()
const
override
{
return
op_type
()
+
":"
+
TargetToStr
(
Target
)
+
"/"
+
return
op_type
()
+
":"
+
TargetToStr
(
Target
)
+
"/"
+
PrecisionToStr
(
Precision
)
+
"/"
+
DataLayoutToStr
(
DataLayout
);
PrecisionToStr
(
Precision
)
+
"/"
+
DataLayoutToStr
(
DataLayout
);
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
65bfecc9
...
@@ -8,6 +8,7 @@ cc_library(mir_passes
...
@@ -8,6 +8,7 @@ cc_library(mir_passes
io_complement_pass.cc
io_complement_pass.cc
graph_visualize_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
generate_program_pass.cc
variable_place_inference_pass.cc
demo_pass.cc
demo_pass.cc
DEPS mir_pass types_lite
)
DEPS mir_pass types_lite
)
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
65bfecc9
...
@@ -20,7 +20,7 @@ namespace lite {
...
@@ -20,7 +20,7 @@ namespace lite {
namespace
mir
{
namespace
mir
{
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
for
(
auto
&
item
:
graph
->
Topolot
icalOrder
())
{
for
(
auto
&
item
:
graph
->
InstructTopolog
icalOrder
())
{
if
(
item
->
IsInstruct
())
{
if
(
item
->
IsInstruct
())
{
auto
&
instruct
=
item
->
AsInstruct
();
auto
&
instruct
=
item
->
AsInstruct
();
kernels_
.
emplace_back
(
std
::
move
(
instruct
.
valid_kernels
.
front
()));
kernels_
.
emplace_back
(
std
::
move
(
instruct
.
valid_kernels
.
front
()));
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
65bfecc9
...
@@ -19,7 +19,9 @@ namespace paddle {
...
@@ -19,7 +19,9 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
mir
{
namespace
mir
{
void
IoComplementPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
{}
void
IoComplementPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
{
// Start from inputs of the graph, those should should have place set.
}
}
// namespace mir
}
// namespace mir
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/mir/node.h
浏览文件 @
65bfecc9
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -44,11 +45,15 @@ class Node {
...
@@ -44,11 +45,15 @@ class Node {
Place
place
;
Place
place
;
// The kernel instances this Instruct contains.
// The kernel instances this Instruct contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
shared_ptr
<
OpInfo
>
op_info
;
};
};
struct
Argument
{
struct
Argument
{
std
::
string
name
;
std
::
string
name
;
Place
place
;
Place
place
;
// Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place.
bool
is_weight
{
false
};
};
};
Argument
&
AsArgument
(
const
std
::
string
&
name
)
{
Argument
&
AsArgument
(
const
std
::
string
&
name
)
{
...
@@ -57,9 +62,13 @@ class Node {
...
@@ -57,9 +62,13 @@ class Node {
return
x
;
return
x
;
}
}
Instruct
&
AsInstruct
(
const
std
::
string
&
op_type
)
{
Instruct
&
AsInstruct
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
lite
::
OpInfo
>&
op_info
)
{
auto
&
x
=
AsInstruct
();
auto
&
x
=
AsInstruct
();
x
.
op_type
=
op_type
;
x
.
op_type
=
op_type
;
x
.
valid_kernels
=
std
::
move
(
kernels
);
x
.
op_info
=
op_info
;
return
x
;
return
x
;
}
}
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
65bfecc9
...
@@ -23,5 +23,6 @@ namespace mir {} // namespace mir
...
@@ -23,5 +23,6 @@ namespace mir {} // namespace mir
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
static_kernel_pick_pass
);
USE_MIR_PASS
(
static_kernel_pick_pass
);
USE_MIR_PASS
(
variable_place_inference_pass
);
USE_MIR_PASS
(
io_complement_pass
);
USE_MIR_PASS
(
io_complement_pass
);
USE_MIR_PASS
(
generate_program_pass
);
USE_MIR_PASS
(
generate_program_pass
);
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
65bfecc9
...
@@ -16,6 +16,79 @@
...
@@ -16,6 +16,79 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
namespace
mir
{}
// namespace mir
namespace
mir
{
bool
SSAGraph
::
CheckBidirectionalConnection
()
{
LOG
(
INFO
)
<<
"node count "
<<
node_storage_
.
size
();
for
(
auto
&
node
:
node_storage_
)
{
for
(
auto
*
in
:
node
.
inlinks
)
{
CHECK
(
in
->
outlinks
.
end
()
!=
std
::
find
(
in
->
outlinks
.
begin
(),
in
->
outlinks
.
end
(),
&
node
));
}
for
(
auto
*
out
:
node
.
outlinks
)
{
CHECK
(
out
->
inlinks
.
end
()
!=
std
::
find
(
out
->
inlinks
.
begin
(),
out
->
inlinks
.
end
(),
&
node
));
}
}
return
true
;
}
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
SSAGraph
::
BuildOperationAdjList
()
{
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
adj_list
;
for
(
auto
&
n
:
mutable_nodes
())
{
if
(
!
n
.
IsInstruct
())
continue
;
if
(
adj_list
.
find
(
&
n
)
==
adj_list
.
end
())
{
adj_list
[
&
n
]
=
std
::
set
<
mir
::
Node
*>
();
}
std
::
vector
<
mir
::
Node
*>
nodes
;
for
(
auto
&
var
:
n
.
inlinks
)
{
for
(
auto
&
adj_n
:
var
->
inlinks
)
{
PADDLE_ENFORCE
(
adj_n
->
IsInstruct
());
nodes
.
push_back
(
adj_n
);
}
}
std
::
sort
(
nodes
.
begin
(),
nodes
.
end
(),
[](
mir
::
Node
*
node1
,
mir
::
Node
*
node2
)
{
return
node1
>
node2
;
});
adj_list
[
&
n
].
insert
(
std
::
make_move_iterator
(
nodes
.
begin
()),
std
::
make_move_iterator
(
nodes
.
end
()));
}
return
adj_list
;
}
void
SSAGraph
::
SortHelper
(
const
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
&
adj_list
,
mir
::
Node
*
node
,
std
::
set
<
mir
::
Node
*>
*
visited
,
std
::
vector
<
mir
::
Node
*>
*
ret
)
{
visited
->
insert
(
node
);
for
(
auto
adj
:
adj_list
.
at
(
node
))
{
if
(
visited
->
find
(
adj
)
==
visited
->
end
())
{
SortHelper
(
adj_list
,
adj
,
visited
,
ret
);
}
}
ret
->
push_back
(
node
);
}
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
InstructTopologicalOrder
()
{
CheckBidirectionalConnection
();
std
::
stack
<
mir
::
Node
*>
stack
;
std
::
set
<
mir
::
Node
*>
visited
;
std
::
vector
<
mir
::
Node
*>
res
;
auto
adj_list
=
BuildOperationAdjList
();
for
(
auto
adj
:
adj_list
)
{
if
(
visited
.
find
(
adj
.
first
)
==
visited
.
end
())
{
SortHelper
(
adj_list
,
adj
.
first
,
&
visited
,
&
res
);
}
}
return
res
;
}
}
// namespace mir
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
65bfecc9
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <stack>
#include <stack>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_lite.h"
...
@@ -34,7 +35,13 @@ struct Program {
...
@@ -34,7 +35,13 @@ struct Program {
std
::
list
<
std
::
string
>
tmp_vars
;
std
::
list
<
std
::
string
>
tmp_vars
;
std
::
list
<
std
::
string
>
weights
;
std
::
list
<
std
::
string
>
weights
;
std
::
list
<
std
::
unique_ptr
<
OpLite
>>
ops
;
std
::
list
<
std
::
unique_ptr
<
OpLite
>>
ops
;
lite
::
Scope
*
scope
;
lite
::
Scope
*
scope
{};
};
// Program of kernel.
struct
KernelProgram
{
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
instructions
;
lite
::
Scope
*
scope
{};
};
};
// An Graph for MIR. It is built from a list of Op and a scope.
// An Graph for MIR. It is built from a list of Op and a scope.
...
@@ -59,17 +66,19 @@ class SSAGraph : GraphBase {
...
@@ -59,17 +66,19 @@ class SSAGraph : GraphBase {
// TODO(Superjomn) remove one valid_places here.
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_kernel
=
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
);
node_storage_
.
back
().
AsInstruct
(
new_kernel
.
valid_kernels
=
op
->
CreateKernels
(
valid_places
);
op
->
op_type_
,
op
->
CreateKernels
(
valid_places
),
op
->
op_info
()
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
// collect inputs and outputs
// collect inputs and outputs
for
(
const
std
::
string
&
name
:
op
->
input_names
())
{
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
new_node
.
inlinks
.
push_back
(
arguments_
.
at
(
name
));
auto
*
arg
=
arguments_
.
at
(
name
);
new_node
.
inlinks
.
push_back
(
arg
);
arg
->
outlinks
.
push_back
(
&
new_node
);
}
}
for
(
const
std
::
string
&
name
:
op
->
output_names
())
{
for
(
const
std
::
string
&
name
:
op
->
o
p_info
()
->
o
utput_names
())
{
if
(
!
arguments_
.
count
(
name
))
{
if
(
!
arguments_
.
count
(
name
))
{
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
...
@@ -77,33 +86,35 @@ class SSAGraph : GraphBase {
...
@@ -77,33 +86,35 @@ class SSAGraph : GraphBase {
arg
.
name
=
name
;
arg
.
name
=
name
;
arguments_
.
emplace
(
name
,
&
new_node
);
arguments_
.
emplace
(
name
,
&
new_node
);
}
}
new_node
.
outlinks
.
push_back
(
arguments_
.
at
(
name
));
auto
*
arg
=
arguments_
.
at
(
name
);
new_node
.
outlinks
.
push_back
(
arg
);
arg
->
inlinks
.
push_back
(
&
new_node
);
}
}
}
}
MarkArgumentWeights
(
program
);
}
}
void
sort_utils
(
mir
::
Node
*
n
,
std
::
map
<
mir
::
Node
*
,
bool
>
&
visited
,
std
::
vector
<
mir
::
Node
*>
InstructTopologicalOrder
();
std
::
stack
<
mir
::
Node
*>
&
stack
)
{
visited
[
n
]
=
true
;
// The inputs of the graph.
for
(
auto
&
out
:
n
->
outlinks
)
{
std
::
vector
<
mir
::
Node
*>
inputs
()
{
if
(
!
visited
[
out
])
{
std
::
vector
<
mir
::
Node
*>
res
;
sort_utils
(
out
,
visited
,
stack
);
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
inlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
}
}
return
res
;
}
}
std
::
vector
<
mir
::
Node
*>
TopoloticalOrder
()
{
// The outputs of the graph.
std
::
map
<
mir
::
Node
*
,
bool
>
visited
;
std
::
vector
<
mir
::
Node
*>
outputs
()
{
std
::
stack
<
mir
::
Node
*>
stack
;
std
::
vector
<
mir
::
Node
*>
res
;
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
for
(
auto
&
n
:
mutable_nodes
())
{
if
(
node
.
outlinks
.
empty
())
{
if
(
!
visited
[
&
n
])
sort_utils
(
&
n
,
visited
,
stack
);
res
.
push_back
(
&
node
);
}
}
while
(
!
stack
.
empty
())
{
res
.
push_back
(
stack
.
top
());
stack
.
pop
();
}
}
return
res
;
return
res
;
}
}
...
@@ -111,6 +122,31 @@ class SSAGraph : GraphBase {
...
@@ -111,6 +122,31 @@ class SSAGraph : GraphBase {
const
std
::
list
<
mir
::
Node
>
&
nodes
()
const
{
return
node_storage_
;
}
const
std
::
list
<
mir
::
Node
>
&
nodes
()
const
{
return
node_storage_
;
}
std
::
list
<
mir
::
Node
>
&
mutable_nodes
()
{
return
node_storage_
;
}
std
::
list
<
mir
::
Node
>
&
mutable_nodes
()
{
return
node_storage_
;
}
mir
::
Node
*
RetriveArgument
(
const
std
::
string
&
arg
)
{
auto
it
=
arguments_
.
find
(
arg
);
if
(
it
!=
arguments_
.
end
())
{
return
it
->
second
;
}
return
nullptr
;
}
private:
// Check the bidirectional connection.
bool
CheckBidirectionalConnection
();
void
MarkArgumentWeights
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
weights
)
{
arguments_
[
name
]
->
AsArgument
().
is_weight
=
true
;
}
}
// Build operator inlink edge table.
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
BuildOperationAdjList
();
void
SortHelper
(
const
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
&
adj_list
,
mir
::
Node
*
node
,
std
::
set
<
mir
::
Node
*>
*
visited
,
std
::
vector
<
mir
::
Node
*>
*
ret
);
private:
private:
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
65bfecc9
...
@@ -47,6 +47,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
...
@@ -47,6 +47,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// TODO(Superjomn) reconsider this.
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
instruct
.
place
=
instruct
.
valid_kernels
.
front
()
->
place
();
LOG
(
INFO
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
LOG
(
INFO
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
}
}
}
}
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
浏览文件 @
65bfecc9
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/variable_place_inference_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
VariablePlaceInferencePass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
MarkInputPlace
(
graph
.
get
());
InferenceArgumentPlace
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
variable_place_inference_pass
,
paddle
::
lite
::
mir
::
VariablePlaceInferencePass
);
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
65bfecc9
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
/*
* Mark the place of the variables in the SSAGrpah, it will inference the
* variables' place by the kernels outputs them.
*/
class
VariablePlaceInferencePass
:
public
DebugPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
private:
// Mark the place of input arguments.
void
MarkInputPlace
(
SSAGraph
*
graph
)
{
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
if
(
v
->
IsInstruct
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
AsInstruct
().
op_type
;
continue
;
}
auto
&
arg
=
v
->
AsArgument
();
arg
.
place
.
target
=
argument_default_target_
;
// the other place description can't be determined yet, until their first
// usage by some kernel.
}
}
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
LOG
(
INFO
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
auto
&
inst
=
x
->
AsInstruct
();
CHECK
(
inst
.
place
.
is_valid
())
<<
"kernel's place should be set when loaded"
;
// deal with inputs
for
(
auto
&
arg_name
:
inst
.
op_info
->
input_argnames
())
{
auto
type
=
ParamTypeRegistry
::
Global
().
Retrieve
<
ParamTypeRegistry
::
IO
::
kInput
>
(
inst
.
place
,
inst
.
op_type
,
arg_name
);
CHECK
(
type
)
<<
"no param-type found for "
<<
inst
.
op_type
<<
":"
<<
arg_name
<<
" "
<<
inst
.
place
.
DebugString
();
auto
arg_names
=
inst
.
op_info
->
input_argument
().
at
(
arg_name
);
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
auto
*
node
=
graph
->
RetriveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
place
.
is_valid
())
continue
;
UpdatePlace
(
&
arg_node
.
place
,
type
->
tensor_place
);
}
}
for
(
auto
&
arg_name
:
inst
.
op_info
->
output_argnames
())
{
auto
type
=
ParamTypeRegistry
::
Global
()
.
Retrieve
<
ParamTypeRegistry
::
IO
::
kOutput
>
(
inst
.
place
,
inst
.
op_type
,
arg_name
);
CHECK
(
type
)
<<
"no param-type found for "
<<
inst
.
op_type
<<
":"
<<
arg_name
<<
" "
<<
inst
.
place
.
DebugString
();
auto
arg_names
=
inst
.
op_info
->
output_argument
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
auto
*
node
=
graph
->
RetriveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
place
.
is_valid
())
continue
;
UpdatePlace
(
&
arg_node
.
place
,
type
->
tensor_place
);
}
}
}
}
// Update me's kUnk fields by other's fields.
void
UpdatePlace
(
Place
*
me
,
const
Place
&
other
)
{
CHECK
(
other
.
is_valid
());
if
(
me
->
target
==
TARGET
(
kUnk
))
{
me
->
target
=
other
.
target
;
}
if
(
me
->
precision
==
PRECISION
(
kUnk
))
{
me
->
precision
=
other
.
precision
;
}
if
(
me
->
layout
==
DATALAYOUT
(
kUnk
))
{
me
->
layout
=
other
.
layout
;
}
}
private:
// The default target for arguments, e.g. load weights to CPU memory for CUDA
// computation by default.
TargetType
argument_default_target_
{
TARGET
(
kHost
)};
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
65bfecc9
...
@@ -54,5 +54,12 @@ bool OpLite::Run() {
...
@@ -54,5 +54,12 @@ bool OpLite::Run() {
return
true
;
return
true
;
}
}
bool
OpLite
::
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK
(
!
op_info_
)
<<
"op_info duplicate build found"
;
op_info_
=
std
::
make_shared
<
OpInfo
>
();
op_info_
->
Build
(
opdesc
);
return
AttachImpl
(
opdesc
,
scope
);
}
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
65bfecc9
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <glog/logging.h>
#include <boost/variant.hpp>
#include <boost/variant.hpp>
#include <map>
#include <map>
#include <memory>
#include <string>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
...
@@ -41,6 +42,8 @@ class Node;
...
@@ -41,6 +42,8 @@ class Node;
class
SSAGraph
;
class
SSAGraph
;
}
}
class
OpInfo
;
/**
/**
* The base class of an light-weight operators, currently just used in inference
* The base class of an light-weight operators, currently just used in inference
* to eliminate overhead of some operations in current framework.
* to eliminate overhead of some operations in current framework.
...
@@ -78,10 +81,10 @@ class OpLite : public Registry {
...
@@ -78,10 +81,10 @@ class OpLite : public Registry {
// Run this operator.
// Run this operator.
virtual
bool
Run
();
virtual
bool
Run
();
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
;
ExtractInputsAndOutputs
(
opdesc
);
return
AttachImpl
(
opdesc
,
scope
);
const
std
::
shared_ptr
<
OpInfo
>
&
op_info
()
const
{
return
op_info_
;
}
}
std
::
shared_ptr
<
OpInfo
>
&
mutable_op_info
()
{
return
op_info_
;
}
// Human-readable information.
// Human-readable information.
virtual
std
::
string
DebugString
()
const
=
0
;
virtual
std
::
string
DebugString
()
const
=
0
;
...
@@ -91,9 +94,6 @@ class OpLite : public Registry {
...
@@ -91,9 +94,6 @@ class OpLite : public Registry {
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
virtual
~
OpLite
()
=
default
;
virtual
~
OpLite
()
=
default
;
protected:
protected:
...
@@ -101,19 +101,6 @@ class OpLite : public Registry {
...
@@ -101,19 +101,6 @@ class OpLite : public Registry {
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
lite
::
Scope
*
scope
)
=
0
;
void
ExtractInputsAndOutputs
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
Inputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
Outputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
output_names_
.
push_back
(
x
);
}
}
}
// Specify the kernel to run by default. This will specify the value of
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
...
@@ -141,8 +128,80 @@ class OpLite : public Registry {
...
@@ -141,8 +128,80 @@ class OpLite : public Registry {
std
::
string
op_type_
;
std
::
string
op_type_
;
std
::
vector
<
Place
>
valid_places_
;
std
::
vector
<
Place
>
valid_places_
;
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
std
::
shared_ptr
<
OpInfo
>
op_info_
;
};
/*
* Operator Information, such as some description. It will be shared by all the
* kernels of the same operator.
*/
class
OpInfo
{
public:
void
Build
(
const
framework
::
OpDesc
&
desc
)
{
ExtractInputsAndOutputs
(
desc
);
CollectInputAndOutputArgnames
(
desc
);
CollectArguments
(
desc
);
}
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
input_argument
()
{
return
input_argument_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
output_argument
()
{
return
output_argument_
;
}
const
std
::
list
<
std
::
string
>
&
input_argnames
()
const
{
return
input_argnames_
;
}
const
std
::
list
<
std
::
string
>
&
output_argnames
()
const
{
return
output_argnames_
;
}
private:
void
ExtractInputsAndOutputs
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
Inputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
Outputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
output_names_
.
push_back
(
x
);
}
}
}
void
CollectInputAndOutputArgnames
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
InputNames
())
{
input_argnames_
.
push_back
(
item
);
}
for
(
const
auto
&
item
:
opdesc
.
OutputNames
())
{
output_argnames_
.
push_back
(
item
);
}
}
void
CollectArguments
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
Inputs
())
{
for
(
auto
&
x
:
item
.
second
)
{
input_argument_
[
item
.
first
].
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
Outputs
())
{
for
(
auto
&
x
:
item
.
second
)
{
output_argument_
[
item
.
first
].
push_back
(
x
);
}
}
}
private:
std
::
list
<
std
::
string
>
input_names_
;
std
::
list
<
std
::
string
>
input_names_
;
std
::
list
<
std
::
string
>
output_names_
;
std
::
list
<
std
::
string
>
output_names_
;
std
::
list
<
std
::
string
>
input_argnames_
;
std
::
list
<
std
::
string
>
output_argnames_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
input_argument_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
output_argument_
;
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/optimizer_test.cc
浏览文件 @
65bfecc9
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
...
@@ -36,6 +37,12 @@ TEST(Optimizer, test) {
...
@@ -36,6 +37,12 @@ TEST(Optimizer, test) {
.
ConsiderPrecision
();
.
ConsiderPrecision
();
optimizer
.
Run
(
std
::
move
(
program
),
places
);
optimizer
.
Run
(
std
::
move
(
program
),
places
);
auto
*
program_pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
GenerateProgramPass
>
(
"generate_program_pass"
);
auto
&
kernels
=
program_pass
->
kernels
();
LOG
(
INFO
)
<<
"get kernels: "
<<
kernels
.
size
();
}
}
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.cc
浏览文件 @
65bfecc9
...
@@ -34,6 +34,14 @@ Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
...
@@ -34,6 +34,14 @@ Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
return
&
x
;
return
&
x
;
}
}
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
true
/*is_tensor*/
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
TensorFp32NCHWTy
x
(
TargetType
::
kHost
);
return
&
x
;
}
template
<
>
template
<
>
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
)
{
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
)
{
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
...
@@ -46,6 +54,10 @@ const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
...
@@ -46,6 +54,10 @@ const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
case
TargetType
::
kX86
:
case
TargetType
::
kX86
:
return
Get
<
false
,
true
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
return
Get
<
false
,
true
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
DataLayoutType
::
kNCHW
>
();
case
TargetType
::
kHost
:
return
Get
<
false
,
true
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
default:
default:
LOG
(
FATAL
)
<<
"unsupported target "
<<
TargetToStr
(
target
);
LOG
(
FATAL
)
<<
"unsupported target "
<<
TargetToStr
(
target
);
return
nullptr
;
return
nullptr
;
...
...
paddle/fluid/lite/kernels/host/fc_compute.cc
浏览文件 @
65bfecc9
...
@@ -52,8 +52,13 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
...
@@ -52,8 +52,13 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
)
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
)
.
BindInput
(
0
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindInput
(
"Input"
,
TARGET
(
kX86
))})
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindOutput
(
0
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
TARGET
(
kX86
))})
.
BindInput
(
"Bias"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
BindInput
(
"W"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
Finalize
();
.
Finalize
();
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录