Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
123c79fa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
123c79fa
编写于
5月 12, 2019
作者:
S
Superjomn
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor type system
上级
9d6a0c88
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
138 addition
and
195 deletion
+138
-195
paddle/fluid/lite/core/type_system.cc
paddle/fluid/lite/core/type_system.cc
+91
-99
paddle/fluid/lite/core/type_system.h
paddle/fluid/lite/core/type_system.h
+41
-88
paddle/fluid/lite/core/type_system_test.cc
paddle/fluid/lite/core/type_system_test.cc
+6
-8
未找到文件。
paddle/fluid/lite/core/type_system.cc
浏览文件 @
123c79fa
...
...
@@ -13,126 +13,118 @@
// limitations under the License.
#include "paddle/fluid/lite/core/type_system.h"
#include "type_system.h"
namespace
paddle
{
namespace
lite
{
// ------------------------- GetType specification ----------------------------
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
false
/*is_tensor*/
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
UnsupportedTy
x
;
return
&
x
;
}
// ------------------------- end GetType specification ------------------------
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
true
/*is_tensor*/
,
TargetType
::
kX86
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
TensorFp32NCHWTy
x
(
TargetType
::
kX86
);
return
&
x
;
size_t
ParamTypeRegistry
::
KernelIdTy
::
hash
()
const
{
std
::
hash
<
std
::
string
>
h
;
size_t
hash
=
h
(
kernel_type
);
hash
=
hash_combine
(
hash
,
place
.
hash
());
hash
=
hash_combine
(
hash
,
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
io
)));
hash
=
hash_combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
arg_name
));
return
hash
;
}
template
<
>
const
Type
*
Type
::
Get
<
false
/*is_unsupported*/
,
true
/*is_tensor*/
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
()
{
static
TensorFp32NCHWTy
x
(
TargetType
::
kHost
);
return
&
x
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Type
&
other
)
{
if
(
other
.
IsUnsupported
())
{
os
<<
"<Unsupported>"
;
return
os
;
}
if
(
other
.
IsVoid
())
{
os
<<
"<Void>"
;
return
os
;
}
if
(
other
.
IsTensor
())
{
os
<<
"<Tensor:"
;
}
else
{
os
<<
"<"
;
}
os
<<
TargetToStr
(
other
.
target
())
<<
"/"
<<
PrecisionToStr
(
other
.
precision
())
<<
"/"
<<
DataLayoutToStr
(
other
.
layout
())
<<
">"
;
return
os
;
}
template
<
>
const
Type
*
Type
::
Get
<
UnsupportedTy
>
(
TargetType
target
)
{
return
Get
<
false
,
false
,
TargetType
::
kHost
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
>
();
}
const
Type
*
Type
::
GetTensorTy
(
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
,
int
device
)
{
// NOTE quite naive implementation here, but not performance sensitive.
DataType
::
ID
type_id
=
DataType
::
ID
::
Tensor
;
template
<
TargetType
Target
>
TensorListAnyTy
*
GetTensorListAnyTy
()
{
static
TensorListAnyTy
x
(
Target
);
return
&
x
;
}
template
<
TargetType
Target
>
TensorAnyTy
*
GetTensorAnyTy
()
{
static
TensorAnyTy
x
(
Target
);
return
&
x
;
}
#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
template
<
>
const
Type
*
Type
::
Get
<
TensorListAnyTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kHost
:
return
GetTensorListAnyTy
<
TARGET
(
kHost
)
>
();
case
TargetType
::
kCUDA
:
return
GetTensorListAnyTy
<
TARGET
(
kCUDA
)
>
();
case
TargetType
::
kX86
:
return
GetTensorListAnyTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
}
}
std
::
hash
<
int
>
hasher
;
size_t
v
=
hasher
(
static_cast
<
int
>
(
type_id
));
HASH_ONE
(
target
);
HASH_ONE
(
precision
);
HASH_ONE
(
layout
);
HASH_ONE
(
device
);
#undef HASH_ONE
template
<
>
const
Type
*
Type
::
Get
<
TensorAnyTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TargetType
::
kHost
:
return
GetTensorAnyTy
<
TARGET
(
kHost
)
>
();
case
TargetType
::
kCUDA
:
return
GetTensorAnyTy
<
TARGET
(
kCUDA
)
>
();
case
TargetType
::
kX86
:
return
GetTensorAnyTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
std
::
stringstream
name
;
name
<<
"Tensor<"
;
name
<<
TargetToStr
(
target
)
<<
","
;
name
<<
PrecisionToStr
(
precision
)
<<
","
;
name
<<
DataLayoutToStr
(
layout
)
<<
","
;
name
<<
device
;
name
<<
">"
;
auto
it
=
type_repo_
.
find
(
v
);
if
(
it
==
type_repo_
.
end
())
{
// The Types should alive across the process life, no need to delete.
type_repo_
[
v
]
=
new
Type
(
type_id
,
name
.
str
(),
target
,
precision
,
layout
,
device
);
}
return
type_repo_
[
v
];
}
template
<
TargetType
Target
>
const
Type
*
GetTensorFp32NCHWTy
()
{
static
TensorFp32NCHWTy
x
(
Target
);
return
&
x
;
}
const
Type
*
Type
::
GetTensorListTy
(
TargetType
target
,
PrecisionType
precision
,
DataLayoutType
layout
,
int
device
)
{
DataType
::
ID
type_id
=
DataType
::
ID
::
TensorList
;
template
<
>
const
Type
*
Type
::
Get
<
TensorFp32NCHWTy
>
(
TargetType
target
)
{
switch
(
target
)
{
case
TARGET
(
kHost
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kHost
)
>
();
case
TARGET
(
kCUDA
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kCUDA
)
>
();
case
TARGET
(
kX86
):
return
GetTensorFp32NCHWTy
<
TARGET
(
kX86
)
>
();
default:
LOG
(
FATAL
)
<<
"unsupported target Type "
<<
TargetToStr
(
target
);
}
return
nullptr
;
}
#define HASH_ONE(x) v = hash_combine(v, hasher(static_cast<int>(x)))
const
Type
*
LookupType
(
DataTypeBase
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
)
{
using
id_t
=
DataTypeBase
::
ID
;
switch
(
type_id
)
{
case
id_t
::
Tensor_Any
:
return
Type
::
Get
<
TensorAnyTy
>
(
place
.
target
);
case
id_t
::
Tensor_Fp32_NCHW
:
return
Type
::
Get
<
TensorFp32NCHWTy
>
(
place
.
target
);
case
id_t
::
TensorList_Any
:
return
Type
::
Get
<
TensorListAnyTy
>
(
place
.
target
);
default:
LOG
(
FATAL
)
<<
"unsupported type"
;
}
return
nullptr
;
std
::
hash
<
int
>
hasher
;
size_t
v
=
hasher
(
static_cast
<
int
>
(
type_id
));
HASH_ONE
(
target
);
HASH_ONE
(
precision
);
HASH_ONE
(
layout
);
HASH_ONE
(
device
);
#undef HASH_ONE
std
::
stringstream
name
;
name
<<
"TensorList<"
;
name
<<
TargetToStr
(
target
)
<<
","
;
name
<<
PrecisionToStr
(
precision
)
<<
","
;
name
<<
DataLayoutToStr
(
layout
)
<<
","
;
name
<<
device
;
name
<<
">"
;
if
(
!
type_repo_
[
v
])
// The Types should alive across the process life, no need to delete.
type_repo_
[
v
]
=
new
Type
(
type_id
,
name
.
str
(),
target
,
precision
,
layout
,
device
);
return
type_repo_
[
v
];
}
// ------------------------- end GetType specification ------------------------
const
Type
*
Type
::
GetUnsupportedTy
()
{
std
::
hash
<
int
>
hasher
;
size_t
v
=
hasher
(
static_cast
<
int
>
(
DataType
::
ID
::
Unsupported
));
if
(
!
type_repo_
[
v
])
type_repo_
[
v
]
=
new
Type
(
DataType
::
ID
::
Unsupported
,
"Unsupported"
,
TARGET
(
kUnk
),
PRECISION
(
kUnk
),
DATALAYOUT
(
kUnk
),
-
1
);
}
size_t
ParamTypeRegistry
::
KernelIdTy
::
hash
()
const
{
std
::
hash
<
std
::
string
>
h
;
size_t
hash
=
h
(
kernel_type
);
hash
=
hash_combine
(
hash
,
place
.
hash
());
hash
=
hash_combine
(
hash
,
std
::
hash
<
int
>
()(
static_cast
<
int
>
(
io
)));
hash
=
hash_combine
(
hash
,
std
::
hash
<
std
::
string
>
()(
arg_name
));
return
hash
;
const
Type
*
Type
::
GetVoidTy
()
{
std
::
hash
<
int
>
hasher
;
size_t
v
=
hasher
(
static_cast
<
int
>
(
DataType
::
ID
::
Void
));
if
(
!
type_repo_
[
v
])
type_repo_
[
v
]
=
new
Type
(
DataType
::
ID
::
Void
,
"Void"
,
TARGET
(
kAny
),
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
),
-
1
);
}
}
// namespace lite
...
...
paddle/fluid/lite/core/type_system.h
浏览文件 @
123c79fa
...
...
@@ -73,7 +73,7 @@ namespace lite {
//
// TODO(Superjomn) Add operator/kernel-wise static checking to avoid unsupported
// type mixed in the system.
class
DataType
Base
{
class
DataType
{
public:
// The Void type can cast to any other type.
// The Unsupported is the data type that developed include in the system, for
...
...
@@ -83,43 +83,56 @@ class DataTypeBase {
enum
class
ID
:
int
{
Void
=
0
,
// unknown type that can be cast to any data type.
Unsupported
,
// Unsupported data type that will not be analyzed.
Tensor_Fp32_NCHW
,
Tensor_Int8_NCHW
,
Tensor_Int64_NCHW
,
// Tensor_Any represents a Tensor with any place, data, layout. It is used
// in some IO kernels those doesn't care the data.
Tensor_Any
,
// Used by feed or fetch op.
TensorList_Any
,
Tensor
,
// A tensor list, but all the elements should have the same type.
TensorList
,
// ---------
NumTypes
,
// Must remains as last defined ID.
};
ID
id
()
const
{
return
id_
;
}
// type check.
bool
IsTensor
()
const
{
return
is_tensor_
;
}
bool
IsVoid
()
const
{
return
id_
==
ID
::
Void
;
}
bool
IsUnsupported
()
const
{
return
id_
==
ID
::
Unsupported
;
}
bool
IsTensorFp32NCHW
()
const
{
return
id_
==
ID
::
Tensor_Fp32_NCHW
;
}
bool
IsTensorInt8NCHW
()
const
{
return
id_
==
ID
::
Tensor_Int8_NCHW
;
}
bool
IsTensorInt64NCHW
()
const
{
return
id_
==
ID
::
Tensor_Int64_NCHW
;
}
bool
IsTensor
()
const
{
return
id_
==
ID
::
Tensor
;
}
bool
IsTensorList
()
const
{
return
id_
==
ID
::
TensorList
;
}
// Get number of types.
int
num_types
()
const
{
return
static_cast
<
int
>
(
ID
::
NumTypes
);
}
protected:
// Can only extended by subclass.
DataType
Base
(
ID
id
,
bool
is_tensor
)
:
id_
(
id
),
is_tensor_
(
is_tensor
)
{}
DataType
(
ID
id
)
:
id_
(
id
)
{}
ID
id_
{
ID
::
Unsupported
};
bool
is_tensor_
{
false
};
};
/*
* Datatype with device info considered.
* NOTE A Type with different device is treated as different DeviceDataType.
*/
class
Type
:
public
DataType
Base
{
class
Type
:
public
DataType
{
public:
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a statement to transform a type to another.
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
/// Get a Tensor type.
static
const
Type
*
GetTensorTy
(
TargetType
target
,
PrecisionType
precision
=
PRECISION
(
kFloat
),
DataLayoutType
layout
=
DATALAYOUT
(
kNCHW
),
int
device
=
0
);
/// Get a TensorList type.
static
const
Type
*
GetTensorListTy
(
TargetType
target
,
PrecisionType
precision
=
PRECISION
(
kFloat
),
DataLayoutType
layout
=
DATALAYOUT
(
kNCHW
),
int
device
=
0
);
/// Get an Unsupported type.
static
const
Type
*
GetUnsupportedTy
();
/// Get an Void type.
static
const
Type
*
GetVoidTy
();
TargetType
target
()
const
{
return
place_
.
target
;
}
PrecisionType
precision
()
const
{
return
place_
.
precision
;
}
DataLayoutType
layout
()
const
{
return
place_
.
layout
;
}
...
...
@@ -130,52 +143,23 @@ class Type : public DataTypeBase {
bool
operator
==
(
const
Type
&
other
)
{
return
id_
==
other
.
id
()
&&
place_
==
other
.
place
();
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Type
&
other
)
{
if
(
other
.
IsUnsupported
())
{
os
<<
"<Unsupported>"
;
return
os
;
}
if
(
other
.
IsVoid
())
{
os
<<
"<Void>"
;
return
os
;
}
if
(
other
.
is_tensor_
)
{
os
<<
"<Tensor:"
;
}
else
{
os
<<
"<"
;
}
os
<<
TargetToStr
(
other
.
target
())
<<
"/"
<<
PrecisionToStr
(
other
.
precision
())
<<
"/"
<<
DataLayoutToStr
(
other
.
layout
())
<<
">"
;
return
os
;
}
// Can cast to another type. This is heavily used in MIR, by determine whether
// is is possible to add a statement to transform a type to another.
virtual
bool
TypeCastable
(
const
Type
&
type
)
const
{
return
id_
==
type
.
id
();
}
template
<
bool
is_unknown
,
bool
is_tensor
=
true
,
TargetType
target
=
TargetType
::
kHost
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
>
// Get a type.
static
const
Type
*
Get
();
template
<
typename
TypeTy
>
static
const
Type
*
Get
(
TargetType
target
=
TargetType
::
kHost
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Type
&
other
);
virtual
~
Type
()
=
default
;
protected:
Type
(
ID
id
,
const
std
::
string
&
name
,
bool
is_tensor
,
TargetType
target
=
TargetType
::
kHost
,
/// One should avoid using this construct.
Type
(
ID
id
,
const
std
::
string
&
name
,
TargetType
target
=
TargetType
::
kHost
,
PrecisionType
precision
=
PrecisionType
::
kFloat
,
DataLayoutType
layout
=
DataLayoutType
::
kNCHW
,
short
device
=
0
)
:
DataTypeBase
(
id
,
is_tensor
),
place_
{
target
,
precision
,
layout
,
device
},
name_
(
name
)
{}
:
DataType
(
id
),
place_
{
target
,
precision
,
layout
,
device
},
name_
(
name
)
{}
// An map is used here to maintain a global repo for types. We don't use
// MACROs with static variables for that the TypeSystem should only used in
// compile time, that is not performance sensitive, and a map-based way is
// easier to implement and maintain.
static
std
::
map
<
size_t
,
const
Type
*>
type_repo_
;
protected:
Place
place_
;
const
std
::
string
name_
;
};
...
...
@@ -224,46 +208,15 @@ static bool TypeCompatibleTo(const Type& a, const Type& b) {
// is only one instance across the system.
class
VoidTy
:
public
Type
{
public:
VoidTy
()
:
Type
(
ID
::
Void
,
"Void"
,
false
/*is_tensor*/
)
{}
VoidTy
()
:
Type
(
ID
::
Void
,
"Void"
)
{}
};
class
UnsupportedTy
:
public
Type
{
public:
UnsupportedTy
()
:
Type
(
ID
::
Unsupported
,
"Unsupported"
,
false
/*is_tensor*/
)
{}
};
class
TensorAnyTy
:
public
Type
{
public:
explicit
TensorAnyTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Any
,
"TensorAny"
,
true
,
target
,
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
))
{}
};
// A list of tensor, and no assumption on the data layout or data type.
class
TensorListAnyTy
:
public
Type
{
public:
explicit
TensorListAnyTy
(
TargetType
target
)
:
Type
(
ID
::
TensorList_Any
,
"TensorList_Any"
,
false
,
target
,
PRECISION
(
kAny
),
DATALAYOUT
(
kAny
))
{}
};
class
TensorFp32NCHWTy
:
public
Type
{
public:
explicit
TensorFp32NCHWTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Fp32_NCHW
,
"TensorFp32NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kFloat
,
DataLayoutType
::
kNCHW
)
{}
};
class
TensorInt8NCHWTy
:
public
Type
{
public:
explicit
TensorInt8NCHWTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Int8_NCHW
,
"TensorInt8NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
};
class
TensorInt64NCHWTy
:
public
Type
{
public:
explicit
TensorInt64NCHWTy
(
TargetType
target
)
:
Type
(
ID
::
Tensor_Int64_NCHW
,
"TensorInt64NCHW"
,
true
/*is_tensor*/
,
target
,
PrecisionType
::
kInt8
,
DataLayoutType
::
kNCHW
)
{}
};
const
Type
*
LookupType
(
DataType
Base
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
);
const
Type
*
LookupType
(
DataType
::
ID
type_id
,
bool
is_unknown
,
bool
is_tensor
,
Place
place
);
// ------------------------- end predefined types ---------------------------
/*
...
...
paddle/fluid/lite/core/type_system_test.cc
浏览文件 @
123c79fa
...
...
@@ -18,15 +18,13 @@
namespace
paddle
{
namespace
lite
{
TEST
(
TypeSystem
,
test
)
{
ASSERT_TRUE
(
TypeSystem
::
Global
().
Contains
<
lite
::
TensorBase
>
());
}
TEST
(
TypeSystem
,
CheckDuplicateGet
)
{
auto
*
tensor_ty
=
Type
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
auto
*
tensor_ty1
=
Type
::
GetTensorTy
(
TARGET
(
kHost
),
PRECISION
(
kFloat
),
DATALAYOUT
(
kNCHW
));
TEST
(
TypeSystem
,
register_new
)
{
TypeSystem
::
Global
().
Register
<
int
>
(
"int32"
);
ASSERT_TRUE
(
TypeSystem
::
Global
().
Contains
<
int
>
());
ASSERT_TRUE
(
TypeSystem
::
Global
().
Contains
(
typeid
(
int
).
hash_code
()));
ASSERT_TRUE
(
TypeSystem
::
Global
().
Contains
(
"int32"
));
ASSERT_EQ
(
tensor_ty
,
tensor_ty1
);
}
}
// namespace lite
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录