提交 d15c81c6 编写于 作者: L luopengting

remove invalid item in user_defined info, record valid items

上级 c4fc9bfb
...@@ -284,8 +284,8 @@ class EvalLineage(Callback): ...@@ -284,8 +284,8 @@ class EvalLineage(Callback):
self.lineage_summary = LineageSummary(self.lineage_log_dir) self.lineage_summary = LineageSummary(self.lineage_log_dir)
self.user_defined_info = user_defined_info self.user_defined_info = user_defined_info
if user_defined_info: if self.user_defined_info:
validate_user_defined_info(user_defined_info) validate_user_defined_info(self.user_defined_info)
except MindInsightException as err: except MindInsightException as err:
log.error(err) log.error(err)
......
...@@ -410,7 +410,7 @@ def validate_path(summary_path): ...@@ -410,7 +410,7 @@ def validate_path(summary_path):
def validate_user_defined_info(user_defined_info): def validate_user_defined_info(user_defined_info):
""" """
Validate user defined info. Validate user defined info, delete the item if its key is in lineage.
Args: Args:
user_defined_info (dict): The user defined info. user_defined_info (dict): The user defined info.
...@@ -437,10 +437,13 @@ def validate_user_defined_info(user_defined_info): ...@@ -437,10 +437,13 @@ def validate_user_defined_info(user_defined_info):
field_map = set(FIELD_MAPPING.keys()) field_map = set(FIELD_MAPPING.keys())
user_defined_keys = set(user_defined_info.keys()) user_defined_keys = set(user_defined_info.keys())
all_keys = field_map | user_defined_keys insertion = list(field_map & user_defined_keys)
if len(field_map) + len(user_defined_keys) != len(all_keys): if insertion:
raise LineageParamValueError("There are some keys have defined in lineage.") for key in insertion:
user_defined_info.pop(key)
raise LineageParamValueError("There are some keys have defined in lineage. "
"Duplicated key(s): %s. " % insertion)
def validate_train_id(relative_path): def validate_train_id(relative_path):
......
...@@ -92,7 +92,7 @@ LINEAGE_FILTRATION_RUN1 = { ...@@ -92,7 +92,7 @@ LINEAGE_FILTRATION_RUN1 = {
'train_dataset_count': 1024, 'train_dataset_count': 1024,
'test_dataset_path': None, 'test_dataset_path': None,
'test_dataset_count': 1024, 'test_dataset_count': 1024,
'user_defined': {}, 'user_defined': {'info': 'info1', 'version': 'v1'},
'network': 'ResNet', 'network': 'ResNet',
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099, 'learning_rate': 0.11999999731779099,
...@@ -329,7 +329,7 @@ class TestModelApi(TestCase): ...@@ -329,7 +329,7 @@ class TestModelApi(TestCase):
def test_filter_summary_lineage(self): def test_filter_summary_lineage(self):
"""Test the interface of filter_summary_lineage.""" """Test the interface of filter_summary_lineage."""
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0, 'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN1,
...@@ -383,7 +383,7 @@ class TestModelApi(TestCase): ...@@ -383,7 +383,7 @@ class TestModelApi(TestCase):
'offset': 0 'offset': 0
} }
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0, 'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_RUN2, LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1 LINEAGE_FILTRATION_RUN1
...@@ -421,7 +421,7 @@ class TestModelApi(TestCase): ...@@ -421,7 +421,7 @@ class TestModelApi(TestCase):
'offset': 0 'offset': 0
} }
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0, 'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_RUN2, LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1 LINEAGE_FILTRATION_RUN1
...@@ -449,7 +449,7 @@ class TestModelApi(TestCase): ...@@ -449,7 +449,7 @@ class TestModelApi(TestCase):
'sorted_name': 'metric/accuracy', 'sorted_name': 'metric/accuracy',
} }
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0, 'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN1,
......
...@@ -70,7 +70,7 @@ class TestModelApi(TestCase): ...@@ -70,7 +70,7 @@ class TestModelApi(TestCase):
def test_filter_summary_lineage(self): def test_filter_summary_lineage(self):
"""Test the interface of filter_summary_lineage.""" """Test the interface of filter_summary_lineage."""
expect_result = { expect_result = {
'customized': event_data.CUSTOMIZED__0, 'customized': event_data.CUSTOMIZED__1,
'object': [ 'object': [
LINEAGE_FILTRATION_EXCEPT_RUN, LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN1,
......
...@@ -28,7 +28,7 @@ from unittest import mock, TestCase ...@@ -28,7 +28,7 @@ from unittest import mock, TestCase
import numpy as np import numpy as np
import pytest import pytest
from mindinsight.lineagemgr import get_summary_lineage from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \ from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \
AnalyzeObject AnalyzeObject
from mindinsight.lineagemgr.common.utils import make_directory from mindinsight.lineagemgr.common.utils import make_directory
...@@ -109,6 +109,36 @@ class TestModelLineage(TestCase): ...@@ -109,6 +109,36 @@ class TestModelLineage(TestCase):
lineage_log_path = train_callback.lineage_summary.lineage_log_path lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True assert os.path.isfile(lineage_log_path) is True
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_key_in_lineage(self):
"""Test TrainLineage with nested user defined info."""
expected_res = {
"info": "info1",
"version": "v1"
}
user_defined_info = {
"info": "info1",
"version": "v1",
"network": "LeNet"
}
train_callback = TrainLineage(
self.summary_record,
False,
user_defined_info
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
res = filter_summary_lineage(os.path.dirname(lineage_log_path))
assert expected_res == res['object'][0]['model_lineage']['user_defined']
@pytest.mark.scene_train(2) @pytest.mark.scene_train(2)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
......
...@@ -192,6 +192,12 @@ CUSTOMIZED__0 = { ...@@ -192,6 +192,12 @@ CUSTOMIZED__0 = {
'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'}, 'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'},
} }
CUSTOMIZED__1 = {
**CUSTOMIZED__0,
'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'},
'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'}
}
CUSTOMIZED_0 = { CUSTOMIZED_0 = {
**CUSTOMIZED__0, **CUSTOMIZED__0,
'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'}, 'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册