diff --git a/tools/check_api_approvals.sh b/tools/check_api_approvals.sh index dcbe853d8a1bcc37522f5c57bf667cefc8d279f4..45d4731ba1dbac0fd396f8d66b4ea00cccf7942a 100644 --- a/tools/check_api_approvals.sh +++ b/tools/check_api_approvals.sh @@ -76,9 +76,21 @@ if [ "$op_type_spec_diff" != "" ]; then fi op_desc_diff=`python ${PADDLE_ROOT}/tools/check_op_desc.py ${PADDLE_ROOT}/paddle/fluid/OP_DESC_DEV.spec ${PADDLE_ROOT}/paddle/fluid/OP_DESC_PR.spec` +inference_approve=`echo "$op_desc_diff" | grep "need inference to review" -` +slim_approve=`echo "$op_desc_diff" | grep "need slim to review" -` if [ "$op_desc_diff" != "" ]; then - echo_line="You must have one RD (cyj1986, Superjomn) approval for the changes of Inputs/Output/Attrs of OPs. The changes of OPs will cause that the new version inference fails to load model trained by the old version. Please modify your code. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${op_desc_diff}\n" - check_approval 1 39645414 328693 + echo_line="You must have one RD (inference[ Superjomn(Recommend), Shixiaowei02, cyj1986 ] or slim[ wanghaoshuang(Recommend), qingqing01 ]) approval for the changes of Inputs/Output/Attrs of OPs. The changes of OPs will cause that the new version inference fails to load model trained by the old version. Please modify your code. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${op_desc_diff}\n" + check_approval 1 39645414 328693 39303645 7534971 7845005 +fi + +if [ "$slim_approve" != "" ]; then + echo_line="You must have one RD (wanghaoshuang(Recommend), qingqing01) approval for the changes of `quant` Inputs/Output/Attrs of OPs. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${slim_approve}\n" + check_approval 1 7534971 7845005 +fi + +if [ "$inference_approve" != "" ]; then + echo_line="You must have one RD (Superjomn(Recommend), Shixiaowei02, cyj1986) approval for the changes of `def` Inputs/Output/Attrs of OPs. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${inference_approve}\n" + check_approval 1 39645414 328693 39303645 fi DEV_OP_USE_DEFAULT_GRAD_MAKER_SPEC=${PADDLE_ROOT}/paddle/fluid/op_use_default_grad_maker_DEV.spec diff --git a/tools/check_op_desc.py b/tools/check_op_desc.py index 78abb6f36c60626d2e022e6be017f5dbfa23d2c3..19984a55a41af43e06c36af1e6b536eb5c8fe98e 100644 --- a/tools/check_op_desc.py +++ b/tools/check_op_desc.py @@ -40,6 +40,11 @@ TYPE = "type" GENERATED = "generated" DEFAULT_VALUE = "default_value" +# add_with_extra, add_with_quant and add_with_def +EXTRA = "extra" +QUANT = "quant" +DEF = "def" + error = False version_update_map = { @@ -64,6 +69,9 @@ def diff_vars(origin_vars, new_vars): var_add_dispensable_massage = [] var_deleted_error_massage = [] + var_add_quant_message = [] + var_add_def_message = [] + common_vars_name = set(origin_vars.keys()) & set(new_vars.keys()) vars_name_only_in_origin = set(origin_vars.keys()) - set(new_vars.keys()) vars_name_only_in_new = set(new_vars.keys()) - set(origin_vars.keys()) @@ -73,11 +81,12 @@ def diff_vars(origin_vars, new_vars): continue else: error, var_error = True, True - var_changed_error_massage[var_name] = {} for arg_name in origin_vars.get(var_name): new_arg_value = new_vars.get(var_name, {}).get(arg_name) origin_arg_value = origin_vars.get(var_name, {}).get(arg_name) if new_arg_value != origin_arg_value: + if var_name not in var_changed_error_massage.keys(): + var_changed_error_massage[var_name] = {} var_changed_error_massage[var_name][arg_name] = ( origin_arg_value, new_arg_value) @@ -91,6 +100,21 @@ def diff_vars(origin_vars, new_vars): error, var_error = True, True var_add_dispensable_massage.append(var_name) + # if added var is extra, then no need to check. + if new_vars.get(var_name).get(EXTRA): + continue + + # if added var is quant, slim needs to review, needs to register. + if new_vars.get(var_name).get(QUANT): + error, var_error = True, True + var_add_quant_message.append(var_name) + + # if added var is def, inference needs to review, needs to register. + if not new_vars.get(var_name).get(EXTRA) and not new_vars.get( + var_name).get(QUANT): + error, var_error = True, True + var_add_def_message.append(var_name) + var_diff_message = {} if var_add_massage: var_diff_message[ADD] = var_add_massage @@ -100,6 +124,10 @@ def diff_vars(origin_vars, new_vars): var_diff_message[CHANGE] = var_changed_error_massage if var_deleted_error_massage: var_diff_message[DELETE] = var_deleted_error_massage + if var_add_quant_message: + var_diff_message[QUANT] = var_add_quant_message + if var_add_def_message: + var_diff_message[DEF] = var_add_def_message return var_error, var_diff_message @@ -113,6 +141,9 @@ def diff_attr(ori_attrs, new_attrs): attr_added_def_error_massage = [] attr_deleted_error_massage = [] + attr_added_quant_message = [] + attr_added_define_message = [] + common_attrs = set(ori_attrs.keys()) & set(new_attrs.keys()) attrs_only_in_origin = set(ori_attrs.keys()) - set(new_attrs.keys()) attrs_only_in_new = set(new_attrs.keys()) - set(ori_attrs.keys()) @@ -122,11 +153,12 @@ def diff_attr(ori_attrs, new_attrs): continue else: error, attr_error = True, True - attr_changed_error_massage[attr_name] = {} for arg_name in ori_attrs.get(attr_name): new_arg_value = new_attrs.get(attr_name, {}).get(arg_name) origin_arg_value = ori_attrs.get(attr_name, {}).get(arg_name) if new_arg_value != origin_arg_value: + if attr_name not in attr_changed_error_massage.keys(): + attr_changed_error_massage[attr_name] = {} attr_changed_error_massage[attr_name][arg_name] = ( origin_arg_value, new_arg_value) @@ -140,6 +172,17 @@ def diff_attr(ori_attrs, new_attrs): error, attr_error = True, True attr_added_def_error_massage.append(attr_name) + # if added attr is quant, slim needs to review, needs to register + if new_attrs.get(attr_name).get(QUANT): + error, var_error = True, True + attr_added_quant_message.append(attr_name) + + # if added attr is def, inference needs to review, needs to register + if not new_attrs.get(attr_name).get(EXTRA) and not new_attrs.get( + attr_name).get(QUANT): + error, var_error = True, True + attr_added_define_message.append(attr_name) + attr_diff_message = {} if attr_added_error_massage: attr_diff_message[ADD] = attr_added_error_massage @@ -149,6 +192,10 @@ def diff_attr(ori_attrs, new_attrs): attr_diff_message[CHANGE] = attr_changed_error_massage if attr_deleted_error_massage: attr_diff_message[DELETE] = attr_deleted_error_massage + if attr_added_define_message: + attr_diff_message[DEF] = attr_added_define_message + if attr_added_quant_message: + attr_diff_message[QUANT] = attr_added_quant_message return attr_error, attr_diff_message @@ -157,23 +204,49 @@ def check_io_registry(io_type, op, diff): checker = OpLastCheckpointChecker() results = {} for update_type in [ADD]: - for item in diff.get(update_type, {}): + for item in diff.get(update_type, []): infos = checker.filter_updates( op, version_update_map[io_type][update_type], item) if not infos: - results[update_type] = (op, item, io_type) + if update_type not in results.keys(): + results[update_type] = [] + # extra not need to register. + qaunt_ios = diff.get(QUANT, []) + def_ios = diff.get(DEF, []) + if item in qaunt_ios or item in def_ios: + results[update_type].append((op, item, io_type)) + return results -def check_attr_registry(op, diff): +def check_attr_registry(op, diff, origin_attrs): checker = OpLastCheckpointChecker() results = {} + qaunt_attrs = diff.get(QUANT, []) + def_attrs = diff.get(DEF, []) + change_attrs = diff.get(CHANGE, {}) for update_type in [ADD, CHANGE]: for item in diff.get(update_type, {}): infos = checker.filter_updates( op, version_update_map[ATTRS][update_type], item) if not infos: - results[update_type] = (op, item) + if update_type == ADD: + if update_type not in results.keys(): + results[update_type] = [] + # extra not need to register. + if item in qaunt_attrs or item in def_attrs: + results[update_type].append((op, item)) + elif update_type == CHANGE: + if CHANGE not in results.keys(): + results[update_type] = {} + for attr_name, attr_change in change_attrs.items(): + # extra not need to register. + if not origin_attrs.get(attr_name).get(EXTRA): + results[update_type][attr_name] = attr_change + + for update_type in [ADD, CHANGE]: + if update_type in results.keys() and len(results[update_type]) == 0: + del results[update_type] return results @@ -206,13 +279,14 @@ def compare_op_desc(origin_op_desc, new_op_desc): origin_attrs = origin_info.get(ATTRS, {}) new_attrs = new_info.get(ATTRS, {}) attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs) - attrs_version_errors = check_attr_registry(op_type, attrs_diff) + attrs_version_errors = check_attr_registry(op_type, attrs_diff, + origin_attrs) - if ins_error: + if ins_diff: desc_error_message.setdefault(op_type, {})[INPUTS] = ins_diff - if outs_error: + if outs_diff: desc_error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff - if attrs_error: + if attrs_diff: desc_error_message.setdefault(op_type, {})[ATTRS] = attrs_diff if ins_version_errors: @@ -250,6 +324,14 @@ def print_desc_error_message(error_message): " * The arg '{}' of Input '{}' is changed: from '{}' to '{}'.". format(arg, name, ori_value, new_value)) + for name in Inputs_error.get(QUANT, {}): + print(" * The added Input '{}' is `quant`, need slim to review.". + format(name)) + + for name in Inputs_error.get(DEF, {}): + print(" * The added Input '{}' is `def`, need inference to review.". + format(name)) + # 2. print outputs error message Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {}) for name in Outputs_error.get(ADD_DISPENSABLE, {}): @@ -266,6 +348,15 @@ def print_desc_error_message(error_message): " * The arg '{}' of Output '{}' is changed: from '{}' to '{}'.". format(arg, name, ori_value, new_value)) + for name in Outputs_error.get(QUANT, {}): + print(" * The added Output '{}' is `quant`, need slim to review.". + format(name)) + + for name in Outputs_error.get(DEF, {}): + print( + " * The added Output '{}' is `def`, need inference to review.". + format(name)) + # 3. print attrs error message attrs_error = error_message.get(op_name, {}).get(ATTRS, {}) for name in attrs_error.get(ADD_WITH_DEFAULT, {}): @@ -283,6 +374,16 @@ def print_desc_error_message(error_message): " * The arg '{}' of attr '{}' is changed: from '{}' to '{}'.". format(arg, name, ori_value, new_value)) + for name in attrs_error.get(QUANT, {}): + # TODO(Wilber): + print(" * The added attr '{}' is `quant`, need slim to review.". + format(name)) + + for name in attrs_error.get(DEF, {}): + # TODO(Wilber): + print(" * The added attr '{}' is `def`, need inference to review.". + format(name)) + def print_version_error_message(error_message): print( @@ -294,28 +395,32 @@ def print_version_error_message(error_message): # 1. print inputs error message inputs_error = error_message.get(op_name, {}).get(INPUTS, {}) - tuple = inputs_error.get(ADD, {}) - if tuple: - print(" * The added input '{}' is not yet registered.".format(tuple[ - 1])) + error_list = inputs_error.get(ADD, []) + if error_list: + for tup in error_list: + print(" * The added input '{}' is not yet registered.".format( + tup[1])) - # 2. print inputs error message + # 2. print outputs error message outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {}) - tuple = outputs_error.get(ADD, {}) - if tuple: - print(" * The added output '{}' is not yet registered.".format( - tuple[1])) + error_list = outputs_error.get(ADD, []) + if error_list: + for tup in error_list: + print(" * The added output '{}' is not yet registered.".format( + tup[1])) #3. print attrs error message attrs_error = error_message.get(op_name, {}).get(ATTRS, {}) - tuple = attrs_error.get(ADD, {}) - if tuple: - print(" * The added attribute '{}' is not yet registered.".format( - tuple[1])) - tuple = attrs_error.get(CHANGE, {}) - if tuple: + error_list = attrs_error.get(ADD, []) + if error_list: + for tup in error_list: + print(" * The added attribute '{}' is not yet registered.". + format(tup[1])) + error_dic = error_message.get(op_name, {}).get(ATTRS, {}).get(CHANGE, + {}) + for key, val in error_dic.items(): print(" * The change of attribute '{}' is not yet registered.". - format(tuple[1])) + format(key)) def print_repeat_process(): diff --git a/tools/print_op_desc.py b/tools/print_op_desc.py index 64445bab3a62c5c182e2d2a4b8f654377190afe2..b85103a7a25e164c83880b518caa803d53923892 100644 --- a/tools/print_op_desc.py +++ b/tools/print_op_desc.py @@ -18,7 +18,9 @@ Print all ops desc in dict: {input_name1: {DISPENSABLE: bool, INTERMEDIATE: bool, - DUPLICABLE: bool + DUPLICABLE: bool, + EXTRA: bool, + QUANT: bool, }, input_name2:{} }, @@ -28,6 +30,8 @@ Print all ops desc in dict: {TYPE: int, GENERATED: bool, DEFAULT_VALUE: int/str/etc, + EXTRA: bool, + QUANT: bool, } } } @@ -55,6 +59,9 @@ TYPE = "type" GENERATED = "generated" DEFAULT_VALUE = "default_value" +EXTRA = "extra" +QUANT = "quant" + def get_attr_default_value(op_name): return core.get_op_attrs_default_value(cpt.to_bytes(op_name)) @@ -68,6 +75,8 @@ def get_vars_info(op_vars_proto): vars_info[name][DUPLICABLE] = var_proto.duplicable vars_info[name][DISPENSABLE] = var_proto.dispensable vars_info[name][INTERMEDIATE] = var_proto.intermediate + vars_info[name][EXTRA] = var_proto.extra + vars_info[name][QUANT] = var_proto.quant return vars_info @@ -81,6 +90,8 @@ def get_attrs_info(op_proto, op_attrs_proto): attrs_info[attr_name][GENERATED] = attr_proto.generated attrs_info[attr_name][DEFAULT_VALUE] = attrs_default_values[ attr_name] if attr_name in attrs_default_values else None + attrs_info[attr_name][EXTRA] = attr_proto.extra + attrs_info[attr_name][QUANT] = attr_proto.quant return attrs_info