Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
e55a5cd9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e55a5cd9
编写于
4月 21, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move ParamTypes to type_system.h
上级
43036686
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
162 addition
and
160 deletion
+162
-160
paddle/fluid/lite/core/kernel.h
paddle/fluid/lite/core/kernel.h
+3
-158
paddle/fluid/lite/core/kernel_test.cc
paddle/fluid/lite/core/kernel_test.cc
+2
-2
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+157
-0
未找到文件。
paddle/fluid/lite/core/kernel.h
浏览文件 @
e55a5cd9
...
@@ -45,9 +45,9 @@ class KernelBase {
...
@@ -45,9 +45,9 @@ class KernelBase {
param_
.
set
<
T
>
(
param
);
param_
.
set
<
T
>
(
param
);
}
}
template
<
typename
P
aram
>
template
<
typename
P
>
P
aram
&
p
aram
()
const
{
P
&
P
aram
()
const
{
return
param_
.
get
<
P
aram
>
();
return
param_
.
get
<
P
>
();
}
}
void
set_op_type
(
const
std
::
string
&
type
)
{
op_type_
=
type
;
}
void
set_op_type
(
const
std
::
string
&
type
)
{
op_type_
=
type
;
}
...
@@ -71,161 +71,6 @@ class KernelBase {
...
@@ -71,161 +71,6 @@ class KernelBase {
std
::
string
op_type_
;
std
::
string
op_type_
;
};
};
/*
* ParamType is used to represent a data type of a parameter for the kernel. It
* can represent any Variable data type.
* The element_type_hash is the hash code of the element, it should be
* registered in the `TypeSystem`.
*/
struct
ParamType
{
// For unsupported types.
size_t
element_type_hash
{};
Place
tensor_place
{};
const
Type
*
type_
;
explicit
ParamType
()
=
default
;
explicit
ParamType
(
size_t
element_type_hash
)
:
element_type_hash
(
element_type_hash
)
{}
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{
tensor_place
=
type_
->
place
();
}
std
::
string
DebugString
()
const
{
return
tensor_place
.
DebugString
();
}
};
/*
* The data types of kernel parameters. It is used to track the type of kernel's
* inputs and outputs.
*/
struct
ParamTypeRecorder
{
std
::
map
<
std
::
string
,
ParamType
>
inputs
;
std
::
map
<
std
::
string
,
ParamType
>
outputs
;
void
RegisterInputType
(
const
std
::
string
&
arg_name
,
const
ParamType
&
type
)
{
Register
(
&
inputs
,
arg_name
,
type
);
}
void
RegisterOutputType
(
const
std
::
string
&
arg_name
,
const
ParamType
&
type
)
{
Register
(
&
outputs
,
arg_name
,
type
);
}
private:
void
Register
(
std
::
map
<
std
::
string
,
ParamType
>*
ts
,
const
std
::
string
&
arg_name
,
ParamType
type
)
{
(
*
ts
)[
arg_name
]
=
type
;
}
};
/*
* The ParamTypeRegistry help register the input and output data types for all
* the kernels. It is made singleton so that all the objects of the same kernel
* can share the same information.
*
* Usage:
* for register a kernel for FC operator.
* ParamTypeRegistry::Global().Register(
* "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0,
* {typeid(Tensor), {TARGET(kCUDA)}});
*/
class
ParamTypeRegistry
{
public:
enum
class
IO
:
int
{
kInput
=
0
,
kOutput
};
template
<
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
>
/*
* Helper class for registering a ParamType for a Kernel.
* Usage:
*
* NewInstance<TARGET(kHost), PRECISION(kFloat)>("fc")
* .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)})
* .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost),
* PRECISION(kFloat)});
*/
struct
NewInstance
{
explicit
NewInstance
(
const
std
::
string
&
kernel_type
)
:
kernel_type_
(
kernel_type
)
{}
NewInstance
&
BindInput
(
const
std
::
string
&
arg_name
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kInput
>
(
kernel_type_
,
Place
{
target
,
precision
,
layout
},
arg_name
,
ptype
);
return
*
this
;
}
NewInstance
&
BindOutput
(
const
std
::
string
&
arg_name
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kOutput
>
(
kernel_type_
,
Place
{
target
,
precision
,
layout
},
arg_name
,
ptype
);
return
*
this
;
}
bool
Finalize
()
{
return
true
;
}
private:
std
::
string
kernel_type_
;
};
template
<
IO
io
>
void
Register
(
const
std
::
string
&
kernel_type
,
const
Place
&
place
,
const
std
::
string
&
arg_name
,
ParamType
data_type
)
{
KernelIdTy
key
{
kernel_type
,
place
,
io
,
arg_name
};
types_
[
key
]
=
data_type
;
CHECK
(
types_
.
count
(
key
));
}
template
<
IO
io
>
const
ParamType
*
Retrieve
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
KernelIdTy
key
{
op_type
,
place
,
io
,
arg_name
};
auto
it
=
types_
.
find
(
key
);
if
(
it
==
types_
.
end
())
return
nullptr
;
return
&
it
->
second
;
}
static
ParamTypeRegistry
&
Global
()
{
static
ParamTypeRegistry
x
;
return
x
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ParamTypeRegistry
&
other
)
{
for
(
auto
&
item
:
other
.
types_
)
{
os
<<
item
.
first
<<
" "
<<
item
.
second
.
DebugString
()
<<
"
\n
"
;
}
return
os
;
}
private:
ParamTypeRegistry
()
=
default
;
public:
// Identification for a Kernel.
struct
KernelIdTy
{
std
::
string
kernel_type
;
Place
place
;
IO
io
;
std
::
string
arg_name
;
size_t
hash
()
const
{
std
::
hash
<
std
::
string
>
h
;
size_t
hash
=
h
(
kernel_type
);
hash
=
hash_combine
(
hash
,
place
.
hash
());
hash
=
hash_combine
(
hash
,
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
io
)));
hash
=
hash_combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
arg_name
));
return
hash
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
KernelIdTy
&
other
);
};
using
key_t
=
KernelIdTy
;
struct
KeyCmp
{
bool
operator
()(
const
key_t
&
a
,
const
key_t
&
b
)
const
;
};
private:
std
::
map
<
key_t
,
ParamType
,
ParamTypeRegistry
::
KeyCmp
>
types_
;
};
// Light-weight kernel implementation.
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// The OpKernel is designed to implement the specific algorithm on a target
// device.
// device.
...
...
paddle/fluid/lite/core/kernel_test.cc
浏览文件 @
e55a5cd9
...
@@ -25,8 +25,8 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
...
@@ -25,8 +25,8 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
public:
public:
void
Run
()
override
{
void
Run
()
override
{
LOG
(
INFO
)
<<
"SomeKernel executed"
;
LOG
(
INFO
)
<<
"SomeKernel executed"
;
LOG
(
INFO
)
<<
p
aram
<
operators
::
FcParam
>
().
in_num_col_dims
;
LOG
(
INFO
)
<<
P
aram
<
operators
::
FcParam
>
().
in_num_col_dims
;
test_code
=
p
aram
<
operators
::
FcParam
>
().
in_num_col_dims
;
test_code
=
P
aram
<
operators
::
FcParam
>
().
in_num_col_dims
;
}
}
TargetType
target
()
const
override
{
return
TARGET
(
kHost
);
}
TargetType
target
()
const
override
{
return
TARGET
(
kHost
);
}
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
e55a5cd9
...
@@ -21,12 +21,14 @@
...
@@ -21,12 +21,14 @@
// for analysis and runtime.
// for analysis and runtime.
#include <glog/logging.h>
#include <glog/logging.h>
#include <map>
#include <string>
#include <string>
#include <typeinfo>
#include <typeinfo>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/utils/all.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -202,5 +204,160 @@ class TypeSystem {
...
@@ -202,5 +204,160 @@ class TypeSystem {
std
::
unordered_set
<
std
::
string
>
names_
;
std
::
unordered_set
<
std
::
string
>
names_
;
};
};
/*
* ParamType is used to represent a data type of a parameter for the kernel. It
* can represent any Variable data type.
* The element_type_hash is the hash code of the element, it should be
* registered in the `TypeSystem`.
*/
struct
ParamType
{
// For unsupported types.
size_t
element_type_hash
{};
Place
tensor_place
{};
const
Type
*
type_
;
explicit
ParamType
()
=
default
;
explicit
ParamType
(
size_t
element_type_hash
)
:
element_type_hash
(
element_type_hash
)
{}
ParamType
(
size_t
element_type_hash
,
const
Place
&
place
)
:
element_type_hash
(
element_type_hash
),
tensor_place
(
place
)
{}
ParamType
(
const
Type
*
type
)
:
type_
(
type
)
{
tensor_place
=
type_
->
place
();
}
std
::
string
DebugString
()
const
{
return
tensor_place
.
DebugString
();
}
};
/*
* The data types of kernel parameters. It is used to track the type of kernel's
* inputs and outputs.
*/
struct
ParamTypeRecorder
{
std
::
map
<
std
::
string
,
ParamType
>
inputs
;
std
::
map
<
std
::
string
,
ParamType
>
outputs
;
void
RegisterInputType
(
const
std
::
string
&
arg_name
,
const
ParamType
&
type
)
{
Register
(
&
inputs
,
arg_name
,
type
);
}
void
RegisterOutputType
(
const
std
::
string
&
arg_name
,
const
ParamType
&
type
)
{
Register
(
&
outputs
,
arg_name
,
type
);
}
private:
void
Register
(
std
::
map
<
std
::
string
,
ParamType
>*
ts
,
const
std
::
string
&
arg_name
,
ParamType
type
)
{
(
*
ts
)[
arg_name
]
=
type
;
}
};
/*
* The ParamTypeRegistry help register the input and output data types for all
* the kernels. It is made singleton so that all the objects of the same kernel
* can share the same information.
*
* Usage:
* for register a kernel for FC operator.
* ParamTypeRegistry::Global().Register(
* "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0,
* {typeid(Tensor), {TARGET(kCUDA)}});
*/
class
ParamTypeRegistry
{
public:
enum
class
IO
:
int
{
kInput
=
0
,
kOutput
};
template
<
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
>
/*
* Helper class for registering a ParamType for a Kernel.
* Usage:
*
* NewInstance<TARGET(kHost), PRECISION(kFloat)>("fc")
* .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)})
* .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost),
* PRECISION(kFloat)});
*/
struct
NewInstance
{
explicit
NewInstance
(
const
std
::
string
&
kernel_type
)
:
kernel_type_
(
kernel_type
)
{}
NewInstance
&
BindInput
(
const
std
::
string
&
arg_name
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kInput
>
(
kernel_type_
,
Place
{
target
,
precision
,
layout
},
arg_name
,
ptype
);
return
*
this
;
}
NewInstance
&
BindOutput
(
const
std
::
string
&
arg_name
,
const
ParamType
&
ptype
)
{
ParamTypeRegistry
::
Global
().
Register
<
IO
::
kOutput
>
(
kernel_type_
,
Place
{
target
,
precision
,
layout
},
arg_name
,
ptype
);
return
*
this
;
}
bool
Finalize
()
{
return
true
;
}
private:
std
::
string
kernel_type_
;
};
template
<
IO
io
>
void
Register
(
const
std
::
string
&
kernel_type
,
const
Place
&
place
,
const
std
::
string
&
arg_name
,
ParamType
data_type
)
{
KernelIdTy
key
{
kernel_type
,
place
,
io
,
arg_name
};
types_
[
key
]
=
data_type
;
CHECK
(
types_
.
count
(
key
));
}
template
<
IO
io
>
const
ParamType
*
Retrieve
(
const
Place
&
place
,
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
)
{
KernelIdTy
key
{
op_type
,
place
,
io
,
arg_name
};
auto
it
=
types_
.
find
(
key
);
if
(
it
==
types_
.
end
())
return
nullptr
;
return
&
it
->
second
;
}
static
ParamTypeRegistry
&
Global
()
{
static
ParamTypeRegistry
x
;
return
x
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ParamTypeRegistry
&
other
)
{
for
(
auto
&
item
:
other
.
types_
)
{
os
<<
item
.
first
<<
" "
<<
item
.
second
.
DebugString
()
<<
"
\n
"
;
}
return
os
;
}
private:
ParamTypeRegistry
()
=
default
;
public:
// Identification for a Kernel.
struct
KernelIdTy
{
std
::
string
kernel_type
;
Place
place
;
IO
io
;
std
::
string
arg_name
;
size_t
hash
()
const
{
std
::
hash
<
std
::
string
>
h
;
size_t
hash
=
h
(
kernel_type
);
hash
=
hash_combine
(
hash
,
place
.
hash
());
hash
=
hash_combine
(
hash
,
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
io
)));
hash
=
hash_combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
arg_name
));
return
hash
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
KernelIdTy
&
other
);
};
using
key_t
=
KernelIdTy
;
struct
KeyCmp
{
bool
operator
()(
const
key_t
&
a
,
const
key_t
&
b
)
const
;
};
private:
std
::
map
<
key_t
,
ParamType
,
ParamTypeRegistry
::
KeyCmp
>
types_
;
};
}
// namespace lite
}
// namespace lite
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录