Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
65bfecc9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
65bfecc9
编写于
4月 18, 2019
作者:
S
superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make kernel param-type-recorder and typesystem work
上级
1efa91dd
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
459 addition
and
75 deletion
+459
-75
paddle/fluid/lite/core/CMakeLists.txt
paddle/fluid/lite/core/CMakeLists.txt
+2
-1
paddle/fluid/lite/core/kernel.cc
paddle/fluid/lite/core/kernel.cc
+9
-20
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+33
-2
paddle/fluid/lite/core/mir/CMakeLists.txt
paddle/fluid/lite/core/mir/CMakeLists.txt
+1
-0
paddle/fluid/lite/core/mir/generate_program_pass.cc
paddle/fluid/lite/core/mir/generate_program_pass.cc
+1
-1
paddle/fluid/lite/core/mir/io_complement_pass.cc
paddle/fluid/lite/core/mir/io_complement_pass.cc
+3
-1
paddle/fluid/lite/core/mir/node.h
paddle/fluid/lite/core/mir/node.h
+10
-1
paddle/fluid/lite/core/mir/passes.h
paddle/fluid/lite/core/mir/passes.h
+1
-0
paddle/fluid/lite/core/mir/ssa_graph.cc
paddle/fluid/lite/core/mir/ssa_graph.cc
+74
-1
paddle/fluid/lite/core/mir/ssa_graph.h
paddle/fluid/lite/core/mir/ssa_graph.h
+60
-24
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
+1
-0
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
+34
-0
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+116
-0
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+7
-0
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+79
-20
paddle/fluid/lite/core/optimizer_test.cc
paddle/fluid/lite/core/optimizer_test.cc
+7
-0
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+12
-0
paddle/fluid/lite/kernels/host/fc_compute.cc
paddle/fluid/lite/kernels/host/fc_compute.cc
+9
-4
未找到文件。
paddle/fluid/lite/core/CMakeLists.txt
浏览文件 @
65bfecc9
cc_library
(
memory_lite SRCS memory.cc
)
cc_library
(
memory_lite SRCS memory.cc
)
cc_library
(
target_wrapper_lite SRCS target_wrapper.cc
)
cc_library
(
tensor_lite SRCS tensor.cc DEPS memory_lite
)
cc_library
(
tensor_lite SRCS tensor.cc DEPS memory_lite
)
cc_library
(
kernel_lite SRCS kernel.cc DEPS type_system
)
cc_library
(
kernel_lite SRCS kernel.cc DEPS type_system
target_wrapper_lite
)
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
variable_lite SRCS variable.cc
)
cc_library
(
op_registry_lite SRCS op_registry.cc
)
cc_library
(
op_registry_lite SRCS op_registry.cc
)
cc_library
(
scope_lite SRCS scope.cc
)
cc_library
(
scope_lite SRCS scope.cc
)
...
...
paddle/fluid/lite/core/kernel.cc
浏览文件 @
65bfecc9
...
@@ -17,29 +17,18 @@
...
@@ -17,29 +17,18 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
bool
operator
<
(
const
Place
&
a
,
const
Place
&
b
)
{
if
(
a
.
target
!=
b
.
target
)
return
a
.
target
<
b
.
target
;
else
if
(
a
.
precision
!=
b
.
precision
)
return
a
.
precision
<
b
.
precision
;
else
if
(
a
.
layout
!=
b
.
layout
)
return
a
.
layout
<
b
.
layout
;
return
true
;
}
bool
ParamTypeRegistry
::
KeyCmp
::
operator
()(
bool
ParamTypeRegistry
::
KeyCmp
::
operator
()(
const
ParamTypeRegistry
::
key_t
&
a
,
const
ParamTypeRegistry
::
key_t
&
a
,
const
ParamTypeRegistry
::
key_t
&
b
)
const
{
const
ParamTypeRegistry
::
key_t
&
b
)
const
{
if
(
a
.
kernel_type
!=
b
.
kernel_type
)
return
a
.
hash
()
<
b
.
hash
();
return
a
.
kernel_type
<
b
.
kernel_type
;
}
else
if
(
a
.
io
!=
b
.
io
)
return
a
.
io
<
b
.
io
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
else
if
(
a
.
arg_name
!=
b
.
arg_name
)
const
ParamTypeRegistry
::
KernelIdTy
&
other
)
{
return
a
.
arg_name
<
b
.
arg_name
;
std
::
string
io_s
=
other
.
io
==
ParamTypeRegistry
::
IO
::
kInput
?
"in"
:
"out"
;
else
if
(
!
(
a
.
place
==
b
.
place
))
{
os
<<
other
.
kernel_type
<<
":"
<<
other
.
arg_name
<<
":"
<<
io_s
<<
":"
return
a
.
place
<
b
.
place
;
<<
other
.
place
.
DebugString
();
}
return
os
;
return
true
;
}
}
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/kernel.h
浏览文件 @
65bfecc9
...
@@ -55,6 +55,7 @@ class KernelBase {
...
@@ -55,6 +55,7 @@ class KernelBase {
void
Torch
()
{}
void
Torch
()
{}
virtual
Place
place
()
const
=
0
;
virtual
TargetType
target
()
const
=
0
;
virtual
TargetType
target
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
PrecisionType
precision
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
virtual
DataLayoutType
layout
()
const
=
0
;
...
@@ -87,7 +88,9 @@ struct ParamType {
...
@@ -87,7 +88,9 @@ struct ParamType {
:
element_type_hash
(
element_type_hash
)
{}
:
element_type_hash
(
element_type_hash
)
{}
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{
tensor_place
=
type_
->
place
();
}
std
::
string
DebugString
()
const
{
return
tensor_place
.
DebugString
();
}
};
};
/*
/*
...
@@ -167,15 +170,32 @@ class ParamTypeRegistry {
...
@@ -167,15 +170,32 @@ class ParamTypeRegistry {
const
std
::
string
&
arg_name
,
ParamType
data_type
)
{
const
std
::
string
&
arg_name
,
ParamType
data_type
)
{
KernelIdTy
key
{
kernel_type
,
place
,
io
,
arg_name
};
KernelIdTy
key
{
kernel_type
,
place
,
io
,
arg_name
};
types_
[
key
]
=
data_type
;
types_
[
key
]
=
data_type
;
CHECK
(
types_
.
count
(
key
));
}
}
ParamType
Retrive
(
const
Place
&
place
,
int
offset
);
template
<
IO
io
>
const
ParamType
*
Retrieve
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
KernelIdTy
key
{
op_type
,
place
,
io
,
arg_name
};
LOG
(
INFO
)
<<
"Looking for "
<<
key
;
auto
it
=
types_
.
find
(
key
);
if
(
it
==
types_
.
end
())
return
nullptr
;
return
&
it
->
second
;
}
static
ParamTypeRegistry
&
Global
()
{
static
ParamTypeRegistry
&
Global
()
{
static
ParamTypeRegistry
x
;
static
ParamTypeRegistry
x
;
return
x
;
return
x
;
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ParamTypeRegistry
&
other
)
{
for
(
auto
&
item
:
other
.
types_
)
{
os
<<
item
.
first
<<
" "
<<
item
.
second
.
DebugString
()
<<
"
\n
"
;
}
return
os
;
}
private:
private:
ParamTypeRegistry
()
=
default
;
ParamTypeRegistry
()
=
default
;
...
@@ -186,6 +206,16 @@ class ParamTypeRegistry {
...
@@ -186,6 +206,16 @@ class ParamTypeRegistry {
Place
place
;
Place
place
;
IO
io
;
IO
io
;
std
::
string
arg_name
;
std
::
string
arg_name
;
size_t
hash
()
const
{
std
::
hash
<
std
::
string
>
h
;
size_t
hash
=
h
(
kernel_type
);
hash
=
hash_combine
(
hash
,
place
.
hash
());
hash
=
hash_combine
(
hash
,
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
io
)));
hash
=
hash_combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
arg_name
));
return
hash
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
KernelIdTy
&
other
);
};
};
using
key_t
=
KernelIdTy
;
using
key_t
=
KernelIdTy
;
...
@@ -213,6 +243,7 @@ class OpKernel : public KernelBase {
...
@@ -213,6 +243,7 @@ class OpKernel : public KernelBase {
TargetType
target
()
const
override
{
return
Target
;
}
TargetType
target
()
const
override
{
return
Target
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
PrecisionType
precision
()
const
override
{
return
Precision
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
DataLayoutType
layout
()
const
override
{
return
DataLayout
;
}
Place
place
()
const
override
{
return
Place
{
Target
,
Precision
,
DataLayout
};
}
std
::
string
name
()
const
override
{
std
::
string
name
()
const
override
{
return
op_type
()
+
":"
+
TargetToStr
(
Target
)
+
"/"
+
return
op_type
()
+
":"
+
TargetToStr
(
Target
)
+
"/"
+
PrecisionToStr
(
Precision
)
+
"/"
+
DataLayoutToStr
(
DataLayout
);
PrecisionToStr
(
Precision
)
+
"/"
+
DataLayoutToStr
(
DataLayout
);
...
...
paddle/fluid/lite/core/mir/CMakeLists.txt
浏览文件 @
65bfecc9
...
@@ -8,6 +8,7 @@ cc_library(mir_passes
...
@@ -8,6 +8,7 @@ cc_library(mir_passes
io_complement_pass.cc
io_complement_pass.cc
graph_visualize_pass.cc
graph_visualize_pass.cc
generate_program_pass.cc
generate_program_pass.cc
variable_place_inference_pass.cc
demo_pass.cc
demo_pass.cc
DEPS mir_pass types_lite
)
DEPS mir_pass types_lite
)
...
...
paddle/fluid/lite/core/mir/generate_program_pass.cc
浏览文件 @
65bfecc9
...
@@ -20,7 +20,7 @@ namespace lite {
...
@@ -20,7 +20,7 @@ namespace lite {
namespace
mir
{
namespace
mir
{
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
void
GenerateProgramPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
for
(
auto
&
item
:
graph
->
Topolot
icalOrder
())
{
for
(
auto
&
item
:
graph
->
InstructTopolog
icalOrder
())
{
if
(
item
->
IsInstruct
())
{
if
(
item
->
IsInstruct
())
{
auto
&
instruct
=
item
->
AsInstruct
();
auto
&
instruct
=
item
->
AsInstruct
();
kernels_
.
emplace_back
(
std
::
move
(
instruct
.
valid_kernels
.
front
()));
kernels_
.
emplace_back
(
std
::
move
(
instruct
.
valid_kernels
.
front
()));
...
...
paddle/fluid/lite/core/mir/io_complement_pass.cc
浏览文件 @
65bfecc9
...
@@ -19,7 +19,9 @@ namespace paddle {
...
@@ -19,7 +19,9 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
mir
{
namespace
mir
{
void
IoComplementPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
{}
void
IoComplementPass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>
&
graph
)
{
// Start from inputs of the graph, those should should have place set.
}
}
// namespace mir
}
// namespace mir
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/mir/node.h
浏览文件 @
65bfecc9
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -44,11 +45,15 @@ class Node {
...
@@ -44,11 +45,15 @@ class Node {
Place
place
;
Place
place
;
// The kernel instances this Instruct contains.
// The kernel instances this Instruct contains.
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
valid_kernels
;
std
::
shared_ptr
<
OpInfo
>
op_info
;
};
};
struct
Argument
{
struct
Argument
{
std
::
string
name
;
std
::
string
name
;
Place
place
;
Place
place
;
// Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place.
bool
is_weight
{
false
};
};
};
Argument
&
AsArgument
(
const
std
::
string
&
name
)
{
Argument
&
AsArgument
(
const
std
::
string
&
name
)
{
...
@@ -57,9 +62,13 @@ class Node {
...
@@ -57,9 +62,13 @@ class Node {
return
x
;
return
x
;
}
}
Instruct
&
AsInstruct
(
const
std
::
string
&
op_type
)
{
Instruct
&
AsInstruct
(
const
std
::
string
&
op_type
,
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>&&
kernels
,
const
std
::
shared_ptr
<
lite
::
OpInfo
>&
op_info
)
{
auto
&
x
=
AsInstruct
();
auto
&
x
=
AsInstruct
();
x
.
op_type
=
op_type
;
x
.
op_type
=
op_type
;
x
.
valid_kernels
=
std
::
move
(
kernels
);
x
.
op_info
=
op_info
;
return
x
;
return
x
;
}
}
...
...
paddle/fluid/lite/core/mir/passes.h
浏览文件 @
65bfecc9
...
@@ -23,5 +23,6 @@ namespace mir {} // namespace mir
...
@@ -23,5 +23,6 @@ namespace mir {} // namespace mir
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
demo
);
USE_MIR_PASS
(
static_kernel_pick_pass
);
USE_MIR_PASS
(
static_kernel_pick_pass
);
USE_MIR_PASS
(
variable_place_inference_pass
);
USE_MIR_PASS
(
io_complement_pass
);
USE_MIR_PASS
(
io_complement_pass
);
USE_MIR_PASS
(
generate_program_pass
);
USE_MIR_PASS
(
generate_program_pass
);
paddle/fluid/lite/core/mir/ssa_graph.cc
浏览文件 @
65bfecc9
...
@@ -16,6 +16,79 @@
...
@@ -16,6 +16,79 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
namespace
mir
{}
// namespace mir
namespace
mir
{
bool
SSAGraph
::
CheckBidirectionalConnection
()
{
LOG
(
INFO
)
<<
"node count "
<<
node_storage_
.
size
();
for
(
auto
&
node
:
node_storage_
)
{
for
(
auto
*
in
:
node
.
inlinks
)
{
CHECK
(
in
->
outlinks
.
end
()
!=
std
::
find
(
in
->
outlinks
.
begin
(),
in
->
outlinks
.
end
(),
&
node
));
}
for
(
auto
*
out
:
node
.
outlinks
)
{
CHECK
(
out
->
inlinks
.
end
()
!=
std
::
find
(
out
->
inlinks
.
begin
(),
out
->
inlinks
.
end
(),
&
node
));
}
}
return
true
;
}
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
SSAGraph
::
BuildOperationAdjList
()
{
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
adj_list
;
for
(
auto
&
n
:
mutable_nodes
())
{
if
(
!
n
.
IsInstruct
())
continue
;
if
(
adj_list
.
find
(
&
n
)
==
adj_list
.
end
())
{
adj_list
[
&
n
]
=
std
::
set
<
mir
::
Node
*>
();
}
std
::
vector
<
mir
::
Node
*>
nodes
;
for
(
auto
&
var
:
n
.
inlinks
)
{
for
(
auto
&
adj_n
:
var
->
inlinks
)
{
PADDLE_ENFORCE
(
adj_n
->
IsInstruct
());
nodes
.
push_back
(
adj_n
);
}
}
std
::
sort
(
nodes
.
begin
(),
nodes
.
end
(),
[](
mir
::
Node
*
node1
,
mir
::
Node
*
node2
)
{
return
node1
>
node2
;
});
adj_list
[
&
n
].
insert
(
std
::
make_move_iterator
(
nodes
.
begin
()),
std
::
make_move_iterator
(
nodes
.
end
()));
}
return
adj_list
;
}
void
SSAGraph
::
SortHelper
(
const
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
&
adj_list
,
mir
::
Node
*
node
,
std
::
set
<
mir
::
Node
*>
*
visited
,
std
::
vector
<
mir
::
Node
*>
*
ret
)
{
visited
->
insert
(
node
);
for
(
auto
adj
:
adj_list
.
at
(
node
))
{
if
(
visited
->
find
(
adj
)
==
visited
->
end
())
{
SortHelper
(
adj_list
,
adj
,
visited
,
ret
);
}
}
ret
->
push_back
(
node
);
}
std
::
vector
<
mir
::
Node
*>
SSAGraph
::
InstructTopologicalOrder
()
{
CheckBidirectionalConnection
();
std
::
stack
<
mir
::
Node
*>
stack
;
std
::
set
<
mir
::
Node
*>
visited
;
std
::
vector
<
mir
::
Node
*>
res
;
auto
adj_list
=
BuildOperationAdjList
();
for
(
auto
adj
:
adj_list
)
{
if
(
visited
.
find
(
adj
.
first
)
==
visited
.
end
())
{
SortHelper
(
adj_list
,
adj
.
first
,
&
visited
,
&
res
);
}
}
return
res
;
}
}
// namespace mir
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/mir/ssa_graph.h
浏览文件 @
65bfecc9
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <stack>
#include <stack>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_lite.h"
...
@@ -34,7 +35,13 @@ struct Program {
...
@@ -34,7 +35,13 @@ struct Program {
std
::
list
<
std
::
string
>
tmp_vars
;
std
::
list
<
std
::
string
>
tmp_vars
;
std
::
list
<
std
::
string
>
weights
;
std
::
list
<
std
::
string
>
weights
;
std
::
list
<
std
::
unique_ptr
<
OpLite
>>
ops
;
std
::
list
<
std
::
unique_ptr
<
OpLite
>>
ops
;
lite
::
Scope
*
scope
;
lite
::
Scope
*
scope
{};
};
// Program of kernel.
struct
KernelProgram
{
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
instructions
;
lite
::
Scope
*
scope
{};
};
};
// An Graph for MIR. It is built from a list of Op and a scope.
// An Graph for MIR. It is built from a list of Op and a scope.
...
@@ -59,17 +66,19 @@ class SSAGraph : GraphBase {
...
@@ -59,17 +66,19 @@ class SSAGraph : GraphBase {
// TODO(Superjomn) remove one valid_places here.
// TODO(Superjomn) remove one valid_places here.
op
->
SetValidPlaces
(
valid_places
);
op
->
SetValidPlaces
(
valid_places
);
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_kernel
=
node_storage_
.
back
().
AsInstruct
(
op
->
op_type_
);
node_storage_
.
back
().
AsInstruct
(
new_kernel
.
valid_kernels
=
op
->
CreateKernels
(
valid_places
);
op
->
op_type_
,
op
->
CreateKernels
(
valid_places
),
op
->
op_info
()
);
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
inlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
CHECK
(
new_node
.
outlinks
.
empty
())
<<
"duplicate Build found"
;
// collect inputs and outputs
// collect inputs and outputs
for
(
const
std
::
string
&
name
:
op
->
input_names
())
{
for
(
const
std
::
string
&
name
:
op
->
op_info
()
->
input_names
())
{
new_node
.
inlinks
.
push_back
(
arguments_
.
at
(
name
));
auto
*
arg
=
arguments_
.
at
(
name
);
new_node
.
inlinks
.
push_back
(
arg
);
arg
->
outlinks
.
push_back
(
&
new_node
);
}
}
for
(
const
std
::
string
&
name
:
op
->
output_names
())
{
for
(
const
std
::
string
&
name
:
op
->
o
p_info
()
->
o
utput_names
())
{
if
(
!
arguments_
.
count
(
name
))
{
if
(
!
arguments_
.
count
(
name
))
{
node_storage_
.
emplace_back
();
node_storage_
.
emplace_back
();
auto
&
new_node
=
node_storage_
.
back
();
auto
&
new_node
=
node_storage_
.
back
();
...
@@ -77,33 +86,35 @@ class SSAGraph : GraphBase {
...
@@ -77,33 +86,35 @@ class SSAGraph : GraphBase {
arg
.
name
=
name
;
arg
.
name
=
name
;
arguments_
.
emplace
(
name
,
&
new_node
);
arguments_
.
emplace
(
name
,
&
new_node
);
}
}
new_node
.
outlinks
.
push_back
(
arguments_
.
at
(
name
));
auto
*
arg
=
arguments_
.
at
(
name
);
new_node
.
outlinks
.
push_back
(
arg
);
arg
->
inlinks
.
push_back
(
&
new_node
);
}
}
}
}
MarkArgumentWeights
(
program
);
}
}
void
sort_utils
(
mir
::
Node
*
n
,
std
::
map
<
mir
::
Node
*
,
bool
>
&
visited
,
std
::
vector
<
mir
::
Node
*>
InstructTopologicalOrder
();
std
::
stack
<
mir
::
Node
*>
&
stack
)
{
visited
[
n
]
=
true
;
// The inputs of the graph.
for
(
auto
&
out
:
n
->
outlinks
)
{
std
::
vector
<
mir
::
Node
*>
inputs
()
{
if
(
!
visited
[
out
])
{
std
::
vector
<
mir
::
Node
*>
res
;
sort_utils
(
out
,
visited
,
stack
);
for
(
auto
&
node
:
node_storage_
)
{
if
(
node
.
inlinks
.
empty
())
{
res
.
push_back
(
&
node
);
}
}
}
}
return
res
;
}
}
std
::
vector
<
mir
::
Node
*>
TopoloticalOrder
()
{
// The outputs of the graph.
std
::
map
<
mir
::
Node
*
,
bool
>
visited
;
std
::
vector
<
mir
::
Node
*>
outputs
()
{
std
::
stack
<
mir
::
Node
*>
stack
;
std
::
vector
<
mir
::
Node
*>
res
;
std
::
vector
<
mir
::
Node
*>
res
;
for
(
auto
&
node
:
node_storage_
)
{
for
(
auto
&
n
:
mutable_nodes
())
{
if
(
node
.
outlinks
.
empty
())
{
if
(
!
visited
[
&
n
])
sort_utils
(
&
n
,
visited
,
stack
);
res
.
push_back
(
&
node
);
}
}
while
(
!
stack
.
empty
())
{
res
.
push_back
(
stack
.
top
());
stack
.
pop
();
}
}
return
res
;
return
res
;
}
}
...
@@ -111,6 +122,31 @@ class SSAGraph : GraphBase {
...
@@ -111,6 +122,31 @@ class SSAGraph : GraphBase {
const
std
::
list
<
mir
::
Node
>
&
nodes
()
const
{
return
node_storage_
;
}
const
std
::
list
<
mir
::
Node
>
&
nodes
()
const
{
return
node_storage_
;
}
std
::
list
<
mir
::
Node
>
&
mutable_nodes
()
{
return
node_storage_
;
}
std
::
list
<
mir
::
Node
>
&
mutable_nodes
()
{
return
node_storage_
;
}
mir
::
Node
*
RetriveArgument
(
const
std
::
string
&
arg
)
{
auto
it
=
arguments_
.
find
(
arg
);
if
(
it
!=
arguments_
.
end
())
{
return
it
->
second
;
}
return
nullptr
;
}
private:
// Check the bidirectional connection.
bool
CheckBidirectionalConnection
();
void
MarkArgumentWeights
(
const
Program
&
program
)
{
for
(
const
auto
&
name
:
program
.
weights
)
{
arguments_
[
name
]
->
AsArgument
().
is_weight
=
true
;
}
}
// Build operator inlink edge table.
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
BuildOperationAdjList
();
void
SortHelper
(
const
std
::
map
<
mir
::
Node
*
,
std
::
set
<
mir
::
Node
*>>
&
adj_list
,
mir
::
Node
*
node
,
std
::
set
<
mir
::
Node
*>
*
visited
,
std
::
vector
<
mir
::
Node
*>
*
ret
);
private:
private:
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
list
<
mir
::
Node
>
node_storage_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
std
::
map
<
std
::
string
,
mir
::
Node
*>
arguments_
;
...
...
paddle/fluid/lite/core/mir/static_kernel_pick_pass.cc
浏览文件 @
65bfecc9
...
@@ -47,6 +47,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
...
@@ -47,6 +47,7 @@ void StaticKernelPickPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
// TODO(Superjomn) reconsider this.
// TODO(Superjomn) reconsider this.
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
clear
();
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
instruct
.
valid_kernels
.
emplace_back
(
std
::
move
(
scored
.
front
().
second
));
instruct
.
place
=
instruct
.
valid_kernels
.
front
()
->
place
();
LOG
(
INFO
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
LOG
(
INFO
)
<<
"pick "
<<
instruct
.
valid_kernels
.
front
()
->
name
();
}
}
}
}
...
...
paddle/fluid/lite/core/mir/variable_place_inference_pass.cc
浏览文件 @
65bfecc9
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/variable_place_inference_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
void
VariablePlaceInferencePass
::
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
{
MarkInputPlace
(
graph
.
get
());
InferenceArgumentPlace
(
graph
.
get
());
}
}
// namespace mir
}
// namespace lite
}
// namespace paddle
REGISTER_MIR_PASS
(
variable_place_inference_pass
,
paddle
::
lite
::
mir
::
VariablePlaceInferencePass
);
paddle/fluid/lite/core/mir/variable_place_inference_pass.h
浏览文件 @
65bfecc9
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace
paddle
{
namespace
lite
{
namespace
mir
{
/*
* Mark the place of the variables in the SSAGrpah, it will inference the
* variables' place by the kernels outputs them.
*/
class
VariablePlaceInferencePass
:
public
DebugPass
{
public:
void
Apply
(
std
::
unique_ptr
<
mir
::
SSAGraph
>&
graph
)
override
;
private:
// Mark the place of input arguments.
void
MarkInputPlace
(
SSAGraph
*
graph
)
{
for
(
const
auto
&
v
:
graph
->
inputs
())
{
// the feed op might in the inputs
if
(
v
->
IsInstruct
())
{
LOG
(
INFO
)
<<
"found kernel in inputs "
<<
v
->
AsInstruct
().
op_type
;
continue
;
}
auto
&
arg
=
v
->
AsArgument
();
arg
.
place
.
target
=
argument_default_target_
;
// the other place description can't be determined yet, until their first
// usage by some kernel.
}
}
void
InferenceArgumentPlace
(
SSAGraph
*
graph
)
{
LOG
(
INFO
)
<<
"param-type-registry:
\n
"
<<
ParamTypeRegistry
::
Global
();
for
(
auto
&
x
:
graph
->
InstructTopologicalOrder
())
{
auto
&
inst
=
x
->
AsInstruct
();
CHECK
(
inst
.
place
.
is_valid
())
<<
"kernel's place should be set when loaded"
;
// deal with inputs
for
(
auto
&
arg_name
:
inst
.
op_info
->
input_argnames
())
{
auto
type
=
ParamTypeRegistry
::
Global
().
Retrieve
<
ParamTypeRegistry
::
IO
::
kInput
>
(
inst
.
place
,
inst
.
op_type
,
arg_name
);
CHECK
(
type
)
<<
"no param-type found for "
<<
inst
.
op_type
<<
":"
<<
arg_name
<<
" "
<<
inst
.
place
.
DebugString
();
auto
arg_names
=
inst
.
op_info
->
input_argument
().
at
(
arg_name
);
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
auto
*
node
=
graph
->
RetriveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
place
.
is_valid
())
continue
;
UpdatePlace
(
&
arg_node
.
place
,
type
->
tensor_place
);
}
}
for
(
auto
&
arg_name
:
inst
.
op_info
->
output_argnames
())
{
auto
type
=
ParamTypeRegistry
::
Global
()
.
Retrieve
<
ParamTypeRegistry
::
IO
::
kOutput
>
(
inst
.
place
,
inst
.
op_type
,
arg_name
);
CHECK
(
type
)
<<
"no param-type found for "
<<
inst
.
op_type
<<
":"
<<
arg_name
<<
" "
<<
inst
.
place
.
DebugString
();
auto
arg_names
=
inst
.
op_info
->
output_argument
().
at
(
arg_name
);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for
(
auto
&
arg_name
:
arg_names
)
{
auto
*
node
=
graph
->
RetriveArgument
(
arg_name
);
CHECK
(
node
)
<<
"argument "
<<
arg_name
<<
" not exists in the graph"
;
auto
&
arg_node
=
node
->
AsArgument
();
if
(
arg_node
.
place
.
is_valid
())
continue
;
UpdatePlace
(
&
arg_node
.
place
,
type
->
tensor_place
);
}
}
}
}
// Update me's kUnk fields by other's fields.
void
UpdatePlace
(
Place
*
me
,
const
Place
&
other
)
{
CHECK
(
other
.
is_valid
());
if
(
me
->
target
==
TARGET
(
kUnk
))
{
me
->
target
=
other
.
target
;
}
if
(
me
->
precision
==
PRECISION
(
kUnk
))
{
me
->
precision
=
other
.
precision
;
}
if
(
me
->
layout
==
DATALAYOUT
(
kUnk
))
{
me
->
layout
=
other
.
layout
;
}
}
private:
// The default target for arguments, e.g. load weights to CPU memory for CUDA
// computation by default.
TargetType
argument_default_target_
{
TARGET
(
kHost
)};
};
}
// namespace mir
}
// namespace lite
}
// namespace paddle
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
65bfecc9
...
@@ -54,5 +54,12 @@ bool OpLite::Run() {
...
@@ -54,5 +54,12 @@ bool OpLite::Run() {
return
true
;
return
true
;
}
}
bool
OpLite
::
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
CHECK
(
!
op_info_
)
<<
"op_info duplicate build found"
;
op_info_
=
std
::
make_shared
<
OpInfo
>
();
op_info_
->
Build
(
opdesc
);
return
AttachImpl
(
opdesc
,
scope
);
}
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
paddle/fluid/lite/core/op_lite.h
浏览文件 @
65bfecc9
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <glog/logging.h>
#include <glog/logging.h>
#include <boost/variant.hpp>
#include <boost/variant.hpp>
#include <map>
#include <map>
#include <memory>
#include <string>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
...
@@ -41,6 +42,8 @@ class Node;
...
@@ -41,6 +42,8 @@ class Node;
class
SSAGraph
;
class
SSAGraph
;
}
}
class
OpInfo
;
/**
/**
* The base class of an light-weight operators, currently just used in inference
* The base class of an light-weight operators, currently just used in inference
* to eliminate overhead of some operations in current framework.
* to eliminate overhead of some operations in current framework.
...
@@ -78,10 +81,10 @@ class OpLite : public Registry {
...
@@ -78,10 +81,10 @@ class OpLite : public Registry {
// Run this operator.
// Run this operator.
virtual
bool
Run
();
virtual
bool
Run
();
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
Attach
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
;
ExtractInputsAndOutputs
(
opdesc
);
return
AttachImpl
(
opdesc
,
scope
);
const
std
::
shared_ptr
<
OpInfo
>
&
op_info
()
const
{
return
op_info_
;
}
}
std
::
shared_ptr
<
OpInfo
>
&
mutable_op_info
()
{
return
op_info_
;
}
// Human-readable information.
// Human-readable information.
virtual
std
::
string
DebugString
()
const
=
0
;
virtual
std
::
string
DebugString
()
const
=
0
;
...
@@ -91,9 +94,6 @@ class OpLite : public Registry {
...
@@ -91,9 +94,6 @@ class OpLite : public Registry {
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
void
PickKernel
(
const
std
::
vector
<
Place
>
&
valid_places
,
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
KernelStrategy
kernel_strategy
=
KernelStrategy
::
kStatic
);
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
virtual
~
OpLite
()
=
default
;
virtual
~
OpLite
()
=
default
;
protected:
protected:
...
@@ -101,19 +101,6 @@ class OpLite : public Registry {
...
@@ -101,19 +101,6 @@ class OpLite : public Registry {
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
virtual
bool
AttachImpl
(
const
framework
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
=
0
;
lite
::
Scope
*
scope
)
=
0
;
void
ExtractInputsAndOutputs
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
Inputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
Outputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
output_names_
.
push_back
(
x
);
}
}
}
// Specify the kernel to run by default. This will specify the value of
// Specify the kernel to run by default. This will specify the value of
// `kernel_place_`.
// `kernel_place_`.
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
virtual
void
StaticPickKernel
(
const
std
::
vector
<
Place
>
&
valid_targets
)
{
...
@@ -141,8 +128,80 @@ class OpLite : public Registry {
...
@@ -141,8 +128,80 @@ class OpLite : public Registry {
std
::
string
op_type_
;
std
::
string
op_type_
;
std
::
vector
<
Place
>
valid_places_
;
std
::
vector
<
Place
>
valid_places_
;
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
std
::
shared_ptr
<
OpInfo
>
op_info_
;
};
/*
* Operator Information, such as some description. It will be shared by all the
* kernels of the same operator.
*/
class
OpInfo
{
public:
void
Build
(
const
framework
::
OpDesc
&
desc
)
{
ExtractInputsAndOutputs
(
desc
);
CollectInputAndOutputArgnames
(
desc
);
CollectArguments
(
desc
);
}
const
std
::
list
<
std
::
string
>
&
input_names
()
const
{
return
input_names_
;
}
const
std
::
list
<
std
::
string
>
&
output_names
()
const
{
return
output_names_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
input_argument
()
{
return
input_argument_
;
}
const
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
&
output_argument
()
{
return
output_argument_
;
}
const
std
::
list
<
std
::
string
>
&
input_argnames
()
const
{
return
input_argnames_
;
}
const
std
::
list
<
std
::
string
>
&
output_argnames
()
const
{
return
output_argnames_
;
}
private:
void
ExtractInputsAndOutputs
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
Inputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
input_names_
.
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
Outputs
())
{
for
(
const
auto
&
x
:
item
.
second
)
{
output_names_
.
push_back
(
x
);
}
}
}
void
CollectInputAndOutputArgnames
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
InputNames
())
{
input_argnames_
.
push_back
(
item
);
}
for
(
const
auto
&
item
:
opdesc
.
OutputNames
())
{
output_argnames_
.
push_back
(
item
);
}
}
void
CollectArguments
(
const
framework
::
OpDesc
&
opdesc
)
{
for
(
const
auto
&
item
:
opdesc
.
Inputs
())
{
for
(
auto
&
x
:
item
.
second
)
{
input_argument_
[
item
.
first
].
push_back
(
x
);
}
}
for
(
const
auto
&
item
:
opdesc
.
Outputs
())
{
for
(
auto
&
x
:
item
.
second
)
{
output_argument_
[
item
.
first
].
push_back
(
x
);
}
}
}
private:
std
::
list
<
std
::
string
>
input_names_
;
std
::
list
<
std
::
string
>
input_names_
;
std
::
list
<
std
::
string
>
output_names_
;
std
::
list
<
std
::
string
>
output_names_
;
std
::
list
<
std
::
string
>
input_argnames_
;
std
::
list
<
std
::
string
>
output_argnames_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
input_argument_
;
std
::
map
<
std
::
string
,
std
::
list
<
std
::
string
>>
output_argument_
;
};
};
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/optimizer_test.cc
浏览文件 @
65bfecc9
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/lite/core/optimizer.h"
#include "paddle/fluid/lite/core/optimizer.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/mir/generate_program_pass.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/pass_manager.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
#include "paddle/fluid/lite/core/mir/static_kernel_pick_pass.h"
...
@@ -36,6 +37,12 @@ TEST(Optimizer, test) {
...
@@ -36,6 +37,12 @@ TEST(Optimizer, test) {
.
ConsiderPrecision
();
.
ConsiderPrecision
();
optimizer
.
Run
(
std
::
move
(
program
),
places
);
optimizer
.
Run
(
std
::
move
(
program
),
places
);
auto
*
program_pass
=
mir
::
PassManager
::
Global
().
LookUp
<
mir
::
GenerateProgramPass
>
(
"generate_program_pass"
);
auto
&
kernels
=
program_pass
->
kernels
();
LOG
(
INFO
)
<<
"get kernels: "
<<
kernels
.
size
();
}
}
}
// namespace lite
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.cc
浏览文件 @
65bfecc9
...
@@ -34,6 +34,14 @@ Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
...
@@ -34,6 +34,14 @@ Type::Get<false /*is_unsupported*/, true /*is_tensor*/, TargetType::kX86,
return
&
x
;
return
&
x
;
}
}
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
true
/*is_tensor*/
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
TensorFp32NCHWTy
x
(
TargetType
::
kHost
);
return
&
x
;
}
template
<
>
template
<
>
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
)
{
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
)
{
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
...
@@ -46,6 +54,10 @@ const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
...
@@ -46,6 +54,10 @@ const Type* Type::Get<TensorFp32NCHWTy>(TargetType target) {
case
TargetType
::
kX86
:
case
TargetType
::
kX86
:
return
Get
<
false
,
true
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
return
Get
<
false
,
true
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
DataLayoutType
::
kNCHW
>
();
case
TargetType
::
kHost
:
return
Get
<
false
,
true
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
default:
default:
LOG
(
FATAL
)
<<
"unsupported target "
<<
TargetToStr
(
target
);
LOG
(
FATAL
)
<<
"unsupported target "
<<
TargetToStr
(
target
);
return
nullptr
;
return
nullptr
;
...
...
paddle/fluid/lite/kernels/host/fc_compute.cc
浏览文件 @
65bfecc9
...
@@ -52,8 +52,13 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
...
@@ -52,8 +52,13 @@ PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
)
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
)
.
BindInput
(
0
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindInput
(
"Input"
,
TARGET
(
kX86
))})
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindOutput
(
0
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
TARGET
(
kX86
))})
.
BindInput
(
"Bias"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
BindInput
(
"W"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
.
Finalize
();
.
Finalize
();
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录