Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
25990d29
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
25990d29
编写于
4月 21, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make kernel_registry support multiple kernels for single type
上级
e55a5cd9
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
88 addition
and
53 deletion
+88
-53
paddle/fluid/lite/core/op_lite.cc
paddle/fluid/lite/core/op_lite.cc
+19
-2
paddle/fluid/lite/core/op_lite.h
paddle/fluid/lite/core/op_lite.h
+3
-0
paddle/fluid/lite/core/op_registry.cc
paddle/fluid/lite/core/op_registry.cc
+3
-4
paddle/fluid/lite/core/op_registry.h
paddle/fluid/lite/core/op_registry.h
+39
-33
paddle/fluid/lite/kernels/host/fc_compute.cc
paddle/fluid/lite/kernels/host/fc_compute.cc
+3
-2
paddle/fluid/lite/kernels/host/feed_compute.cc
paddle/fluid/lite/kernels/host/feed_compute.cc
+2
-2
paddle/fluid/lite/kernels/host/mul_compute.cc
paddle/fluid/lite/kernels/host/mul_compute.cc
+2
-2
paddle/fluid/lite/kernels/host/relu_compute.h
paddle/fluid/lite/kernels/host/relu_compute.h
+2
-2
paddle/fluid/lite/kernels/host/scale_compute.cc
paddle/fluid/lite/kernels/host/scale_compute.cc
+2
-2
paddle/fluid/lite/operators/fc_op.cc
paddle/fluid/lite/operators/fc_op.cc
+1
-1
paddle/fluid/lite/utils/factory.h
paddle/fluid/lite/utils/factory.h
+12
-3
未找到文件。
paddle/fluid/lite/core/op_lite.cc
浏览文件 @
25990d29
...
@@ -25,9 +25,12 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
...
@@ -25,9 +25,12 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
CHECK
(
!
op_type_
.
empty
())
<<
"op_type_ should be set first"
;
CHECK
(
!
op_type_
.
empty
())
<<
"op_type_ should be set first"
;
for
(
auto
place
:
places
)
{
for
(
auto
place
:
places
)
{
kernels
.
emplace_back
(
KernelRegistry
::
Global
().
Create
(
auto
ks
=
KernelRegistry
::
Global
().
Create
(
(
kernel_type
.
empty
()
?
op_type_
:
kernel_type
),
place
.
target
,
(
kernel_type
.
empty
()
?
op_type_
:
kernel_type
),
place
.
target
,
place
.
precision
));
place
.
precision
);
for
(
auto
&&
it
:
ks
)
{
kernels
.
emplace_back
(
std
::
move
(
it
));
}
}
}
return
kernels
;
return
kernels
;
...
@@ -61,6 +64,20 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
...
@@ -61,6 +64,20 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
return
AttachImpl
(
opdesc
,
scope
);
return
AttachImpl
(
opdesc
,
scope
);
}
}
const
Tensor
*
OpLite
::
GetTensor
(
lite
::
Scope
*
scope
,
const
std
::
string
&
name
)
const
{
auto
*
var
=
scope
->
FindVar
(
name
);
CHECK
(
var
)
<<
"no variable called "
<<
name
<<
" found"
;
return
&
var
->
Get
<
lite
::
Tensor
>
();
}
Tensor
*
OpLite
::
GetMutableTensor
(
lite
::
Scope
*
scope
,
const
std
::
string
&
name
)
const
{
auto
*
var
=
scope
->
FindVar
(
name
);
CHECK
(
var
)
<<
"no variable called "
<<
name
<<
" found"
;
return
var
->
GetMutable
<
lite
::
Tensor
>
();
}
bool
OpInfo
::
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
{
bool
OpInfo
::
GetInputArgname
(
const
std
::
string
&
value_name
,
std
::
string
*
out
)
{
for
(
auto
&
item
:
input_argument_
)
{
for
(
auto
&
item
:
input_argument_
)
{
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
auto
it
=
std
::
find
(
item
.
second
.
begin
(),
item
.
second
.
end
(),
value_name
);
...
...
paddle/fluid/lite/core/op_lite.h
浏览文件 @
25990d29
...
@@ -119,6 +119,9 @@ class OpLite : public Registry {
...
@@ -119,6 +119,9 @@ class OpLite : public Registry {
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
=
""
);
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
=
""
);
const
Tensor
*
GetTensor
(
lite
::
Scope
*
scope
,
const
std
::
string
&
name
)
const
;
Tensor
*
GetMutableTensor
(
lite
::
Scope
*
scope
,
const
std
::
string
&
name
)
const
;
friend
class
mir
::
Node
;
friend
class
mir
::
Node
;
friend
class
mir
::
SSAGraph
;
friend
class
mir
::
SSAGraph
;
...
...
paddle/fluid/lite/core/op_registry.cc
浏览文件 @
25990d29
...
@@ -17,9 +17,8 @@
...
@@ -17,9 +17,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
std
::
unique_ptr
<
KernelBase
>
KernelRegistry
::
Create
(
const
std
::
string
&
op_type
,
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
KernelRegistry
::
Create
(
TargetType
target
,
const
std
::
string
&
op_type
,
TargetType
target
,
PrecisionType
precision
)
{
PrecisionType
precision
)
{
#define CREATE_KERNEL(target__) \
#define CREATE_KERNEL(target__) \
switch (precision) { \
switch (precision) { \
case PRECISION(kFloat): \
case PRECISION(kFloat): \
...
@@ -43,7 +42,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
...
@@ -43,7 +42,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
}
}
#undef CREATE_KERNEL
#undef CREATE_KERNEL
return
nullptr
;
return
std
::
list
<
std
::
unique_ptr
<
KernelBase
>>
()
;
}
}
KernelRegistry
::
KernelRegistry
()
{
KernelRegistry
::
KernelRegistry
()
{
...
...
paddle/fluid/lite/core/op_registry.h
浏览文件 @
25990d29
...
@@ -52,8 +52,7 @@ class OpLiteRegistor : public Registor<OpClass> {
...
@@ -52,8 +52,7 @@ class OpLiteRegistor : public Registor<OpClass> {
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
>
using
KernelRegistryForTarget
=
using
KernelRegistryForTarget
=
Factory
<
OpKernel
<
Target
,
Precision
>
,
Factory
<
OpKernel
<
Target
,
Precision
>
,
std
::
unique_ptr
<
KernelBase
>>
;
std
::
unique_ptr
<
OpKernel
<
Target
,
Precision
>>>
;
class
KernelRegistry
final
{
class
KernelRegistry
final
{
public:
public:
...
@@ -80,16 +79,16 @@ class KernelRegistry final {
...
@@ -80,16 +79,16 @@ class KernelRegistry final {
}
}
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
>
std
::
unique_ptr
<
KernelBase
>
Create
(
const
std
::
string
&
op_type
)
{
std
::
list
<
std
::
unique_ptr
<
KernelBase
>
>
Create
(
const
std
::
string
&
op_type
)
{
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
>
;
using
kernel_registor_t
=
KernelRegistryForTarget
<
Target
,
Precision
>
;
return
registries_
[
GetKernelOffset
<
Target
,
Precision
>
()]
return
registries_
[
GetKernelOffset
<
Target
,
Precision
>
()]
.
template
get
<
kernel_registor_t
*
>()
.
template
get
<
kernel_registor_t
*
>()
->
Create
(
op_type
);
->
Create
s
(
op_type
);
}
}
std
::
unique_ptr
<
KernelBase
>
Create
(
const
std
::
string
&
op_type
,
std
::
list
<
std
::
unique_ptr
<
KernelBase
>
>
Create
(
const
std
::
string
&
op_type
,
TargetType
target
,
TargetType
target
,
PrecisionType
precision
);
PrecisionType
precision
);
// Get a kernel registry offset in all the registries.
// Get a kernel registry offset in all the registries.
template
<
TargetType
Target
,
PrecisionType
Precision
>
template
<
TargetType
Target
,
PrecisionType
Precision
>
...
@@ -151,29 +150,36 @@ class KernelRegistor : public lite::Registor<KernelType> {
...
@@ -151,29 +150,36 @@ class KernelRegistor : public lite::Registor<KernelType> {
// Kernel registry
// Kernel registry
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
op_type__##__##target__##__##precision__##__registor__
op_type__##__##target__##__##precision__##__registor__
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \
#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
op_type__##__##target__##__##precision__##__registor__instance__
alias__) \
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
op_type__##__##target__##__##precision__##__registor__instance__##alias__
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__)
#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass) \
static paddle::lite::KernelRegistor<TARGET(target__), \
#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass, \
PRECISION(precision__), KernelClass> \
alias__) \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, \
static paddle::lite::KernelRegistor<TARGET(target__), \
precision__)(#op_type__); \
PRECISION(precision__), KernelClass> \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__); \
LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, \
int touch_##op_type__##target__##precision__() { \
alias__)(#op_type__); \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch(); \
static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__, \
return 0; \
alias__); \
} \
int touch_##op_type__##target__##precision__##alias__() { \
static bool op_type__##target__##precision__##param_register \
LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \
__attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
return 0; \
TARGET(target__), PRECISION(precision__)>(#op_type__)
} \
static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, \
#define USE_LITE_KERNEL(op_type__, target__, precision__) \
alias__) __attribute__((unused)) = \
extern int touch_##op_type__##target__##precision__(); \
paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__), \
int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
PRECISION(precision__)>( \
__attribute__((unused)) = touch_##op_type__##target__##precision__();
#op_type__)
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__) \
#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__
extern int touch_##op_type__##target__##precision__##alias__(); \
int op_type__##target__##precision__##alias__ __attribute__((unused)) = \
touch_##op_type__##target__##precision__##alias__();
#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__
#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \
op_type__##target__##precision__##alias__##param_register
paddle/fluid/lite/kernels/host/fc_compute.cc
浏览文件 @
25990d29
...
@@ -24,7 +24,7 @@ namespace host {
...
@@ -24,7 +24,7 @@ namespace host {
// NOTE should use pure std C++ implementation.
// NOTE should use pure std C++ implementation.
void
FcCompute
::
Run
()
{
void
FcCompute
::
Run
()
{
auto
&
param
=
this
->
p
aram
<
operators
::
FcParam
>
();
auto
&
param
=
this
->
P
aram
<
operators
::
FcParam
>
();
CHECK_GE
(
param
.
input
->
dims
().
size
(),
2UL
);
CHECK_GE
(
param
.
input
->
dims
().
size
(),
2UL
);
CHECK_EQ
(
param
.
output
->
dims
().
size
(),
2UL
);
CHECK_EQ
(
param
.
output
->
dims
().
size
(),
2UL
);
...
@@ -51,7 +51,8 @@ void FcCompute::Run() {
...
@@ -51,7 +51,8 @@ void FcCompute::Run() {
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
)
REGISTER_LITE_KERNEL
(
fc
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FcCompute
,
def
)
.
BindInput
(
"Input"
,
.
BindInput
(
"Input"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
TARGET
(
kHost
))})
...
...
paddle/fluid/lite/kernels/host/feed_compute.cc
浏览文件 @
25990d29
...
@@ -26,7 +26,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -26,7 +26,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using
param_t
=
operators
::
FeedParam
;
using
param_t
=
operators
::
FeedParam
;
void
Run
()
override
{
void
Run
()
override
{
auto
&
theparam
=
p
aram
<
operators
::
FeedParam
>
();
auto
&
theparam
=
P
aram
<
operators
::
FeedParam
>
();
const
Tensor
&
feed_item
=
theparam
.
feed_list
->
at
(
theparam
.
col
);
const
Tensor
&
feed_item
=
theparam
.
feed_list
->
at
(
theparam
.
col
);
theparam
.
out
->
CopyDataFrom
(
feed_item
);
theparam
.
out
->
CopyDataFrom
(
feed_item
);
}
}
...
@@ -38,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -38,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
feed
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
FeedCompute
)
paddle
::
lite
::
kernels
::
host
::
FeedCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
TARGET
(
kHost
))})
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindOutput
(
"Out"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
...
...
paddle/fluid/lite/kernels/host/mul_compute.cc
浏览文件 @
25990d29
...
@@ -40,7 +40,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -40,7 +40,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using
param_t
=
operators
::
MulParam
;
using
param_t
=
operators
::
MulParam
;
void
Run
()
override
{
void
Run
()
override
{
auto
&
theparam
=
p
aram
<
operators
::
MulParam
>
();
auto
&
theparam
=
P
aram
<
operators
::
MulParam
>
();
core
::
dim2
x_shape
(
core
::
dim2
x_shape
(
{
product
(
theparam
.
x
->
dims
().
begin
(),
{
product
(
theparam
.
x
->
dims
().
begin
(),
theparam
.
x
->
dims
().
begin
()
+
theparam
.
x_num_col_dims
),
theparam
.
x
->
dims
().
begin
()
+
theparam
.
x_num_col_dims
),
...
@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
mul
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
MulCompute
)
paddle
::
lite
::
kernels
::
host
::
MulCompute
,
def
)
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindInput
(
"X"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
TARGET
(
kHost
))})
TARGET
(
kHost
))})
.
BindInput
(
"Y"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
.
BindInput
(
"Y"
,
{
paddle
::
lite
::
Type
::
Get
<
paddle
::
lite
::
TensorFp32NCHWTy
>
(
...
...
paddle/fluid/lite/kernels/host/relu_compute.h
浏览文件 @
25990d29
...
@@ -24,7 +24,7 @@ namespace host {
...
@@ -24,7 +24,7 @@ namespace host {
class
ReluCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
{
class
ReluCompute
:
public
OpKernel
<
TARGET
(
kHost
),
PRECISION
(
kFloat
)
>
{
public:
public:
void
Run
()
override
{
void
Run
()
override
{
auto
&
theparam
=
p
aram
<
operators
::
ReluParam
>
();
auto
&
theparam
=
P
aram
<
operators
::
ReluParam
>
();
auto
n
=
product
(
theparam
.
input
->
dims
());
auto
n
=
product
(
theparam
.
input
->
dims
());
const
float
*
input
=
theparam
.
input
->
data
<
float
>
();
const
float
*
input
=
theparam
.
input
->
data
<
float
>
();
float
*
output
=
theparam
.
output
->
mutable_data
<
float
>
();
float
*
output
=
theparam
.
output
->
mutable_data
<
float
>
();
...
@@ -43,5 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -43,5 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
relu
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
relu
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
ReluCompute
)
paddle
::
lite
::
kernels
::
host
::
ReluCompute
,
def
)
.
Finalize
();
.
Finalize
();
paddle/fluid/lite/kernels/host/scale_compute.cc
浏览文件 @
25990d29
...
@@ -36,7 +36,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -36,7 +36,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
using
param_t
=
operators
::
MulParam
;
using
param_t
=
operators
::
MulParam
;
void
Run
()
override
{
void
Run
()
override
{
auto
&
theparam
=
p
aram
<
operators
::
ScaleParam
>
();
auto
&
theparam
=
P
aram
<
operators
::
ScaleParam
>
();
scale_compute
(
theparam
.
x
->
data
<
float
>
(),
theparam
.
x
->
mutable_data
<
float
>
(),
scale_compute
(
theparam
.
x
->
data
<
float
>
(),
theparam
.
x
->
mutable_data
<
float
>
(),
product
(
theparam
.
x
->
dims
()),
theparam
.
scale
,
theparam
.
bias
,
product
(
theparam
.
x
->
dims
()),
theparam
.
scale
,
theparam
.
bias
,
theparam
.
bias_after_scale
);
theparam
.
bias_after_scale
);
...
@@ -51,5 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -51,5 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
}
// namespace paddle
}
// namespace paddle
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
REGISTER_LITE_KERNEL
(
scale
,
kHost
,
kFloat
,
paddle
::
lite
::
kernels
::
host
::
ScaleCompute
)
paddle
::
lite
::
kernels
::
host
::
ScaleCompute
,
def
)
.
Finalize
();
.
Finalize
();
paddle/fluid/lite/operators/fc_op.cc
浏览文件 @
25990d29
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "fc_op.h"
#include "
paddle/fluid/lite/operators/
fc_op.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace
paddle
{
namespace
paddle
{
...
...
paddle/fluid/lite/utils/factory.h
浏览文件 @
25990d29
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include <list>
#include <memory>
#include <memory>
#include <sstream>
#include <sstream>
#include <unordered_map>
#include <unordered_map>
...
@@ -49,13 +50,21 @@ class Factory {
...
@@ -49,13 +50,21 @@ class Factory {
void
Register
(
const
std
::
string
&
op_type
,
creator_t
&&
creator
)
{
void
Register
(
const
std
::
string
&
op_type
,
creator_t
&&
creator
)
{
CHECK
(
!
creators_
.
count
(
op_type
))
<<
"The op "
<<
op_type
CHECK
(
!
creators_
.
count
(
op_type
))
<<
"The op "
<<
op_type
<<
" has already registered"
;
<<
" has already registered"
;
creators_
.
emplace
(
op_type
,
std
::
move
(
creator
));
creators_
[
op_type
].
emplace_back
(
std
::
move
(
creator
));
}
}
item_ptr_t
Create
(
const
std
::
string
&
op_type
)
const
{
item_ptr_t
Create
(
const
std
::
string
&
op_type
)
const
{
return
std
::
move
(
Creates
(
op_type
).
front
());
}
std
::
list
<
item_ptr_t
>
Creates
(
const
std
::
string
&
op_type
)
const
{
auto
it
=
creators_
.
find
(
op_type
);
auto
it
=
creators_
.
find
(
op_type
);
CHECK
(
it
!=
creators_
.
end
())
<<
"no item called "
<<
op_type
;
CHECK
(
it
!=
creators_
.
end
())
<<
"no item called "
<<
op_type
;
return
it
->
second
();
std
::
list
<
item_ptr_t
>
res
;
for
(
auto
&
c
:
it
->
second
)
{
res
.
emplace_back
(
c
());
}
return
res
;
}
}
std
::
string
DebugString
()
const
{
std
::
string
DebugString
()
const
{
...
@@ -67,7 +76,7 @@ class Factory {
...
@@ -67,7 +76,7 @@ class Factory {
}
}
protected:
protected:
std
::
unordered_map
<
std
::
string
,
creator_t
>
creators_
;
std
::
unordered_map
<
std
::
string
,
std
::
list
<
creator_t
>
>
creators_
;
};
};
/* A helper function to help run a lambda at the start.
/* A helper function to help run a lambda at the start.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录