提交 d15c81c6 编写于 作者: L luopengting

remove invalid item in user_defined info, record valid items

上级 c4fc9bfb
......@@ -284,8 +284,8 @@ class EvalLineage(Callback):
self.lineage_summary = LineageSummary(self.lineage_log_dir)
self.user_defined_info = user_defined_info
if user_defined_info:
validate_user_defined_info(user_defined_info)
if self.user_defined_info:
validate_user_defined_info(self.user_defined_info)
except MindInsightException as err:
log.error(err)
......
......@@ -410,7 +410,7 @@ def validate_path(summary_path):
def validate_user_defined_info(user_defined_info):
"""
Validate user defined info.
Validate user defined info, delete the item if its key is in lineage.
Args:
user_defined_info (dict): The user defined info.
......@@ -437,10 +437,13 @@ def validate_user_defined_info(user_defined_info):
field_map = set(FIELD_MAPPING.keys())
user_defined_keys = set(user_defined_info.keys())
all_keys = field_map | user_defined_keys
insertion = list(field_map & user_defined_keys)
if len(field_map) + len(user_defined_keys) != len(all_keys):
raise LineageParamValueError("There are some keys have defined in lineage.")
if insertion:
for key in insertion:
user_defined_info.pop(key)
raise LineageParamValueError("There are some keys have defined in lineage. "
"Duplicated key(s): %s. " % insertion)
def validate_train_id(relative_path):
......
......@@ -92,7 +92,7 @@ LINEAGE_FILTRATION_RUN1 = {
'train_dataset_count': 1024,
'test_dataset_path': None,
'test_dataset_count': 1024,
'user_defined': {},
'user_defined': {'info': 'info1', 'version': 'v1'},
'network': 'ResNet',
'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099,
......@@ -329,7 +329,7 @@ class TestModelApi(TestCase):
def test_filter_summary_lineage(self):
"""Test the interface of filter_summary_lineage."""
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1,
......@@ -383,7 +383,7 @@ class TestModelApi(TestCase):
'offset': 0
}
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1
......@@ -421,7 +421,7 @@ class TestModelApi(TestCase):
'offset': 0
}
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_RUN2,
LINEAGE_FILTRATION_RUN1
......@@ -449,7 +449,7 @@ class TestModelApi(TestCase):
'sorted_name': 'metric/accuracy',
}
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1,
......
......@@ -70,7 +70,7 @@ class TestModelApi(TestCase):
def test_filter_summary_lineage(self):
"""Test the interface of filter_summary_lineage."""
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'customized': event_data.CUSTOMIZED__1,
'object': [
LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1,
......
......@@ -28,7 +28,7 @@ from unittest import mock, TestCase
import numpy as np
import pytest
from mindinsight.lineagemgr import get_summary_lineage
from mindinsight.lineagemgr import get_summary_lineage, filter_summary_lineage
from mindinsight.lineagemgr.collection.model.model_lineage import TrainLineage, EvalLineage, \
AnalyzeObject
from mindinsight.lineagemgr.common.utils import make_directory
......@@ -109,6 +109,36 @@ class TestModelLineage(TestCase):
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_key_in_lineage(self):
"""Test TrainLineage with nested user defined info."""
expected_res = {
"info": "info1",
"version": "v1"
}
user_defined_info = {
"info": "info1",
"version": "v1",
"network": "LeNet"
}
train_callback = TrainLineage(
self.summary_record,
False,
user_defined_info
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
res = filter_summary_lineage(os.path.dirname(lineage_log_path))
assert expected_res == res['object'][0]['model_lineage']['user_defined']
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
......
......@@ -192,6 +192,12 @@ CUSTOMIZED__0 = {
'metric/accuracy': {'label': 'metric/accuracy', 'required': True, 'type': 'float'},
}
CUSTOMIZED__1 = {
**CUSTOMIZED__0,
'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'},
'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'}
}
CUSTOMIZED_0 = {
**CUSTOMIZED__0,
'metric/mae': {'label': 'metric/mae', 'required': True, 'type': 'float'},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册