From b65708a812cac2f3cf4e6d8ed9e03a5dd8278a91 Mon Sep 17 00:00:00 2001 From: Wilber Date: Fri, 3 Dec 2021 15:24:58 +0800 Subject: [PATCH] update ci check_op_desc to support op_version and op_compat. (#37600) * update check_op_desc to support op_version and op_compat. --- tools/check_api_approvals.sh | 16 +++- tools/check_op_desc.py | 157 +++++++++++++++++++++++++++++------ tools/print_op_desc.py | 13 ++- 3 files changed, 157 insertions(+), 29 deletions(-) diff --git a/tools/check_api_approvals.sh b/tools/check_api_approvals.sh index dcbe853d8a1..45d4731ba1d 100644 --- a/tools/check_api_approvals.sh +++ b/tools/check_api_approvals.sh @@ -76,9 +76,21 @@ if [ "$op_type_spec_diff" != "" ]; then fi op_desc_diff=`python ${PADDLE_ROOT}/tools/check_op_desc.py ${PADDLE_ROOT}/paddle/fluid/OP_DESC_DEV.spec ${PADDLE_ROOT}/paddle/fluid/OP_DESC_PR.spec` +inference_approve=`echo "$op_desc_diff" | grep "need inference to review" -` +slim_approve=`echo "$op_desc_diff" | grep "need slim to review" -` if [ "$op_desc_diff" != "" ]; then - echo_line="You must have one RD (cyj1986, Superjomn) approval for the changes of Inputs/Output/Attrs of OPs. The changes of OPs will cause that the new version inference fails to load model trained by the old version. Please modify your code. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${op_desc_diff}\n" - check_approval 1 39645414 328693 + echo_line="You must have one RD (inference[ Superjomn(Recommend), Shixiaowei02, cyj1986 ] or slim[ wanghaoshuang(Recommend), qingqing01 ]) approval for the changes of Inputs/Output/Attrs of OPs. The changes of OPs will cause that the new version inference fails to load model trained by the old version. Please modify your code. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${op_desc_diff}\n" + check_approval 1 39645414 328693 39303645 7534971 7845005 +fi + +if [ "$slim_approve" != "" ]; then + echo_line="You must have one RD (wanghaoshuang(Recommend), qingqing01) approval for the changes of `quant` Inputs/Output/Attrs of OPs. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${slim_approve}\n" + check_approval 1 7534971 7845005 +fi + +if [ "$inference_approve" != "" ]; then + echo_line="You must have one RD (Superjomn(Recommend), Shixiaowei02, cyj1986) approval for the changes of `def` Inputs/Output/Attrs of OPs. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${inference_approve}\n" + check_approval 1 39645414 328693 39303645 fi DEV_OP_USE_DEFAULT_GRAD_MAKER_SPEC=${PADDLE_ROOT}/paddle/fluid/op_use_default_grad_maker_DEV.spec diff --git a/tools/check_op_desc.py b/tools/check_op_desc.py index 78abb6f36c6..19984a55a41 100644 --- a/tools/check_op_desc.py +++ b/tools/check_op_desc.py @@ -40,6 +40,11 @@ TYPE = "type" GENERATED = "generated" DEFAULT_VALUE = "default_value" +# add_with_extra, add_with_quant and add_with_def +EXTRA = "extra" +QUANT = "quant" +DEF = "def" + error = False version_update_map = { @@ -64,6 +69,9 @@ def diff_vars(origin_vars, new_vars): var_add_dispensable_massage = [] var_deleted_error_massage = [] + var_add_quant_message = [] + var_add_def_message = [] + common_vars_name = set(origin_vars.keys()) & set(new_vars.keys()) vars_name_only_in_origin = set(origin_vars.keys()) - set(new_vars.keys()) vars_name_only_in_new = set(new_vars.keys()) - set(origin_vars.keys()) @@ -73,11 +81,12 @@ def diff_vars(origin_vars, new_vars): continue else: error, var_error = True, True - var_changed_error_massage[var_name] = {} for arg_name in origin_vars.get(var_name): new_arg_value = new_vars.get(var_name, {}).get(arg_name) origin_arg_value = origin_vars.get(var_name, {}).get(arg_name) if new_arg_value != origin_arg_value: + if var_name not in var_changed_error_massage.keys(): + var_changed_error_massage[var_name] = {} var_changed_error_massage[var_name][arg_name] = ( origin_arg_value, new_arg_value) @@ -91,6 +100,21 @@ def diff_vars(origin_vars, new_vars): error, var_error = True, True var_add_dispensable_massage.append(var_name) + # if added var is extra, then no need to check. + if new_vars.get(var_name).get(EXTRA): + continue + + # if added var is quant, slim needs to review, needs to register. + if new_vars.get(var_name).get(QUANT): + error, var_error = True, True + var_add_quant_message.append(var_name) + + # if added var is def, inference needs to review, needs to register. + if not new_vars.get(var_name).get(EXTRA) and not new_vars.get( + var_name).get(QUANT): + error, var_error = True, True + var_add_def_message.append(var_name) + var_diff_message = {} if var_add_massage: var_diff_message[ADD] = var_add_massage @@ -100,6 +124,10 @@ def diff_vars(origin_vars, new_vars): var_diff_message[CHANGE] = var_changed_error_massage if var_deleted_error_massage: var_diff_message[DELETE] = var_deleted_error_massage + if var_add_quant_message: + var_diff_message[QUANT] = var_add_quant_message + if var_add_def_message: + var_diff_message[DEF] = var_add_def_message return var_error, var_diff_message @@ -113,6 +141,9 @@ def diff_attr(ori_attrs, new_attrs): attr_added_def_error_massage = [] attr_deleted_error_massage = [] + attr_added_quant_message = [] + attr_added_define_message = [] + common_attrs = set(ori_attrs.keys()) & set(new_attrs.keys()) attrs_only_in_origin = set(ori_attrs.keys()) - set(new_attrs.keys()) attrs_only_in_new = set(new_attrs.keys()) - set(ori_attrs.keys()) @@ -122,11 +153,12 @@ def diff_attr(ori_attrs, new_attrs): continue else: error, attr_error = True, True - attr_changed_error_massage[attr_name] = {} for arg_name in ori_attrs.get(attr_name): new_arg_value = new_attrs.get(attr_name, {}).get(arg_name) origin_arg_value = ori_attrs.get(attr_name, {}).get(arg_name) if new_arg_value != origin_arg_value: + if attr_name not in attr_changed_error_massage.keys(): + attr_changed_error_massage[attr_name] = {} attr_changed_error_massage[attr_name][arg_name] = ( origin_arg_value, new_arg_value) @@ -140,6 +172,17 @@ def diff_attr(ori_attrs, new_attrs): error, attr_error = True, True attr_added_def_error_massage.append(attr_name) + # if added attr is quant, slim needs to review, needs to register + if new_attrs.get(attr_name).get(QUANT): + error, var_error = True, True + attr_added_quant_message.append(attr_name) + + # if added attr is def, inference needs to review, needs to register + if not new_attrs.get(attr_name).get(EXTRA) and not new_attrs.get( + attr_name).get(QUANT): + error, var_error = True, True + attr_added_define_message.append(attr_name) + attr_diff_message = {} if attr_added_error_massage: attr_diff_message[ADD] = attr_added_error_massage @@ -149,6 +192,10 @@ def diff_attr(ori_attrs, new_attrs): attr_diff_message[CHANGE] = attr_changed_error_massage if attr_deleted_error_massage: attr_diff_message[DELETE] = attr_deleted_error_massage + if attr_added_define_message: + attr_diff_message[DEF] = attr_added_define_message + if attr_added_quant_message: + attr_diff_message[QUANT] = attr_added_quant_message return attr_error, attr_diff_message @@ -157,23 +204,49 @@ def check_io_registry(io_type, op, diff): checker = OpLastCheckpointChecker() results = {} for update_type in [ADD]: - for item in diff.get(update_type, {}): + for item in diff.get(update_type, []): infos = checker.filter_updates( op, version_update_map[io_type][update_type], item) if not infos: - results[update_type] = (op, item, io_type) + if update_type not in results.keys(): + results[update_type] = [] + # extra not need to register. + qaunt_ios = diff.get(QUANT, []) + def_ios = diff.get(DEF, []) + if item in qaunt_ios or item in def_ios: + results[update_type].append((op, item, io_type)) + return results -def check_attr_registry(op, diff): +def check_attr_registry(op, diff, origin_attrs): checker = OpLastCheckpointChecker() results = {} + qaunt_attrs = diff.get(QUANT, []) + def_attrs = diff.get(DEF, []) + change_attrs = diff.get(CHANGE, {}) for update_type in [ADD, CHANGE]: for item in diff.get(update_type, {}): infos = checker.filter_updates( op, version_update_map[ATTRS][update_type], item) if not infos: - results[update_type] = (op, item) + if update_type == ADD: + if update_type not in results.keys(): + results[update_type] = [] + # extra not need to register. + if item in qaunt_attrs or item in def_attrs: + results[update_type].append((op, item)) + elif update_type == CHANGE: + if CHANGE not in results.keys(): + results[update_type] = {} + for attr_name, attr_change in change_attrs.items(): + # extra not need to register. + if not origin_attrs.get(attr_name).get(EXTRA): + results[update_type][attr_name] = attr_change + + for update_type in [ADD, CHANGE]: + if update_type in results.keys() and len(results[update_type]) == 0: + del results[update_type] return results @@ -206,13 +279,14 @@ def compare_op_desc(origin_op_desc, new_op_desc): origin_attrs = origin_info.get(ATTRS, {}) new_attrs = new_info.get(ATTRS, {}) attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs) - attrs_version_errors = check_attr_registry(op_type, attrs_diff) + attrs_version_errors = check_attr_registry(op_type, attrs_diff, + origin_attrs) - if ins_error: + if ins_diff: desc_error_message.setdefault(op_type, {})[INPUTS] = ins_diff - if outs_error: + if outs_diff: desc_error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff - if attrs_error: + if attrs_diff: desc_error_message.setdefault(op_type, {})[ATTRS] = attrs_diff if ins_version_errors: @@ -250,6 +324,14 @@ def print_desc_error_message(error_message): " * The arg '{}' of Input '{}' is changed: from '{}' to '{}'.". format(arg, name, ori_value, new_value)) + for name in Inputs_error.get(QUANT, {}): + print(" * The added Input '{}' is `quant`, need slim to review.". + format(name)) + + for name in Inputs_error.get(DEF, {}): + print(" * The added Input '{}' is `def`, need inference to review.". + format(name)) + # 2. print outputs error message Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {}) for name in Outputs_error.get(ADD_DISPENSABLE, {}): @@ -266,6 +348,15 @@ def print_desc_error_message(error_message): " * The arg '{}' of Output '{}' is changed: from '{}' to '{}'.". format(arg, name, ori_value, new_value)) + for name in Outputs_error.get(QUANT, {}): + print(" * The added Output '{}' is `quant`, need slim to review.". + format(name)) + + for name in Outputs_error.get(DEF, {}): + print( + " * The added Output '{}' is `def`, need inference to review.". + format(name)) + # 3. print attrs error message attrs_error = error_message.get(op_name, {}).get(ATTRS, {}) for name in attrs_error.get(ADD_WITH_DEFAULT, {}): @@ -283,6 +374,16 @@ def print_desc_error_message(error_message): " * The arg '{}' of attr '{}' is changed: from '{}' to '{}'.". format(arg, name, ori_value, new_value)) + for name in attrs_error.get(QUANT, {}): + # TODO(Wilber): + print(" * The added attr '{}' is `quant`, need slim to review.". + format(name)) + + for name in attrs_error.get(DEF, {}): + # TODO(Wilber): + print(" * The added attr '{}' is `def`, need inference to review.". + format(name)) + def print_version_error_message(error_message): print( @@ -294,28 +395,32 @@ def print_version_error_message(error_message): # 1. print inputs error message inputs_error = error_message.get(op_name, {}).get(INPUTS, {}) - tuple = inputs_error.get(ADD, {}) - if tuple: - print(" * The added input '{}' is not yet registered.".format(tuple[ - 1])) + error_list = inputs_error.get(ADD, []) + if error_list: + for tup in error_list: + print(" * The added input '{}' is not yet registered.".format( + tup[1])) - # 2. print inputs error message + # 2. print outputs error message outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {}) - tuple = outputs_error.get(ADD, {}) - if tuple: - print(" * The added output '{}' is not yet registered.".format( - tuple[1])) + error_list = outputs_error.get(ADD, []) + if error_list: + for tup in error_list: + print(" * The added output '{}' is not yet registered.".format( + tup[1])) #3. print attrs error message attrs_error = error_message.get(op_name, {}).get(ATTRS, {}) - tuple = attrs_error.get(ADD, {}) - if tuple: - print(" * The added attribute '{}' is not yet registered.".format( - tuple[1])) - tuple = attrs_error.get(CHANGE, {}) - if tuple: + error_list = attrs_error.get(ADD, []) + if error_list: + for tup in error_list: + print(" * The added attribute '{}' is not yet registered.". + format(tup[1])) + error_dic = error_message.get(op_name, {}).get(ATTRS, {}).get(CHANGE, + {}) + for key, val in error_dic.items(): print(" * The change of attribute '{}' is not yet registered.". - format(tuple[1])) + format(key)) def print_repeat_process(): diff --git a/tools/print_op_desc.py b/tools/print_op_desc.py index 64445bab3a6..b85103a7a25 100644 --- a/tools/print_op_desc.py +++ b/tools/print_op_desc.py @@ -18,7 +18,9 @@ Print all ops desc in dict: {input_name1: {DISPENSABLE: bool, INTERMEDIATE: bool, - DUPLICABLE: bool + DUPLICABLE: bool, + EXTRA: bool, + QUANT: bool, }, input_name2:{} }, @@ -28,6 +30,8 @@ Print all ops desc in dict: {TYPE: int, GENERATED: bool, DEFAULT_VALUE: int/str/etc, + EXTRA: bool, + QUANT: bool, } } } @@ -55,6 +59,9 @@ TYPE = "type" GENERATED = "generated" DEFAULT_VALUE = "default_value" +EXTRA = "extra" +QUANT = "quant" + def get_attr_default_value(op_name): return core.get_op_attrs_default_value(cpt.to_bytes(op_name)) @@ -68,6 +75,8 @@ def get_vars_info(op_vars_proto): vars_info[name][DUPLICABLE] = var_proto.duplicable vars_info[name][DISPENSABLE] = var_proto.dispensable vars_info[name][INTERMEDIATE] = var_proto.intermediate + vars_info[name][EXTRA] = var_proto.extra + vars_info[name][QUANT] = var_proto.quant return vars_info @@ -81,6 +90,8 @@ def get_attrs_info(op_proto, op_attrs_proto): attrs_info[attr_name][GENERATED] = attr_proto.generated attrs_info[attr_name][DEFAULT_VALUE] = attrs_default_values[ attr_name] if attr_name in attrs_default_values else None + attrs_info[attr_name][EXTRA] = attr_proto.extra + attrs_info[attr_name][QUANT] = attr_proto.quant return attrs_info -- GitLab