Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
13429c3e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
13429c3e
编写于
12月 20, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean code, remove void registration
test why MAC CI fail again test=develop
上级
ce4a26dd
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
83 addition
and
41 deletion
+83
-41
paddle/fluid/framework/var_type_traits.cc
paddle/fluid/framework/var_type_traits.cc
+44
-14
paddle/fluid/framework/var_type_traits.h
paddle/fluid/framework/var_type_traits.h
+17
-16
paddle/fluid/framework/var_type_traits_test.cc
paddle/fluid/framework/var_type_traits_test.cc
+22
-11
未找到文件。
paddle/fluid/framework/var_type_traits.cc
浏览文件 @
13429c3e
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/platform/macros.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -23,54 +24,83 @@ namespace detail {
template
<
int
kStart
,
int
kEnd
,
bool
kStop
>
struct
VarIdToTypeIndexMapInitializerImpl
{
static
void
Init
(
std
::
unordered_map
<
int
,
std
::
type_index
>
*
m
)
{
template
<
typename
MapType1
,
typename
MapType2
>
static
void
Init
(
MapType1
*
id_to_type
,
MapType2
*
type_to_id
)
{
using
Type
=
typename
std
::
tuple_element
<
kStart
,
VarTypeRegistry
::
ArgTuple
>::
type
;
static_assert
(
!
std
::
is_same
<
Type
,
void
>::
value
,
"Type cannot be void"
);
constexpr
int
kId
=
VarTypeTrait
<
Type
>::
kId
;
if
(
!
std
::
is_same
<
Type
,
void
>::
value
)
{
m
->
emplace
(
kId
,
std
::
type_index
(
typeid
(
Type
)));
}
auto
type
=
std
::
type_index
(
typeid
(
Type
));
PADDLE_ENFORCE
(
id_to_type
->
count
(
kId
)
==
0
,
"Registered duplicate type id %d for type %s"
,
kId
,
type
.
name
());
PADDLE_ENFORCE
(
type_to_id
->
count
(
type
)
==
0
,
"Registered duplicate type_index %s for id %d"
,
type
.
name
(),
kId
);
id_to_type
->
emplace
(
kId
,
type
);
type_to_id
->
emplace
(
type
,
kId
);
VarIdToTypeIndexMapInitializerImpl
<
kStart
+
1
,
kEnd
,
kStart
+
1
==
kEnd
>::
Init
(
m
);
kStart
+
1
==
kEnd
>::
Init
(
id_to_type
,
type_to_id
);
}
};
template
<
int
kStart
,
int
kEnd
>
struct
VarIdToTypeIndexMapInitializerImpl
<
kStart
,
kEnd
,
true
>
{
static
void
Init
(
std
::
unordered_map
<
int
,
std
::
type_index
>
*
m
)
{}
template
<
typename
MapType1
,
typename
MapType2
>
static
void
Init
(
MapType1
*
,
MapType2
*
)
{}
};
// VarIdToTypeIndexMapInitializer is designed to initialize var_id ->
// std::type_index map
// std::type_index map
and std::type_index -> var_id map
using
VarIdToTypeIndexMapInitializer
=
VarIdToTypeIndexMapInitializerImpl
<
0
,
VarTypeRegistry
::
kRegisteredTypeNum
,
VarTypeRegistry
::
kRegisteredTypeNum
==
0
>
;
struct
VarIdToTypeIndexMapHolder
{
DISABLE_COPY_AND_ASSIGN
(
VarIdToTypeIndexMapHolder
);
public:
static
const
std
::
type_index
&
ToTypeIndex
(
int
var_id
)
{
static
const
VarIdToTypeIndexMapHolder
instance
;
auto
it
=
instance
.
var_type_map_
.
find
(
var_id
);
PADDLE_ENFORCE
(
it
!=
instance
.
var_type_map_
.
end
(),
auto
it
=
Instance
().
id_to_type_map_
.
find
(
var_id
);
PADDLE_ENFORCE
(
it
!=
Instance
().
id_to_type_map_
.
end
(),
"VarId %d is not registered."
,
var_id
);
return
it
->
second
;
}
static
int
ToTypeId
(
const
std
::
type_index
&
type
)
{
auto
it
=
Instance
().
type_to_id_map_
.
find
(
type
);
PADDLE_ENFORCE
(
it
!=
Instance
().
type_to_id_map_
.
end
(),
"VarType %s is not registered."
,
type
.
name
());
return
it
->
second
;
}
private:
VarIdToTypeIndexMapHolder
()
{
VarIdToTypeIndexMapInitializer
::
Init
(
&
var_type_map_
);
VarIdToTypeIndexMapInitializer
::
Init
(
&
id_to_type_map_
,
&
type_to_id_map_
);
}
static
const
VarIdToTypeIndexMapHolder
&
Instance
()
{
static
const
VarIdToTypeIndexMapHolder
instance
;
return
instance
;
}
std
::
unordered_map
<
int
,
std
::
type_index
>
var_type_map_
;
std
::
unordered_map
<
int
,
std
::
type_index
>
id_to_type_map_
;
std
::
unordered_map
<
std
::
type_index
,
int
>
type_to_id_map_
;
};
}
// namespace detail
const
char
*
ToTypeName
(
int
var_id
)
{
return
ToTypeIndex
(
var_id
).
name
();
}
const
std
::
type_index
&
ToTypeIndex
(
int
var_id
)
{
return
detail
::
VarIdToTypeIndexMapHolder
::
ToTypeIndex
(
var_id
);
}
const
char
*
ToTypeName
(
int
var_id
)
{
return
ToTypeIndex
(
var_id
).
name
();
}
int
ToTypeId
(
const
std
::
type_index
&
type
)
{
return
detail
::
VarIdToTypeIndexMapHolder
::
ToTypeId
(
type
);
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/var_type_traits.h
浏览文件 @
13429c3e
...
...
@@ -42,6 +42,7 @@ namespace framework {
const
char
*
ToTypeName
(
int
var_id
);
const
std
::
type_index
&
ToTypeIndex
(
int
var_id
);
int
ToTypeId
(
const
std
::
type_index
&
type
);
namespace
detail
{
...
...
@@ -75,10 +76,10 @@ struct VarTypeRegistryImpl {
using
ArgTuple
=
std
::
tuple
<
Args
...
>
;
// TypePos() returns the position in which T is inside Args...
// If T is not inside Args...
or T is void
, return -1
// If T is not inside Args..., return -1
template
<
typename
T
>
static
constexpr
int
TypePos
()
{
return
std
::
is_same
<
T
,
void
>::
value
?
-
1
:
TypePosFinder
<
T
,
Args
...
>::
kPos
;
return
TypePosFinder
<
T
,
Args
...
>::
kPos
;
}
// IsRegistered() returns whether T is registered inside RegistryImpl
...
...
@@ -90,19 +91,22 @@ struct VarTypeRegistryImpl {
}
// namespace detail
#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id) \
template <> \
struct VarTypeTrait<type> { \
static_assert(VarTypeRegistry::IsRegistered<type>(), \
"Must be registered type"); \
using Type = type; \
static constexpr int kId =
proto_id;
\
#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id)
\
template <>
\
struct VarTypeTrait<type> {
\
static_assert(VarTypeRegistry::IsRegistered<type>(),
\
"Must be registered type");
\
using Type = type;
\
static constexpr int kId =
static_cast<int>(proto_id);
\
}
/**
* The following codes are designed to register variable types.
* Only registered types can be stored in Variable.
* This registry mechanism is designed to speed up Variable.
*
* Caution: If you want to add more var types, please consider carefully
* whether you really need to add it.
*/
// Users should add other variable types below.
...
...
@@ -110,10 +114,9 @@ struct VarTypeRegistryImpl {
class
Scope
;
using
VarTypeRegistry
=
detail
::
VarTypeRegistryImpl
<
LoDTensor
,
SelectedRows
,
std
::
vector
<
Scope
*>
,
LoDRankTable
,
LoDTensorArray
,
platform
::
PlaceList
,
ReaderHolder
,
Tenso
r
,
std
::
string
,
Scope
*
,
Tensor
,
LoDTensor
,
SelectedRows
,
std
::
vector
<
Scope
*>
,
LoDRankTable
,
LoDTensorArray
,
platform
::
PlaceList
,
ReaderHolde
r
,
std
::
string
,
Scope
*
,
std
::
map
<
size_t
,
Tensor
>
,
operators
::
reader
::
LoDTensorBlockingQueueHolder
,
int
,
float
,
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
ncclUniqueId
,
platform
::
Communicator
,
...
...
@@ -123,13 +126,11 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
operators
::
AlgorithmsCache
<
cudnnConvolutionBwdFilterAlgo_t
>
,
operators
::
CudnnRNNCache
,
#endif
void
>
;
// void indicates end of registration, add other types before void
int
,
float
>
;
template
<
typename
T
>
struct
VarTypeTrait
{
static_assert
(
std
::
is_same
<
T
,
void
>::
value
||
VarTypeRegistry
::
IsRegistered
<
T
>
(),
"Must be registered type"
);
static_assert
(
VarTypeRegistry
::
IsRegistered
<
T
>
(),
"Must be registered type"
);
using
Type
=
T
;
// Default id generation
static
constexpr
int
kId
=
VarTypeRegistry
::
TypePos
<
T
>
()
+
...
...
paddle/fluid/framework/var_type_traits_test.cc
浏览文件 @
13429c3e
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/var_type_traits.h"
#include <gtest/gtest.h>
#include <cstdint>
#include <iostream>
#include <unordered_set>
namespace
paddle
{
...
...
@@ -29,15 +30,27 @@ struct TypeIndexChecker {
static_assert
(
std
::
is_same
<
typename
VarTypeTrait
<
Type
>::
Type
,
Type
>::
value
,
"Type must be the same"
);
constexpr
auto
kId
=
VarTypeTrait
<
Type
>::
kId
;
if
(
!
std
::
is_same
<
Type
,
void
>::
value
)
{
std
::
type_index
actual_type
(
typeid
(
Type
));
EXPECT_EQ
(
std
::
string
(
ToTypeName
(
kId
)),
std
::
string
(
actual_type
.
name
()));
EXPECT_EQ
(
ToTypeIndex
(
kId
),
actual_type
);
EXPECT_TRUE
(
var_id_set
->
count
(
kId
)
==
0
);
// NOLINT
EXPECT_TRUE
(
type_index_set
->
count
(
actual_type
)
==
0
);
// NOLINT
var_id_set
->
insert
(
kId
);
type_index_set
->
insert
(
std
::
type_index
(
typeid
(
Type
)));
std
::
type_index
actual_type
(
typeid
(
Type
));
EXPECT_EQ
(
std
::
string
(
ToTypeName
(
kId
)),
std
::
string
(
actual_type
.
name
()));
// For some reasons, comparing std::type_index using EXPECT_EQ would fail
// in MAC CI
bool
is_same_type_index
=
(
ToTypeIndex
(
kId
)
==
actual_type
);
if
(
!
is_same_type_index
)
{
std
::
string
s1
=
ToTypeName
(
kId
);
std
::
string
s2
=
actual_type
.
name
();
PADDLE_THROW
(
"Step %d: type %s is not the same as %s, var_id %d"
,
kPos
,
s1
.
c_str
(),
s2
.
c_str
(),
kId
);
}
EXPECT_TRUE
(
is_same_type_index
);
EXPECT_TRUE
(
ToTypeId
(
actual_type
)
==
kId
);
// NOLINT
is_same_type_index
=
(
ToTypeIndex
(
ToTypeId
(
actual_type
))
==
actual_type
);
EXPECT_TRUE
(
is_same_type_index
);
EXPECT_EQ
(
ToTypeId
(
ToTypeIndex
(
kId
)),
kId
);
EXPECT_TRUE
(
var_id_set
->
count
(
kId
)
==
0
);
// NOLINT
EXPECT_TRUE
(
type_index_set
->
count
(
actual_type
)
==
0
);
// NOLINT
var_id_set
->
insert
(
kId
);
type_index_set
->
insert
(
std
::
type_index
(
typeid
(
Type
)));
TypeIndexChecker
<
kPos
+
1
,
kEnd
,
kPos
+
1
==
kEnd
>::
Check
(
var_id_set
,
type_index_set
);
}
...
...
@@ -75,13 +88,11 @@ TEST(var_type_traits, check_proto_type_id) {
}
TEST
(
var_type_traits
,
test_registry
)
{
using
Registry
=
detail
::
VarTypeRegistryImpl
<
int8_t
,
int32_t
,
size_t
,
double
,
void
>
;
using
Registry
=
detail
::
VarTypeRegistryImpl
<
int8_t
,
int32_t
,
size_t
,
double
>
;
ASSERT_TRUE
(
Registry
::
TypePos
<
int8_t
>
()
==
0
);
ASSERT_TRUE
(
Registry
::
TypePos
<
int32_t
>
()
==
1
);
ASSERT_TRUE
(
Registry
::
TypePos
<
size_t
>
()
==
2
);
ASSERT_TRUE
(
Registry
::
TypePos
<
double
>
()
==
3
);
ASSERT_TRUE
(
Registry
::
TypePos
<
void
>
()
==
-
1
);
ASSERT_TRUE
(
Registry
::
TypePos
<
float
>
()
==
-
1
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录