提交 98ec661b 编写于 作者: L luopengting

clean lineage when train lineage comes in, add st cases for lineage cache

上级 750385ad
...@@ -90,6 +90,9 @@ def general_get_summary_lineage(data_manager=None, summary_dir=None, keys=None): ...@@ -90,6 +90,9 @@ def general_get_summary_lineage(data_manager=None, summary_dir=None, keys=None):
default_result = {} default_result = {}
if data_manager is None and summary_dir is None: if data_manager is None and summary_dir is None:
raise LineageParamTypeError("One of data_manager or summary_dir needs to be specified.") raise LineageParamTypeError("One of data_manager or summary_dir needs to be specified.")
if data_manager is not None and summary_dir is None:
raise LineageParamTypeError("If data_manager is specified, the summary_dir needs to be "
"specified as relative path.")
if keys is not None: if keys is not None:
validate_filter_key(keys) validate_filter_key(keys)
......
...@@ -497,8 +497,7 @@ def validate_added_info(added_info: dict): ...@@ -497,8 +497,7 @@ def validate_added_info(added_info: dict):
""" """
added_info_keys = ["tag", "remark"] added_info_keys = ["tag", "remark"]
if not set(added_info.keys()).issubset(added_info_keys): if not set(added_info.keys()).issubset(added_info_keys):
err_msg = "Keys must be in {}.".format(added_info_keys) err_msg = "Keys of added_info must be in {}.".format(added_info_keys)
log.error(err_msg)
raise LineageParamValueError(err_msg) raise LineageParamValueError(err_msg)
for key, value in added_info.items(): for key, value in added_info.items():
......
...@@ -112,6 +112,11 @@ class LineageObj: ...@@ -112,6 +112,11 @@ class LineageObj:
dataset_graph = kwargs.get('dataset_graph') dataset_graph = kwargs.get('dataset_graph')
if not any([train_lineage, evaluation_lineage, dataset_graph]): if not any([train_lineage, evaluation_lineage, dataset_graph]):
raise LineageEventNotExistException() raise LineageEventNotExistException()
# If new train lineage, will clean the lineage saved before.
if train_lineage is not None or dataset_graph is not None:
self._init_lineage()
self._parse_user_defined_info(user_defined_info_list) self._parse_user_defined_info(user_defined_info_list)
self._parse_train_lineage(train_lineage) self._parse_train_lineage(train_lineage)
self._parse_evaluation_lineage(evaluation_lineage) self._parse_evaluation_lineage(evaluation_lineage)
......
...@@ -157,14 +157,6 @@ class TestModelApi(TestCase): ...@@ -157,14 +157,6 @@ class TestModelApi(TestCase):
cls.empty_dir = os.path.join(BASE_SUMMARY_DIR, 'empty_dir') cls.empty_dir = os.path.join(BASE_SUMMARY_DIR, 'empty_dir')
os.makedirs(cls.empty_dir) os.makedirs(cls.empty_dir)
def generate_lineage_object(self, lineage):
lineage = dict(lineage)
lineage_object = dict()
lineage_object.update({'summary_dir': lineage.pop('summary_dir')})
lineage_object.update({'dataset_graph': lineage.pop('dataset_graph')})
lineage_object.update({'model_lineage': lineage})
return lineage_object
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Function:
Test the query module about lineage information.
Usage:
The query module test should be run after lineagemgr/collection/model/test_model_lineage.py
pytest lineagemgr
"""
from unittest import TestCase
import pytest
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
from mindinsight.lineagemgr.api.model import general_filter_summary_lineage, \
general_get_summary_lineage
from ..api.test_model_api import LINEAGE_INFO_RUN1, LINEAGE_FILTRATION_EXCEPT_RUN, \
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2
from ..conftest import BASE_SUMMARY_DIR
from .....ut.lineagemgr.querier import event_data
from .....utils.tools import check_loading_done
@pytest.mark.usefixtures("create_summary_dir")
class TestModelApi(TestCase):
"""Test get lineage from data_manager."""
@classmethod
def setup_class(cls):
data_manager = DataManager(BASE_SUMMARY_DIR)
data_manager.register_brief_cache_item_updater(LineageCacheItemUpdater())
data_manager.start_load_data(reload_interval=0)
check_loading_done(data_manager)
cls._data_manger = data_manager
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_get_summary_lineage(self):
"""Test the interface of get_summary_lineage."""
total_res = general_get_summary_lineage(data_manager=self._data_manger, summary_dir="./run1")
expect_total_res = LINEAGE_INFO_RUN1
assert expect_total_res == total_res
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_filter_summary_lineage(self):
"""Test the interface of filter_summary_lineage."""
expect_result = {
'customized': event_data.CUSTOMIZED__0,
'object': [
LINEAGE_FILTRATION_EXCEPT_RUN,
LINEAGE_FILTRATION_RUN1,
LINEAGE_FILTRATION_RUN2
],
'count': 3
}
search_condition = {
'sorted_name': 'summary_dir'
}
res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res
expect_result = {
'customized': {},
'object': [],
'count': 0
}
search_condition = {
'summary_dir': {
"in": ['./dir_with_empty_lineage']
}
}
res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition)
assert expect_result == res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册