diff --git a/tools/check_op_desc.py b/tools/check_op_desc.py index 9e5bd08be698dc45d8f56de98557bc298569d8d0..15e410401216ce68fc4178bb4e60d3ac159523f4 100644 --- a/tools/check_op_desc.py +++ b/tools/check_op_desc.py @@ -14,7 +14,6 @@ import json import sys -import operator from paddle.utils import OpLastCheckpointChecker from paddle.fluid.core import OpUpdateType @@ -72,8 +71,7 @@ def diff_vars(origin_vars, new_vars): vars_name_only_in_new = set(new_vars.keys()) - set(origin_vars.keys()) for var_name in common_vars_name: - if operator.eq(origin_vars.get(var_name), - new_vars.get(var_name)) == SAME: + if cmp(origin_vars.get(var_name), new_vars.get(var_name)) == SAME: continue else: error, var_error = True, True @@ -122,8 +120,7 @@ def diff_attr(ori_attrs, new_attrs): attrs_only_in_new = set(new_attrs.keys()) - set(ori_attrs.keys()) for attr_name in common_attrs: - if operator.eq(ori_attrs.get(attr_name), - new_attrs.get(attr_name)) == SAME: + if cmp(ori_attrs.get(attr_name), new_attrs.get(attr_name)) == SAME: continue else: error, attr_error = True, True @@ -187,7 +184,7 @@ def compare_op_desc(origin_op_desc, new_op_desc): new = json.loads(new_op_desc) desc_error_message = {} version_error_message = {} - if operator.eq(origin_op_desc, new_op_desc) == SAME: + if cmp(origin_op_desc, new_op_desc) == SAME: return desc_error_message, version_error_message for op_type in origin: