diff --git a/mindinsight/profiler/common/validator/validate.py b/mindinsight/profiler/common/validator/validate.py index 86bf4965f1c126564b264c4fcbc15bbd44d44f56..68a3aef79196314aec498b10e535a344b1cd7a7b 100644 --- a/mindinsight/profiler/common/validator/validate.py +++ b/mindinsight/profiler/common/validator/validate.py @@ -14,7 +14,7 @@ # ============================================================================ """Validate the profiler parameters.""" from mindinsight.profiler.common.exceptions.exceptions import ProfilerParamTypeErrorException, \ - ProfilerParamValueErrorException, ProfilerDeviceIdException, ProfilerOpTypeException, \ + ProfilerDeviceIdException, ProfilerOpTypeException, \ ProfilerSortConditionException, ProfilerFilterConditionException, ProfilerGroupConditionException from mindinsight.profiler.common.log import logger as log @@ -63,77 +63,125 @@ def validate_condition(search_condition): raise ProfilerOpTypeException("The op_type must in ['aicpu', 'aicore_type', 'aicore_detail']") if "group_condition" in search_condition: - group_condition = search_condition.get("group_condition") - if not isinstance(group_condition, dict): - raise ProfilerGroupConditionException("The group condition must be dict.") - if "limit" in group_condition: - limit = group_condition.get("limit", 0) - if isinstance(limit, bool) \ - or not isinstance(group_condition.get("limit"), int): - log.error("The limit must be int.") - raise ProfilerGroupConditionException("The limit must be int.") - if limit < 1 or limit > 100: - raise ProfilerGroupConditionException("The limit must in [1, 100].") - - if "offset" in group_condition: - offset = group_condition.get("offset", 0) - if isinstance(offset, bool) \ - or not isinstance(group_condition.get("offset"), int): - log.error("The offset must be int.") - raise ProfilerGroupConditionException("The offset must be int.") - if offset < 0: - raise ProfilerGroupConditionException("The offset must ge 0.") - - if offset > 1000000: - raise ProfilerGroupConditionException("The offset must le 1000000.") + validata_group_condition(search_condition) if "sort_condition" in search_condition: - sort_condition = search_condition.get("sort_condition") - if not isinstance(sort_condition, dict): - raise ProfilerSortConditionException("The sort condition must be dict.") - if "name" in sort_condition: - sorted_name = sort_condition.get("name", "") - err_msg = "The sorted_name must be in {}".format(search_scope) - if not isinstance(sorted_name, str): - log.error("Wrong sorted name type.") - raise ProfilerSortConditionException("Wrong sorted name type.") - if sorted_name not in search_scope: - log.error(err_msg) - raise ProfilerSortConditionException(err_msg) - - if "type" in sort_condition: - sorted_type_param = ['ascending', 'descending'] - sorted_type = sort_condition.get("type") - if sorted_type not in sorted_type_param: - err_msg = "The sorted type must be ascending or descending." - log.error(err_msg) - raise ProfilerParamValueErrorException(err_msg) + validate_sort_condition(search_condition, search_scope) if "filter_condition" in search_condition: - def validate_op_filter_condition(op_condition): - if not isinstance(op_condition, dict): - raise ProfilerFilterConditionException("Wrong op_type filter condition.") - for key, value in op_condition.items(): - if not isinstance(key, str): - raise ProfilerFilterConditionException("The filter key must be str") - if not isinstance(value, list): - raise ProfilerFilterConditionException("The filter value must be list") - if key not in filter_key: - raise ProfilerFilterConditionException("The filter key must in {}.".format(filter_key)) - for item in value: - if not isinstance(item, str): - raise ProfilerFilterConditionException("The item in filter value must be str") - - filter_condition = search_condition.get("filter_condition") - if not isinstance(filter_condition, dict): - raise ProfilerFilterConditionException("The filter condition must be dict.") - filter_key = ["in", "not_in", "partial_match_str_in"] - if filter_condition: - if "op_type" in filter_condition: - op_type_condition = filter_condition.get("op_type") - validate_op_filter_condition(op_type_condition) - if "op_name" in filter_condition: - op_name_condition = filter_condition.get("op_name") - validate_op_filter_condition(op_name_condition) - if "op_type" not in filter_condition and "op_name" not in filter_condition: - raise ProfilerFilterConditionException("The key of filter_condition is not support") + validate_filter_condition(search_condition) + +def validata_group_condition(search_condition): + """ + Verify the group_condition in search_condition is valid or not. + + Args: + search_condition (dict): The search condition. + + Raises: + ProfilerGroupConditionException: If the group_condition param in search_condition is invalid. + """ + group_condition = search_condition.get("group_condition") + if not isinstance(group_condition, dict): + raise ProfilerGroupConditionException("The group condition must be dict.") + if "limit" in group_condition: + limit = group_condition.get("limit", 0) + if isinstance(limit, bool) \ + or not isinstance(group_condition.get("limit"), int): + log.error("The limit must be int.") + raise ProfilerGroupConditionException("The limit must be int.") + if limit < 1 or limit > 100: + raise ProfilerGroupConditionException("The limit must in [1, 100].") + + if "offset" in group_condition: + offset = group_condition.get("offset", 0) + if isinstance(offset, bool) \ + or not isinstance(group_condition.get("offset"), int): + log.error("The offset must be int.") + raise ProfilerGroupConditionException("The offset must be int.") + if offset < 0: + raise ProfilerGroupConditionException("The offset must ge 0.") + + if offset > 1000000: + raise ProfilerGroupConditionException("The offset must le 1000000.") + + +def validate_sort_condition(search_condition, search_scope): + """ + Verify the sort_condition in search_condition is valid or not. + + Args: + search_condition (dict): The search condition. + search_scope (list): The search scope. + + Raises: + ProfilerSortConditionException: If the sort_condition param in search_condition is invalid. + """ + sort_condition = search_condition.get("sort_condition") + if not isinstance(sort_condition, dict): + raise ProfilerSortConditionException("The sort condition must be dict.") + if "name" in sort_condition: + sorted_name = sort_condition.get("name", "") + err_msg = "The sorted_name must be in {}".format(search_scope) + if not isinstance(sorted_name, str): + log.error("Wrong sorted name type.") + raise ProfilerSortConditionException("Wrong sorted name type.") + if sorted_name not in search_scope: + log.error(err_msg) + raise ProfilerSortConditionException(err_msg) + + if "type" in sort_condition: + sorted_type_param = ['ascending', 'descending'] + sorted_type = sort_condition.get("type") + if sorted_type not in sorted_type_param: + err_msg = "The sorted type must be ascending or descending." + log.error(err_msg) + raise ProfilerSortConditionException(err_msg) + + +def validate_filter_condition(search_condition): + """ + Verify the filter_condition in search_condition is valid or not. + + Args: + search_condition (dict): The search condition. + + Raises: + ProfilerFilterConditionException: If the filter_condition param in search_condition is invalid. + """ + def validate_op_filter_condition(op_condition): + """ + Verify the op_condition in filter_condition is valid or not. + + Args: + op_condition (dict): The op_condition in search_condition. + + Raises: + ProfilerFilterConditionException: If the filter_condition param in search_condition is invalid. + """ + if not isinstance(op_condition, dict): + raise ProfilerFilterConditionException("Wrong op_type filter condition.") + for key, value in op_condition.items(): + if not isinstance(key, str): + raise ProfilerFilterConditionException("The filter key must be str") + if not isinstance(value, list): + raise ProfilerFilterConditionException("The filter value must be list") + if key not in filter_key: + raise ProfilerFilterConditionException("The filter key must in {}.".format(filter_key)) + for item in value: + if not isinstance(item, str): + raise ProfilerFilterConditionException("The item in filter value must be str") + + filter_condition = search_condition.get("filter_condition") + if not isinstance(filter_condition, dict): + raise ProfilerFilterConditionException("The filter condition must be dict.") + filter_key = ["in", "not_in", "partial_match_str_in"] + if filter_condition: + if "op_type" in filter_condition: + op_type_condition = filter_condition.get("op_type") + validate_op_filter_condition(op_type_condition) + if "op_name" in filter_condition: + op_name_condition = filter_condition.get("op_name") + validate_op_filter_condition(op_name_condition) + if "op_type" not in filter_condition and "op_name" not in filter_condition: + raise ProfilerFilterConditionException("The key of filter_condition is not support") diff --git a/tests/ut/backend/profiler/test_profiler_restful_api.py b/tests/ut/backend/profiler/test_profiler_restful_api.py index 2775cba328874c46ccd8b6ac8fd782dd085dac4f..278394076ba28ccd2eb4b7272cfa2cab2a34d46a 100644 --- a/tests/ut/backend/profiler/test_profiler_restful_api.py +++ b/tests/ut/backend/profiler/test_profiler_restful_api.py @@ -78,3 +78,44 @@ class TestProfilerRestfulApi(TestCase): result = response.get_json() del result["error_msg"] self.assertDictEqual(expect_result, result) + + body_data = {"op_type": "aicore_type", "device_id": 1} + response = self.app_client.post(self.url, data=json.dumps(body_data)) + self.assertEqual(400, response.status_code) + expect_result = { + 'error_code': '50546182', + } + result = response.get_json() + del result["error_msg"] + self.assertDictEqual(expect_result, result) + + body_data = {"op_type": "aicore_type", "device_id": "1", "group_condition": 1} + response = self.app_client.post(self.url, data=json.dumps(body_data)) + self.assertEqual(400, response.status_code) + expect_result = { + 'error_code': '50546184', + } + result = response.get_json() + del result["error_msg"] + self.assertDictEqual(expect_result, result) + + body_data = {"op_type": "aicore_type", "device_id": "1", "sort_condition": {"type": 1}} + response = self.app_client.post(self.url, data=json.dumps(body_data)) + self.assertEqual(400, response.status_code) + expect_result = { + 'error_code': '50546185', + } + result = response.get_json() + del result["error_msg"] + self.assertDictEqual(expect_result, result) + + body_data = {"op_type": "aicore_type", "device_id": "1", + "filter_condition": {"op_type": {"in": ["1", 2]}}} + response = self.app_client.post(self.url, data=json.dumps(body_data)) + self.assertEqual(400, response.status_code) + expect_result = { + 'error_code': '50546186', + } + result = response.get_json() + del result["error_msg"] + self.assertDictEqual(expect_result, result)