diff --git a/mindinsight/backend/lineagemgr/lineage_api.py b/mindinsight/backend/lineagemgr/lineage_api.py
index 97108928babc2b467ffd8e09bcc86ff611d1d325..2aa64ad2b88db485c3f4f3679f27db646f6a6af3 100644
--- a/mindinsight/backend/lineagemgr/lineage_api.py
+++ b/mindinsight/backend/lineagemgr/lineage_api.py
@@ -27,52 +27,20 @@ from mindinsight.utils.exceptions import MindInsightException, ParamValueError
 BLUEPRINT = Blueprint("lineage", __name__, url_prefix=settings.URL_PREFIX.rstrip("/"))
 
 
-@BLUEPRINT.route("/models/model_lineage", methods=["POST"])
-def search_model():
+@BLUEPRINT.route("/lineagemgr/lineages", methods=["POST"])
+def get_lineage():
     """
-    Get model lineage info.
-
-    Get model info by summary base dir return a model lineage information list of dict
-    contains model's all kinds of param and count of summary log.
-
-    Returns:
-        str, the model lineage information.
-
-    Raises:
-        MindInsightException: If method fails to be called.
-        ParamValueError: If parsing json data search_condition fails.
-
-    Examples:
-        >>> POST http://xxxx/v1/mindinsight/models/model_lineage
-    """
-    search_condition = request.stream.read()
-    try:
-        search_condition = json.loads(search_condition if search_condition else "{}")
-    except Exception:
-        raise ParamValueError("Json data parse failed.")
-
-    model_lineage_info = _get_lineage_info(
-        lineage_type="model",
-        search_condition=search_condition
-    )
-
-    return jsonify(model_lineage_info)
-
-
-@BLUEPRINT.route("/datasets/dataset_lineage", methods=["POST"])
-def get_datasets_lineage():
-    """
-    Get dataset lineage.
+    Get lineage.
 
     Returns:
-        str, the dataset lineage information.
+        str, the lineage information.
 
     Raises:
         MindInsightException: If method fails to be called.
         ParamValueError: If parsing json data search_condition fails.
 
     Examples:
-        >>> POST http://xxxx/v1/minddata/datasets/dataset_lineage
+        >>> POST http://xxxx/v1/mindinsight/lineagemgr/lineages
     """
     search_condition = request.stream.read()
     try:
@@ -80,20 +48,16 @@ def get_datasets_lineage():
     except Exception:
         raise ParamValueError("Json data parse failed.")
 
-    dataset_lineage_info = _get_lineage_info(
-        lineage_type="dataset",
-        search_condition=search_condition
-    )
+    lineage_info = _get_lineage_info(search_condition=search_condition)
 
-    return jsonify(dataset_lineage_info)
+    return jsonify(lineage_info)
 
 
-def _get_lineage_info(lineage_type, search_condition):
+def _get_lineage_info(search_condition):
     """
     Get lineage info for dataset or model.
 
     Args:
-        lineage_type (str): Lineage type, 'dataset' or 'model'.
         search_condition (dict): Search condition.
 
     Returns:
@@ -102,10 +66,6 @@ def _get_lineage_info(lineage_type, search_condition):
     Raises:
         MindInsightException: If method fails to be called.
     """
-    if 'lineage_type' in search_condition:
-        raise ParamValueError("Lineage type does not need to be assigned in a specific interface.")
-    if lineage_type == 'dataset':
-        search_condition.update({'lineage_type': 'dataset'})
     summary_base_dir = str(settings.SUMMARY_BASE_DIR)
     try:
        lineage_info = filter_summary_lineage(
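For reviewers trying the merged route by hand, the sketch below posts a search condition to the new endpoint. It is illustrative only: the host and port (`127.0.0.1:8080`) and the use of the `requests` package are assumptions, and the condition shape follows the schema and tests changed later in this patch.

```python
# Minimal client-side sketch of the unified lineage endpoint (illustrative, not part of the patch).
# Assumes a MindInsight server reachable at 127.0.0.1:8080.
import json

import requests

URL = "http://127.0.0.1:8080/v1/mindinsight/lineagemgr/lineages"

# lineage_type is now expressed like the other condition fields,
# as a comparison dict: {'eq': ...} or {'in': [...]}.
search_condition = {
    "limit": 10,
    "offset": 0,
    "lineage_type": {"in": ["model", "dataset"]},
}

response = requests.post(URL, data=json.dumps(search_condition))
response.raise_for_status()
print(json.dumps(response.json(), indent=2))
```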
""" - if 'lineage_type' in search_condition: - raise ParamValueError("Lineage type does not need to be assigned in a specific interface.") - if lineage_type == 'dataset': - search_condition.update({'lineage_type': 'dataset'}) summary_base_dir = str(settings.SUMMARY_BASE_DIR) try: lineage_info = filter_summary_lineage( diff --git a/mindinsight/lineagemgr/api/model.py b/mindinsight/lineagemgr/api/model.py index d22f7a0650b07eca3838b42e447ffaa11a88408a..d37f9d8c38e5e0e0eba19a523ff29909a0fe07bf 100644 --- a/mindinsight/lineagemgr/api/model.py +++ b/mindinsight/lineagemgr/api/model.py @@ -262,8 +262,6 @@ def _convert_relative_path_to_abspath(summary_base_dir, search_condition): return search_condition summary_dir_condition = search_condition.get("summary_dir") - if not set(summary_dir_condition.keys()).issubset(['in', 'eq']): - raise LineageParamValueError("Invalid operation of summary dir.") if 'in' in summary_dir_condition: summary_paths = [] diff --git a/mindinsight/lineagemgr/common/exceptions/error_code.py b/mindinsight/lineagemgr/common/exceptions/error_code.py index fe81a6e8fe775f76c0badc9cb0d6cdd6fab5d3d0..0620112b3f7398468346613fe224e612009f8275 100644 --- a/mindinsight/lineagemgr/common/exceptions/error_code.py +++ b/mindinsight/lineagemgr/common/exceptions/error_code.py @@ -193,7 +193,7 @@ class LineageErrorMsg(Enum): "It should be a string." LINEAGE_PARAM_LINEAGE_TYPE_ERROR = "The parameter lineage_type is invalid. " \ - "It should be None, 'dataset' or 'model'." + "It should be 'dataset' or 'model'." SUMMARY_ANALYZE_ERROR = "Failed to analyze summary log. {}" SUMMARY_VERIFICATION_ERROR = "Verification failed in summary analysis. {}" diff --git a/mindinsight/lineagemgr/common/validator/model_parameter.py b/mindinsight/lineagemgr/common/validator/model_parameter.py index d4ee016b26046caea0b2fe81929ee7bcbc2b0eaf..88f2a2b14d6a412594e2dbeb930d4c676f8d12e5 100644 --- a/mindinsight/lineagemgr/common/validator/model_parameter.py +++ b/mindinsight/lineagemgr/common/validator/model_parameter.py @@ -14,7 +14,7 @@ # ============================================================================ """Define schema of model lineage input parameters.""" from marshmallow import Schema, fields, ValidationError, pre_load, validates -from marshmallow.validate import Range, OneOf +from marshmallow.validate import Range from mindinsight.lineagemgr.common.exceptions.error_code import LineageErrorMsg, \ LineageErrors @@ -129,10 +129,7 @@ class SearchModelConditionParameter(Schema): offset = fields.Int(validate=lambda n: 0 <= n <= 100000) sorted_name = fields.Str() sorted_type = fields.Str(allow_none=True) - lineage_type = fields.Str( - validate=OneOf(enum_to_list(LineageType)), - allow_none=True - ) + lineage_type = fields.Dict() @staticmethod def check_dict_value_type(data, value_type): @@ -174,53 +171,79 @@ class SearchModelConditionParameter(Schema): @validates("loss_function") def check_loss_function(self, data): + """Check loss function.""" SearchModelConditionParameter.check_dict_value_type(data, str) @validates("train_dataset_path") def check_train_dataset_path(self, data): + """Check train dataset path.""" SearchModelConditionParameter.check_dict_value_type(data, str) @validates("train_dataset_count") def check_train_dataset_count(self, data): + """Check train dataset count.""" SearchModelConditionParameter.check_dict_value_type(data, int) @validates("test_dataset_path") def check_test_dataset_path(self, data): + """Check test dataset path.""" SearchModelConditionParameter.check_dict_value_type(data, 
str) @validates("test_dataset_count") def check_test_dataset_count(self, data): + """Check test dataset count.""" SearchModelConditionParameter.check_dict_value_type(data, int) @validates("network") def check_network(self, data): + """Check network.""" SearchModelConditionParameter.check_dict_value_type(data, str) @validates("optimizer") def check_optimizer(self, data): + """Check optimizer.""" SearchModelConditionParameter.check_dict_value_type(data, str) @validates("epoch") def check_epoch(self, data): + """Check epoch.""" SearchModelConditionParameter.check_dict_value_type(data, int) @validates("batch_size") def check_batch_size(self, data): + """Check batch size.""" SearchModelConditionParameter.check_dict_value_type(data, int) @validates("model_size") def check_model_size(self, data): + """Check model size.""" SearchModelConditionParameter.check_dict_value_type(data, int) @validates("summary_dir") def check_summary_dir(self, data): + """Check summary dir.""" SearchModelConditionParameter.check_dict_value_type(data, str) + @validates("lineage_type") + def check_lineage_type(self, data): + """Check lineage type.""" + SearchModelConditionParameter.check_dict_value_type(data, str) + recv_types = [] + for key, value in data.items(): + if key == "in": + recv_types = value + else: + recv_types.append(value) + + lineage_types = enum_to_list(LineageType) + if not set(recv_types).issubset(lineage_types): + raise ValidationError("Given lineage type should be one of %s." % lineage_types) + @pre_load def check_comparision(self, data, **kwargs): """Check comparision for all parameters in schema.""" for attr, condition in data.items(): - if attr in ["limit", "offset", "sorted_name", "sorted_type", "lineage_type"]: + if attr in ["limit", "offset", "sorted_name", "sorted_type"]: continue if not isinstance(attr, str): @@ -233,6 +256,13 @@ class SearchModelConditionParameter(Schema): raise LineageParamTypeError("The search_condition element {} should be dict." .format(attr)) + if attr in ["summary_dir", "lineage_type"]: + if not set(condition.keys()).issubset(['in', 'eq']): + raise LineageParamValueError("Invalid operation of %s." % attr) + if len(condition.keys()) > 1: + raise LineageParamValueError("More than one operation of %s." 
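Taken together, the schema now treats `lineage_type` like `summary_dir`: exactly one of `in` or `eq`, and every value must be a known lineage type. The standalone sketch below restates those rules outside marshmallow so the accepted and rejected shapes are easy to see; it mirrors `check_lineage_type` and `check_comparision` above but is not the schema code itself.

```python
# Illustration of the condition shapes accepted for lineage_type after this change.
# Mirrors check_lineage_type/check_comparision; not the real schema code.
VALID_LINEAGE_TYPES = {"dataset", "model"}


def lineage_type_condition_error(condition):
    """Return an error message for an invalid condition, or None if it is acceptable."""
    if not set(condition) <= {"in", "eq"}:
        return "Invalid operation of lineage_type."
    if len(condition) > 1:
        return "More than one operation of lineage_type."
    values = condition.get("in", [condition.get("eq")])
    if not set(values) <= VALID_LINEAGE_TYPES:
        return "Given lineage type should be one of ['dataset', 'model']."
    return None


assert lineage_type_condition_error({"eq": "model"}) is None
assert lineage_type_condition_error({"in": ["dataset", "model"]}) is None
assert lineage_type_condition_error({"lt": "model"}) is not None               # unsupported operation
assert lineage_type_condition_error({"in": ["x"], "eq": "model"}) is not None  # two operations
assert lineage_type_condition_error({"eq": None}) is not None                  # unknown lineage type
```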
diff --git a/mindinsight/lineagemgr/querier/querier.py b/mindinsight/lineagemgr/querier/querier.py
index 2f8798126e13e56a93a3152b1193d0e5d8773fc8..b7d9207e7f9d5d5c73a83b39edab787a02627fe7 100644
--- a/mindinsight/lineagemgr/querier/querier.py
+++ b/mindinsight/lineagemgr/querier/querier.py
@@ -23,6 +23,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import \
     LineageEventNotExistException, LineageQuerierParamException, \
     LineageSummaryParseException, LineageEventFieldNotExistException
 from mindinsight.lineagemgr.common.log import logger
+from mindinsight.lineagemgr.common.utils import enum_to_list
 from mindinsight.lineagemgr.querier.query_model import LineageObj, FIELD_MAPPING
 from mindinsight.lineagemgr.summary.lineage_summary_analyzer import \
     LineageSummaryAnalyzer
@@ -318,18 +319,46 @@ class Querier:
                 customized[label]["required"] = True
                 customized[label]["type"] = type(value).__name__
 
-        search_type = condition.get(ConditionParam.LINEAGE_TYPE.value)
+        lineage_types = condition.get(ConditionParam.LINEAGE_TYPE.value)
+        lineage_types = self._get_lineage_types(lineage_types)
+
+        object_items = []
+        for item in offset_results:
+            lineage_object = dict()
+            if LineageType.MODEL.value in lineage_types:
+                lineage_object.update(item.to_model_lineage_dict())
+            if LineageType.DATASET.value in lineage_types:
+                lineage_object.update(item.to_dataset_lineage_dict())
+            object_items.append(lineage_object)
+
         lineage_info = {
             'customized': customized,
-            'object': [
-                item.to_dataset_lineage_dict() if search_type == LineageType.DATASET.value
-                else item.to_filtration_dict() for item in offset_results
-            ],
+            'object': object_items,
             'count': len(results)
         }
         return lineage_info
 
+    def _get_lineage_types(self, lineage_type_param):
+        """
+        Get lineage types.
+
+        Args:
+            lineage_type_param (dict): A dict contains "in" or "eq".
+
+        Returns:
+            list, lineage type.
+
+        """
+        # lineage_type_param is None or an empty dict
+        if not lineage_type_param:
+            return enum_to_list(LineageType)
+
+        if lineage_type_param.get("in") is not None:
+            return lineage_type_param.get("in")
+
+        return [lineage_type_param.get("eq")]
+
     def _is_valid_field(self, field_name):
         """
         Check if field name is valid.
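As a reading aid, `_get_lineage_types` reduces the validated condition to the list of lineage parts that get merged into each result object. The mapping below restates that behavior as plain data (the order of the unfiltered case depends on `enum_to_list(LineageType)` and is assumed here):

```python
# Expected resolution of the lineage_type condition by Querier._get_lineage_types
# (illustrative restatement; value order for the unfiltered case is assumed).
cases = {
    "missing or {}": ["model", "dataset"],        # no filter: model and dataset parts are merged
    "{'eq': 'model'}": ["model"],                 # only to_model_lineage_dict() output
    "{'eq': 'dataset'}": ["dataset"],             # only to_dataset_lineage_dict() output
    "{'in': ['model', 'dataset']}": ["model", "dataset"],
}
for condition, resolved in cases.items():
    print(f"{condition:30} -> {resolved}")
```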
diff --git a/mindinsight/lineagemgr/querier/query_model.py b/mindinsight/lineagemgr/querier/query_model.py
index a70a215ba16b844fe52183ace2a79604bd46f19a..b797cdc111a7cdd73582e16d813e6526e072040a 100644
--- a/mindinsight/lineagemgr/querier/query_model.py
+++ b/mindinsight/lineagemgr/querier/query_model.py
@@ -38,6 +38,7 @@ FIELD_MAPPING = {
     "loss": Field('algorithm', 'loss'),
     "model_size": Field('model', 'size'),
     "dataset_mark": Field('dataset_mark', None),
+    "lineage_type": Field(None, None)
 }
 
 
@@ -75,6 +76,7 @@ class LineageObj:
     _name_dataset_graph = 'dataset_graph'
     _name_dataset_mark = 'dataset_mark'
     _name_user_defined = 'user_defined'
+    _name_model_lineage = 'model_lineage'
 
     def __init__(self, summary_dir, **kwargs):
         self._lineage_info = {
@@ -227,15 +229,6 @@ class LineageObj:
             result[key] = getattr(self, key)
         return result
 
-    def to_filtration_dict(self):
-        """
-        Returns the lineage information required by filtering interface.
-
-        Returns:
-            dict, the lineage information required by filtering interface.
-        """
-        return self._filtration_result
-
     def to_dataset_lineage_dict(self):
         """
         Returns the dataset part lineage information.
@@ -250,6 +243,22 @@ class LineageObj:
 
         return dataset_lineage
 
+    def to_model_lineage_dict(self):
+        """
+        Returns the model part lineage information.
+
+        Returns:
+            dict, the model lineage information.
+        """
+        filtration_result = dict(self._filtration_result)
+        filtration_result.pop(self._name_dataset_graph)
+
+        model_lineage = dict()
+        model_lineage.update({self._name_summary_dir: filtration_result.pop(self._name_summary_dir)})
+        model_lineage.update({self._name_model_lineage: filtration_result})
+
+        return model_lineage
+
     def get_value_by_key(self, key):
         """
         Get the value based on the key in `FIELD_MAPPING` or
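With `to_filtration_dict` gone, each entry in the querier's `'object'` list is now built from `to_model_lineage_dict` and/or `to_dataset_lineage_dict`, so the flat per-run dict becomes a nested one. A sketch of the resulting shape, with placeholder values taken from the test expectations below:

```python
# Shape of one entry in lineage_info['object'] after this change (placeholder values).
# With lineage_type {'eq': 'model'} only summary_dir and model_lineage are present;
# with {'eq': 'dataset'} only summary_dir and dataset_graph; with no filter, all three.
example_object = {
    "summary_dir": "/path/to/summary_dir",          # placeholder path
    "model_lineage": {
        "loss_function": "SoftmaxCrossEntropyWithLogits",
        "epoch": 10,
        "batch_size": 32,
        "metric": {"accuracy": 0.78},
        "dataset_mark": 2,
        # ... remaining flat model fields, exactly as asserted in the tests below
    },
    "dataset_graph": {},                            # dataset pipeline description
}
```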
diff --git a/tests/st/func/lineagemgr/api/test_model_api.py b/tests/st/func/lineagemgr/api/test_model_api.py
index 8b2df937c9cfd3ce2becc21e168c520acdac56b6..95fa5180d1c499501049bd754295587bfdcc0dc2 100644
--- a/tests/st/func/lineagemgr/api/test_model_api.py
+++ b/tests/st/func/lineagemgr/api/test_model_api.py
@@ -20,7 +20,6 @@ Usage:
     The query module test should be run after lineagemgr/collection/model/test_model_lineage.py
     pytest lineagemgr
 """
-
 import os
 from unittest import TestCase
 
@@ -66,64 +65,70 @@ LINEAGE_INFO_RUN1 = {
 }
 LINEAGE_FILTRATION_EXCEPT_RUN = {
     'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'except_run'),
-    'loss_function': 'SoftmaxCrossEntropyWithLogits',
-    'train_dataset_path': None,
-    'train_dataset_count': 1024,
-    'user_defined': {},
-    'test_dataset_path': None,
-    'test_dataset_count': None,
-    'network': 'ResNet',
-    'optimizer': 'Momentum',
-    'learning_rate': 0.11999999731779099,
-    'epoch': 10,
-    'batch_size': 32,
-    'loss': 0.029999999329447746,
-    'model_size': 64,
-    'metric': {},
-    'dataset_graph': DATASET_GRAPH,
-    'dataset_mark': 2
+    'model_lineage': {
+        'loss_function': 'SoftmaxCrossEntropyWithLogits',
+        'train_dataset_path': None,
+        'train_dataset_count': 1024,
+        'test_dataset_path': None,
+        'test_dataset_count': None,
+        'user_defined': {},
+        'network': 'ResNet',
+        'optimizer': 'Momentum',
+        'learning_rate': 0.11999999731779099,
+        'epoch': 10,
+        'batch_size': 32,
+        'loss': 0.029999999329447746,
+        'model_size': 64,
+        'metric': {},
+        'dataset_mark': 2
+    },
+    'dataset_graph': DATASET_GRAPH
 }
 LINEAGE_FILTRATION_RUN1 = {
     'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
-    'loss_function': 'SoftmaxCrossEntropyWithLogits',
-    'train_dataset_path': None,
-    'train_dataset_count': 731,
-    'test_dataset_path': None,
-    'user_defined': {},
-    'test_dataset_count': 10240,
-    'network': 'ResNet',
-    'optimizer': 'Momentum',
-    'learning_rate': 0.11999999731779099,
-    'epoch': 14,
-    'batch_size': 32,
-    'loss': None,
-    'model_size': 64,
-    'metric': {
-        'accuracy': 0.78
+    'model_lineage': {
+        'loss_function': 'SoftmaxCrossEntropyWithLogits',
+        'train_dataset_path': None,
+        'train_dataset_count': 731,
+        'test_dataset_path': None,
+        'test_dataset_count': 10240,
+        'user_defined': {},
+        'network': 'ResNet',
+        'optimizer': 'Momentum',
+        'learning_rate': 0.11999999731779099,
+        'epoch': 14,
+        'batch_size': 32,
+        'loss': None,
+        'model_size': 64,
+        'metric': {
+            'accuracy': 0.78
+        },
+        'dataset_mark': 2
     },
-    'dataset_graph': DATASET_GRAPH,
-    'dataset_mark': 2
+    'dataset_graph': DATASET_GRAPH
 }
 LINEAGE_FILTRATION_RUN2 = {
     'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run2'),
-    'loss_function': None,
-    'train_dataset_path': None,
-    'train_dataset_count': None,
-    'user_defined': {},
-    'test_dataset_path': None,
-    'test_dataset_count': 10240,
-    'network': None,
-    'optimizer': None,
-    'learning_rate': None,
-    'epoch': None,
-    'batch_size': None,
-    'loss': None,
-    'model_size': None,
-    'metric': {
-        'accuracy': 2.7800000000000002
+    'model_lineage': {
+        'loss_function': None,
+        'train_dataset_path': None,
+        'train_dataset_count': None,
+        'test_dataset_path': None,
+        'test_dataset_count': 10240,
+        'user_defined': {},
+        'network': None,
+        'optimizer': None,
+        'learning_rate': None,
+        'epoch': None,
+        'batch_size': None,
+        'loss': None,
+        'model_size': None,
+        'metric': {
+            'accuracy': 2.7800000000000002
+        },
+        'dataset_mark': 3
     },
-    'dataset_graph': {},
-    'dataset_mark': 3
+    'dataset_graph': {}
 }
 
 
@@ -150,6 +155,14 @@ class TestModelApi(TestCase):
         cls.empty_dir = os.path.join(BASE_SUMMARY_DIR, 'empty_dir')
         os.makedirs(cls.empty_dir)
 
+    def generate_lineage_object(self, lineage):
+        lineage = dict(lineage)
+        lineage_object = dict()
+        lineage_object.update({'summary_dir': lineage.pop('summary_dir')})
+        lineage_object.update({'dataset_graph': lineage.pop('dataset_graph')})
+        lineage_object.update({'model_lineage': lineage})
+        return lineage_object
+
     @pytest.mark.level0
     @pytest.mark.platform_arm_ascend_training
     @pytest.mark.platform_x86_gpu_training
@@ -337,7 +350,7 @@ class TestModelApi(TestCase):
         res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition)
         expect_objects = expect_result.get('object')
         for idx, res_object in enumerate(res.get('object')):
-            expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
+            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
         assert expect_result == res
 
         expect_result = {
@@ -347,7 +360,7 @@ class TestModelApi(TestCase):
         res = filter_summary_lineage(self.dir_with_empty_lineage)
         expect_objects = expect_result.get('object')
         for idx, res_object in enumerate(res.get('object')):
-            expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
+            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
         assert expect_result == res
 
     @pytest.mark.level0
@@ -385,7 +398,7 @@ class TestModelApi(TestCase):
         partial_res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition)
         expect_objects = expect_result.get('object')
         for idx, res_object in enumerate(partial_res.get('object')):
-            expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
+            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
         assert expect_result == partial_res
 
     @pytest.mark.level0
@@ -423,7 +436,7 @@ class TestModelApi(TestCase):
         partial_res = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition)
         expect_objects = expect_result.get('object')
         for idx, res_object in enumerate(partial_res.get('object')):
-            expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
+            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
         assert expect_result == partial_res
 
     @pytest.mark.level0
@@ -439,7 +452,6 @@ class TestModelApi(TestCase):
                 'ge': 30
             },
             'sorted_name': 'metric/accuracy',
-            'lineage_type': None
         }
         expect_result = {
             'customized': event_data.CUSTOMIZED__0,
@@ -452,14 +464,16 @@ class TestModelApi(TestCase):
         partial_res1 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition1)
         expect_objects = expect_result.get('object')
         for idx, res_object in enumerate(partial_res1.get('object')):
-            expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
+            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
         assert expect_result == partial_res1
 
         search_condition2 = {
             'batch_size': {
                 'lt': 30
             },
-            'lineage_type': 'model'
+            'lineage_type': {
+                'eq': 'model'
+            },
         }
         expect_result = {
             'customized': {},
@@ -469,7 +483,7 @@ class TestModelApi(TestCase):
         partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2)
         expect_objects = expect_result.get('object')
         for idx, res_object in enumerate(partial_res2.get('object')):
-            expect_objects[idx]['dataset_mark'] = res_object.get('dataset_mark')
+            expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
         assert expect_result == partial_res2
 
     @pytest.mark.level0
@@ -485,7 +499,9 @@ class TestModelApi(TestCase):
             'summary_dir': {
                 'in': [summary_dir]
            },
-            'lineage_type': 'dataset'
+            'lineage_type': {
+                'eq': 'dataset'
+            },
         }
         expect_result = {
             'customized': {},
@@ -705,15 +721,29 @@ class TestModelApi(TestCase):
             search_condition
         )
-        # the condition type not supported in summary dir
         search_condition = {
-            'summary_dir': {
-                'lt': '/xxx'
+            'lineage_type': {
+                'in': [
+                    'xxx'
+                ]
             }
         }
         self.assertRaisesRegex(
-            LineageParamSummaryPathError,
-            'Invalid operation of summary dir.',
+            LineageSearchConditionParamError,
+            "The parameter lineage_type is invalid. It should be 'dataset' or 'model'.",
+            filter_summary_lineage,
+            BASE_SUMMARY_DIR,
+            search_condition
+        )
+
+        search_condition = {
+            'lineage_type': {
+                'eq': None
+            }
+        }
+        self.assertRaisesRegex(
+            LineageSearchConditionParamError,
+            "The parameter lineage_type is invalid. It should be 'dataset' or 'model'.",
             filter_summary_lineage,
             BASE_SUMMARY_DIR,
             search_condition
@@ -779,3 +809,42 @@ class TestModelApi(TestCase):
         }
         partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2)
         assert expect_result == partial_res2
+
+    @pytest.mark.level0
+    @pytest.mark.platform_arm_ascend_training
+    @pytest.mark.platform_x86_gpu_training
+    @pytest.mark.platform_x86_ascend_training
+    @pytest.mark.platform_x86_cpu
+    @pytest.mark.env_single
+    def test_filter_summary_lineage_exception_7(self):
+        """Test the abnormal execution of the filter_summary_lineage interface."""
+        condition_keys = ["summary_dir", "lineage_type"]
+        for condition_key in condition_keys:
+            # the condition type not supported in summary_dir and lineage_type
+            search_condition = {
+                condition_key: {
+                    'lt': '/xxx'
+                }
+            }
+            self.assertRaisesRegex(
+                LineageSearchConditionParamError,
+                f'Invalid operation of {condition_key}.',
+                filter_summary_lineage,
+                BASE_SUMMARY_DIR,
+                search_condition
+            )
+
+            # more than one operation in summary_dir and lineage_type
+            search_condition = {
+                condition_key: {
+                    'in': ['/xxx', '/yyy'],
+                    'eq': '/zzz',
+                }
+            }
+            self.assertRaisesRegex(
+                LineageSearchConditionParamError,
+                f'More than one operation of {condition_key}.',
+                filter_summary_lineage,
+                BASE_SUMMARY_DIR,
+                search_condition
+            )
diff --git a/tests/ut/backend/lineagemgr/test_lineage_api.py b/tests/ut/backend/lineagemgr/test_lineage_api.py
index 5e80ff439487a90a7b8f7996b5e853247942f49f..8ff5b222f67fb5cb6fb6c604b7a2459047883530 100644
--- a/tests/ut/backend/lineagemgr/test_lineage_api.py
+++ b/tests/ut/backend/lineagemgr/test_lineage_api.py
@@ -67,7 +67,7 @@ class TestSearchModel(TestCase):
         """Test init."""
         APP.response_class = Response
         self.app_client = APP.test_client()
-        self.url = '/v1/mindinsight/models/model_lineage'
+        self.url = '/v1/mindinsight/lineagemgr/lineages'
 
     @mock.patch('mindinsight.backend.lineagemgr.lineage_api.settings')
     @mock.patch('mindinsight.backend.lineagemgr.lineage_api.filter_summary_lineage')
@@ -78,11 +78,11 @@ class TestSearchModel(TestCase):
             'object': [
                 {
                     'summary_dir': base_dir,
-                    **LINEAGE_FILTRATION_BASE
+                    'model_lineage': LINEAGE_FILTRATION_BASE
                 },
                 {
                     'summary_dir': os.path.join(base_dir, 'run1'),
-                    **LINEAGE_FILTRATION_RUN1
+                    'model_lineage': LINEAGE_FILTRATION_RUN1
                 }
             ],
             'count': 2
@@ -101,11 +101,11 @@ class TestSearchModel(TestCase):
             'object': [
                 {
                     'summary_dir': './',
-                    **LINEAGE_FILTRATION_BASE
+                    'model_lineage': LINEAGE_FILTRATION_BASE
                 },
                {
                     'summary_dir': './run1',
-                    **LINEAGE_FILTRATION_RUN1
+                    'model_lineage': LINEAGE_FILTRATION_RUN1
                 }
             ],
             'count': 2
diff --git a/tests/ut/lineagemgr/api/test_model.py b/tests/ut/lineagemgr/api/test_model.py
index 6dccde6a2022c5dc05262d5100251b7bfb4aff4f..7e7b442095cb17c9a5cc941e6f797e88b8789687 100644
--- a/tests/ut/lineagemgr/api/test_model.py
+++ b/tests/ut/lineagemgr/api/test_model.py
@@ -131,18 +131,6 @@ class TestModel(TestCase):
         self.assertDictEqual(
             result, search_condition
         )
-        search_condition = {
-            'summary_dir': {
-                'gt': 3
-            }
-        }
-        self.assertRaisesRegex(
-            LineageParamValueError,
-            'Invalid operation of summary dir',
-            _convert_relative_path_to_abspath,
-            summary_base_dir,
-            search_condition
-        )
 
 
 class TestFilterAPI(TestCase):
diff --git a/tests/ut/lineagemgr/querier/test_querier.py b/tests/ut/lineagemgr/querier/test_querier.py
index 8ededa0ed02b2bd2950ceb1c658cce8e413285a8..6059b246252a4f7234fdac22361e4bc259170771 100644
--- a/tests/ut/lineagemgr/querier/test_querier.py
+++ b/tests/ut/lineagemgr/querier/test_querier.py
@@ -82,22 +82,24 @@ def create_filtration_result(summary_dir, train_event_dict,
     """
     filtration_result = {
         "summary_dir": summary_dir,
-        "loss_function": train_event_dict['train_lineage']['hyper_parameters']['loss_function'],
-        "train_dataset_path": train_event_dict['train_lineage']['train_dataset']['train_dataset_path'],
-        "train_dataset_count": train_event_dict['train_lineage']['train_dataset']['train_dataset_size'],
-        "test_dataset_path": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
-        "test_dataset_count": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
-        "network": train_event_dict['train_lineage']['algorithm']['network'],
-        "optimizer": train_event_dict['train_lineage']['hyper_parameters']['optimizer'],
-        "learning_rate": train_event_dict['train_lineage']['hyper_parameters']['learning_rate'],
-        "epoch": train_event_dict['train_lineage']['hyper_parameters']['epoch'],
-        "batch_size": train_event_dict['train_lineage']['hyper_parameters']['batch_size'],
-        "loss": train_event_dict['train_lineage']['algorithm']['loss'],
-        "model_size": train_event_dict['train_lineage']['model']['size'],
-        "metric": metric_dict,
+        "model_lineage": {
+            "loss_function": train_event_dict['train_lineage']['hyper_parameters']['loss_function'],
+            "train_dataset_path": train_event_dict['train_lineage']['train_dataset']['train_dataset_path'],
+            "train_dataset_count": train_event_dict['train_lineage']['train_dataset']['train_dataset_size'],
+            "test_dataset_path": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
+            "test_dataset_count": eval_event_dict['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
+            "network": train_event_dict['train_lineage']['algorithm']['network'],
+            "optimizer": train_event_dict['train_lineage']['hyper_parameters']['optimizer'],
+            "learning_rate": train_event_dict['train_lineage']['hyper_parameters']['learning_rate'],
+            "epoch": train_event_dict['train_lineage']['hyper_parameters']['epoch'],
+            "batch_size": train_event_dict['train_lineage']['hyper_parameters']['batch_size'],
+            "loss": train_event_dict['train_lineage']['algorithm']['loss'],
+            "model_size": train_event_dict['train_lineage']['model']['size'],
+            "metric": metric_dict,
+            "dataset_mark": '2',
+            "user_defined": {}
+        },
         "dataset_graph": dataset_dict,
-        "dataset_mark": '2',
-        "user_defined": {}
     }
     return filtration_result
 
@@ -192,47 +194,50 @@ LINEAGE_FILTRATION_4 = create_filtration_result(
 )
 LINEAGE_FILTRATION_5 = {
     "summary_dir": '/path/to/summary5',
-    "loss_function":
-        event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'],
-    "train_dataset_path": None,
-    "train_dataset_count":
-        event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'],
-    "test_dataset_path": None,
-    "test_dataset_count": None,
-    "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'],
-    "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'],
-    "learning_rate":
-        event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'],
-    "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'],
-    "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'],
-    "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'],
-    "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
-    "metric": {},
-    "dataset_graph": event_data.DATASET_DICT_0,
-    "dataset_mark": '2',
-    "user_defined": {}
-
+    "model_lineage": {
+        "loss_function":
+            event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['loss_function'],
+        "train_dataset_path": None,
+        "train_dataset_count":
+            event_data.EVENT_TRAIN_DICT_5['train_lineage']['train_dataset']['train_dataset_size'],
+        "test_dataset_path": None,
+        "test_dataset_count": None,
+        "network": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['network'],
+        "optimizer": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['optimizer'],
+        "learning_rate":
+            event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['learning_rate'],
+        "epoch": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['epoch'],
+        "batch_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['hyper_parameters']['batch_size'],
+        "loss": event_data.EVENT_TRAIN_DICT_5['train_lineage']['algorithm']['loss'],
+        "model_size": event_data.EVENT_TRAIN_DICT_5['train_lineage']['model']['size'],
+        "metric": {},
+        "dataset_mark": '2',
+        "user_defined": {}
+    },
+    "dataset_graph": event_data.DATASET_DICT_0
 }
 LINEAGE_FILTRATION_6 = {
     "summary_dir": '/path/to/summary6',
-    "loss_function": None,
-    "train_dataset_path": None,
-    "train_dataset_count": None,
-    "test_dataset_path":
-        event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
-    "test_dataset_count":
-        event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
-    "network": None,
-    "optimizer": None,
-    "learning_rate": None,
-    "epoch": None,
-    "batch_size": None,
-    "loss": None,
-    "model_size": None,
-    "metric": event_data.METRIC_5,
-    "dataset_graph": event_data.DATASET_DICT_0,
-    "dataset_mark": '2',
-    "user_defined": {}
+    "model_lineage": {
+        "loss_function": None,
+        "train_dataset_path": None,
+        "train_dataset_count": None,
+        "test_dataset_path":
+            event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_path'],
+        "test_dataset_count":
+            event_data.EVENT_EVAL_DICT_5['evaluation_lineage']['valid_dataset']['valid_dataset_size'],
+        "network": None,
+        "optimizer": None,
+        "learning_rate": None,
+        "epoch": None,
+        "batch_size": None,
+        "loss": None,
+        "model_size": None,
+        "metric": event_data.METRIC_5,
+        "dataset_mark": '2',
+        "user_defined": {}
+    },
+    "dataset_graph": event_data.DATASET_DICT_0
 }
diff --git a/tests/ut/lineagemgr/querier/test_query_model.py b/tests/ut/lineagemgr/querier/test_query_model.py
index bcf94ef24a3c3afb22f9d63c73f3ec6568968765..ed89ff40cb76e1425274a0cf4b3c848caecdbe45 100644
--- a/tests/ut/lineagemgr/querier/test_query_model.py
+++ b/tests/ut/lineagemgr/querier/test_query_model.py
@@ -108,8 +108,8 @@ class TestLineageObj(TestCase):
         result = self.lineage_obj.get_summary_info(filter_keys)
         self.assertDictEqual(expected_result, result)
 
-    def test_to_filtration_dict(self):
-        """Test the function of to_filtration_dict."""
+    def test_to_model_lineage_dict(self):
+        """Test the function of to_model_lineage_dict."""
         expected_result = create_filtration_result(
             self.summary_dir,
             event_data.EVENT_TRAIN_DICT_0,
@@ -117,8 +117,18 @@ class TestLineageObj(TestCase):
             event_data.EVENT_EVAL_DICT_0,
             event_data.METRIC_0,
             event_data.DATASET_DICT_0
         )
-        expected_result['dataset_mark'] = None
-        result = self.lineage_obj.to_filtration_dict()
+        expected_result['model_lineage']['dataset_mark'] = None
+        expected_result.pop('dataset_graph')
+        result = self.lineage_obj.to_model_lineage_dict()
+        self.assertDictEqual(expected_result, result)
+
+    def test_to_dataset_lineage_dict(self):
+        """Test the function of to_dataset_lineage_dict."""
+        expected_result = {
+            "summary_dir": self.summary_dir,
+            "dataset_graph": event_data.DATASET_DICT_0
+        }
+        result = self.lineage_obj.to_dataset_lineage_dict()
         self.assertDictEqual(expected_result, result)
 
     def test_get_value_by_key(self):