提交 a1c35761 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!231 enhance float cmp in tests.lineagemgr, fix probabilistic failure in st

Merge pull request !231 from luopengting/lineage_parsing
...@@ -31,6 +31,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF ...@@ -31,6 +31,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF
LineageSearchConditionParamError) LineageSearchConditionParamError)
from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2
from .....ut.lineagemgr.querier import event_data from .....ut.lineagemgr.querier import event_data
from .....utils.tools import assert_equal_lineages
LINEAGE_INFO_RUN1 = { LINEAGE_INFO_RUN1 = {
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
...@@ -39,7 +40,7 @@ LINEAGE_INFO_RUN1 = { ...@@ -39,7 +40,7 @@ LINEAGE_INFO_RUN1 = {
}, },
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099, 'learning_rate': 0.12,
'loss_function': 'SoftmaxCrossEntropyWithLogits', 'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14, 'epoch': 14,
'parallel_mode': 'stand_alone', 'parallel_mode': 'stand_alone',
...@@ -73,11 +74,11 @@ LINEAGE_FILTRATION_EXCEPT_RUN = { ...@@ -73,11 +74,11 @@ LINEAGE_FILTRATION_EXCEPT_RUN = {
'user_defined': {}, 'user_defined': {},
'network': 'ResNet', 'network': 'ResNet',
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099, 'learning_rate': 0.12,
'epoch': 10, 'epoch': 10,
'batch_size': 32, 'batch_size': 32,
'device_num': 2, 'device_num': 2,
'loss': 0.029999999329447746, 'loss': 0.03,
'model_size': 64, 'model_size': 64,
'metric': {}, 'metric': {},
'dataset_mark': 2 'dataset_mark': 2
...@@ -92,10 +93,14 @@ LINEAGE_FILTRATION_RUN1 = { ...@@ -92,10 +93,14 @@ LINEAGE_FILTRATION_RUN1 = {
'train_dataset_count': 1024, 'train_dataset_count': 1024,
'test_dataset_path': None, 'test_dataset_path': None,
'test_dataset_count': 1024, 'test_dataset_count': 1024,
'user_defined': {'info': 'info1', 'version': 'v1'}, 'user_defined': {
'info': 'info1',
'version': 'v1',
'eval_version': 'version2'
},
'network': 'ResNet', 'network': 'ResNet',
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099, 'learning_rate': 0.12,
'epoch': 14, 'epoch': 14,
'batch_size': 32, 'batch_size': 32,
'device_num': 2, 'device_num': 2,
...@@ -119,14 +124,14 @@ LINEAGE_FILTRATION_RUN2 = { ...@@ -119,14 +124,14 @@ LINEAGE_FILTRATION_RUN2 = {
'user_defined': {}, 'user_defined': {},
'network': "ResNet", 'network': "ResNet",
'optimizer': "Momentum", 'optimizer': "Momentum",
'learning_rate': 0.11999999731779099, 'learning_rate': 0.12,
'epoch': 10, 'epoch': 10,
'batch_size': 32, 'batch_size': 32,
'device_num': 2, 'device_num': 2,
'loss': 0.029999999329447746, 'loss': 0.03,
'model_size': 10, 'model_size': 10,
'metric': { 'metric': {
'accuracy': 2.7800000000000002 'accuracy': 2.78
}, },
'dataset_mark': 3 'dataset_mark': 3
}, },
...@@ -173,7 +178,7 @@ class TestModelApi(TestCase): ...@@ -173,7 +178,7 @@ class TestModelApi(TestCase):
'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'),
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'Momentum', 'optimizer': 'Momentum',
'learning_rate': 0.11999999731779099, 'learning_rate': 0.12,
'loss_function': 'SoftmaxCrossEntropyWithLogits', 'loss_function': 'SoftmaxCrossEntropyWithLogits',
'epoch': 14, 'epoch': 14,
'parallel_mode': 'stand_alone', 'parallel_mode': 'stand_alone',
...@@ -190,9 +195,9 @@ class TestModelApi(TestCase): ...@@ -190,9 +195,9 @@ class TestModelApi(TestCase):
'network': 'ResNet' 'network': 'ResNet'
} }
} }
assert expect_total_res == total_res assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual)
assert expect_partial_res1 == partial_res1 assert_equal_lineages(expect_partial_res1, partial_res1, self.assertDictEqual)
assert expect_partial_res2 == partial_res2 assert_equal_lineages(expect_partial_res2, partial_res2, self.assertDictEqual)
# the lineage summary file is empty # the lineage summary file is empty
result = get_summary_lineage(self.dir_with_empty_lineage) result = get_summary_lineage(self.dir_with_empty_lineage)
...@@ -345,7 +350,7 @@ class TestModelApi(TestCase): ...@@ -345,7 +350,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res assert_equal_lineages(expect_result, res, self.assertDictEqual)
expect_result = { expect_result = {
'customized': {}, 'customized': {},
...@@ -356,7 +361,7 @@ class TestModelApi(TestCase): ...@@ -356,7 +361,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res assert_equal_lineages(expect_result, res, self.assertDictEqual)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
...@@ -394,7 +399,7 @@ class TestModelApi(TestCase): ...@@ -394,7 +399,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res assert_equal_lineages(expect_result, partial_res, self.assertDictEqual)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
...@@ -432,7 +437,7 @@ class TestModelApi(TestCase): ...@@ -432,7 +437,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res.get('object')): for idx, res_object in enumerate(partial_res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res assert_equal_lineages(expect_result, partial_res, self.assertDictEqual)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
...@@ -461,7 +466,7 @@ class TestModelApi(TestCase): ...@@ -461,7 +466,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res1.get('object')): for idx, res_object in enumerate(partial_res1.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res1 assert_equal_lineages(expect_result, partial_res1, self.assertDictEqual)
search_condition2 = { search_condition2 = {
'batch_size': { 'batch_size': {
...@@ -477,9 +482,6 @@ class TestModelApi(TestCase): ...@@ -477,9 +482,6 @@ class TestModelApi(TestCase):
'count': 0 'count': 0
} }
partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2) partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2)
expect_objects = expect_result.get('object')
for idx, res_object in enumerate(partial_res2.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == partial_res2 assert expect_result == partial_res2
@pytest.mark.level0 @pytest.mark.level0
......
...@@ -33,7 +33,7 @@ from ..api.test_model_api import LINEAGE_INFO_RUN1, LINEAGE_FILTRATION_EXCEPT_RU ...@@ -33,7 +33,7 @@ from ..api.test_model_api import LINEAGE_INFO_RUN1, LINEAGE_FILTRATION_EXCEPT_RU
LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2 LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2
from ..conftest import BASE_SUMMARY_DIR from ..conftest import BASE_SUMMARY_DIR
from .....ut.lineagemgr.querier import event_data from .....ut.lineagemgr.querier import event_data
from .....utils.tools import check_loading_done from .....utils.tools import check_loading_done, assert_equal_lineages
@pytest.mark.usefixtures("create_summary_dir") @pytest.mark.usefixtures("create_summary_dir")
...@@ -58,8 +58,7 @@ class TestModelApi(TestCase): ...@@ -58,8 +58,7 @@ class TestModelApi(TestCase):
"""Test the interface of get_summary_lineage.""" """Test the interface of get_summary_lineage."""
total_res = general_get_summary_lineage(data_manager=self._data_manger, summary_dir="./run1") total_res = general_get_summary_lineage(data_manager=self._data_manger, summary_dir="./run1")
expect_total_res = LINEAGE_INFO_RUN1 expect_total_res = LINEAGE_INFO_RUN1
assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual)
assert expect_total_res == total_res
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_arm_ascend_training @pytest.mark.platform_arm_ascend_training
...@@ -86,7 +85,7 @@ class TestModelApi(TestCase): ...@@ -86,7 +85,7 @@ class TestModelApi(TestCase):
expect_objects = expect_result.get('object') expect_objects = expect_result.get('object')
for idx, res_object in enumerate(res.get('object')): for idx, res_object in enumerate(res.get('object')):
expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark')
assert expect_result == res assert_equal_lineages(expect_result, res, self.assertDictEqual)
expect_result = { expect_result = {
'customized': {}, 'customized': {},
...@@ -100,4 +99,4 @@ class TestModelApi(TestCase): ...@@ -100,4 +99,4 @@ class TestModelApi(TestCase):
} }
} }
res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition) res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition)
assert expect_result == res assert_equal_lineages(expect_result, res, self.assertDictEqual)
...@@ -73,6 +73,10 @@ class TestModelLineage(TestCase): ...@@ -73,6 +73,10 @@ class TestModelLineage(TestCase):
TrainLineage(cls.summary_record) TrainLineage(cls.summary_record)
] ]
cls.run_context['list_callback'] = _ListCallback(callback) cls.run_context['list_callback'] = _ListCallback(callback)
cls.user_defined_info = {
"info": "info1",
"version": "v1"
}
@pytest.mark.scene_train(2) @pytest.mark.scene_train(2)
@pytest.mark.level0 @pytest.mark.level0
...@@ -83,7 +87,7 @@ class TestModelLineage(TestCase): ...@@ -83,7 +87,7 @@ class TestModelLineage(TestCase):
@pytest.mark.env_single @pytest.mark.env_single
def test_train_begin(self): def test_train_begin(self):
"""Test the begin function in TrainLineage.""" """Test the begin function in TrainLineage."""
train_callback = TrainLineage(self.summary_record, True) train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback.begin(RunContext(self.run_context)) train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12 assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path lineage_log_path = train_callback.lineage_summary.lineage_log_path
...@@ -98,30 +102,6 @@ class TestModelLineage(TestCase): ...@@ -98,30 +102,6 @@ class TestModelLineage(TestCase):
@pytest.mark.env_single @pytest.mark.env_single
def test_train_begin_with_user_defined_info(self): def test_train_begin_with_user_defined_info(self):
"""Test TrainLineage with nested user defined info.""" """Test TrainLineage with nested user defined info."""
user_defined_info = {"info": {"version": "v1"}}
train_callback = TrainLineage(
self.summary_record,
False,
user_defined_info
)
train_callback.begin(RunContext(self.run_context))
assert train_callback.initial_learning_rate == 0.12
lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True
@pytest.mark.scene_train(2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_single
def test_train_begin_with_user_defined_key_in_lineage(self):
"""Test TrainLineage with nested user defined info."""
expected_res = {
"info": "info1",
"version": "v1"
}
user_defined_info = { user_defined_info = {
"info": "info1", "info": "info1",
"version": "v1", "version": "v1",
...@@ -137,7 +117,7 @@ class TestModelLineage(TestCase): ...@@ -137,7 +117,7 @@ class TestModelLineage(TestCase):
lineage_log_path = train_callback.lineage_summary.lineage_log_path lineage_log_path = train_callback.lineage_summary.lineage_log_path
assert os.path.isfile(lineage_log_path) is True assert os.path.isfile(lineage_log_path) is True
res = filter_summary_lineage(os.path.dirname(lineage_log_path)) res = filter_summary_lineage(os.path.dirname(lineage_log_path))
assert expected_res == res['object'][0]['model_lineage']['user_defined'] assert self.user_defined_info == res['object'][0]['model_lineage']['user_defined']
@pytest.mark.scene_train(2) @pytest.mark.scene_train(2)
@pytest.mark.level0 @pytest.mark.level0
...@@ -168,7 +148,7 @@ class TestModelLineage(TestCase): ...@@ -168,7 +148,7 @@ class TestModelLineage(TestCase):
def test_training_end(self, *args): def test_training_end(self, *args):
"""Test the end function in TrainLineage.""" """Test the end function in TrainLineage."""
args[0].return_value = 64 args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True) train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
train_callback.initial_learning_rate = 0.12 train_callback.initial_learning_rate = 0.12
train_callback.end(RunContext(self.run_context)) train_callback.end(RunContext(self.run_context))
res = get_summary_lineage(SUMMARY_DIR) res = get_summary_lineage(SUMMARY_DIR)
...@@ -188,7 +168,7 @@ class TestModelLineage(TestCase): ...@@ -188,7 +168,7 @@ class TestModelLineage(TestCase):
@pytest.mark.env_single @pytest.mark.env_single
def test_eval_end(self): def test_eval_end(self):
"""Test the end function in EvalLineage.""" """Test the end function in EvalLineage."""
eval_callback = EvalLineage(self.summary_record, True) eval_callback = EvalLineage(self.summary_record, True, {'eval_version': 'version2'})
eval_run_context = self.run_context eval_run_context = self.run_context
eval_run_context['metrics'] = {'accuracy': 0.78} eval_run_context['metrics'] = {'accuracy': 0.78}
eval_run_context['valid_dataset'] = self.run_context['train_dataset'] eval_run_context['valid_dataset'] = self.run_context['train_dataset']
...@@ -361,7 +341,7 @@ class TestModelLineage(TestCase): ...@@ -361,7 +341,7 @@ class TestModelLineage(TestCase):
def test_train_with_customized_network(self, *args): def test_train_with_customized_network(self, *args):
"""Test train with customized network.""" """Test train with customized network."""
args[0].return_value = 64 args[0].return_value = 64
train_callback = TrainLineage(self.summary_record, True) train_callback = TrainLineage(self.summary_record, True, self.user_defined_info)
run_context_customized = self.run_context run_context_customized = self.run_context
del run_context_customized['optimizer'] del run_context_customized['optimizer']
del run_context_customized['net_outputs'] del run_context_customized['net_outputs']
......
...@@ -195,7 +195,8 @@ CUSTOMIZED__0 = { ...@@ -195,7 +195,8 @@ CUSTOMIZED__0 = {
CUSTOMIZED__1 = { CUSTOMIZED__1 = {
**CUSTOMIZED__0, **CUSTOMIZED__0,
'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'}, 'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'},
'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'} 'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'},
'user_defined/eval_version': {'label': 'user_defined/eval_version', 'required': False, 'type': 'str'}
} }
CUSTOMIZED_0 = { CUSTOMIZED_0 = {
......
...@@ -27,7 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier ...@@ -27,7 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo
from . import event_data from . import event_data
from ....utils.tools import deal_float_for_dict from ....utils.tools import assert_equal_lineages
def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict):
...@@ -282,31 +282,17 @@ class TestQuerier(TestCase): ...@@ -282,31 +282,17 @@ class TestQuerier(TestCase):
lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs
self.multi_querier = Querier(lineage_objects) self.multi_querier = Querier(lineage_objects)
def _deal_float_for_list(self, list1, list2):
index = 0
for _ in list1:
deal_float_for_dict(list1[index], list2[index])
index += 1
def _assert_list_equal(self, list1, list2):
self._deal_float_for_list(list1, list2)
self.assertListEqual(list1, list2)
def _assert_lineages_equal(self, lineages1, lineages2):
self._deal_float_for_list(lineages1['object'], lineages2['object'])
self.assertDictEqual(lineages1, lineages2)
def test_get_summary_lineage_success_1(self): def test_get_summary_lineage_success_1(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self._assert_list_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertListEqual)
def test_get_summary_lineage_success_2(self): def test_get_summary_lineage_success_2(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self._assert_list_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertListEqual)
def test_get_summary_lineage_success_3(self): def test_get_summary_lineage_success_3(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
...@@ -320,7 +306,7 @@ class TestQuerier(TestCase): ...@@ -320,7 +306,7 @@ class TestQuerier(TestCase):
result = self.single_querier.get_summary_lineage( result = self.single_querier.get_summary_lineage(
filter_keys=['model', 'algorithm'] filter_keys=['model', 'algorithm']
) )
self._assert_list_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertListEqual)
def test_get_summary_lineage_success_4(self): def test_get_summary_lineage_success_4(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
...@@ -367,7 +353,7 @@ class TestQuerier(TestCase): ...@@ -367,7 +353,7 @@ class TestQuerier(TestCase):
} }
] ]
result = self.multi_querier.get_summary_lineage() result = self.multi_querier.get_summary_lineage()
self._assert_list_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertListEqual)
def test_get_summary_lineage_success_5(self): def test_get_summary_lineage_success_5(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
...@@ -375,7 +361,7 @@ class TestQuerier(TestCase): ...@@ -375,7 +361,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary1' summary_dir='/path/to/summary1'
) )
self._assert_list_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertListEqual)
def test_get_summary_lineage_success_6(self): def test_get_summary_lineage_success_6(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
...@@ -394,7 +380,7 @@ class TestQuerier(TestCase): ...@@ -394,7 +380,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary0', filter_keys=filter_keys summary_dir='/path/to/summary0', filter_keys=filter_keys
) )
self._assert_list_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertListEqual)
def test_get_summary_lineage_fail(self): def test_get_summary_lineage_fail(self):
"""Test the function of get_summary_lineage with exception.""" """Test the function of get_summary_lineage with exception."""
...@@ -437,7 +423,7 @@ class TestQuerier(TestCase): ...@@ -437,7 +423,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_filter_summary_lineage_success_2(self): def test_filter_summary_lineage_success_2(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -462,7 +448,7 @@ class TestQuerier(TestCase): ...@@ -462,7 +448,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_filter_summary_lineage_success_3(self): def test_filter_summary_lineage_success_3(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -479,7 +465,7 @@ class TestQuerier(TestCase): ...@@ -479,7 +465,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_filter_summary_lineage_success_4(self): def test_filter_summary_lineage_success_4(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -497,7 +483,7 @@ class TestQuerier(TestCase): ...@@ -497,7 +483,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage() result = self.multi_querier.filter_summary_lineage()
self._assert_lineages_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_filter_summary_lineage_success_5(self): def test_filter_summary_lineage_success_5(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -512,7 +498,7 @@ class TestQuerier(TestCase): ...@@ -512,7 +498,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_filter_summary_lineage_success_6(self): def test_filter_summary_lineage_success_6(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -534,7 +520,7 @@ class TestQuerier(TestCase): ...@@ -534,7 +520,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_filter_summary_lineage_success_7(self): def test_filter_summary_lineage_success_7(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -556,7 +542,7 @@ class TestQuerier(TestCase): ...@@ -556,7 +542,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_filter_summary_lineage_success_8(self): def test_filter_summary_lineage_success_8(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -572,7 +558,7 @@ class TestQuerier(TestCase): ...@@ -572,7 +558,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_filter_summary_lineage_success_9(self): def test_filter_summary_lineage_success_9(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -586,7 +572,7 @@ class TestQuerier(TestCase): ...@@ -586,7 +572,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self._assert_lineages_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_filter_summary_lineage_fail(self): def test_filter_summary_lineage_fail(self):
"""Test the function of filter_summary_lineage with exception.""" """Test the function of filter_summary_lineage with exception."""
......
...@@ -21,7 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj ...@@ -21,7 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj
from . import event_data from . import event_data
from .test_querier import create_filtration_result, create_lineage_info from .test_querier import create_filtration_result, create_lineage_info
from ....utils.tools import deal_float_for_dict from ....utils.tools import assert_equal_lineages
class TestLineageObj(TestCase): class TestLineageObj(TestCase):
...@@ -51,56 +51,65 @@ class TestLineageObj(TestCase): ...@@ -51,56 +51,65 @@ class TestLineageObj(TestCase):
evaluation_lineage=lineage_info.eval_lineage evaluation_lineage=lineage_info.eval_lineage
) )
def _assert_dict_equal(self, dict1, dict2):
deal_float_for_dict(dict1, dict2)
self.assertDictEqual(dict1, dict2)
def test_property(self): def test_property(self):
"""Test the function of getting property.""" """Test the function of getting property."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self._assert_dict_equal( assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj.algorithm self.lineage_obj.algorithm,
self.assertDictEqual
) )
self._assert_dict_equal( assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj.model self.lineage_obj.model,
self.assertDictEqual
) )
self._assert_dict_equal( assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj.train_dataset self.lineage_obj.train_dataset,
self.assertDictEqual
) )
self._assert_dict_equal( assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj.hyper_parameters self.lineage_obj.hyper_parameters,
self.assertDictEqual
)
assert_equal_lineages(
event_data.METRIC_0,
self.lineage_obj.metric,
self.assertDictEqual
) )
self._assert_dict_equal(event_data.METRIC_0, self.lineage_obj.metric) assert_equal_lineages(
self._assert_dict_equal(
event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
self.lineage_obj.valid_dataset self.lineage_obj.valid_dataset,
self.assertDictEqual
) )
def test_property_eval_not_exist(self): def test_property_eval_not_exist(self):
"""Test the function of getting property with no evaluation event.""" """Test the function of getting property with no evaluation event."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self._assert_dict_equal( assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj_no_eval.algorithm self.lineage_obj_no_eval.algorithm,
self.assertDictEqual
) )
self._assert_dict_equal( assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj_no_eval.model self.lineage_obj_no_eval.model,
self.assertDictEqual
) )
self._assert_dict_equal( assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj_no_eval.train_dataset self.lineage_obj_no_eval.train_dataset,
self.assertDictEqual
) )
self._assert_dict_equal( assert_equal_lineages(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj_no_eval.hyper_parameters self.lineage_obj_no_eval.hyper_parameters,
self.assertDictEqual
) )
self._assert_dict_equal({}, self.lineage_obj_no_eval.metric) assert_equal_lineages({}, self.lineage_obj_no_eval.metric, self.assertDictEqual)
self._assert_dict_equal({}, self.lineage_obj_no_eval.valid_dataset) assert_equal_lineages({}, self.lineage_obj_no_eval.valid_dataset, self.assertDictEqual)
def test_get_summary_info(self): def test_get_summary_info(self):
"""Test the function of get_summary_info.""" """Test the function of get_summary_info."""
...@@ -111,7 +120,7 @@ class TestLineageObj(TestCase): ...@@ -111,7 +120,7 @@ class TestLineageObj(TestCase):
'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']
} }
result = self.lineage_obj.get_summary_info(filter_keys) result = self.lineage_obj.get_summary_info(filter_keys)
self._assert_dict_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_to_model_lineage_dict(self): def test_to_model_lineage_dict(self):
"""Test the function of to_model_lineage_dict.""" """Test the function of to_model_lineage_dict."""
...@@ -125,7 +134,7 @@ class TestLineageObj(TestCase): ...@@ -125,7 +134,7 @@ class TestLineageObj(TestCase):
expected_result['model_lineage']['dataset_mark'] = None expected_result['model_lineage']['dataset_mark'] = None
expected_result.pop('dataset_graph') expected_result.pop('dataset_graph')
result = self.lineage_obj.to_model_lineage_dict() result = self.lineage_obj.to_model_lineage_dict()
self._assert_dict_equal(expected_result, result) assert_equal_lineages(expected_result, result, self.assertDictEqual)
def test_to_dataset_lineage_dict(self): def test_to_dataset_lineage_dict(self):
"""Test the function of to_dataset_lineage_dict.""" """Test the function of to_dataset_lineage_dict."""
......
...@@ -83,9 +83,9 @@ def compare_result_with_file(result, expected_file_path): ...@@ -83,9 +83,9 @@ def compare_result_with_file(result, expected_file_path):
assert result == expected_results assert result == expected_results
def deal_float_for_dict(res: dict, expected_res: dict): def deal_float_for_dict(res: dict, expected_res: dict, decimal_num=5):
""" """
Deal float rounded to five decimals in dict. Deal float rounded to specified decimals in dict.
For example: For example:
res:{ res:{
...@@ -125,10 +125,9 @@ def deal_float_for_dict(res: dict, expected_res: dict): ...@@ -125,10 +125,9 @@ def deal_float_for_dict(res: dict, expected_res: dict):
"metric": {"acc": 0.1234562} "metric": {"acc": 0.1234562}
} }
} }
decimal_num (int): decimal rounded digits.
""" """
decimal_num = 5
for key in res: for key in res:
value = res[key] value = res[key]
expected_value = expected_res[key] expected_value = expected_res[key]
...@@ -137,3 +136,22 @@ def deal_float_for_dict(res: dict, expected_res: dict): ...@@ -137,3 +136,22 @@ def deal_float_for_dict(res: dict, expected_res: dict):
elif isinstance(value, float): elif isinstance(value, float):
res[key] = round(value, decimal_num) res[key] = round(value, decimal_num)
expected_res[key] = round(expected_value, decimal_num) expected_res[key] = round(expected_value, decimal_num)
def _deal_float_for_list(list1, list2, decimal_num):
"""Deal float for list1 and list2."""
index = 0
for _ in list1:
deal_float_for_dict(list1[index], list2[index], decimal_num)
index += 1
def assert_equal_lineages(lineages1, lineages2, assert_func, decimal_num=2):
"""Assert lineages."""
if isinstance(lineages1, list) and isinstance(lineages2, list):
_deal_float_for_list(lineages1, lineages2, decimal_num)
elif lineages1.get('object') is not None and lineages2.get('object') is not None:
_deal_float_for_list(lineages1['object'], lineages2['object'], decimal_num)
else:
deal_float_for_dict(lineages1, lineages2, decimal_num)
assert_func(lineages1, lineages2)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册