diff --git a/paddle/fluid/framework/op_version_registry.h b/paddle/fluid/framework/op_version_registry.h index 5822dfa11dd25a5e84800871c7efc73f375e2109..c9d3084724bcdbcee6e0e43a985d5c41c5a8ae84 100644 --- a/paddle/fluid/framework/op_version_registry.h +++ b/paddle/fluid/framework/op_version_registry.h @@ -92,7 +92,7 @@ enum class OpUpdateType { class OpUpdateBase { public: - virtual const OpUpdateInfo* info() const = 0; + virtual const OpUpdateInfo& info() const = 0; virtual OpUpdateType type() const = 0; virtual ~OpUpdateBase() = default; }; @@ -101,7 +101,7 @@ template class OpUpdate : public OpUpdateBase { public: explicit OpUpdate(const InfoType& info) : info_{info}, type_{type__} {} - const OpUpdateInfo* info() const override { return &info_; } + const InfoType& info() const override { return info_; } OpUpdateType type() const override { return type_; } private: @@ -169,7 +169,6 @@ class OpVersion { class OpVersionRegistrar { public: - OpVersionRegistrar() = default; static OpVersionRegistrar& GetInstance() { static OpVersionRegistrar instance; return instance; @@ -185,6 +184,8 @@ class OpVersionRegistrar { private: std::unordered_map op_version_map_; + OpVersionRegistrar() = default; + OpVersionRegistrar& operator=(const OpVersionRegistrar&) = delete; }; inline const std::unordered_map& get_op_version_map() { diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index 614b37e703e721337057e04c5611386ff87a1e9e..b0c9d968e47b7968584d6af234fe1debcde153d0 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -130,7 +130,7 @@ REGISTER_OP_VERSION(distribute_fpn_proposals) Upgrade distribute_fpn_proposals add a new input [RoisNum] and add a new output [MultiLevelRoIsNum].)ROC", paddle::framework::compatible::OpVersionDesc() - .NewInput("RoIsNum", "The number of RoIs in each image.") + .NewInput("RoisNum", "The number of RoIs in each image.") .NewOutput("MultiLevelRoisNum", "The RoIs' number of each image on multiple " "levels. The number on each level has the shape of (B)," diff --git a/paddle/fluid/pybind/compatible.cc b/paddle/fluid/pybind/compatible.cc index 57b024c25cbaf9ce87081a33e2b8756ed4e725eb..cfe87a86cf0e559e9e0ef314b5ba475571f08e3e 100644 --- a/paddle/fluid/pybind/compatible.cc +++ b/paddle/fluid/pybind/compatible.cc @@ -95,8 +95,7 @@ void BindOpUpdateType(py::module *m) { void BindOpUpdateBase(py::module *m) { py::class_(*m, "OpUpdateBase") - .def("info", [](const OpUpdateBase &obj) { return obj.info(); }, - py::return_value_policy::reference) + .def("info", &OpUpdateBase::info, py::return_value_policy::reference) .def("type", &OpUpdateBase::type); } diff --git a/tools/check_op_desc.py b/tools/check_op_desc.py index 1873fde0c432db386af8122dba317f2bb6652da6..15e410401216ce68fc4178bb4e60d3ac159523f4 100644 --- a/tools/check_op_desc.py +++ b/tools/check_op_desc.py @@ -14,6 +14,8 @@ import json import sys +from paddle.utils import OpLastCheckpointChecker +from paddle.fluid.core import OpUpdateType SAME = 0 @@ -21,7 +23,14 @@ INPUTS = "Inputs" OUTPUTS = "Outputs" ATTRS = "Attrs" +# The constant `ADD` means that an item has been added. In particular, +# we use `ADD_WITH_DEFAULT` to mean adding attributes with default +# attributes, and `ADD_DISPENSABLE` to mean adding optional inputs or +# outputs. +ADD_WITH_DEFAULT = "Add_with_default" +ADD_DISPENSABLE = "Add_dispensable" ADD = "Add" + DELETE = "Delete" CHANGE = "Change" @@ -35,12 +44,26 @@ DEFAULT_VALUE = "default_value" error = False +version_update_map = { + INPUTS: { + ADD: OpUpdateType.kNewInput, + }, + OUTPUTS: { + ADD: OpUpdateType.kNewOutput, + }, + ATTRS: { + ADD: OpUpdateType.kNewAttr, + CHANGE: OpUpdateType.kModifyAttr, + }, +} + def diff_vars(origin_vars, new_vars): global error var_error = False var_changed_error_massage = {} - var_added_error_massage = [] + var_add_massage = [] + var_add_dispensable_massage = [] var_deleted_error_massage = [] common_vars_name = set(origin_vars.keys()) & set(new_vars.keys()) @@ -65,13 +88,16 @@ def diff_vars(origin_vars, new_vars): var_deleted_error_massage.append(var_name) for var_name in vars_name_only_in_new: + var_add_massage.append(var_name) if not new_vars.get(var_name).get(DISPENSABLE): error, var_error = True, True - var_added_error_massage.append(var_name) + var_add_dispensable_massage.append(var_name) var_diff_message = {} - if var_added_error_massage: - var_diff_message[ADD] = var_added_error_massage + if var_add_massage: + var_diff_message[ADD] = var_add_massage + if var_add_dispensable_massage: + var_diff_message[ADD_DISPENSABLE] = var_add_dispensable_massage if var_changed_error_massage: var_diff_message[CHANGE] = var_changed_error_massage if var_deleted_error_massage: @@ -86,6 +112,7 @@ def diff_attr(ori_attrs, new_attrs): attr_changed_error_massage = {} attr_added_error_massage = [] + attr_added_def_error_massage = [] attr_deleted_error_massage = [] common_attrs = set(ori_attrs.keys()) & set(new_attrs.keys()) @@ -110,13 +137,16 @@ def diff_attr(ori_attrs, new_attrs): attr_deleted_error_massage.append(attr_name) for attr_name in attrs_only_in_new: + attr_added_error_massage.append(attr_name) if new_attrs.get(attr_name).get(DEFAULT_VALUE) == None: error, attr_error = True, True - attr_added_error_massage.append(attr_name) + attr_added_def_error_massage.append(attr_name) attr_diff_message = {} if attr_added_error_massage: attr_diff_message[ADD] = attr_added_error_massage + if attr_added_def_error_massage: + attr_diff_message[ADD_WITH_DEFAULT] = attr_added_def_error_massage if attr_changed_error_massage: attr_diff_message[CHANGE] = attr_changed_error_massage if attr_deleted_error_massage: @@ -125,15 +155,39 @@ def diff_attr(ori_attrs, new_attrs): return attr_error, attr_diff_message +def check_io_registry(io_type, op, diff): + checker = OpLastCheckpointChecker() + results = {} + for update_type in [ADD]: + for item in diff.get(update_type, {}): + infos = checker.filter_updates( + op, version_update_map[io_type][update_type], item) + if not infos: + results[update_type] = (op, item, io_type) + return results + + +def check_attr_registry(op, diff): + checker = OpLastCheckpointChecker() + results = {} + for update_type in [ADD, CHANGE]: + for item in diff.get(update_type, {}): + infos = checker.filter_updates( + op, version_update_map[ATTRS][update_type], item) + if not infos: + results[update_type] = (op, item) + return results + + def compare_op_desc(origin_op_desc, new_op_desc): origin = json.loads(origin_op_desc) new = json.loads(new_op_desc) - error_message = {} + desc_error_message = {} + version_error_message = {} if cmp(origin_op_desc, new_op_desc) == SAME: - return error_message + return desc_error_message, version_error_message for op_type in origin: - # no need to compare if the operator is deleted if op_type not in new: continue @@ -144,33 +198,47 @@ def compare_op_desc(origin_op_desc, new_op_desc): origin_inputs = origin_info.get(INPUTS, {}) new_inputs = new_info.get(INPUTS, {}) ins_error, ins_diff = diff_vars(origin_inputs, new_inputs) + ins_version_errors = check_io_registry(INPUTS, op_type, ins_diff) origin_outputs = origin_info.get(OUTPUTS, {}) new_outputs = new_info.get(OUTPUTS, {}) outs_error, outs_diff = diff_vars(origin_outputs, new_outputs) + outs_version_errors = check_io_registry(OUTPUTS, op_type, outs_diff) origin_attrs = origin_info.get(ATTRS, {}) new_attrs = new_info.get(ATTRS, {}) attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs) + attrs_version_errors = check_attr_registry(op_type, attrs_diff) if ins_error: - error_message.setdefault(op_type, {})[INPUTS] = ins_diff + desc_error_message.setdefault(op_type, {})[INPUTS] = ins_diff if outs_error: - error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff + desc_error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff if attrs_error: - error_message.setdefault(op_type, {})[ATTRS] = attrs_diff + desc_error_message.setdefault(op_type, {})[ATTRS] = attrs_diff - return error_message + if ins_version_errors: + version_error_message.setdefault(op_type, + {})[INPUTS] = ins_version_errors + if outs_version_errors: + version_error_message.setdefault(op_type, + {})[OUTPUTS] = outs_version_errors + if attrs_version_errors: + version_error_message.setdefault(op_type, + {})[ATTRS] = attrs_version_errors + return desc_error_message, version_error_message -def print_error_message(error_message): - print("Op desc error for the changes of Inputs/Outputs/Attrs of OPs:\n") + +def print_desc_error_message(error_message): + print("\n======================= \n" + "Op desc error for the changes of Inputs/Outputs/Attrs of OPs:\n") for op_name in error_message: print("For OP '{}':".format(op_name)) # 1. print inputs error message Inputs_error = error_message.get(op_name, {}).get(INPUTS, {}) - for name in Inputs_error.get(ADD, {}): + for name in Inputs_error.get(ADD_DISPENSABLE, {}): print(" * The added Input '{}' is not dispensable.".format(name)) for name in Inputs_error.get(DELETE, {}): @@ -186,7 +254,7 @@ def print_error_message(error_message): # 2. print outputs error message Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {}) - for name in Outputs_error.get(ADD, {}): + for name in Outputs_error.get(ADD_DISPENSABLE, {}): print(" * The added Output '{}' is not dispensable.".format(name)) for name in Outputs_error.get(DELETE, {}): @@ -202,7 +270,7 @@ def print_error_message(error_message): # 3. print attrs error message attrs_error = error_message.get(op_name, {}).get(ATTRS, {}) - for name in attrs_error.get(ADD, {}): + for name in attrs_error.get(ADD_WITH_DEFAULT, {}): print(" * The added attr '{}' doesn't set default value.".format( name)) @@ -218,6 +286,40 @@ def print_error_message(error_message): format(arg, name, ori_value, new_value)) +def print_version_error_message(error_message): + print( + "\n======================= \n" + "Operator registration error for the changes of Inputs/Outputs/Attrs of OPs:\n" + ) + for op_name in error_message: + print("For OP '{}':".format(op_name)) + + # 1. print inputs error message + inputs_error = error_message.get(op_name, {}).get(INPUTS, {}) + tuple = inputs_error.get(ADD, {}) + if tuple: + print(" * The added input '{}' is not yet registered.".format(tuple[ + 1])) + + # 2. print inputs error message + outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {}) + tuple = outputs_error.get(ADD, {}) + if tuple: + print(" * The added output '{}' is not yet registered.".format( + tuple[1])) + + #3. print attrs error message + attrs_error = error_message.get(op_name, {}).get(ATTRS, {}) + tuple = attrs_error.get(ADD, {}) + if tuple: + print(" * The added attribute '{}' is not yet registered.".format( + tuple[1])) + tuple = attrs_error.get(CHANGE, {}) + if tuple: + print(" * The change of attribute '{}' is not yet registered.". + format(tuple[1])) + + def print_repeat_process(): print( "Tips:" @@ -241,10 +343,12 @@ if len(sys.argv) == 3: with open(sys.argv[2], 'r') as f: new_op_desc = f.read() - error_message = compare_op_desc(origin_op_desc, new_op_desc) + desc_error_message, version_error_message = compare_op_desc(origin_op_desc, + new_op_desc) if error: print("-" * 30) - print_error_message(error_message) + print_desc_error_message(desc_error_message) + print_version_error_message(version_error_message) print("-" * 30) else: print("Usage: python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec")