Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
123c79fa
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
123c79fa
编写于
5月 12, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor type system
上级
9d6a0c88
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
138 addition
and
195 deletion
+138
-195
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+91
-99
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+41
-88
paddle/fluid/lite/core/type_system_test.cc
paddle/fluid/lite/core/type_system_test.cc
+6
-8
未找到文件。
paddle/fluid/lite/core/type_system.cc
浏览文件 @
123c79fa
...
...
@@ -13,126 +13,118 @@
// limitations under the License.
#include "paddle/fluid/lite/core/type_system.h"
#include "type_system.h"
namespace
paddle
{
namespace
lite
{
// ------------------------- GetType specification ----------------------------
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
false
/*is_tensor*/
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
UnsupportedTy
x
;
return
&
x
;
}
// ------------------------- end GetType specification ------------------------
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
true
/*is_tensor*/
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
TensorFp32NCHWTy
x
(
TargetType
::
kX86
);
return
&
x
;
size_t
ParamTypeRegistry
::
KernelIdTy
::
hash
()
const
{
std
::
hash
<
std
::
string
>
h
;
size_t
hash
=
h
(
kernel_type
);
hash
=
hash_combine
(
hash
,
place
.
hash
());
hash
=
hash_combine
(
hash
,
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
io
)));
hash
=
hash_combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
arg_name
));
return
hash
;
}
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
true
/*is_tensor*/
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
TensorFp32NCHWTy
x
(
TargetType
::
kHost
);
return
&
x
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Type
&
other
)
{
if
(
other
.
IsUnsupported
())
{
os
<<
"<Unsupported>"
;
return
os
;
}
if
(
other
.
IsVoid
())
{
os
<<
"<Void>"
;
return
os
;
}
if
(
other
.
IsTensor
())
{
os
<<
"<Tensor:"
;
}
else
{
os
<<
"<"
;
}
os
<<
TargetToStr
(
other
.
target
())
<<
"/"
<<
PrecisionToStr
(
other
.
precision
())
<<
"/"
<<
DataLayoutToStr
(
other
.
layout
())
<<
">"
;
return
os
;
}
template
<
>
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
)
{
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
}
const
Type
*
Type
::
GetTensorTy
(
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
,
int
device
)
{
// NOTE quite naive implementation here, but not performance sensitive.
DataType
::
ID
type_id
=
DataType
::
ID
::
Tensor
;
template
<
TargetType
Target
>
TensorListAnyTy
*
GetTensorListAnyTy
()
{
static
TensorListAnyTy
x
(
Target
);
return
&
x
;
}
template
<
TargetType
Target
>
TensorAnyTy
*
GetTensorAnyTy
()
{
static
TensorAnyTy
x
(
Target
);
return
&
x
;
}
#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
template
<
>
const
Type
*
Type
::
Get
<
TensorListAnyTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kHost
:
return
GetTensorListAnyTy
<
TARGET
(
kHost
)
>
();
case
TargetType
::
kCUDA
:
return
GetTensorListAnyTy
<
TARGET
(
kCUDA
)
>
();
case
TargetType
::
kX86
:
return
GetTensorListAnyTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
}
}
std
::
hash
<
int
>
hasher
;
size_t
v
=
hasher
(
static_cast
<
int
>
(
type_id
));
HASH_ONE
(
target
);
HASH_ONE
(
precision
);
HASH_ONE
(
layout
);
HASH_ONE
(
device
);
#undef HASH_ONE
template
<
>
const
Type
*
Type
::
Get
<
TensorAnyTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kHost
:
return
GetTensorAnyTy
<
TARGET
(
kHost
)
>
();
case
TargetType
::
kCUDA
:
return
GetTensorAnyTy
<
TARGET
(
kCUDA
)
>
();
case
TargetType
::
kX86
:
return
GetTensorAnyTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
std
::
stringstream
name
;
name
<<
"Tensor<"
;
name
<<
TargetToStr
(
target
)
<<
","
;
name
<<
PrecisionToStr
(
precision
)
<<
","
;
name
<<
DataLayoutToStr
(
layout
)
<<
","
;
name
<<
device
;
name
<<
">"
;
auto
it
=
type_repo_
.
find
(
v
);
if
(
it
==
type_repo_
.
end
())
{
// The Types should alive across the process life, no need to delete.
type_repo_
[
v
]
=
new
Type
(
type_id
,
name
.
str
(),
target
,
precision
,
layout
,
device
);
}
return
type_repo_
[
v
];
}
template
<
TargetType
Target
>
const
Type
*
GetTensorFp32NCHWTy
()
{
static
TensorFp32NCHWTy
x
(
Target
);
return
&
x
;
}
const
Type
*
Type
::
GetTensorListTy
(
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
,
int
device
)
{
DataType
::
ID
type_id
=
DataType
::
ID
::
TensorList
;
template
<
>
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TARGET
(
kHost
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kHost
)
>
();
case
TARGET
(
kCUDA
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kCUDA
)
>
();
case
TARGET
(
kX86
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported target Type "
<<
TargetToStr
(
target
);
}
return
nullptr
;
}
#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
const
Type
*
LookupType
(
DataTypeBase
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
)
{
using
id_t
=
DataTypeBase
::
ID
;
switch
(
type_id
)
{
case
id_t
::
Tensor_Any
:
return
Type
::
Get
<
TensorAnyTy
>
(
place
.
target
);
case
id_t
::
Tensor_Fp32_NCHW
:
return
Type
::
Get
<
TensorFp32NCHWTy
>
(
place
.
target
);
case
id_t
::
TensorList_Any
:
return
Type
::
Get
<
TensorListAnyTy
>
(
place
.
target
);
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
}
return
nullptr
;
std
::
hash
<
int
>
hasher
;
size_t
v
=
hasher
(
static_cast
<
int
>
(
type_id
));
HASH_ONE
(
target
);
HASH_ONE
(
precision
);
HASH_ONE
(
layout
);
HASH_ONE
(
device
);
#undef HASH_ONE
std
::
stringstream
name
;
name
<<
"TensorList<"
;
name
<<
TargetToStr
(
target
)
<<
","
;
name
<<
PrecisionToStr
(
precision
)
<<
","
;
name
<<
DataLayoutToStr
(
layout
)
<<
","
;
name
<<
device
;
name
<<
">"
;
if
(
!
type_repo_
[
v
])
// The Types should alive across the process life, no need to delete.
type_repo_
[
v
]
=
new
Type
(
type_id
,
name
.
str
(),
target
,
precision
,
layout
,
device
);
return
type_repo_
[
v
];
}
// ------------------------- end GetType specification ------------------------
const
Type
*
Type
::
GetUnsupportedTy
()
{
std
::
hash
<
int
>
hasher
;
size_t
v
=
hasher
(
static_cast
<
int
>
(
DataType
::
ID
::
Unsupported
));
if
(
!
type_repo_
[
v
])
type_repo_
[
v
]
=
new
Type
(
DataType
::
ID
::
Unsupported
,
"Unsupported"
,
TARGET
(
kUnk
),
PRECISION
(
kUnk
),
DATALAYOUT
(
kUnk
),
-
1
);
}
size_t
ParamTypeRegistry
::
KernelIdTy
::
hash
()
const
{
std
::
hash
<
std
::
string
>
h
;
size_t
hash
=
h
(
kernel_type
);
hash
=
hash_combine
(
hash
,
place
.
hash
());
hash
=
hash_combine
(
hash
,
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
io
)));
hash
=
hash_combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
arg_name
));
return
hash
;
const
Type
*
Type
::
GetVoidTy
()
{
std
::
hash
<
int
>
hasher
;
size_t
v
=
hasher
(
static_cast
<
int
>
(
DataType
::
ID
::
Void
));
if
(
!
type_repo_
[
v
])
type_repo_
[
v
]
=
new
Type
(
DataType
::
ID
::
Void
,
"Void"
,
TARGET
(
kAny
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
),
-
1
);
}
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
123c79fa
...
...
@@ -73,7 +73,7 @@ namespace lite {
//
// TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported
// type mixed in the system.
class
DataType
Base
{
class
DataType
{
public:
// The Void type can cast to any other type.
// The Unsupported is the data type that developed include in the system, for
...
...
@@ -83,43 +83,56 @@ class DataTypeBase {
enum
class
ID
:
int
{
Void
=
0
,
// unknown type that can be cast to any data type.
Unsupported
,
// Unsupported data type that will not be analyzed.
Tensor_Fp32_NCHW
,
Tensor_Int8_NCHW
,
Tensor_Int64_NCHW
,
// Tensor_Any represents a Tensor with any place, data, layout. It is used
// in some IO kernels those doesn't care the data.
Tensor_Any
,
// Used by feed or fetch op.
TensorList_Any
,
Tensor
,
// A tensor list, but all the elements should have the same type.
TensorList
,
// ---------
NumTypes
,
// Must remains as last defined ID.
};
ID
id
()
const
{
return
id_
;
}
// type check.
bool
IsTensor
()
const
{
return
is_tensor_
;
}
bool
IsVoid
()
const
{
return
id_
==
ID
::
Void
;
}
bool
IsUnsupported
()
const
{
return
id_
==
ID
::
Unsupported
;
}
bool
IsTensorFp32NCHW
()
const
{
return
id_
==
ID
::
Tensor_Fp32_NCHW
;
}
bool
IsTensorInt8NCHW
()
const
{
return
id_
==
ID
::
Tensor_Int8_NCHW
;
}
bool
IsTensorInt64NCHW
()
const
{
return
id_
==
ID
::
Tensor_Int64_NCHW
;
}
bool
IsTensor
()
const
{
return
id_
==
ID
::
Tensor
;
}
bool
IsTensorList
()
const
{
return
id_
==
ID
::
TensorList
;
}
// Get number of types.
int
num_types
()
const
{
return
static_cast
<
int
>
(
ID
::
NumTypes
);
}
protected:
// Can only extended by subclass.
DataType
Base
(
ID
id
,
bool
is_tensor
)
:
id_
(
id
),
is_tensor_
(
is_tensor
)
{}
DataType
(
ID
id
)
:
id_
(
id
)
{}
ID
id_
{
ID
::
Unsupported
};
bool
is_tensor_
{
false
};
};
/*
* Datatype with device info considered.
* NOTE A Type with different device is treated as different DeviceDataType.
*/
class
Type
:
public
DataType
Base
{
class
Type
:
public
DataType
{
public:
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a statement to transform a type to another.
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
/// Get a Tensor type.
static
const
Type
*
GetTensorTy
(
TargetType
target
,
PrecisionType
precision
=
PRECISION
(
kFloat
),
DataLayoutType
layout
=
DATALAYOUT
(
kNCHW
),
int
device
=
0
);
/// Get a TensorList type.
static
const
Type
*
GetTensorListTy
(
TargetType
target
,
PrecisionType
precision
=
PRECISION
(
kFloat
),
DataLayoutType
layout
=
DATALAYOUT
(
kNCHW
),
int
device
=
0
);
/// Get an Unsupported type.
static
const
Type
*
GetUnsupportedTy
();
/// Get an Void type.
static
const
Type
*
GetVoidTy
();
TargetType
target
()
const
{
return
place_
.
target
;
}
PrecisionType
precision
()
const
{
return
place_
.
precision
;
}
DataLayoutType
layout
()
const
{
return
place_
.
layout
;
}
...
...
@@ -130,52 +143,23 @@ class Type : public DataTypeBase {
bool
operator
==
(
const
Type
&
other
)
{
return
id_
==
other
.
id
()
&&
place_
==
other
.
place
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Type
&
other
)
{
if
(
other
.
IsUnsupported
())
{
os
<<
"<Unsupported>"
;
return
os
;
}
if
(
other
.
IsVoid
())
{
os
<<
"<Void>"
;
return
os
;
}
if
(
other
.
is_tensor_
)
{
os
<<
"<Tensor:"
;
}
else
{
os
<<
"<"
;
}
os
<<
TargetToStr
(
other
.
target
())
<<
"/"
<<
PrecisionToStr
(
other
.
precision
())
<<
"/"
<<
DataLayoutToStr
(
other
.
layout
())
<<
">"
;
return
os
;
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a statement to transform a type to another.
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
template
<
bool
is_unknown
,
bool
is_tensor
=
true
,
TargetType
target
=
TargetType
::
kHost
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
>
// Get a type.
static
const
Type
*
Get
();
template
<
typename
TypeTy
>
static
const
Type
*
Get
(
TargetType
target
=
TargetType
::
kHost
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Type
&
other
);
virtual
~
Type
()
=
default
;
protected:
Type
(
ID
id
,
const
std
::
string
&
name
,
bool
is_tensor
,
TargetType
target
=
TargetType
::
kHost
,
/// One should avoid using this construct.
Type
(
ID
id
,
const
std
::
string
&
name
,
TargetType
target
=
TargetType
::
kHost
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
,
short
device
=
0
)
:
DataTypeBase
(
id
,
is_tensor
),
place_
{
target
,
precision
,
layout
,
device
},
name_
(
name
)
{}
:
DataType
(
id
),
place_
{
target
,
precision
,
layout
,
device
},
name_
(
name
)
{}
// An map is used here to maintain a global repo for types. We don't use
// MACROs with static variables for that the TypeSystem should only used in
// compile time, that is not performance sensitive, and a map-based way is
// easier to implement and maintain.
static
std
::
map
<
size_t
,
const
Type
*>
type_repo_
;
protected:
Place
place_
;
const
std
::
string
name_
;
};
...
...
@@ -224,46 +208,15 @@ static bool TypeCompatibleTo(const Type& a, const Type& b) {
// is only one instance across the system.
class
VoidTy
:
public
Type
{
public:
VoidTy
()
:
Type
(
ID
::
Void
,
"Void"
,
false
/*is_tensor*/
)
{}
VoidTy
()
:
Type
(
ID
::
Void
,
"Void"
)
{}
};
class
UnsupportedTy
:
public
Type
{
public:
UnsupportedTy
()
:
Type
(
ID
::
Unsupported
,
"Unsupported"
,
false
/*is_tensor*/
)
{}
};
class
TensorAnyTy
:
public
Type
{
public:
explicit
TensorAnyTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Any
,
"TensorAny"
,
true
,
target
,
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
))
{}
};
// A list of tensor, and no assumption on the data layout or data type.
class
TensorListAnyTy
:
public
Type
{
public:
explicit
TensorListAnyTy
(
TargetType
target
)
:
Type
(
ID
::
TensorList_Any
,
"TensorList_Any"
,
false
,
target
,
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
))
{}
};
class
TensorFp32NCHWTy
:
public
Type
{
public:
explicit
TensorFp32NCHWTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Fp32_NCHW
,
"TensorFp32NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
)
{}
};
class
TensorInt8NCHWTy
:
public
Type
{
public:
explicit
TensorInt8NCHWTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Int8_NCHW
,
"TensorInt8NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
};
class
TensorInt64NCHWTy
:
public
Type
{
public:
explicit
TensorInt64NCHWTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Int64_NCHW
,
"TensorInt64NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
};
const
Type
*
LookupType
(
DataType
Base
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
);
const
Type
*
LookupType
(
DataType
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
);
// ------------------------- end predefined types ---------------------------
/*
...
...
paddle/fluid/lite/core/type_system_test.cc
浏览文件 @
123c79fa
...
...
@@ -18,15 +18,13 @@
namespace
paddle
{
namespace
lite
{
TEST
(
TypeSystem
,
test
)
{
ASSERT_TRUE
(
TypeSystem
::
Global
().
Contains
<
lite
::
TensorBase
>
());
}
TEST
(
TypeSystem
,
CheckDuplicateGet
)
{
auto
*
tensor_ty
=
Type
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
auto
*
tensor_ty1
=
Type
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
TEST
(
TypeSystem
,
register_new
)
{
TypeSystem
::
Global
().
Register
<
int
>
(
"int32"
);
ASSERT_TRUE
(
TypeSystem
::
Global
().
Contains
<
int
>
());
ASSERT_TRUE
(
TypeSystem
::
Global
().
Contains
(
typeid
(
int
).
hash_code
()));
ASSERT_TRUE
(
TypeSystem
::
Global
().
Contains
(
"int32"
));
ASSERT_EQ
(
tensor_ty
,
tensor_ty1
);
}
}
// namespace lite
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录