Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c41fd033
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c41fd033
编写于
11月 05, 2020
作者:
石
石晓伟
提交者:
GitHub
11月 05, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
check op_version_registry in CI test, test=develop (#28402)
上级
2500dca8
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
129 addition
and
25 deletion
+129
-25
paddle/fluid/framework/op_version_registry.h
paddle/fluid/framework/op_version_registry.h
+4
-3
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc
.../fluid/operators/detection/distribute_fpn_proposals_op.cc
+1
-1
paddle/fluid/pybind/compatible.cc
paddle/fluid/pybind/compatible.cc
+1
-2
tools/check_op_desc.py
tools/check_op_desc.py
+123
-19
未找到文件。
paddle/fluid/framework/op_version_registry.h
浏览文件 @
c41fd033
...
...
@@ -92,7 +92,7 @@ enum class OpUpdateType {
class
OpUpdateBase
{
public:
virtual
const
OpUpdateInfo
*
info
()
const
=
0
;
virtual
const
OpUpdateInfo
&
info
()
const
=
0
;
virtual
OpUpdateType
type
()
const
=
0
;
virtual
~
OpUpdateBase
()
=
default
;
};
...
...
@@ -101,7 +101,7 @@ template <typename InfoType, OpUpdateType type__>
class
OpUpdate
:
public
OpUpdateBase
{
public:
explicit
OpUpdate
(
const
InfoType
&
info
)
:
info_
{
info
},
type_
{
type__
}
{}
const
OpUpdateInfo
*
info
()
const
override
{
return
&
info_
;
}
const
InfoType
&
info
()
const
override
{
return
info_
;
}
OpUpdateType
type
()
const
override
{
return
type_
;
}
private:
...
...
@@ -169,7 +169,6 @@ class OpVersion {
class
OpVersionRegistrar
{
public:
OpVersionRegistrar
()
=
default
;
static
OpVersionRegistrar
&
GetInstance
()
{
static
OpVersionRegistrar
instance
;
return
instance
;
...
...
@@ -185,6 +184,8 @@ class OpVersionRegistrar {
private:
std
::
unordered_map
<
std
::
string
,
OpVersion
>
op_version_map_
;
OpVersionRegistrar
()
=
default
;
OpVersionRegistrar
&
operator
=
(
const
OpVersionRegistrar
&
)
=
delete
;
};
inline
const
std
::
unordered_map
<
std
::
string
,
OpVersion
>&
get_op_version_map
()
{
...
...
paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc
浏览文件 @
c41fd033
...
...
@@ -130,7 +130,7 @@ REGISTER_OP_VERSION(distribute_fpn_proposals)
Upgrade distribute_fpn_proposals add a new input
[RoisNum] and add a new output [MultiLevelRoIsNum].)ROC"
,
paddle
::
framework
::
compatible
::
OpVersionDesc
()
.
NewInput
(
"Ro
I
sNum"
,
"The number of RoIs in each image."
)
.
NewInput
(
"Ro
i
sNum"
,
"The number of RoIs in each image."
)
.
NewOutput
(
"MultiLevelRoisNum"
,
"The RoIs' number of each image on multiple "
"levels. The number on each level has the shape of (B),"
...
...
paddle/fluid/pybind/compatible.cc
浏览文件 @
c41fd033
...
...
@@ -95,8 +95,7 @@ void BindOpUpdateType(py::module *m) {
void
BindOpUpdateBase
(
py
::
module
*
m
)
{
py
::
class_
<
OpUpdateBase
>
(
*
m
,
"OpUpdateBase"
)
.
def
(
"info"
,
[](
const
OpUpdateBase
&
obj
)
{
return
obj
.
info
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"info"
,
&
OpUpdateBase
::
info
,
py
::
return_value_policy
::
reference
)
.
def
(
"type"
,
&
OpUpdateBase
::
type
);
}
...
...
tools/check_op_desc.py
浏览文件 @
c41fd033
...
...
@@ -14,6 +14,8 @@
import
json
import
sys
from
paddle.utils
import
OpLastCheckpointChecker
from
paddle.fluid.core
import
OpUpdateType
SAME
=
0
...
...
@@ -21,7 +23,14 @@ INPUTS = "Inputs"
OUTPUTS
=
"Outputs"
ATTRS
=
"Attrs"
# The constant `ADD` means that an item has been added. In particular,
# we use `ADD_WITH_DEFAULT` to mean adding attributes with default
# attributes, and `ADD_DISPENSABLE` to mean adding optional inputs or
# outputs.
ADD_WITH_DEFAULT
=
"Add_with_default"
ADD_DISPENSABLE
=
"Add_dispensable"
ADD
=
"Add"
DELETE
=
"Delete"
CHANGE
=
"Change"
...
...
@@ -35,12 +44,26 @@ DEFAULT_VALUE = "default_value"
error
=
False
version_update_map
=
{
INPUTS
:
{
ADD
:
OpUpdateType
.
kNewInput
,
},
OUTPUTS
:
{
ADD
:
OpUpdateType
.
kNewOutput
,
},
ATTRS
:
{
ADD
:
OpUpdateType
.
kNewAttr
,
CHANGE
:
OpUpdateType
.
kModifyAttr
,
},
}
def
diff_vars
(
origin_vars
,
new_vars
):
global
error
var_error
=
False
var_changed_error_massage
=
{}
var_added_error_massage
=
[]
var_add_massage
=
[]
var_add_dispensable_massage
=
[]
var_deleted_error_massage
=
[]
common_vars_name
=
set
(
origin_vars
.
keys
())
&
set
(
new_vars
.
keys
())
...
...
@@ -65,13 +88,16 @@ def diff_vars(origin_vars, new_vars):
var_deleted_error_massage
.
append
(
var_name
)
for
var_name
in
vars_name_only_in_new
:
var_add_massage
.
append
(
var_name
)
if
not
new_vars
.
get
(
var_name
).
get
(
DISPENSABLE
):
error
,
var_error
=
True
,
True
var_add
ed_error
_massage
.
append
(
var_name
)
var_add
_dispensable
_massage
.
append
(
var_name
)
var_diff_message
=
{}
if
var_added_error_massage
:
var_diff_message
[
ADD
]
=
var_added_error_massage
if
var_add_massage
:
var_diff_message
[
ADD
]
=
var_add_massage
if
var_add_dispensable_massage
:
var_diff_message
[
ADD_DISPENSABLE
]
=
var_add_dispensable_massage
if
var_changed_error_massage
:
var_diff_message
[
CHANGE
]
=
var_changed_error_massage
if
var_deleted_error_massage
:
...
...
@@ -86,6 +112,7 @@ def diff_attr(ori_attrs, new_attrs):
attr_changed_error_massage
=
{}
attr_added_error_massage
=
[]
attr_added_def_error_massage
=
[]
attr_deleted_error_massage
=
[]
common_attrs
=
set
(
ori_attrs
.
keys
())
&
set
(
new_attrs
.
keys
())
...
...
@@ -110,13 +137,16 @@ def diff_attr(ori_attrs, new_attrs):
attr_deleted_error_massage
.
append
(
attr_name
)
for
attr_name
in
attrs_only_in_new
:
attr_added_error_massage
.
append
(
attr_name
)
if
new_attrs
.
get
(
attr_name
).
get
(
DEFAULT_VALUE
)
==
None
:
error
,
attr_error
=
True
,
True
attr_added_error_massage
.
append
(
attr_name
)
attr_added_
def_
error_massage
.
append
(
attr_name
)
attr_diff_message
=
{}
if
attr_added_error_massage
:
attr_diff_message
[
ADD
]
=
attr_added_error_massage
if
attr_added_def_error_massage
:
attr_diff_message
[
ADD_WITH_DEFAULT
]
=
attr_added_def_error_massage
if
attr_changed_error_massage
:
attr_diff_message
[
CHANGE
]
=
attr_changed_error_massage
if
attr_deleted_error_massage
:
...
...
@@ -125,15 +155,39 @@ def diff_attr(ori_attrs, new_attrs):
return
attr_error
,
attr_diff_message
def
check_io_registry
(
io_type
,
op
,
diff
):
checker
=
OpLastCheckpointChecker
()
results
=
{}
for
update_type
in
[
ADD
]:
for
item
in
diff
.
get
(
update_type
,
{}):
infos
=
checker
.
filter_updates
(
op
,
version_update_map
[
io_type
][
update_type
],
item
)
if
not
infos
:
results
[
update_type
]
=
(
op
,
item
,
io_type
)
return
results
def
check_attr_registry
(
op
,
diff
):
checker
=
OpLastCheckpointChecker
()
results
=
{}
for
update_type
in
[
ADD
,
CHANGE
]:
for
item
in
diff
.
get
(
update_type
,
{}):
infos
=
checker
.
filter_updates
(
op
,
version_update_map
[
ATTRS
][
update_type
],
item
)
if
not
infos
:
results
[
update_type
]
=
(
op
,
item
)
return
results
def
compare_op_desc
(
origin_op_desc
,
new_op_desc
):
origin
=
json
.
loads
(
origin_op_desc
)
new
=
json
.
loads
(
new_op_desc
)
error_message
=
{}
desc_error_message
=
{}
version_error_message
=
{}
if
cmp
(
origin_op_desc
,
new_op_desc
)
==
SAME
:
return
error_message
return
desc_error_message
,
version_
error_message
for
op_type
in
origin
:
# no need to compare if the operator is deleted
if
op_type
not
in
new
:
continue
...
...
@@ -144,33 +198,47 @@ def compare_op_desc(origin_op_desc, new_op_desc):
origin_inputs
=
origin_info
.
get
(
INPUTS
,
{})
new_inputs
=
new_info
.
get
(
INPUTS
,
{})
ins_error
,
ins_diff
=
diff_vars
(
origin_inputs
,
new_inputs
)
ins_version_errors
=
check_io_registry
(
INPUTS
,
op_type
,
ins_diff
)
origin_outputs
=
origin_info
.
get
(
OUTPUTS
,
{})
new_outputs
=
new_info
.
get
(
OUTPUTS
,
{})
outs_error
,
outs_diff
=
diff_vars
(
origin_outputs
,
new_outputs
)
outs_version_errors
=
check_io_registry
(
OUTPUTS
,
op_type
,
outs_diff
)
origin_attrs
=
origin_info
.
get
(
ATTRS
,
{})
new_attrs
=
new_info
.
get
(
ATTRS
,
{})
attrs_error
,
attrs_diff
=
diff_attr
(
origin_attrs
,
new_attrs
)
attrs_version_errors
=
check_attr_registry
(
op_type
,
attrs_diff
)
if
ins_error
:
error_message
.
setdefault
(
op_type
,
{})[
INPUTS
]
=
ins_diff
desc_
error_message
.
setdefault
(
op_type
,
{})[
INPUTS
]
=
ins_diff
if
outs_error
:
error_message
.
setdefault
(
op_type
,
{})[
OUTPUTS
]
=
outs_diff
desc_
error_message
.
setdefault
(
op_type
,
{})[
OUTPUTS
]
=
outs_diff
if
attrs_error
:
error_message
.
setdefault
(
op_type
,
{})[
ATTRS
]
=
attrs_diff
desc_
error_message
.
setdefault
(
op_type
,
{})[
ATTRS
]
=
attrs_diff
return
error_message
if
ins_version_errors
:
version_error_message
.
setdefault
(
op_type
,
{})[
INPUTS
]
=
ins_version_errors
if
outs_version_errors
:
version_error_message
.
setdefault
(
op_type
,
{})[
OUTPUTS
]
=
outs_version_errors
if
attrs_version_errors
:
version_error_message
.
setdefault
(
op_type
,
{})[
ATTRS
]
=
attrs_version_errors
return
desc_error_message
,
version_error_message
def
print_error_message
(
error_message
):
print
(
"Op desc error for the changes of Inputs/Outputs/Attrs of OPs:
\n
"
)
def
print_desc_error_message
(
error_message
):
print
(
"
\n
=======================
\n
"
"Op desc error for the changes of Inputs/Outputs/Attrs of OPs:
\n
"
)
for
op_name
in
error_message
:
print
(
"For OP '{}':"
.
format
(
op_name
))
# 1. print inputs error message
Inputs_error
=
error_message
.
get
(
op_name
,
{}).
get
(
INPUTS
,
{})
for
name
in
Inputs_error
.
get
(
ADD
,
{}):
for
name
in
Inputs_error
.
get
(
ADD
_DISPENSABLE
,
{}):
print
(
" * The added Input '{}' is not dispensable."
.
format
(
name
))
for
name
in
Inputs_error
.
get
(
DELETE
,
{}):
...
...
@@ -186,7 +254,7 @@ def print_error_message(error_message):
# 2. print outputs error message
Outputs_error
=
error_message
.
get
(
op_name
,
{}).
get
(
OUTPUTS
,
{})
for
name
in
Outputs_error
.
get
(
ADD
,
{}):
for
name
in
Outputs_error
.
get
(
ADD
_DISPENSABLE
,
{}):
print
(
" * The added Output '{}' is not dispensable."
.
format
(
name
))
for
name
in
Outputs_error
.
get
(
DELETE
,
{}):
...
...
@@ -202,7 +270,7 @@ def print_error_message(error_message):
# 3. print attrs error message
attrs_error
=
error_message
.
get
(
op_name
,
{}).
get
(
ATTRS
,
{})
for
name
in
attrs_error
.
get
(
ADD
,
{}):
for
name
in
attrs_error
.
get
(
ADD
_WITH_DEFAULT
,
{}):
print
(
" * The added attr '{}' doesn't set default value."
.
format
(
name
))
...
...
@@ -218,6 +286,40 @@ def print_error_message(error_message):
format
(
arg
,
name
,
ori_value
,
new_value
))
def
print_version_error_message
(
error_message
):
print
(
"
\n
=======================
\n
"
"Operator registration error for the changes of Inputs/Outputs/Attrs of OPs:
\n
"
)
for
op_name
in
error_message
:
print
(
"For OP '{}':"
.
format
(
op_name
))
# 1. print inputs error message
inputs_error
=
error_message
.
get
(
op_name
,
{}).
get
(
INPUTS
,
{})
tuple
=
inputs_error
.
get
(
ADD
,
{})
if
tuple
:
print
(
" * The added input '{}' is not yet registered."
.
format
(
tuple
[
1
]))
# 2. print inputs error message
outputs_error
=
error_message
.
get
(
op_name
,
{}).
get
(
OUTPUTS
,
{})
tuple
=
outputs_error
.
get
(
ADD
,
{})
if
tuple
:
print
(
" * The added output '{}' is not yet registered."
.
format
(
tuple
[
1
]))
#3. print attrs error message
attrs_error
=
error_message
.
get
(
op_name
,
{}).
get
(
ATTRS
,
{})
tuple
=
attrs_error
.
get
(
ADD
,
{})
if
tuple
:
print
(
" * The added attribute '{}' is not yet registered."
.
format
(
tuple
[
1
]))
tuple
=
attrs_error
.
get
(
CHANGE
,
{})
if
tuple
:
print
(
" * The change of attribute '{}' is not yet registered."
.
format
(
tuple
[
1
]))
def
print_repeat_process
():
print
(
"Tips:"
...
...
@@ -241,10 +343,12 @@ if len(sys.argv) == 3:
with
open
(
sys
.
argv
[
2
],
'r'
)
as
f
:
new_op_desc
=
f
.
read
()
error_message
=
compare_op_desc
(
origin_op_desc
,
new_op_desc
)
desc_error_message
,
version_error_message
=
compare_op_desc
(
origin_op_desc
,
new_op_desc
)
if
error
:
print
(
"-"
*
30
)
print_error_message
(
error_message
)
print_desc_error_message
(
desc_error_message
)
print_version_error_message
(
version_error_message
)
print
(
"-"
*
30
)
else
:
print
(
"Usage: python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录