未验证 提交 c41fd033 编写于 作者: 石晓伟 提交者: GitHub

check op_version_registry in CI test, test=develop (#28402)

上级 2500dca8
......@@ -92,7 +92,7 @@ enum class OpUpdateType {
class OpUpdateBase {
public:
virtual const OpUpdateInfo* info() const = 0;
virtual const OpUpdateInfo& info() const = 0;
virtual OpUpdateType type() const = 0;
virtual ~OpUpdateBase() = default;
};
......@@ -101,7 +101,7 @@ template <typename InfoType, OpUpdateType type__>
class OpUpdate : public OpUpdateBase {
public:
explicit OpUpdate(const InfoType& info) : info_{info}, type_{type__} {}
const OpUpdateInfo* info() const override { return &info_; }
const InfoType& info() const override { return info_; }
OpUpdateType type() const override { return type_; }
private:
......@@ -169,7 +169,6 @@ class OpVersion {
class OpVersionRegistrar {
public:
OpVersionRegistrar() = default;
static OpVersionRegistrar& GetInstance() {
static OpVersionRegistrar instance;
return instance;
......@@ -185,6 +184,8 @@ class OpVersionRegistrar {
private:
std::unordered_map<std::string, OpVersion> op_version_map_;
OpVersionRegistrar() = default;
OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete;
};
inline const std::unordered_map<std::string, OpVersion>& get_op_version_map() {
......
......@@ -130,7 +130,7 @@ REGISTER_OP_VERSION(distribute_fpn_proposals)
Upgrade distribute_fpn_proposals add a new input
[RoisNum] and add a new output [MultiLevelRoIsNum].)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput("RoIsNum", "The number of RoIs in each image.")
.NewInput("RoisNum", "The number of RoIs in each image.")
.NewOutput("MultiLevelRoisNum",
"The RoIs' number of each image on multiple "
"levels. The number on each level has the shape of (B),"
......
......@@ -95,8 +95,7 @@ void BindOpUpdateType(py::module *m) {
void BindOpUpdateBase(py::module *m) {
py::class_<OpUpdateBase>(*m, "OpUpdateBase")
.def("info", [](const OpUpdateBase &obj) { return obj.info(); },
py::return_value_policy::reference)
.def("info", &OpUpdateBase::info, py::return_value_policy::reference)
.def("type", &OpUpdateBase::type);
}
......
......@@ -14,6 +14,8 @@
import json
import sys
from paddle.utils import OpLastCheckpointChecker
from paddle.fluid.core import OpUpdateType
SAME = 0
......@@ -21,7 +23,14 @@ INPUTS = "Inputs"
OUTPUTS = "Outputs"
ATTRS = "Attrs"
# The constant `ADD` means that an item has been added. In particular,
# we use `ADD_WITH_DEFAULT` to mean adding attributes with default
# attributes, and `ADD_DISPENSABLE` to mean adding optional inputs or
# outputs.
ADD_WITH_DEFAULT = "Add_with_default"
ADD_DISPENSABLE = "Add_dispensable"
ADD = "Add"
DELETE = "Delete"
CHANGE = "Change"
......@@ -35,12 +44,26 @@ DEFAULT_VALUE = "default_value"
error = False
version_update_map = {
INPUTS: {
ADD: OpUpdateType.kNewInput,
},
OUTPUTS: {
ADD: OpUpdateType.kNewOutput,
},
ATTRS: {
ADD: OpUpdateType.kNewAttr,
CHANGE: OpUpdateType.kModifyAttr,
},
}
def diff_vars(origin_vars, new_vars):
global error
var_error = False
var_changed_error_massage = {}
var_added_error_massage = []
var_add_massage = []
var_add_dispensable_massage = []
var_deleted_error_massage = []
common_vars_name = set(origin_vars.keys()) & set(new_vars.keys())
......@@ -65,13 +88,16 @@ def diff_vars(origin_vars, new_vars):
var_deleted_error_massage.append(var_name)
for var_name in vars_name_only_in_new:
var_add_massage.append(var_name)
if not new_vars.get(var_name).get(DISPENSABLE):
error, var_error = True, True
var_added_error_massage.append(var_name)
var_add_dispensable_massage.append(var_name)
var_diff_message = {}
if var_added_error_massage:
var_diff_message[ADD] = var_added_error_massage
if var_add_massage:
var_diff_message[ADD] = var_add_massage
if var_add_dispensable_massage:
var_diff_message[ADD_DISPENSABLE] = var_add_dispensable_massage
if var_changed_error_massage:
var_diff_message[CHANGE] = var_changed_error_massage
if var_deleted_error_massage:
......@@ -86,6 +112,7 @@ def diff_attr(ori_attrs, new_attrs):
attr_changed_error_massage = {}
attr_added_error_massage = []
attr_added_def_error_massage = []
attr_deleted_error_massage = []
common_attrs = set(ori_attrs.keys()) & set(new_attrs.keys())
......@@ -110,13 +137,16 @@ def diff_attr(ori_attrs, new_attrs):
attr_deleted_error_massage.append(attr_name)
for attr_name in attrs_only_in_new:
attr_added_error_massage.append(attr_name)
if new_attrs.get(attr_name).get(DEFAULT_VALUE) == None:
error, attr_error = True, True
attr_added_error_massage.append(attr_name)
attr_added_def_error_massage.append(attr_name)
attr_diff_message = {}
if attr_added_error_massage:
attr_diff_message[ADD] = attr_added_error_massage
if attr_added_def_error_massage:
attr_diff_message[ADD_WITH_DEFAULT] = attr_added_def_error_massage
if attr_changed_error_massage:
attr_diff_message[CHANGE] = attr_changed_error_massage
if attr_deleted_error_massage:
......@@ -125,15 +155,39 @@ def diff_attr(ori_attrs, new_attrs):
return attr_error, attr_diff_message
def check_io_registry(io_type, op, diff):
checker = OpLastCheckpointChecker()
results = {}
for update_type in [ADD]:
for item in diff.get(update_type, {}):
infos = checker.filter_updates(
op, version_update_map[io_type][update_type], item)
if not infos:
results[update_type] = (op, item, io_type)
return results
def check_attr_registry(op, diff):
checker = OpLastCheckpointChecker()
results = {}
for update_type in [ADD, CHANGE]:
for item in diff.get(update_type, {}):
infos = checker.filter_updates(
op, version_update_map[ATTRS][update_type], item)
if not infos:
results[update_type] = (op, item)
return results
def compare_op_desc(origin_op_desc, new_op_desc):
origin = json.loads(origin_op_desc)
new = json.loads(new_op_desc)
error_message = {}
desc_error_message = {}
version_error_message = {}
if cmp(origin_op_desc, new_op_desc) == SAME:
return error_message
return desc_error_message, version_error_message
for op_type in origin:
# no need to compare if the operator is deleted
if op_type not in new:
continue
......@@ -144,33 +198,47 @@ def compare_op_desc(origin_op_desc, new_op_desc):
origin_inputs = origin_info.get(INPUTS, {})
new_inputs = new_info.get(INPUTS, {})
ins_error, ins_diff = diff_vars(origin_inputs, new_inputs)
ins_version_errors = check_io_registry(INPUTS, op_type, ins_diff)
origin_outputs = origin_info.get(OUTPUTS, {})
new_outputs = new_info.get(OUTPUTS, {})
outs_error, outs_diff = diff_vars(origin_outputs, new_outputs)
outs_version_errors = check_io_registry(OUTPUTS, op_type, outs_diff)
origin_attrs = origin_info.get(ATTRS, {})
new_attrs = new_info.get(ATTRS, {})
attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs)
attrs_version_errors = check_attr_registry(op_type, attrs_diff)
if ins_error:
error_message.setdefault(op_type, {})[INPUTS] = ins_diff
desc_error_message.setdefault(op_type, {})[INPUTS] = ins_diff
if outs_error:
error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff
desc_error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff
if attrs_error:
error_message.setdefault(op_type, {})[ATTRS] = attrs_diff
desc_error_message.setdefault(op_type, {})[ATTRS] = attrs_diff
if ins_version_errors:
version_error_message.setdefault(op_type,
{})[INPUTS] = ins_version_errors
if outs_version_errors:
version_error_message.setdefault(op_type,
{})[OUTPUTS] = outs_version_errors
if attrs_version_errors:
version_error_message.setdefault(op_type,
{})[ATTRS] = attrs_version_errors
return error_message
return desc_error_message, version_error_message
def print_error_message(error_message):
print("Op desc error for the changes of Inputs/Outputs/Attrs of OPs:\n")
def print_desc_error_message(error_message):
print("\n======================= \n"
"Op desc error for the changes of Inputs/Outputs/Attrs of OPs:\n")
for op_name in error_message:
print("For OP '{}':".format(op_name))
# 1. print inputs error message
Inputs_error = error_message.get(op_name, {}).get(INPUTS, {})
for name in Inputs_error.get(ADD, {}):
for name in Inputs_error.get(ADD_DISPENSABLE, {}):
print(" * The added Input '{}' is not dispensable.".format(name))
for name in Inputs_error.get(DELETE, {}):
......@@ -186,7 +254,7 @@ def print_error_message(error_message):
# 2. print outputs error message
Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {})
for name in Outputs_error.get(ADD, {}):
for name in Outputs_error.get(ADD_DISPENSABLE, {}):
print(" * The added Output '{}' is not dispensable.".format(name))
for name in Outputs_error.get(DELETE, {}):
......@@ -202,7 +270,7 @@ def print_error_message(error_message):
# 3. print attrs error message
attrs_error = error_message.get(op_name, {}).get(ATTRS, {})
for name in attrs_error.get(ADD, {}):
for name in attrs_error.get(ADD_WITH_DEFAULT, {}):
print(" * The added attr '{}' doesn't set default value.".format(
name))
......@@ -218,6 +286,40 @@ def print_error_message(error_message):
format(arg, name, ori_value, new_value))
def print_version_error_message(error_message):
print(
"\n======================= \n"
"Operator registration error for the changes of Inputs/Outputs/Attrs of OPs:\n"
)
for op_name in error_message:
print("For OP '{}':".format(op_name))
# 1. print inputs error message
inputs_error = error_message.get(op_name, {}).get(INPUTS, {})
tuple = inputs_error.get(ADD, {})
if tuple:
print(" * The added input '{}' is not yet registered.".format(tuple[
1]))
# 2. print inputs error message
outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {})
tuple = outputs_error.get(ADD, {})
if tuple:
print(" * The added output '{}' is not yet registered.".format(
tuple[1]))
#3. print attrs error message
attrs_error = error_message.get(op_name, {}).get(ATTRS, {})
tuple = attrs_error.get(ADD, {})
if tuple:
print(" * The added attribute '{}' is not yet registered.".format(
tuple[1]))
tuple = attrs_error.get(CHANGE, {})
if tuple:
print(" * The change of attribute '{}' is not yet registered.".
format(tuple[1]))
def print_repeat_process():
print(
"Tips:"
......@@ -241,10 +343,12 @@ if len(sys.argv) == 3:
with open(sys.argv[2], 'r') as f:
new_op_desc = f.read()
error_message = compare_op_desc(origin_op_desc, new_op_desc)
desc_error_message, version_error_message = compare_op_desc(origin_op_desc,
new_op_desc)
if error:
print("-" * 30)
print_error_message(error_message)
print_desc_error_message(desc_error_message)
print_version_error_message(version_error_message)
print("-" * 30)
else:
print("Usage: python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册