diff --git a/mindinsight/lineagemgr/collection/model/model_lineage.py b/mindinsight/lineagemgr/collection/model/model_lineage.py index 6b4547cf54b902f7117de6be6d46ecb6d7735580..4272d0addfa4d9cac70b669e838c3d80230d603a 100644 --- a/mindinsight/lineagemgr/collection/model/model_lineage.py +++ b/mindinsight/lineagemgr/collection/model/model_lineage.py @@ -284,8 +284,8 @@ class EvalLineage(Callback): self.lineage_summary = LineageSummary(self.lineage_log_dir) self.user_defined_info = user_defined_info - if user_defined_info: - validate_user_defined_info(user_defined_info) + if self.user_defined_info: + validate_user_defined_info(self.user_defined_info) except MindInsightException as err: log.error(err) diff --git a/mindinsight/lineagemgr/common/validator/validate.py b/mindinsight/lineagemgr/common/validator/validate.py index 91c71b22f59fcb5f3960c919985fead63a4077fa..c04a174ea256ca0395328f7714fa6dcd10d29743 100644 --- a/mindinsight/lineagemgr/common/validator/validate.py +++ b/mindinsight/lineagemgr/common/validator/validate.py @@ -410,7 +410,7 @@ def validate_path(summary_path): def validate_user_defined_info(user_defined_info): """ - Validate user defined info. + Validate user defined info, delete the item if its key is in lineage. Args: user_defined_info (dict): The user defined info. @@ -437,10 +437,13 @@ def validate_user_defined_info(user_defined_info): field_map = set(FIELD_MAPPING.keys()) user_defined_keys = set(user_defined_info.keys()) - all_keys = field_map | user_defined_keys + insertion = list(field_map & user_defined_keys) - if len(field_map) + len(user_defined_keys) != len(all_keys): - raise LineageParamValueError("There are some keys have defined in lineage.") + if insertion: + for key in insertion: + user_defined_info.pop(key) + raise LineageParamValueError("There are some keys have defined in lineage. " + "Duplicated key(s): %s. " % insertion) def validate_train_id(relative_path): diff --git a/tests/st/func/lineagemgr/api/test_model_api.py b/tests/st/func/lineagemgr/api/test_model_api.py index c8aea9dbabbb0e27adbcbbd574cdfb5600e990a8..6da892f4773f9c80446b04523b58600f8078bfd3 100644 --- a/tests/st/func/lineagemgr/api/test_model_api.py +++ b/tests/st/func/lineagemgr/api/test_model_api.py @@ -92,7 +92,7 @@ LINEAGE_FILTRATION_RUN1 = { 'train_dataset_count': 1024, 'test_dataset_path': None, 'test_dataset_count': 1024, - 'user_defined': {}, + 'user_defined': {'info': 'info1', 'version': 'v1'}, 'network': 'ResNet', 'optimizer': 'Momentum', 'learning_rate': 0.11999999731779099, @@ -329,7 +329,7 @@ class TestModelApi(TestCase): def test_filter_summary_lineage(self): """Test the interface of filter_summary_lineage.""" expect_result = { - 'customized': event_data.CUSTOMIZED__0, + 'customized': event_data.CUSTOMIZED__1, 'object': [ LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_RUN1, @@ -383,7 +383,7 @@ class TestModelApi(TestCase): 'offset': 0 } expect_result = { - 'customized': event_data.CUSTOMIZED__0, + 'customized': event_data.CUSTOMIZED__1, 'object': [ LINEAGE_FILTRATION_RUN2, LINEAGE_FILTRATION_RUN1 @@ -421,7 +421,7 @@ class TestModelApi(TestCase): 'offset': 0 } expect_result = { - 'customized': event_data.CUSTOMIZED__0, + 'customized': event_data.CUSTOMIZED__1, 'object': [ LINEAGE_FILTRATION_RUN2, LINEAGE_FILTRATION_RUN1 @@ -449,7 +449,7 @@ class TestModelApi(TestCase): 'sorted_name': 'metric/accuracy', } expect_result = { - 'customized': event_data.CUSTOMIZED__0, + 'customized': event_data.CUSTOMIZED__1, 'object': [ LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_RUN1, diff --git a/tests/st/func/lineagemgr/cache/test_lineage_cache.py b/tests/st/func/lineagemgr/cache/test_lineage_cache.py index 30c6f08e4b6c3a534072239349bcbca5eb29ee2b..17e60f1f02eee4e3fb7c3c867bb6753b46ee9ae3 100644 --- a/tests/st/func/lineagemgr/cache/test_lineage_cache.py +++ b/tests/st/func/lineagemgr/cache/test_lineage_cache.py @@ -70,7 +70,7 @@ class TestModelApi(TestCase): def test_filter_summary_lineage(self): """Test the interface of filter_summary_lineage.""" expect_result = { - 'customized': event_data.CUSTOMIZED__0, + 'customized': event_data.CUSTOMIZED__1, 'object': [ LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_RUN1, diff --git a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py index 866a2467f321d3a6244f05712a568fd414e8dc15..e048eac934eb1c83a5b2cc57d54e19afa0732b73 100644 --- a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py +++ b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py @@ -28,7 +28,7 @@ from unittest import mock, TestCase import numpy as np import pytest -from mindinsight.lineagemgr import get_summary_lineage +from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ AnalyzeObject from mindinsight.lineagemgr.common.utils import make_directory @@ -109,6 +109,36 @@ class TestModelLineage(TestCase): lineage_log_path = train_callback.lineage_summary.lineage_log_path assert os.path.isfile(lineage_log_path) is True + @pytest.mark.scene_train(2) + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend_training + @pytest.mark.platform_x86_gpu_training + @pytest.mark.platform_x86_ascend_training + @pytest.mark.platform_x86_cpu + @pytest.mark.env_single + def test_train_begin_with_user_defined_key_in_lineage(self): + """Test TrainLineage with nested user defined info.""" + expected_res = { + "info": "info1", + "version": "v1" + } + user_defined_info = { + "info": "info1", + "version": "v1", + "network": "LeNet" + } + train_callback = TrainLineage( + self.summary_record, + False, + user_defined_info + ) + train_callback.begin(RunContext(self.run_context)) + assert train_callback.initial_learning_rate == 0.12 + lineage_log_path = train_callback.lineage_summary.lineage_log_path + assert os.path.isfile(lineage_log_path) is True + res = filter_summary_lineage(os.path.dirname(lineage_log_path)) + assert expected_res == res['object'][0]['model_lineage']['user_defined'] + @pytest.mark.scene_train(2) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training diff --git a/tests/ut/lineagemgr/querier/event_data.py b/tests/ut/lineagemgr/querier/event_data.py index 7e606fda6458081bfb92f30a5b64c3a484c0b8b7..a46ecfbc8e90e8e727a001aabc12c3a68212538d 100644 --- a/tests/ut/lineagemgr/querier/event_data.py +++ b/tests/ut/lineagemgr/querier/event_data.py @@ -192,6 +192,12 @@ CUSTOMIZED__0 = { 'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'}, } +CUSTOMIZED__1 = { + **CUSTOMIZED__0, + 'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'}, + 'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'} +} + CUSTOMIZED_0 = { **CUSTOMIZED__0, 'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'},