未验证 提交 c41fd033 编写于 作者: 石晓伟 提交者: GitHub

check op_version_registry in CI test, test=develop (#28402)

上级 2500dca8
...@@ -92,7 +92,7 @@ enum class OpUpdateType { ...@@ -92,7 +92,7 @@ enum class OpUpdateType {
class OpUpdateBase { class OpUpdateBase {
public: public:
virtual const OpUpdateInfo* info() const = 0; virtual const OpUpdateInfo& info() const = 0;
virtual OpUpdateType type() const = 0; virtual OpUpdateType type() const = 0;
virtual ~OpUpdateBase() = default; virtual ~OpUpdateBase() = default;
}; };
...@@ -101,7 +101,7 @@ template <typename InfoType, OpUpdateType type__> ...@@ -101,7 +101,7 @@ template <typename InfoType, OpUpdateType type__>
class OpUpdate : public OpUpdateBase { class OpUpdate : public OpUpdateBase {
public: public:
explicit OpUpdate(const InfoType& info) : info_{info}, type_{type__} {} explicit OpUpdate(const InfoType& info) : info_{info}, type_{type__} {}
const OpUpdateInfo* info() const override { return &info_; } const InfoType& info() const override { return info_; }
OpUpdateType type() const override { return type_; } OpUpdateType type() const override { return type_; }
private: private:
...@@ -169,7 +169,6 @@ class OpVersion { ...@@ -169,7 +169,6 @@ class OpVersion {
class OpVersionRegistrar { class OpVersionRegistrar {
public: public:
OpVersionRegistrar() = default;
static OpVersionRegistrar& GetInstance() { static OpVersionRegistrar& GetInstance() {
static OpVersionRegistrar instance; static OpVersionRegistrar instance;
return instance; return instance;
...@@ -185,6 +184,8 @@ class OpVersionRegistrar { ...@@ -185,6 +184,8 @@ class OpVersionRegistrar {
private: private:
std::unordered_map<std::string, OpVersion> op_version_map_; std::unordered_map<std::string, OpVersion> op_version_map_;
OpVersionRegistrar() = default;
OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete;
}; };
inline const std::unordered_map<std::string, OpVersion>& get_op_version_map() { inline const std::unordered_map<std::string, OpVersion>& get_op_version_map() {
......
...@@ -130,7 +130,7 @@ REGISTER_OP_VERSION(distribute_fpn_proposals) ...@@ -130,7 +130,7 @@ REGISTER_OP_VERSION(distribute_fpn_proposals)
Upgrade distribute_fpn_proposals add a new input Upgrade distribute_fpn_proposals add a new input
[RoisNum] and add a new output [MultiLevelRoIsNum].)ROC", [RoisNum] and add a new output [MultiLevelRoIsNum].)ROC",
paddle::framework::compatible::OpVersionDesc() paddle::framework::compatible::OpVersionDesc()
.NewInput("RoIsNum", "The number of RoIs in each image.") .NewInput("RoisNum", "The number of RoIs in each image.")
.NewOutput("MultiLevelRoisNum", .NewOutput("MultiLevelRoisNum",
"The RoIs' number of each image on multiple " "The RoIs' number of each image on multiple "
"levels. The number on each level has the shape of (B)," "levels. The number on each level has the shape of (B),"
......
...@@ -95,8 +95,7 @@ void BindOpUpdateType(py::module *m) { ...@@ -95,8 +95,7 @@ void BindOpUpdateType(py::module *m) {
void BindOpUpdateBase(py::module *m) { void BindOpUpdateBase(py::module *m) {
py::class_<OpUpdateBase>(*m, "OpUpdateBase") py::class_<OpUpdateBase>(*m, "OpUpdateBase")
.def("info", [](const OpUpdateBase &obj) { return obj.info(); }, .def("info", &OpUpdateBase::info, py::return_value_policy::reference)
py::return_value_policy::reference)
.def("type", &OpUpdateBase::type); .def("type", &OpUpdateBase::type);
} }
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import json import json
import sys import sys
from paddle.utils import OpLastCheckpointChecker
from paddle.fluid.core import OpUpdateType
SAME = 0 SAME = 0
...@@ -21,7 +23,14 @@ INPUTS = "Inputs" ...@@ -21,7 +23,14 @@ INPUTS = "Inputs"
OUTPUTS = "Outputs" OUTPUTS = "Outputs"
ATTRS = "Attrs" ATTRS = "Attrs"
# The constant `ADD` means that an item has been added. In particular,
# we use `ADD_WITH_DEFAULT` to mean adding attributes with default
# attributes, and `ADD_DISPENSABLE` to mean adding optional inputs or
# outputs.
ADD_WITH_DEFAULT = "Add_with_default"
ADD_DISPENSABLE = "Add_dispensable"
ADD = "Add" ADD = "Add"
DELETE = "Delete" DELETE = "Delete"
CHANGE = "Change" CHANGE = "Change"
...@@ -35,12 +44,26 @@ DEFAULT_VALUE = "default_value" ...@@ -35,12 +44,26 @@ DEFAULT_VALUE = "default_value"
error = False error = False
version_update_map = {
INPUTS: {
ADD: OpUpdateType.kNewInput,
},
OUTPUTS: {
ADD: OpUpdateType.kNewOutput,
},
ATTRS: {
ADD: OpUpdateType.kNewAttr,
CHANGE: OpUpdateType.kModifyAttr,
},
}
def diff_vars(origin_vars, new_vars): def diff_vars(origin_vars, new_vars):
global error global error
var_error = False var_error = False
var_changed_error_massage = {} var_changed_error_massage = {}
var_added_error_massage = [] var_add_massage = []
var_add_dispensable_massage = []
var_deleted_error_massage = [] var_deleted_error_massage = []
common_vars_name = set(origin_vars.keys()) & set(new_vars.keys()) common_vars_name = set(origin_vars.keys()) & set(new_vars.keys())
...@@ -65,13 +88,16 @@ def diff_vars(origin_vars, new_vars): ...@@ -65,13 +88,16 @@ def diff_vars(origin_vars, new_vars):
var_deleted_error_massage.append(var_name) var_deleted_error_massage.append(var_name)
for var_name in vars_name_only_in_new: for var_name in vars_name_only_in_new:
var_add_massage.append(var_name)
if not new_vars.get(var_name).get(DISPENSABLE): if not new_vars.get(var_name).get(DISPENSABLE):
error, var_error = True, True error, var_error = True, True
var_added_error_massage.append(var_name) var_add_dispensable_massage.append(var_name)
var_diff_message = {} var_diff_message = {}
if var_added_error_massage: if var_add_massage:
var_diff_message[ADD] = var_added_error_massage var_diff_message[ADD] = var_add_massage
if var_add_dispensable_massage:
var_diff_message[ADD_DISPENSABLE] = var_add_dispensable_massage
if var_changed_error_massage: if var_changed_error_massage:
var_diff_message[CHANGE] = var_changed_error_massage var_diff_message[CHANGE] = var_changed_error_massage
if var_deleted_error_massage: if var_deleted_error_massage:
...@@ -86,6 +112,7 @@ def diff_attr(ori_attrs, new_attrs): ...@@ -86,6 +112,7 @@ def diff_attr(ori_attrs, new_attrs):
attr_changed_error_massage = {} attr_changed_error_massage = {}
attr_added_error_massage = [] attr_added_error_massage = []
attr_added_def_error_massage = []
attr_deleted_error_massage = [] attr_deleted_error_massage = []
common_attrs = set(ori_attrs.keys()) & set(new_attrs.keys()) common_attrs = set(ori_attrs.keys()) & set(new_attrs.keys())
...@@ -110,13 +137,16 @@ def diff_attr(ori_attrs, new_attrs): ...@@ -110,13 +137,16 @@ def diff_attr(ori_attrs, new_attrs):
attr_deleted_error_massage.append(attr_name) attr_deleted_error_massage.append(attr_name)
for attr_name in attrs_only_in_new: for attr_name in attrs_only_in_new:
attr_added_error_massage.append(attr_name)
if new_attrs.get(attr_name).get(DEFAULT_VALUE) == None: if new_attrs.get(attr_name).get(DEFAULT_VALUE) == None:
error, attr_error = True, True error, attr_error = True, True
attr_added_error_massage.append(attr_name) attr_added_def_error_massage.append(attr_name)
attr_diff_message = {} attr_diff_message = {}
if attr_added_error_massage: if attr_added_error_massage:
attr_diff_message[ADD] = attr_added_error_massage attr_diff_message[ADD] = attr_added_error_massage
if attr_added_def_error_massage:
attr_diff_message[ADD_WITH_DEFAULT] = attr_added_def_error_massage
if attr_changed_error_massage: if attr_changed_error_massage:
attr_diff_message[CHANGE] = attr_changed_error_massage attr_diff_message[CHANGE] = attr_changed_error_massage
if attr_deleted_error_massage: if attr_deleted_error_massage:
...@@ -125,15 +155,39 @@ def diff_attr(ori_attrs, new_attrs): ...@@ -125,15 +155,39 @@ def diff_attr(ori_attrs, new_attrs):
return attr_error, attr_diff_message return attr_error, attr_diff_message
def check_io_registry(io_type, op, diff):
checker = OpLastCheckpointChecker()
results = {}
for update_type in [ADD]:
for item in diff.get(update_type, {}):
infos = checker.filter_updates(
op, version_update_map[io_type][update_type], item)
if not infos:
results[update_type] = (op, item, io_type)
return results
def check_attr_registry(op, diff):
checker = OpLastCheckpointChecker()
results = {}
for update_type in [ADD, CHANGE]:
for item in diff.get(update_type, {}):
infos = checker.filter_updates(
op, version_update_map[ATTRS][update_type], item)
if not infos:
results[update_type] = (op, item)
return results
def compare_op_desc(origin_op_desc, new_op_desc): def compare_op_desc(origin_op_desc, new_op_desc):
origin = json.loads(origin_op_desc) origin = json.loads(origin_op_desc)
new = json.loads(new_op_desc) new = json.loads(new_op_desc)
error_message = {} desc_error_message = {}
version_error_message = {}
if cmp(origin_op_desc, new_op_desc) == SAME: if cmp(origin_op_desc, new_op_desc) == SAME:
return error_message return desc_error_message, version_error_message
for op_type in origin: for op_type in origin:
# no need to compare if the operator is deleted # no need to compare if the operator is deleted
if op_type not in new: if op_type not in new:
continue continue
...@@ -144,33 +198,47 @@ def compare_op_desc(origin_op_desc, new_op_desc): ...@@ -144,33 +198,47 @@ def compare_op_desc(origin_op_desc, new_op_desc):
origin_inputs = origin_info.get(INPUTS, {}) origin_inputs = origin_info.get(INPUTS, {})
new_inputs = new_info.get(INPUTS, {}) new_inputs = new_info.get(INPUTS, {})
ins_error, ins_diff = diff_vars(origin_inputs, new_inputs) ins_error, ins_diff = diff_vars(origin_inputs, new_inputs)
ins_version_errors = check_io_registry(INPUTS, op_type, ins_diff)
origin_outputs = origin_info.get(OUTPUTS, {}) origin_outputs = origin_info.get(OUTPUTS, {})
new_outputs = new_info.get(OUTPUTS, {}) new_outputs = new_info.get(OUTPUTS, {})
outs_error, outs_diff = diff_vars(origin_outputs, new_outputs) outs_error, outs_diff = diff_vars(origin_outputs, new_outputs)
outs_version_errors = check_io_registry(OUTPUTS, op_type, outs_diff)
origin_attrs = origin_info.get(ATTRS, {}) origin_attrs = origin_info.get(ATTRS, {})
new_attrs = new_info.get(ATTRS, {}) new_attrs = new_info.get(ATTRS, {})
attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs) attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs)
attrs_version_errors = check_attr_registry(op_type, attrs_diff)
if ins_error: if ins_error:
error_message.setdefault(op_type, {})[INPUTS] = ins_diff desc_error_message.setdefault(op_type, {})[INPUTS] = ins_diff
if outs_error: if outs_error:
error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff desc_error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff
if attrs_error: if attrs_error:
error_message.setdefault(op_type, {})[ATTRS] = attrs_diff desc_error_message.setdefault(op_type, {})[ATTRS] = attrs_diff
return error_message if ins_version_errors:
version_error_message.setdefault(op_type,
{})[INPUTS] = ins_version_errors
if outs_version_errors:
version_error_message.setdefault(op_type,
{})[OUTPUTS] = outs_version_errors
if attrs_version_errors:
version_error_message.setdefault(op_type,
{})[ATTRS] = attrs_version_errors
return desc_error_message, version_error_message
def print_error_message(error_message):
print("Op desc error for the changes of Inputs/Outputs/Attrs of OPs:\n") def print_desc_error_message(error_message):
print("\n======================= \n"
"Op desc error for the changes of Inputs/Outputs/Attrs of OPs:\n")
for op_name in error_message: for op_name in error_message:
print("For OP '{}':".format(op_name)) print("For OP '{}':".format(op_name))
# 1. print inputs error message # 1. print inputs error message
Inputs_error = error_message.get(op_name, {}).get(INPUTS, {}) Inputs_error = error_message.get(op_name, {}).get(INPUTS, {})
for name in Inputs_error.get(ADD, {}): for name in Inputs_error.get(ADD_DISPENSABLE, {}):
print(" * The added Input '{}' is not dispensable.".format(name)) print(" * The added Input '{}' is not dispensable.".format(name))
for name in Inputs_error.get(DELETE, {}): for name in Inputs_error.get(DELETE, {}):
...@@ -186,7 +254,7 @@ def print_error_message(error_message): ...@@ -186,7 +254,7 @@ def print_error_message(error_message):
# 2. print outputs error message # 2. print outputs error message
Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {}) Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {})
for name in Outputs_error.get(ADD, {}): for name in Outputs_error.get(ADD_DISPENSABLE, {}):
print(" * The added Output '{}' is not dispensable.".format(name)) print(" * The added Output '{}' is not dispensable.".format(name))
for name in Outputs_error.get(DELETE, {}): for name in Outputs_error.get(DELETE, {}):
...@@ -202,7 +270,7 @@ def print_error_message(error_message): ...@@ -202,7 +270,7 @@ def print_error_message(error_message):
# 3. print attrs error message # 3. print attrs error message
attrs_error = error_message.get(op_name, {}).get(ATTRS, {}) attrs_error = error_message.get(op_name, {}).get(ATTRS, {})
for name in attrs_error.get(ADD, {}): for name in attrs_error.get(ADD_WITH_DEFAULT, {}):
print(" * The added attr '{}' doesn't set default value.".format( print(" * The added attr '{}' doesn't set default value.".format(
name)) name))
...@@ -218,6 +286,40 @@ def print_error_message(error_message): ...@@ -218,6 +286,40 @@ def print_error_message(error_message):
format(arg, name, ori_value, new_value)) format(arg, name, ori_value, new_value))
def print_version_error_message(error_message):
print(
"\n======================= \n"
"Operator registration error for the changes of Inputs/Outputs/Attrs of OPs:\n"
)
for op_name in error_message:
print("For OP '{}':".format(op_name))
# 1. print inputs error message
inputs_error = error_message.get(op_name, {}).get(INPUTS, {})
tuple = inputs_error.get(ADD, {})
if tuple:
print(" * The added input '{}' is not yet registered.".format(tuple[
1]))
# 2. print inputs error message
outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {})
tuple = outputs_error.get(ADD, {})
if tuple:
print(" * The added output '{}' is not yet registered.".format(
tuple[1]))
#3. print attrs error message
attrs_error = error_message.get(op_name, {}).get(ATTRS, {})
tuple = attrs_error.get(ADD, {})
if tuple:
print(" * The added attribute '{}' is not yet registered.".format(
tuple[1]))
tuple = attrs_error.get(CHANGE, {})
if tuple:
print(" * The change of attribute '{}' is not yet registered.".
format(tuple[1]))
def print_repeat_process(): def print_repeat_process():
print( print(
"Tips:" "Tips:"
...@@ -241,10 +343,12 @@ if len(sys.argv) == 3: ...@@ -241,10 +343,12 @@ if len(sys.argv) == 3:
with open(sys.argv[2], 'r') as f: with open(sys.argv[2], 'r') as f:
new_op_desc = f.read() new_op_desc = f.read()
error_message = compare_op_desc(origin_op_desc, new_op_desc) desc_error_message, version_error_message = compare_op_desc(origin_op_desc,
new_op_desc)
if error: if error:
print("-" * 30) print("-" * 30)
print_error_message(error_message) print_desc_error_message(desc_error_message)
print_version_error_message(version_error_message)
print("-" * 30) print("-" * 30)
else: else:
print("Usage: python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec") print("Usage: python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册