diff --git a/mindinsight/lineagemgr/cache_item_updater.py b/mindinsight/lineagemgr/cache_item_updater.py
index abcea1a8221a2b5e13c21db5b159aa576e55996a..52e522c3fe737bcec2213bca716ecb5c885c5719 100644
--- a/mindinsight/lineagemgr/cache_item_updater.py
+++ b/mindinsight/lineagemgr/cache_item_updater.py
@@ -18,7 +18,7 @@ import os
 from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob
 from mindinsight.lineagemgr.common.log import logger
 from mindinsight.lineagemgr.common.exceptions.exceptions import LineageFileNotFoundError
-from mindinsight.lineagemgr.common.validator.validate import validate_train_id
+from mindinsight.lineagemgr.common.validator.validate import validate_train_id, validate_added_info
 from mindinsight.lineagemgr.lineage_parser import LineageParser, LINEAGE
 from mindinsight.utils.exceptions import ParamValueError
 
@@ -26,6 +26,7 @@ from mindinsight.utils.exceptions import ParamValueError
 def update_lineage_object(data_manager, train_id, added_info: dict):
     """Update lineage objects about tag and remark."""
     validate_train_id(train_id)
+    validate_added_info(added_info)
     cache_item = data_manager.get_brief_train_job(train_id)
     cached_added_info = cache_item.get(key=LINEAGE).added_info
     new_added_info = dict(cached_added_info)
diff --git a/mindinsight/lineagemgr/common/validator/validate.py b/mindinsight/lineagemgr/common/validator/validate.py
index 4a839d95445b70099c02eeafff19efebc90d29b5..5e6418daeebe41598d61543ff8621757ef071bb8 100644
--- a/mindinsight/lineagemgr/common/validator/validate.py
+++ b/mindinsight/lineagemgr/common/validator/validate.py
@@ -362,8 +362,9 @@ def validate_condition(search_condition):
             log.error(err_msg)
             raise LineageParamValueError(err_msg)
         if not (sorted_name in FIELD_MAPPING
-                or (sorted_name.startswith('metric/') and len(sorted_name) > 7)
-                or (sorted_name.startswith('user_defined/') and len(sorted_name) > 13)):
+                or (sorted_name.startswith('metric/') and len(sorted_name) > len('metric/'))
+                or (sorted_name.startswith('user_defined/') and len(sorted_name) > len('user_defined/'))
+                or sorted_name in ['tag']):
             log.error(err_msg)
             raise LineageParamValueError(err_msg)
 
@@ -460,3 +461,54 @@ def validate_train_id(relative_path):
         raise ParamValueError(
             "Summary dir should be relative path starting with './'."
         )
+
+
+def validate_range(name, value, min_value, max_value):
+    """
+    Check if value is in [min_value, max_value].
+
+    Args:
+        name (str): Value name.
+        value (Union[int, float]): Value to be checked.
+        min_value (Union[int, float]): Min value.
+        max_value (Union[int, float]): Max value.
+
+    Raises:
+        LineageParamValueError, if value type is invalid or value is out of [min_value, max_value].
+
+    """
+    if not isinstance(value, (int, float)):
+        raise LineageParamValueError("Value should be int or float.")
+
+    if value < min_value or value > max_value:
+        raise LineageParamValueError("The %s should in [%d, %d]." % (name, min_value, max_value))
+
+
+def validate_added_info(added_info: dict):
+    """
+    Check if added_info is valid.
+
+    Args:
+        added_info (dict): The added info.
+
+    Raises:
+        LineageParamValueError, if a key is not 'tag'/'remark' or a value has a wrong type or range.
+
+    """
+    added_info_keys = ["tag", "remark"]
+    if not set(added_info.keys()).issubset(added_info_keys):
+        err_msg = "Keys must be in {}.".format(added_info_keys)
+        log.error(err_msg)
+        raise LineageParamValueError(err_msg)
+
+    for key, value in added_info.items():
+        if key == "tag":
+            if not isinstance(value, int) or isinstance(value, bool):
+                raise LineageParamValueError("'tag' must be int.")
+            # tag should be in [0, 10]; bool is rejected above although it subclasses int.
+            validate_range("tag", value, min_value=0, max_value=10)
+        elif key == "remark":
+            if not isinstance(value, str):
+                raise LineageParamValueError("'remark' must be str.")
+            # length of remark should be in [0, 128].
+            validate_range("length of remark", len(value), min_value=0, max_value=128)
diff --git a/mindinsight/lineagemgr/querier/querier.py b/mindinsight/lineagemgr/querier/querier.py
index 1407dc3bd5657dba21c919a9280f28d07b6c1285..142e4c0ae2803d4785b7da6f11a0457421eab963 100644
--- a/mindinsight/lineagemgr/querier/querier.py
+++ b/mindinsight/lineagemgr/querier/querier.py
@@ -271,25 +271,6 @@ class Querier:
                         return False
             return True
 
