未验证 提交 b65708a8 编写于 作者: W Wilber 提交者: GitHub

update ci check_op_desc to support op_version and op_compat. (#37600)

* update check_op_desc to support op_version and op_compat.
上级 7e9b20b5
......@@ -76,9 +76,21 @@ if [ "$op_type_spec_diff" != "" ]; then
fi
op_desc_diff=`python ${PADDLE_ROOT}/tools/check_op_desc.py ${PADDLE_ROOT}/paddle/fluid/OP_DESC_DEV.spec ${PADDLE_ROOT}/paddle/fluid/OP_DESC_PR.spec`
inference_approve=`echo "$op_desc_diff" | grep "need inference to review" -`
slim_approve=`echo "$op_desc_diff" | grep "need slim to review" -`
if [ "$op_desc_diff" != "" ]; then
echo_line="You must have one RD (cyj1986, Superjomn) approval for the changes of Inputs/Output/Attrs of OPs. The changes of OPs will cause that the new version inference fails to load model trained by the old version. Please modify your code. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${op_desc_diff}\n"
check_approval 1 39645414 328693
echo_line="You must have one RD (inference[ Superjomn(Recommend), Shixiaowei02, cyj1986 ] or slim[ wanghaoshuang(Recommend), qingqing01 ]) approval for the changes of Inputs/Output/Attrs of OPs. The changes of OPs will cause that the new version inference fails to load model trained by the old version. Please modify your code. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${op_desc_diff}\n"
check_approval 1 39645414 328693 39303645 7534971 7845005
fi
if [ "$slim_approve" != "" ]; then
echo_line="You must have one RD (wanghaoshuang(Recommend), qingqing01) approval for the changes of `quant` Inputs/Output/Attrs of OPs. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${slim_approve}\n"
check_approval 1 7534971 7845005
fi
if [ "$inference_approve" != "" ]; then
echo_line="You must have one RD (Superjomn(Recommend), Shixiaowei02, cyj1986) approval for the changes of `def` Inputs/Output/Attrs of OPs. \n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/OP-Input-Output-Attribute-Compatibility-Modification].\n${inference_approve}\n"
check_approval 1 39645414 328693 39303645
fi
DEV_OP_USE_DEFAULT_GRAD_MAKER_SPEC=${PADDLE_ROOT}/paddle/fluid/op_use_default_grad_maker_DEV.spec
......
......@@ -40,6 +40,11 @@ TYPE = "type"
GENERATED = "generated"
DEFAULT_VALUE = "default_value"
# add_with_extra, add_with_quant and add_with_def
EXTRA = "extra"
QUANT = "quant"
DEF = "def"
error = False
version_update_map = {
......@@ -64,6 +69,9 @@ def diff_vars(origin_vars, new_vars):
var_add_dispensable_massage = []
var_deleted_error_massage = []
var_add_quant_message = []
var_add_def_message = []
common_vars_name = set(origin_vars.keys()) & set(new_vars.keys())
vars_name_only_in_origin = set(origin_vars.keys()) - set(new_vars.keys())
vars_name_only_in_new = set(new_vars.keys()) - set(origin_vars.keys())
......@@ -73,11 +81,12 @@ def diff_vars(origin_vars, new_vars):
continue
else:
error, var_error = True, True
var_changed_error_massage[var_name] = {}
for arg_name in origin_vars.get(var_name):
new_arg_value = new_vars.get(var_name, {}).get(arg_name)
origin_arg_value = origin_vars.get(var_name, {}).get(arg_name)
if new_arg_value != origin_arg_value:
if var_name not in var_changed_error_massage.keys():
var_changed_error_massage[var_name] = {}
var_changed_error_massage[var_name][arg_name] = (
origin_arg_value, new_arg_value)
......@@ -91,6 +100,21 @@ def diff_vars(origin_vars, new_vars):
error, var_error = True, True
var_add_dispensable_massage.append(var_name)
# if added var is extra, then no need to check.
if new_vars.get(var_name).get(EXTRA):
continue
# if added var is quant, slim needs to review, needs to register.
if new_vars.get(var_name).get(QUANT):
error, var_error = True, True
var_add_quant_message.append(var_name)
# if added var is def, inference needs to review, needs to register.
if not new_vars.get(var_name).get(EXTRA) and not new_vars.get(
var_name).get(QUANT):
error, var_error = True, True
var_add_def_message.append(var_name)
var_diff_message = {}
if var_add_massage:
var_diff_message[ADD] = var_add_massage
......@@ -100,6 +124,10 @@ def diff_vars(origin_vars, new_vars):
var_diff_message[CHANGE] = var_changed_error_massage
if var_deleted_error_massage:
var_diff_message[DELETE] = var_deleted_error_massage
if var_add_quant_message:
var_diff_message[QUANT] = var_add_quant_message
if var_add_def_message:
var_diff_message[DEF] = var_add_def_message
return var_error, var_diff_message
......@@ -113,6 +141,9 @@ def diff_attr(ori_attrs, new_attrs):
attr_added_def_error_massage = []
attr_deleted_error_massage = []
attr_added_quant_message = []
attr_added_define_message = []
common_attrs = set(ori_attrs.keys()) & set(new_attrs.keys())
attrs_only_in_origin = set(ori_attrs.keys()) - set(new_attrs.keys())
attrs_only_in_new = set(new_attrs.keys()) - set(ori_attrs.keys())
......@@ -122,11 +153,12 @@ def diff_attr(ori_attrs, new_attrs):
continue
else:
error, attr_error = True, True
attr_changed_error_massage[attr_name] = {}
for arg_name in ori_attrs.get(attr_name):
new_arg_value = new_attrs.get(attr_name, {}).get(arg_name)
origin_arg_value = ori_attrs.get(attr_name, {}).get(arg_name)
if new_arg_value != origin_arg_value:
if attr_name not in attr_changed_error_massage.keys():
attr_changed_error_massage[attr_name] = {}
attr_changed_error_massage[attr_name][arg_name] = (
origin_arg_value, new_arg_value)
......@@ -140,6 +172,17 @@ def diff_attr(ori_attrs, new_attrs):
error, attr_error = True, True
attr_added_def_error_massage.append(attr_name)
# if added attr is quant, slim needs to review, needs to register
if new_attrs.get(attr_name).get(QUANT):
error, var_error = True, True
attr_added_quant_message.append(attr_name)
# if added attr is def, inference needs to review, needs to register
if not new_attrs.get(attr_name).get(EXTRA) and not new_attrs.get(
attr_name).get(QUANT):
error, var_error = True, True
attr_added_define_message.append(attr_name)
attr_diff_message = {}
if attr_added_error_massage:
attr_diff_message[ADD] = attr_added_error_massage
......@@ -149,6 +192,10 @@ def diff_attr(ori_attrs, new_attrs):
attr_diff_message[CHANGE] = attr_changed_error_massage
if attr_deleted_error_massage:
attr_diff_message[DELETE] = attr_deleted_error_massage
if attr_added_define_message:
attr_diff_message[DEF] = attr_added_define_message
if attr_added_quant_message:
attr_diff_message[QUANT] = attr_added_quant_message
return attr_error, attr_diff_message
......@@ -157,23 +204,49 @@ def check_io_registry(io_type, op, diff):
checker = OpLastCheckpointChecker()
results = {}
for update_type in [ADD]:
for item in diff.get(update_type, {}):
for item in diff.get(update_type, []):
infos = checker.filter_updates(
op, version_update_map[io_type][update_type], item)
if not infos:
results[update_type] = (op, item, io_type)
if update_type not in results.keys():
results[update_type] = []
# extra not need to register.
qaunt_ios = diff.get(QUANT, [])
def_ios = diff.get(DEF, [])
if item in qaunt_ios or item in def_ios:
results[update_type].append((op, item, io_type))
return results
def check_attr_registry(op, diff):
def check_attr_registry(op, diff, origin_attrs):
checker = OpLastCheckpointChecker()
results = {}
qaunt_attrs = diff.get(QUANT, [])
def_attrs = diff.get(DEF, [])
change_attrs = diff.get(CHANGE, {})
for update_type in [ADD, CHANGE]:
for item in diff.get(update_type, {}):
infos = checker.filter_updates(
op, version_update_map[ATTRS][update_type], item)
if not infos:
results[update_type] = (op, item)
if update_type == ADD:
if update_type not in results.keys():
results[update_type] = []
# extra not need to register.
if item in qaunt_attrs or item in def_attrs:
results[update_type].append((op, item))
elif update_type == CHANGE:
if CHANGE not in results.keys():
results[update_type] = {}
for attr_name, attr_change in change_attrs.items():
# extra not need to register.
if not origin_attrs.get(attr_name).get(EXTRA):
results[update_type][attr_name] = attr_change
for update_type in [ADD, CHANGE]:
if update_type in results.keys() and len(results[update_type]) == 0:
del results[update_type]
return results
......@@ -206,13 +279,14 @@ def compare_op_desc(origin_op_desc, new_op_desc):
origin_attrs = origin_info.get(ATTRS, {})
new_attrs = new_info.get(ATTRS, {})
attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs)
attrs_version_errors = check_attr_registry(op_type, attrs_diff)
attrs_version_errors = check_attr_registry(op_type, attrs_diff,
origin_attrs)
if ins_error:
if ins_diff:
desc_error_message.setdefault(op_type, {})[INPUTS] = ins_diff
if outs_error:
if outs_diff:
desc_error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff
if attrs_error:
if attrs_diff:
desc_error_message.setdefault(op_type, {})[ATTRS] = attrs_diff
if ins_version_errors:
......@@ -250,6 +324,14 @@ def print_desc_error_message(error_message):
" * The arg '{}' of Input '{}' is changed: from '{}' to '{}'.".
format(arg, name, ori_value, new_value))
for name in Inputs_error.get(QUANT, {}):
print(" * The added Input '{}' is `quant`, need slim to review.".
format(name))
for name in Inputs_error.get(DEF, {}):
print(" * The added Input '{}' is `def`, need inference to review.".
format(name))
# 2. print outputs error message
Outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {})
for name in Outputs_error.get(ADD_DISPENSABLE, {}):
......@@ -266,6 +348,15 @@ def print_desc_error_message(error_message):
" * The arg '{}' of Output '{}' is changed: from '{}' to '{}'.".
format(arg, name, ori_value, new_value))
for name in Outputs_error.get(QUANT, {}):
print(" * The added Output '{}' is `quant`, need slim to review.".
format(name))
for name in Outputs_error.get(DEF, {}):
print(
" * The added Output '{}' is `def`, need inference to review.".
format(name))
# 3. print attrs error message
attrs_error = error_message.get(op_name, {}).get(ATTRS, {})
for name in attrs_error.get(ADD_WITH_DEFAULT, {}):
......@@ -283,6 +374,16 @@ def print_desc_error_message(error_message):
" * The arg '{}' of attr '{}' is changed: from '{}' to '{}'.".
format(arg, name, ori_value, new_value))
for name in attrs_error.get(QUANT, {}):
# TODO(Wilber):
print(" * The added attr '{}' is `quant`, need slim to review.".
format(name))
for name in attrs_error.get(DEF, {}):
# TODO(Wilber):
print(" * The added attr '{}' is `def`, need inference to review.".
format(name))
def print_version_error_message(error_message):
print(
......@@ -294,28 +395,32 @@ def print_version_error_message(error_message):
# 1. print inputs error message
inputs_error = error_message.get(op_name, {}).get(INPUTS, {})
tuple = inputs_error.get(ADD, {})
if tuple:
print(" * The added input '{}' is not yet registered.".format(tuple[
1]))
error_list = inputs_error.get(ADD, [])
if error_list:
for tup in error_list:
print(" * The added input '{}' is not yet registered.".format(
tup[1]))
# 2. print inputs error message
# 2. print outputs error message
outputs_error = error_message.get(op_name, {}).get(OUTPUTS, {})
tuple = outputs_error.get(ADD, {})
if tuple:
error_list = outputs_error.get(ADD, [])
if error_list:
for tup in error_list:
print(" * The added output '{}' is not yet registered.".format(
tuple[1]))
tup[1]))
#3. print attrs error message
attrs_error = error_message.get(op_name, {}).get(ATTRS, {})
tuple = attrs_error.get(ADD, {})
if tuple:
print(" * The added attribute '{}' is not yet registered.".format(
tuple[1]))
tuple = attrs_error.get(CHANGE, {})
if tuple:
error_list = attrs_error.get(ADD, [])
if error_list:
for tup in error_list:
print(" * The added attribute '{}' is not yet registered.".
format(tup[1]))
error_dic = error_message.get(op_name, {}).get(ATTRS, {}).get(CHANGE,
{})
for key, val in error_dic.items():
print(" * The change of attribute '{}' is not yet registered.".
format(tuple[1]))
format(key))
def print_repeat_process():
......
......@@ -18,7 +18,9 @@ Print all ops desc in dict:
{input_name1:
{DISPENSABLE: bool,
INTERMEDIATE: bool,
DUPLICABLE: bool
DUPLICABLE: bool,
EXTRA: bool,
QUANT: bool,
},
input_name2:{}
},
......@@ -28,6 +30,8 @@ Print all ops desc in dict:
{TYPE: int,
GENERATED: bool,
DEFAULT_VALUE: int/str/etc,
EXTRA: bool,
QUANT: bool,
}
}
}
......@@ -55,6 +59,9 @@ TYPE = "type"
GENERATED = "generated"
DEFAULT_VALUE = "default_value"
EXTRA = "extra"
QUANT = "quant"
def get_attr_default_value(op_name):
return core.get_op_attrs_default_value(cpt.to_bytes(op_name))
......@@ -68,6 +75,8 @@ def get_vars_info(op_vars_proto):
vars_info[name][DUPLICABLE] = var_proto.duplicable
vars_info[name][DISPENSABLE] = var_proto.dispensable
vars_info[name][INTERMEDIATE] = var_proto.intermediate
vars_info[name][EXTRA] = var_proto.extra
vars_info[name][QUANT] = var_proto.quant
return vars_info
......@@ -81,6 +90,8 @@ def get_attrs_info(op_proto, op_attrs_proto):
attrs_info[attr_name][GENERATED] = attr_proto.generated
attrs_info[attr_name][DEFAULT_VALUE] = attrs_default_values[
attr_name] if attr_name in attrs_default_values else None
attrs_info[attr_name][EXTRA] = attr_proto.extra
attrs_info[attr_name][QUANT] = attr_proto.quant
return attrs_info
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册