BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 25990d29
Authored Apr 21, 2019 by Superjomn
make kernel_registry support multiple kernels for single type
Parent: e55a5cd9
Showing 11 changed files with 88 additions and 53 deletions (+88 -53).
paddle/fluid/lite/core/op_lite.cc                +19  -2
paddle/fluid/lite/core/op_lite.h                  +3  -0
paddle/fluid/lite/core/op_registry.cc             +3  -4
paddle/fluid/lite/core/op_registry.h             +39 -33
paddle/fluid/lite/kernels/host/fc_compute.cc      +3  -2
paddle/fluid/lite/kernels/host/feed_compute.cc    +2  -2
paddle/fluid/lite/kernels/host/mul_compute.cc     +2  -2
paddle/fluid/lite/kernels/host/relu_compute.h     +2  -2
paddle/fluid/lite/kernels/host/scale_compute.cc   +2  -2
paddle/fluid/lite/operators/fc_op.cc              +1  -1
paddle/fluid/lite/utils/factory.h                +12  -3
paddle/fluid/lite/core/op_lite.cc

@@ -25,9 +25,12 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
   CHECK(!op_type_.empty()) << "op_type_ should be set first";
   for (auto place : places) {
-    kernels.emplace_back(KernelRegistry::Global().Create(
+    auto ks = KernelRegistry::Global().Create(
         (kernel_type.empty() ? op_type_ : kernel_type), place.target,
-        place.precision));
+        place.precision);
+    for (auto &&it : ks) {
+      kernels.emplace_back(std::move(it));
+    }
   }
   return kernels;

@@ -61,6 +64,20 @@ bool OpLite::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
   return AttachImpl(opdesc, scope);
 }
 
+const Tensor *OpLite::GetTensor(lite::Scope *scope,
+                                const std::string &name) const {
+  auto *var = scope->FindVar(name);
+  CHECK(var) << "no variable called " << name << " found";
+  return &var->Get<lite::Tensor>();
+}
+
+Tensor *OpLite::GetMutableTensor(lite::Scope *scope,
+                                 const std::string &name) const {
+  auto *var = scope->FindVar(name);
+  CHECK(var) << "no variable called " << name << " found";
+  return var->GetMutable<lite::Tensor>();
+}
+
 bool OpInfo::GetInputArgname(const std::string &value_name, std::string *out) {
   for (auto &item : input_argument_) {
     auto it = std::find(item.second.begin(), item.second.end(), value_name);
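With this change a single (op, target, precision) query can yield several kernels, and CreateKernels flattens them into one candidate vector. A minimal caller-side sketch, assuming a hypothetical OpLite subclass instance `op` (the Place literal and the pick-the-first policy are illustrative, not part of this commit):

    // Collect every registered kernel for the host/float place.
    std::vector<Place> places{Place{TARGET(kHost), PRECISION(kFloat)}};
    auto kernels = op->CreateKernels(places);  // may now hold >1 kernel per place
    CHECK(!kernels.empty()) << "no kernel found";
    // A real picker could score the candidates; this sketch just runs the first.
    kernels.front()->Run();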
paddle/fluid/lite/core/op_lite.h

@@ -119,6 +119,9 @@ class OpLite : public Registry {
   std::vector<std::unique_ptr<KernelBase>> CreateKernels(
       const std::vector<Place> &places, const std::string &kernel_type = "");
 
+  const Tensor *GetTensor(lite::Scope *scope, const std::string &name) const;
+  Tensor *GetMutableTensor(lite::Scope *scope,
+                           const std::string &name) const;
 
   friend class mir::Node;
   friend class mir::SSAGraph;
paddle/fluid/lite/core/op_registry.cc

@@ -17,9 +17,8 @@
 namespace paddle {
 namespace lite {
 
-std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
-                                                   TargetType target,
-                                                   PrecisionType precision) {
+std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
+    const std::string &op_type, TargetType target, PrecisionType precision) {
 #define CREATE_KERNEL(target__)  \
   switch (precision) {           \
     case PRECISION(kFloat):      \

@@ -43,7 +42,7 @@ std::unique_ptr<KernelBase> KernelRegistry::Create(const std::string &op_type,
 }
 #undef CREATE_KERNEL
-  return nullptr;
+  return std::list<std::unique_ptr<KernelBase>>();
 }
 
 KernelRegistry::KernelRegistry() {
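Note the lookup-miss path now returns an empty std::list rather than nullptr, so callers test for emptiness instead of a null pointer. A hedged sketch of the new caller-side check (the names are grounded in this diff; the logging line is illustrative):

    auto ks = KernelRegistry::Global().Create("fc", TARGET(kHost), PRECISION(kFloat));
    if (ks.empty()) {
      LOG(ERROR) << "no kernel registered for fc on kHost/kFloat";
    }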
paddle/fluid/lite/core/op_registry.h

@@ -52,8 +52,7 @@ class OpLiteRegistor : public Registor<OpClass> {
 template <TargetType Target, PrecisionType Precision>
 using KernelRegistryForTarget =
-    Factory<OpKernel<Target, Precision>,
-            std::unique_ptr<OpKernel<Target, Precision>>>;
+    Factory<OpKernel<Target, Precision>, std::unique_ptr<KernelBase>>;
 
 class KernelRegistry final {
  public:

@@ -80,16 +79,16 @@ class KernelRegistry final {
   }
 
   template <TargetType Target, PrecisionType Precision>
-  std::unique_ptr<KernelBase> Create(const std::string &op_type) {
+  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type) {
     using kernel_registor_t = KernelRegistryForTarget<Target, Precision>;
     return registries_[GetKernelOffset<Target, Precision>()]
         .template get<kernel_registor_t *>()
-        ->Create(op_type);
+        ->Creates(op_type);
   }
 
-  std::unique_ptr<KernelBase> Create(const std::string &op_type,
-                                     TargetType target,
-                                     PrecisionType precision);
+  std::list<std::unique_ptr<KernelBase>> Create(const std::string &op_type,
+                                                TargetType target,
+                                                PrecisionType precision);
 
   // Get a kernel registry offset in all the registries.
   template <TargetType Target, PrecisionType Precision>

@@ -151,29 +150,36 @@ class KernelRegistor : public lite::Registor<KernelType> {
 // Kernel registry
 #define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
   op_type__##__##target__##__##precision__##__registor__
-#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__) \
-  op_type__##__##target__##__##precision__##__registor__instance__
-#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
-  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__)
-
-#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass)  \
-  static paddle::lite::KernelRegistor<TARGET(target__),                      \
-                                      PRECISION(precision__), KernelClass>   \
-      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__,                     \
-                                    precision__)(#op_type__);                \
-  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__); \
-  int touch_##op_type__##target__##precision__() {                           \
-    LITE_KERNEL_INSTANCE(op_type__, target__, precision__).Touch();          \
-    return 0;                                                                \
-  }                                                                          \
-  static bool op_type__##target__##precision__##param_register              \
-      __attribute__((unused)) = paddle::lite::ParamTypeRegistry::NewInstance< \
-          TARGET(target__), PRECISION(precision__)>(#op_type__)
-
-#define USE_LITE_KERNEL(op_type__, target__, precision__)         \
-  extern int touch_##op_type__##target__##precision__();          \
-  int LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__) \
-      __attribute__((unused)) = touch_##op_type__##target__##precision__();
-
-#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__) \
-  op_type__##target__##precision__
+#define LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__,      \
+                                      alias__)                               \
+  op_type__##__##target__##__##precision__##__registor__instance__##alias__
+#define LITE_KERNEL_REGISTER_FAKE(op_type__, target__, precision__, alias__) \
+  LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__, alias__)
+
+#define REGISTER_LITE_KERNEL(op_type__, target__, precision__, KernelClass,  \
+                             alias__)                                        \
+  static paddle::lite::KernelRegistor<TARGET(target__),                      \
+                                      PRECISION(precision__), KernelClass>   \
+      LITE_KERNEL_REGISTER_INSTANCE(op_type__, target__, precision__,        \
+                                    alias__)(#op_type__);                    \
+  static KernelClass LITE_KERNEL_INSTANCE(op_type__, target__, precision__,  \
+                                          alias__);                          \
+  int touch_##op_type__##target__##precision__##alias__() {                  \
+    LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__).Touch(); \
+    return 0;                                                                \
+  }                                                                          \
+  static bool LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__,   \
+                                         alias__) __attribute__((unused)) =  \
+      paddle::lite::ParamTypeRegistry::NewInstance<TARGET(target__),         \
+                                                   PRECISION(precision__)>(  \
+          #op_type__)
+
+#define USE_LITE_KERNEL(op_type__, target__, precision__, alias__)        \
+  extern int touch_##op_type__##target__##precision__##alias__();         \
+  int op_type__##target__##precision__##alias__ __attribute__((unused)) = \
+      touch_##op_type__##target__##precision__##alias__();
+
+#define LITE_KERNEL_INSTANCE(op_type__, target__, precision__, alias__) \
+  op_type__##target__##precision__##alias__
+#define LITE_KERNEL_PARAM_INSTANCE(op_type__, target__, precision__, alias__) \
+  op_type__##target__##precision__##alias__##param_register
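The new trailing alias__ parameter is what lets several kernels coexist under one (op, target, precision) key: each registration now mints a distinct static KernelRegistor instance, kernel instance, and touch function. A hedged sketch of registering two variants of one op (FcComputeNaive and FcComputeTiled are hypothetical class names; only the `def` alias appears in this commit):

    // Two kernels for the same op/target/precision, told apart by alias.
    REGISTER_LITE_KERNEL(fc, kHost, kFloat, FcComputeNaive, def).Finalize();
    REGISTER_LITE_KERNEL(fc, kHost, kFloat, FcComputeTiled, tiled).Finalize();

    // A consuming translation unit forces both registrations to link:
    USE_LITE_KERNEL(fc, kHost, kFloat, def);
    USE_LITE_KERNEL(fc, kHost, kFloat, tiled);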
paddle/fluid/lite/kernels/host/fc_compute.cc

@@ -24,7 +24,7 @@ namespace host {
 // NOTE should use pure std C++ implementation.
 void FcCompute::Run() {
-  auto &param = this->param<operators::FcParam>();
+  auto &param = this->Param<operators::FcParam>();
   CHECK_GE(param.input->dims().size(), 2UL);
   CHECK_EQ(param.output->dims().size(), 2UL);

@@ -51,7 +51,8 @@ void FcCompute::Run() {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute)
+REGISTER_LITE_KERNEL(fc, kHost, kFloat, paddle::lite::kernels::host::FcCompute,
+                     def)
     .BindInput("Input",
                {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                    TARGET(kHost))})
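Every kernel registration in this commit adopts `def` as its alias, so any use-site that previously pulled in these kernels must pass the matching alias as well. A hedged sketch of the corresponding use-side update (no such call site appears in this diff):

    // Old three-argument form, before this commit:
    // USE_LITE_KERNEL(fc, kHost, kFloat);
    // New form; the alias must match the registration above:
    USE_LITE_KERNEL(fc, kHost, kFloat, def);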
paddle/fluid/lite/kernels/host/feed_compute.cc

@@ -26,7 +26,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::FeedParam;
 
   void Run() override {
-    auto &theparam = param<operators::FeedParam>();
+    auto &theparam = Param<operators::FeedParam>();
     const Tensor &feed_item = theparam.feed_list->at(theparam.col);
     theparam.out->CopyDataFrom(feed_item);
   }

@@ -38,7 +38,7 @@ class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(feed, kHost, kFloat,
-                     paddle::lite::kernels::host::FeedCompute)
+                     paddle::lite::kernels::host::FeedCompute, def)
     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                         TARGET(kHost))})
     .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
paddle/fluid/lite/kernels/host/mul_compute.cc

@@ -40,7 +40,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto &theparam = param<operators::MulParam>();
+    auto &theparam = Param<operators::MulParam>();
     core::dim2 x_shape(
         {product(theparam.x->dims().begin(),
                  theparam.x->dims().begin() + theparam.x_num_col_dims),

@@ -67,7 +67,7 @@ class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(mul, kHost, kFloat,
-                     paddle::lite::kernels::host::MulCompute)
+                     paddle::lite::kernels::host::MulCompute, def)
     .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                         TARGET(kHost))})
     .BindInput("Y", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
paddle/fluid/lite/kernels/host/relu_compute.h

@@ -24,7 +24,7 @@ namespace host {
 class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
  public:
   void Run() override {
-    auto &theparam = param<operators::ReluParam>();
+    auto &theparam = Param<operators::ReluParam>();
     auto n = product(theparam.input->dims());
     const float *input = theparam.input->data<float>();
     float *output = theparam.output->mutable_data<float>();

@@ -43,5 +43,5 @@ class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(relu, kHost, kFloat,
-                     paddle::lite::kernels::host::ReluCompute)
+                     paddle::lite::kernels::host::ReluCompute, def)
     .Finalize();
paddle/fluid/lite/kernels/host/scale_compute.cc

@@ -36,7 +36,7 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
   using param_t = operators::MulParam;
 
   void Run() override {
-    auto &theparam = param<operators::ScaleParam>();
+    auto &theparam = Param<operators::ScaleParam>();
     scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(),
                   product(theparam.x->dims()), theparam.scale, theparam.bias,
                   theparam.bias_after_scale);

@@ -51,5 +51,5 @@ class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(scale, kHost, kFloat,
-                     paddle::lite::kernels::host::ScaleCompute)
+                     paddle::lite::kernels::host::ScaleCompute, def)
     .Finalize();
paddle/fluid/lite/operators/fc_op.cc

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "fc_op.h"
+#include "paddle/fluid/lite/operators/fc_op.h"
 #include "paddle/fluid/lite/core/op_registry.h"
 
 namespace paddle {
paddle/fluid/lite/utils/factory.h

@@ -14,6 +14,7 @@
 #pragma once
 #include <iostream>
+#include <list>
 #include <memory>
 #include <sstream>
 #include <unordered_map>

@@ -49,13 +50,21 @@ class Factory {
   void Register(const std::string &op_type, creator_t &&creator) {
-    CHECK(!creators_.count(op_type)) << "The op " << op_type
-                                     << " has already registered";
-    creators_.emplace(op_type, std::move(creator));
+    creators_[op_type].emplace_back(std::move(creator));
   }
 
   item_ptr_t Create(const std::string &op_type) const {
-    auto it = creators_.find(op_type);
-    CHECK(it != creators_.end()) << "no item called " << op_type;
-    return it->second();
+    return std::move(Creates(op_type).front());
+  }
+
+  std::list<item_ptr_t> Creates(const std::string &op_type) const {
+    auto it = creators_.find(op_type);
+    CHECK(it != creators_.end()) << "no item called " << op_type;
+    std::list<item_ptr_t> res;
+    for (auto &c : it->second) {
+      res.emplace_back(c());
+    }
+    return res;
   }
 
   std::string DebugString() const {

@@ -67,7 +76,7 @@ class Factory {
   }
 
  protected:
-  std::unordered_map<std::string, creator_t> creators_;
+  std::unordered_map<std::string, std::list<creator_t>> creators_;
 };
 
 /* A helper function to help run a lambda at the start.
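The essence of the change is that Factory now maps each key to a list of creators instead of a single one, so duplicate registrations accumulate rather than trip the old uniqueness CHECK. A minimal self-contained sketch of that pattern (simplified names, not the actual Paddle header):

    #include <functional>
    #include <iostream>
    #include <list>
    #include <memory>
    #include <string>
    #include <unordered_map>

    // Simplified stand-in for paddle::lite::Factory after this commit:
    // one key may own several creators, and Creates() instantiates them all.
    template <typename ItemType>
    class MultiFactory {
     public:
      using item_ptr_t = std::unique_ptr<ItemType>;
      using creator_t = std::function<item_ptr_t()>;

      void Register(const std::string &key, creator_t &&creator) {
        // No uniqueness check any more: duplicates simply append.
        creators_[key].emplace_back(std::move(creator));
      }

      std::list<item_ptr_t> Creates(const std::string &key) const {
        std::list<item_ptr_t> res;
        auto it = creators_.find(key);
        if (it == creators_.end()) return res;  // empty list = "not found"
        for (auto &c : it->second) res.emplace_back(c());
        return res;
      }

     private:
      std::unordered_map<std::string, std::list<creator_t>> creators_;
    };

    struct Kernel {
      virtual ~Kernel() = default;
    };
    struct FcNaive : Kernel {};
    struct FcTiled : Kernel {};

    int main() {
      MultiFactory<Kernel> factory;
      factory.Register("fc", [] { return std::unique_ptr<Kernel>(new FcNaive); });
      factory.Register("fc", [] { return std::unique_ptr<Kernel>(new FcTiled); });
      std::cout << factory.Creates("fc").size() << " kernels for fc\n";  // prints 2
      return 0;
    }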