-        def _cmp(obj1: SuperLineageObj, obj2: SuperLineageObj):
-            value1 = obj1.lineage_obj.get_value_by_key(sorted_name)
-            value2 = obj2.lineage_obj.get_value_by_key(sorted_name)
-
-            if value1 is None and value2 is None:
-                cmp_result = 0
-            elif value1 is None:
-                cmp_result = -1
-            elif value2 is None:
-                cmp_result = 1
-            else:
-                try:
-                    cmp_result = (value1 > value2) - (value1 < value2)
-                except TypeError:
-                    type1 = type(value1).__name__
-                    type2 = type(value2).__name__
-                    cmp_result = (type1 > type2) - (type1 < type2)
-            return cmp_result
-
         if condition is None:
             condition = {}
 
@@ -298,19 +279,7 @@ class Querier:
         super_lineage_objs.sort(key=lambda x: x.update_time, reverse=True)
 
         results = list(filter(_filter, super_lineage_objs))
-
-        if ConditionParam.SORTED_NAME.value in condition:
-            sorted_name = condition.get(ConditionParam.SORTED_NAME.value)
-            if self._is_valid_field(sorted_name):
-                raise LineageQuerierParamException(
-                    'condition',
-                    'The sorted name {} not supported.'.format(sorted_name)
-                )
-            sorted_type = condition.get(ConditionParam.SORTED_TYPE.value)
-            reverse = sorted_type == 'descending'
-            results = sorted(
-                results, key=functools.cmp_to_key(_cmp), reverse=reverse
-            )
+        results = self._sorted_results(results, condition)
 
         offset_results = self._handle_limit_and_offset(condition, results)
 
@@ -338,6 +307,55 @@ class Querier:
 
         return lineage_info
 
+    def _sorted_results(self, results, condition):
+        """Sort results by the sorted name in condition; 'tag' is compared via added_info."""
+        def _cmp(value1, value2):
+            if value1 is None and value2 is None:
+                cmp_result = 0
+            elif value1 is None:
+                cmp_result = -1
+            elif value2 is None:
+                cmp_result = 1
+            else:
+                try:
+                    cmp_result = (value1 > value2) - (value1 < value2)
+                except TypeError:
+                    type1 = type(value1).__name__
+                    type2 = type(value2).__name__
+                    cmp_result = (type1 > type2) - (type1 < type2)
+            return cmp_result
+
+        def _cmp_added_info(obj1: SuperLineageObj, obj2: SuperLineageObj):
+            value1 = obj1.added_info.get(sorted_name)
+            value2 = obj2.added_info.get(sorted_name)
+            return _cmp(value1, value2)
+
+        def _cmp_super_lineage_obj(obj1: SuperLineageObj, obj2: SuperLineageObj):
+            value1 = obj1.lineage_obj.get_value_by_key(sorted_name)
+            value2 = obj2.lineage_obj.get_value_by_key(sorted_name)
+
+            return _cmp(value1, value2)
+
+        if ConditionParam.SORTED_NAME.value in condition:
+            sorted_name = condition.get(ConditionParam.SORTED_NAME.value)
+            sorted_type = condition.get(ConditionParam.SORTED_TYPE.value)
+            reverse = sorted_type == 'descending'
+            if sorted_name in ['tag']:
+                results = sorted(
+                    results, key=functools.cmp_to_key(_cmp_added_info), reverse=reverse
+                )
+                return results
+
+            if self._is_valid_field(sorted_name):
+                raise LineageQuerierParamException(
+                    'condition',
+                    'The sorted name {} not supported.'.format(sorted_name)
+                )
+            results = sorted(
+                results, key=functools.cmp_to_key(_cmp_super_lineage_obj), reverse=reverse
+            )
+        return results
+
     def _organize_customized(self, offset_results):
         """Organize customized."""
         customized = dict()
@@ -403,8 +421,8 @@ class Querier:
         Returns:
             bool, `True` if the field name is valid, else `False`.
         """
-        return field_name not in FIELD_MAPPING and \
-               not field_name.startswith(('metric/', 'user_defined/'))
+        return field_name not in FIELD_MAPPING \
+               and not field_name.startswith(('metric/', 'user_defined/'))
 
     def _handle_limit_and_offset(self, condition, result):
         """