Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
61fc7a3e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
61fc7a3e
编写于
9月 03, 2020
作者:
S
Shang Zhizhou
提交者:
GitHub
9月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Pass version check (#26887)
上级
f772540d
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
202 addition
and
0 deletion
+202
-0
paddle/fluid/framework/op_version_registry.h
paddle/fluid/framework/op_version_registry.h
+136
-0
paddle/fluid/framework/op_version_registry_test.cc
paddle/fluid/framework/op_version_registry_test.cc
+66
-0
未找到文件。
paddle/fluid/framework/op_version_registry.h
浏览文件 @
61fc7a3e
...
@@ -133,6 +133,9 @@ class OpVersion {
...
@@ -133,6 +133,9 @@ class OpVersion {
checkpoints_
.
push_back
(
Checkpoint
({
note
,
op_version_desc
}));
checkpoints_
.
push_back
(
Checkpoint
({
note
,
op_version_desc
}));
return
*
this
;
return
*
this
;
}
}
uint32_t
GetVersionID
()
const
{
return
static_cast
<
uint32_t
>
(
checkpoints_
.
size
());
}
private:
private:
struct
Checkpoint
{
struct
Checkpoint
{
...
@@ -156,6 +159,14 @@ class OpVersionRegistrar {
...
@@ -156,6 +159,14 @@ class OpVersionRegistrar {
op_version_map_
.
insert
({
op_type
,
OpVersion
()});
op_version_map_
.
insert
({
op_type
,
OpVersion
()});
return
op_version_map_
[
op_type
];
return
op_version_map_
[
op_type
];
}
}
uint32_t
GetVersionID
(
const
std
::
string
&
op_type
)
const
{
auto
it
=
op_version_map_
.
find
(
op_type
);
if
(
it
==
op_version_map_
.
end
())
{
return
0
;
}
return
it
->
second
.
GetVersionID
();
}
private:
private:
std
::
unordered_map
<
std
::
string
,
OpVersion
>
op_version_map_
;
std
::
unordered_map
<
std
::
string
,
OpVersion
>
op_version_map_
;
...
@@ -164,6 +175,125 @@ class OpVersionRegistrar {
...
@@ -164,6 +175,125 @@ class OpVersionRegistrar {
OpVersionRegistrar
&
operator
=
(
const
OpVersionRegistrar
&
)
=
delete
;
OpVersionRegistrar
&
operator
=
(
const
OpVersionRegistrar
&
)
=
delete
;
};
};
class
OpVersionComparator
{
public:
virtual
bool
operator
()()
=
0
;
virtual
~
OpVersionComparator
()
=
default
;
};
#define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \
class OpVersion##cmp_name##Comparator : public OpVersionComparator { \
public: \
explicit OpVersion##cmp_name##Comparator(const std::string op_name, \
uint32_t target_version) \
: op_name_(op_name), target_version_(target_version) {} \
virtual bool operator()() { \
return OpVersionRegistrar::GetInstance().GetVersionID(op_name_) \
cmp_math target_version_; \
} \
virtual ~OpVersion##cmp_name##Comparator() {} \
\
private: \
std::string op_name_; \
uint32_t target_version_; \
};
ADD_OP_VERSION_COMPARATOR
(
LE
,
<=
);
ADD_OP_VERSION_COMPARATOR
(
EQ
,
==
);
ADD_OP_VERSION_COMPARATOR
(
GE
,
>=
);
ADD_OP_VERSION_COMPARATOR
(
NE
,
!=
);
class
OpVersionComparatorCombination
{
public:
OpVersionComparatorCombination
()
{}
OpVersionComparatorCombination
&
LE
(
const
std
::
string
&
op_name
,
int
target_version
)
{
op_version_comparators_
.
push_back
(
std
::
shared_ptr
<
OpVersionComparator
>
(
new
OpVersionLEComparator
(
op_name
,
target_version
)));
return
*
this
;
}
OpVersionComparatorCombination
&
EQ
(
const
std
::
string
&
op_name
,
int
target_version
)
{
op_version_comparators_
.
push_back
(
std
::
shared_ptr
<
OpVersionComparator
>
(
new
OpVersionEQComparator
(
op_name
,
target_version
)));
return
*
this
;
}
OpVersionComparatorCombination
&
GE
(
const
std
::
string
&
op_name
,
int
target_version
)
{
op_version_comparators_
.
push_back
(
std
::
shared_ptr
<
OpVersionComparator
>
(
new
OpVersionGEComparator
(
op_name
,
target_version
)));
return
*
this
;
}
OpVersionComparatorCombination
&
NE
(
const
std
::
string
&
op_name
,
int
target_version
)
{
op_version_comparators_
.
push_back
(
std
::
shared_ptr
<
OpVersionComparator
>
(
new
OpVersionNEComparator
(
op_name
,
target_version
)));
return
*
this
;
}
bool
IsMatched
()
const
{
for
(
const
auto
&
cmp
:
op_version_comparators_
)
{
if
(
!
(
*
cmp
)())
{
return
false
;
}
}
return
true
;
}
private:
std
::
vector
<
std
::
shared_ptr
<
OpVersionComparator
>>
op_version_comparators_
;
};
class
PassVersionCheckers
{
public:
PassVersionCheckers
&
AddCombination
(
const
OpVersionComparatorCombination
&
combinations
)
{
pass_version_checkers_
.
push_back
(
combinations
);
return
*
this
;
}
bool
IsPassCompatible
()
const
{
if
(
pass_version_checkers_
.
empty
())
{
return
true
;
}
for
(
const
auto
&
checker
:
pass_version_checkers_
)
{
if
(
checker
.
IsMatched
())
{
return
true
;
}
}
return
false
;
}
private:
std
::
vector
<
OpVersionComparatorCombination
>
pass_version_checkers_
;
};
class
PassVersionCheckerRegistrar
{
public:
static
PassVersionCheckerRegistrar
&
GetInstance
()
{
static
PassVersionCheckerRegistrar
instance
;
return
instance
;
}
PassVersionCheckers
&
Register
(
const
std
::
string
&
pass_name
)
{
return
pass_version_checkers_map_
[
pass_name
];
}
bool
IsPassCompatible
(
const
std
::
string
&
fuse_pass_name
)
const
{
auto
iter
=
pass_version_checkers_map_
.
find
(
fuse_pass_name
);
if
(
iter
==
pass_version_checkers_map_
.
end
())
{
return
true
;
}
return
iter
->
second
.
IsPassCompatible
();
}
private:
std
::
unordered_map
<
std
::
string
,
PassVersionCheckers
>
pass_version_checkers_map_
;
PassVersionCheckerRegistrar
()
=
default
;
PassVersionCheckerRegistrar
&
operator
=
(
const
PassVersionCheckerRegistrar
&
)
=
delete
;
};
}
// namespace compatible
}
// namespace compatible
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
@@ -173,3 +303,9 @@ class OpVersionRegistrar {
...
@@ -173,3 +303,9 @@ class OpVersionRegistrar {
RegisterOpVersion__##op_type = \
RegisterOpVersion__##op_type = \
paddle::framework::compatible::OpVersionRegistrar::GetInstance() \
paddle::framework::compatible::OpVersionRegistrar::GetInstance() \
.Register(#op_type)
.Register(#op_type)
#define REGISTER_PASS_CAPABILITY(pass_name) \
static auto RegisterOpPassVersionChecker__##pass_name = \
paddle::framework::compatible::PassVersionCheckerRegistrar:: \
GetInstance() \
.Register(#pass_name)
paddle/fluid/framework/op_version_registry_test.cc
浏览文件 @
61fc7a3e
...
@@ -55,6 +55,72 @@ TEST(test_operator_version, test_operator_version) {
...
@@ -55,6 +55,72 @@ TEST(test_operator_version, test_operator_version) {
.
NewInput
(
"X2"
,
"The second input."
)
.
NewInput
(
"X2"
,
"The second input."
)
.
NewOutput
(
"Y2"
,
"The second output."
));
.
NewOutput
(
"Y2"
,
"The second output."
));
}
}
TEST
(
test_pass_op_version_checker
,
test_pass_op_version_checker
)
{
ASSERT_TRUE
(
PassVersionCheckerRegistrar
::
GetInstance
().
IsPassCompatible
(
"no_bind_pass"
));
REGISTER_PASS_CAPABILITY
(
test_pass1
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"mul"
,
1
)
.
EQ
(
"fc"
,
0
));
ASSERT_TRUE
(
PassVersionCheckerRegistrar
::
GetInstance
().
IsPassCompatible
(
"test_pass1"
));
REGISTER_PASS_CAPABILITY
(
test_pass2
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
GE
(
"mul"
,
0
)
.
NE
(
"fc"
,
0
));
ASSERT_FALSE
(
PassVersionCheckerRegistrar
::
GetInstance
().
IsPassCompatible
(
"test_pass2"
));
REGISTER_PASS_CAPABILITY
(
test_pass3
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
GE
(
"mul"
,
0
)
.
NE
(
"fc"
,
0
))
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"mul"
,
1
)
.
EQ
(
"fc"
,
0
));
ASSERT_TRUE
(
PassVersionCheckerRegistrar
::
GetInstance
().
IsPassCompatible
(
"test_pass3"
));
REGISTER_PASS_CAPABILITY
(
test_pass4
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
GE
(
"test__"
,
5
)
.
EQ
(
"fc"
,
0
));
ASSERT_FALSE
(
PassVersionCheckerRegistrar
::
GetInstance
().
IsPassCompatible
(
"test_pass4"
));
REGISTER_PASS_CAPABILITY
(
test_pass5
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
GE
(
"test__"
,
4
)
.
EQ
(
"fc"
,
0
));
ASSERT_TRUE
(
PassVersionCheckerRegistrar
::
GetInstance
().
IsPassCompatible
(
"test_pass5"
));
REGISTER_PASS_CAPABILITY
(
test_pass6
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
EQ
(
"test__"
,
4
)
.
EQ
(
"fc"
,
0
));
ASSERT_TRUE
(
PassVersionCheckerRegistrar
::
GetInstance
().
IsPassCompatible
(
"test_pass6"
));
REGISTER_PASS_CAPABILITY
(
test_pass7
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
NE
(
"test__"
,
4
)
.
EQ
(
"fc"
,
0
));
ASSERT_FALSE
(
PassVersionCheckerRegistrar
::
GetInstance
().
IsPassCompatible
(
"test_pass7"
));
}
}
// namespace compatible
}
// namespace compatible
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录