提交 f379b204 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!133 add sorting and checking for added_info

Merge pull request !133 from luopengting/lineage_added_info
......@@ -18,7 +18,7 @@ import os
from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob
from mindinsight.lineagemgr.common.log import logger
from mindinsight.lineagemgr.common.exceptions.exceptions import LineageFileNotFoundError
from mindinsight.lineagemgr.common.validator.validate import validate_train_id
from mindinsight.lineagemgr.common.validator.validate import validate_train_id, validate_added_info
from mindinsight.lineagemgr.lineage_parser import LineageParser, LINEAGE
from mindinsight.utils.exceptions import ParamValueError
......@@ -26,6 +26,7 @@ from mindinsight.utils.exceptions import ParamValueError
def update_lineage_object(data_manager, train_id, added_info: dict):
"""Update lineage objects about tag and remark."""
validate_train_id(train_id)
validate_added_info(added_info)
cache_item = data_manager.get_brief_train_job(train_id)
lineage_item = cache_item.get(key=LINEAGE, raise_exception=False)
if lineage_item is None:
......
......@@ -362,8 +362,9 @@ def validate_condition(search_condition):
log.error(err_msg)
raise LineageParamValueError(err_msg)
if not (sorted_name in FIELD_MAPPING
or (sorted_name.startswith('metric/') and len(sorted_name) > 7)
or (sorted_name.startswith('user_defined/') and len(sorted_name) > 13)):
or (sorted_name.startswith('metric/') and len(sorted_name) > len('metric/'))
or (sorted_name.startswith('user_defined/') and len(sorted_name) > len('user_defined/'))
or sorted_name in ['tag']):
log.error(err_msg)
raise LineageParamValueError(err_msg)
......@@ -460,3 +461,54 @@ def validate_train_id(relative_path):
raise ParamValueError(
"Summary dir should be relative path starting with './'."
)
def validate_range(name, value, min_value, max_value):
"""
Check if value is in [min_value, max_value].
Args:
name (str): Value name.
value (Union[int, float]): Value to be check.
min_value (Union[int, float]): Min value.
max_value (Union[int, float]): Max value.
Raises:
LineageParamValueError, if value type is invalid or value is out of [min_value, max_value].
"""
if not isinstance(value, (int, float)):
raise LineageParamValueError("Value should be int or float.")
if value < min_value or value > max_value:
raise LineageParamValueError("The %s should in [%d, %d]." % (name, min_value, max_value))
def validate_added_info(added_info: dict):
"""
Check if added_info is valid.
Args:
added_info (dict): The added info.
Raises:
bool, if added_info is valid, return True.
"""
added_info_keys = ["tag", "remark"]
if not set(added_info.keys()).issubset(added_info_keys):
err_msg = "Keys must be in {}.".format(added_info_keys)
log.error(err_msg)
raise LineageParamValueError(err_msg)
for key, value in added_info.items():
if key == "tag":
if not isinstance(value, int):
raise LineageParamValueError("'tag' must be int.")
# tag should be in [0, 10].
validate_range("tag", value, min_value=0, max_value=10)
elif key == "remark":
if not isinstance(value, str):
raise LineageParamValueError("'remark' must be str.")
# length of remark should be in [0, 128].
validate_range("length of remark", len(value), min_value=0, max_value=128)
......@@ -271,25 +271,6 @@ class Querier:
return False
return True
def _cmp(obj1: SuperLineageObj, obj2: SuperLineageObj):
value1 = obj1.lineage_obj.get_value_by_key(sorted_name)
value2 = obj2.lineage_obj.get_value_by_key(sorted_name)
if value1 is None and value2 is None:
cmp_result = 0
elif value1 is None:
cmp_result = -1
elif value2 is None:
cmp_result = 1
else:
try:
cmp_result = (value1 > value2) - (value1 < value2)
except TypeError:
type1 = type(value1).__name__
type2 = type(value2).__name__
cmp_result = (type1 > type2) - (type1 < type2)
return cmp_result
if condition is None:
condition = {}
......@@ -298,19 +279,7 @@ class Querier:
super_lineage_objs.sort(key=lambda x: x.update_time, reverse=True)
results = list(filter(_filter, super_lineage_objs))
if ConditionParam.SORTED_NAME.value in condition:
sorted_name = condition.get(ConditionParam.SORTED_NAME.value)
if self._is_valid_field(sorted_name):
raise LineageQuerierParamException(
'condition',
'The sorted name {} not supported.'.format(sorted_name)
)
sorted_type = condition.get(ConditionParam.SORTED_TYPE.value)
reverse = sorted_type == 'descending'
results = sorted(
results, key=functools.cmp_to_key(_cmp), reverse=reverse
)
results = self._sorted_results(results, condition)
offset_results = self._handle_limit_and_offset(condition, results)
......@@ -338,6 +307,55 @@ class Querier:
return lineage_info
def _sorted_results(self, results, condition):
"""Get sorted results."""
def _cmp(value1, value2):
if value1 is None and value2 is None:
cmp_result = 0
elif value1 is None:
cmp_result = -1
elif value2 is None:
cmp_result = 1
else:
try:
cmp_result = (value1 > value2) - (value1 < value2)
except TypeError:
type1 = type(value1).__name__
type2 = type(value2).__name__
cmp_result = (type1 > type2) - (type1 < type2)
return cmp_result
def _cmp_added_info(obj1: SuperLineageObj, obj2: SuperLineageObj):
value1 = obj1.added_info.get(sorted_name)
value2 = obj2.added_info.get(sorted_name)
return _cmp(value1, value2)
def _cmp_super_lineage_obj(obj1: SuperLineageObj, obj2: SuperLineageObj):
value1 = obj1.lineage_obj.get_value_by_key(sorted_name)
value2 = obj2.lineage_obj.get_value_by_key(sorted_name)
return _cmp(value1, value2)
if ConditionParam.SORTED_NAME.value in condition:
sorted_name = condition.get(ConditionParam.SORTED_NAME.value)
sorted_type = condition.get(ConditionParam.SORTED_TYPE.value)
reverse = sorted_type == 'descending'
if sorted_name in ['tag']:
results = sorted(
results, key=functools.cmp_to_key(_cmp_added_info), reverse=reverse
)
return results
if self._is_valid_field(sorted_name):
raise LineageQuerierParamException(
'condition',
'The sorted name {} not supported.'.format(sorted_name)
)
results = sorted(
results, key=functools.cmp_to_key(_cmp_super_lineage_obj), reverse=reverse
)
return results
def _organize_customized(self, offset_results):
"""Organize customized."""
customized = dict()
......@@ -403,8 +421,8 @@ class Querier:
Returns:
bool, `True` if the field name is valid, else `False`.
"""
return field_name not in FIELD_MAPPING and \
not field_name.startswith(('metric/', 'user_defined/'))
return field_name not in FIELD_MAPPING \
and not field_name.startswith(('metric/', 'user_defined/'))
def _handle_limit_and_offset(self, condition, result):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册