diff --git a/tests/st/func/lineagemgr/api/test_model_api.py b/tests/st/func/lineagemgr/api/test_model_api.py index 6da892f4773f9c80446b04523b58600f8078bfd3..4c421eb011166c2ceebc16affd358b63c0f9db00 100644 --- a/tests/st/func/lineagemgr/api/test_model_api.py +++ b/tests/st/func/lineagemgr/api/test_model_api.py @@ -31,6 +31,7 @@ from mindinsight.lineagemgr.common.exceptions.exceptions import (LineageFileNotF LineageSearchConditionParamError) from ..conftest import BASE_SUMMARY_DIR, DATASET_GRAPH, SUMMARY_DIR, SUMMARY_DIR_2 from .....ut.lineagemgr.querier import event_data +from .....utils.tools import assert_equal_lineages LINEAGE_INFO_RUN1 = { 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), @@ -39,7 +40,7 @@ LINEAGE_INFO_RUN1 = { }, 'hyper_parameters': { 'optimizer': 'Momentum', - 'learning_rate': 0.11999999731779099, + 'learning_rate': 0.12, 'loss_function': 'SoftmaxCrossEntropyWithLogits', 'epoch': 14, 'parallel_mode': 'stand_alone', @@ -73,11 +74,11 @@ LINEAGE_FILTRATION_EXCEPT_RUN = { 'user_defined': {}, 'network': 'ResNet', 'optimizer': 'Momentum', - 'learning_rate': 0.11999999731779099, + 'learning_rate': 0.12, 'epoch': 10, 'batch_size': 32, 'device_num': 2, - 'loss': 0.029999999329447746, + 'loss': 0.03, 'model_size': 64, 'metric': {}, 'dataset_mark': 2 @@ -92,10 +93,14 @@ LINEAGE_FILTRATION_RUN1 = { 'train_dataset_count': 1024, 'test_dataset_path': None, 'test_dataset_count': 1024, - 'user_defined': {'info': 'info1', 'version': 'v1'}, + 'user_defined': { + 'info': 'info1', + 'version': 'v1', + 'eval_version': 'version2' + }, 'network': 'ResNet', 'optimizer': 'Momentum', - 'learning_rate': 0.11999999731779099, + 'learning_rate': 0.12, 'epoch': 14, 'batch_size': 32, 'device_num': 2, @@ -119,14 +124,14 @@ LINEAGE_FILTRATION_RUN2 = { 'user_defined': {}, 'network': "ResNet", 'optimizer': "Momentum", - 'learning_rate': 0.11999999731779099, + 'learning_rate': 0.12, 'epoch': 10, 'batch_size': 32, 'device_num': 2, - 'loss': 0.029999999329447746, + 'loss': 0.03, 'model_size': 10, 'metric': { - 'accuracy': 2.7800000000000002 + 'accuracy': 2.78 }, 'dataset_mark': 3 }, @@ -173,7 +178,7 @@ class TestModelApi(TestCase): 'summary_dir': os.path.join(BASE_SUMMARY_DIR, 'run1'), 'hyper_parameters': { 'optimizer': 'Momentum', - 'learning_rate': 0.11999999731779099, + 'learning_rate': 0.12, 'loss_function': 'SoftmaxCrossEntropyWithLogits', 'epoch': 14, 'parallel_mode': 'stand_alone', @@ -190,9 +195,9 @@ class TestModelApi(TestCase): 'network': 'ResNet' } } - assert expect_total_res == total_res - assert expect_partial_res1 == partial_res1 - assert expect_partial_res2 == partial_res2 + assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual) + assert_equal_lineages(expect_partial_res1, partial_res1, self.assertDictEqual) + assert_equal_lineages(expect_partial_res2, partial_res2, self.assertDictEqual) # the lineage summary file is empty result = get_summary_lineage(self.dir_with_empty_lineage) @@ -345,7 +350,7 @@ class TestModelApi(TestCase): expect_objects = expect_result.get('object') for idx, res_object in enumerate(res.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') - assert expect_result == res + assert_equal_lineages(expect_result, res, self.assertDictEqual) expect_result = { 'customized': {}, @@ -356,7 +361,7 @@ class TestModelApi(TestCase): expect_objects = expect_result.get('object') for idx, res_object in enumerate(res.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') - assert expect_result == res + assert_equal_lineages(expect_result, res, self.assertDictEqual) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @@ -394,7 +399,7 @@ class TestModelApi(TestCase): expect_objects = expect_result.get('object') for idx, res_object in enumerate(partial_res.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') - assert expect_result == partial_res + assert_equal_lineages(expect_result, partial_res, self.assertDictEqual) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @@ -432,7 +437,7 @@ class TestModelApi(TestCase): expect_objects = expect_result.get('object') for idx, res_object in enumerate(partial_res.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') - assert expect_result == partial_res + assert_equal_lineages(expect_result, partial_res, self.assertDictEqual) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @@ -461,7 +466,7 @@ class TestModelApi(TestCase): expect_objects = expect_result.get('object') for idx, res_object in enumerate(partial_res1.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') - assert expect_result == partial_res1 + assert_equal_lineages(expect_result, partial_res1, self.assertDictEqual) search_condition2 = { 'batch_size': { @@ -477,9 +482,6 @@ class TestModelApi(TestCase): 'count': 0 } partial_res2 = filter_summary_lineage(BASE_SUMMARY_DIR, search_condition2) - expect_objects = expect_result.get('object') - for idx, res_object in enumerate(partial_res2.get('object')): - expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') assert expect_result == partial_res2 @pytest.mark.level0 diff --git a/tests/st/func/lineagemgr/cache/test_lineage_cache.py b/tests/st/func/lineagemgr/cache/test_lineage_cache.py index 17e60f1f02eee4e3fb7c3c867bb6753b46ee9ae3..0bfb2c40e19d447ee9e5ba86e951d9c403998356 100644 --- a/tests/st/func/lineagemgr/cache/test_lineage_cache.py +++ b/tests/st/func/lineagemgr/cache/test_lineage_cache.py @@ -33,7 +33,7 @@ from ..api.test_model_api import LINEAGE_INFO_RUN1, LINEAGE_FILTRATION_EXCEPT_RU LINEAGE_FILTRATION_RUN1, LINEAGE_FILTRATION_RUN2 from ..conftest import BASE_SUMMARY_DIR from .....ut.lineagemgr.querier import event_data -from .....utils.tools import check_loading_done +from .....utils.tools import check_loading_done, assert_equal_lineages @pytest.mark.usefixtures("create_summary_dir") @@ -58,8 +58,7 @@ class TestModelApi(TestCase): """Test the interface of get_summary_lineage.""" total_res = general_get_summary_lineage(data_manager=self._data_manger, summary_dir="./run1") expect_total_res = LINEAGE_INFO_RUN1 - - assert expect_total_res == total_res + assert_equal_lineages(expect_total_res, total_res, self.assertDictEqual) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @@ -86,7 +85,7 @@ class TestModelApi(TestCase): expect_objects = expect_result.get('object') for idx, res_object in enumerate(res.get('object')): expect_objects[idx]['model_lineage']['dataset_mark'] = res_object['model_lineage'].get('dataset_mark') - assert expect_result == res + assert_equal_lineages(expect_result, res, self.assertDictEqual) expect_result = { 'customized': {}, @@ -100,4 +99,4 @@ class TestModelApi(TestCase): } } res = general_filter_summary_lineage(data_manager=self._data_manger, search_condition=search_condition) - assert expect_result == res + assert_equal_lineages(expect_result, res, self.assertDictEqual) diff --git a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py index e048eac934eb1c83a5b2cc57d54e19afa0732b73..268ac92c93d1bd272263194b269ad121e338712f 100644 --- a/tests/st/func/lineagemgr/collection/model/test_model_lineage.py +++ b/tests/st/func/lineagemgr/collection/model/test_model_lineage.py @@ -73,6 +73,10 @@ class TestModelLineage(TestCase): TrainLineage(cls.summary_record) ] cls.run_context['list_callback'] = _ListCallback(callback) + cls.user_defined_info = { + "info": "info1", + "version": "v1" + } @pytest.mark.scene_train(2) @pytest.mark.level0 @@ -83,7 +87,7 @@ class TestModelLineage(TestCase): @pytest.mark.env_single def test_train_begin(self): """Test the begin function in TrainLineage.""" - train_callback = TrainLineage(self.summary_record, True) + train_callback = TrainLineage(self.summary_record, True, self.user_defined_info) train_callback.begin(RunContext(self.run_context)) assert train_callback.initial_learning_rate == 0.12 lineage_log_path = train_callback.lineage_summary.lineage_log_path @@ -98,30 +102,6 @@ class TestModelLineage(TestCase): @pytest.mark.env_single def test_train_begin_with_user_defined_info(self): """Test TrainLineage with nested user defined info.""" - user_defined_info = {"info": {"version": "v1"}} - train_callback = TrainLineage( - self.summary_record, - False, - user_defined_info - ) - train_callback.begin(RunContext(self.run_context)) - assert train_callback.initial_learning_rate == 0.12 - lineage_log_path = train_callback.lineage_summary.lineage_log_path - assert os.path.isfile(lineage_log_path) is True - - @pytest.mark.scene_train(2) - @pytest.mark.level0 - @pytest.mark.platform_arm_ascend_training - @pytest.mark.platform_x86_gpu_training - @pytest.mark.platform_x86_ascend_training - @pytest.mark.platform_x86_cpu - @pytest.mark.env_single - def test_train_begin_with_user_defined_key_in_lineage(self): - """Test TrainLineage with nested user defined info.""" - expected_res = { - "info": "info1", - "version": "v1" - } user_defined_info = { "info": "info1", "version": "v1", @@ -137,7 +117,7 @@ class TestModelLineage(TestCase): lineage_log_path = train_callback.lineage_summary.lineage_log_path assert os.path.isfile(lineage_log_path) is True res = filter_summary_lineage(os.path.dirname(lineage_log_path)) - assert expected_res == res['object'][0]['model_lineage']['user_defined'] + assert self.user_defined_info == res['object'][0]['model_lineage']['user_defined'] @pytest.mark.scene_train(2) @pytest.mark.level0 @@ -168,7 +148,7 @@ class TestModelLineage(TestCase): def test_training_end(self, *args): """Test the end function in TrainLineage.""" args[0].return_value = 64 - train_callback = TrainLineage(self.summary_record, True) + train_callback = TrainLineage(self.summary_record, True, self.user_defined_info) train_callback.initial_learning_rate = 0.12 train_callback.end(RunContext(self.run_context)) res = get_summary_lineage(SUMMARY_DIR) @@ -188,7 +168,7 @@ class TestModelLineage(TestCase): @pytest.mark.env_single def test_eval_end(self): """Test the end function in EvalLineage.""" - eval_callback = EvalLineage(self.summary_record, True) + eval_callback = EvalLineage(self.summary_record, True, {'eval_version': 'version2'}) eval_run_context = self.run_context eval_run_context['metrics'] = {'accuracy': 0.78} eval_run_context['valid_dataset'] = self.run_context['train_dataset'] @@ -361,7 +341,7 @@ class TestModelLineage(TestCase): def test_train_with_customized_network(self, *args): """Test train with customized network.""" args[0].return_value = 64 - train_callback = TrainLineage(self.summary_record, True) + train_callback = TrainLineage(self.summary_record, True, self.user_defined_info) run_context_customized = self.run_context del run_context_customized['optimizer'] del run_context_customized['net_outputs'] diff --git a/tests/ut/lineagemgr/querier/event_data.py b/tests/ut/lineagemgr/querier/event_data.py index a46ecfbc8e90e8e727a001aabc12c3a68212538d..c1e5280ce2fa360e420e8ef844b0bec791e15759 100644 --- a/tests/ut/lineagemgr/querier/event_data.py +++ b/tests/ut/lineagemgr/querier/event_data.py @@ -195,7 +195,8 @@ CUSTOMIZED__0 = { CUSTOMIZED__1 = { **CUSTOMIZED__0, 'user_defined/info': {'label': 'user_defined/info', 'required': False, 'type': 'str'}, - 'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'} + 'user_defined/version': {'label': 'user_defined/version', 'required': False, 'type': 'str'}, + 'user_defined/eval_version': {'label': 'user_defined/eval_version', 'required': False, 'type': 'str'} } CUSTOMIZED_0 = { diff --git a/tests/ut/lineagemgr/querier/test_querier.py b/tests/ut/lineagemgr/querier/test_querier.py index 75c465cffc23a81a1b6d35ced0cc9feec61cc788..c8a4a6f23fa94a677ca6e4e7f7dea5f54db894df 100644 --- a/tests/ut/lineagemgr/querier/test_querier.py +++ b/tests/ut/lineagemgr/querier/test_querier.py @@ -27,7 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo from . import event_data -from ....utils.tools import deal_float_for_dict +from ....utils.tools import assert_equal_lineages def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): @@ -282,31 +282,17 @@ class TestQuerier(TestCase): lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs self.multi_querier = Querier(lineage_objects) - def _deal_float_for_list(self, list1, list2): - index = 0 - for _ in list1: - deal_float_for_dict(list1[index], list2[index]) - index += 1 - - def _assert_list_equal(self, list1, list2): - self._deal_float_for_list(list1, list2) - self.assertListEqual(list1, list2) - - def _assert_lineages_equal(self, lineages1, lineages2): - self._deal_float_for_list(lineages1['object'], lineages2['object']) - self.assertDictEqual(lineages1, lineages2) - def test_get_summary_lineage_success_1(self): """Test the success of get_summary_lineage.""" expected_result = [LINEAGE_INFO_0] result = self.single_querier.get_summary_lineage() - self._assert_list_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertListEqual) def test_get_summary_lineage_success_2(self): """Test the success of get_summary_lineage.""" expected_result = [LINEAGE_INFO_0] result = self.single_querier.get_summary_lineage() - self._assert_list_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertListEqual) def test_get_summary_lineage_success_3(self): """Test the success of get_summary_lineage.""" @@ -320,7 +306,7 @@ class TestQuerier(TestCase): result = self.single_querier.get_summary_lineage( filter_keys=['model', 'algorithm'] ) - self._assert_list_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertListEqual) def test_get_summary_lineage_success_4(self): """Test the success of get_summary_lineage.""" @@ -367,7 +353,7 @@ class TestQuerier(TestCase): } ] result = self.multi_querier.get_summary_lineage() - self._assert_list_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertListEqual) def test_get_summary_lineage_success_5(self): """Test the success of get_summary_lineage.""" @@ -375,7 +361,7 @@ class TestQuerier(TestCase): result = self.multi_querier.get_summary_lineage( summary_dir='/path/to/summary1' ) - self._assert_list_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertListEqual) def test_get_summary_lineage_success_6(self): """Test the success of get_summary_lineage.""" @@ -394,7 +380,7 @@ class TestQuerier(TestCase): result = self.multi_querier.get_summary_lineage( summary_dir='/path/to/summary0', filter_keys=filter_keys ) - self._assert_list_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertListEqual) def test_get_summary_lineage_fail(self): """Test the function of get_summary_lineage with exception.""" @@ -437,7 +423,7 @@ class TestQuerier(TestCase): 'count': 2, } result = self.multi_querier.filter_summary_lineage(condition=condition) - self._assert_lineages_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_filter_summary_lineage_success_2(self): """Test the success of filter_summary_lineage.""" @@ -462,7 +448,7 @@ class TestQuerier(TestCase): 'count': 2, } result = self.multi_querier.filter_summary_lineage(condition=condition) - self._assert_lineages_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_filter_summary_lineage_success_3(self): """Test the success of filter_summary_lineage.""" @@ -479,7 +465,7 @@ class TestQuerier(TestCase): 'count': 7, } result = self.multi_querier.filter_summary_lineage(condition=condition) - self._assert_lineages_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_filter_summary_lineage_success_4(self): """Test the success of filter_summary_lineage.""" @@ -497,7 +483,7 @@ class TestQuerier(TestCase): 'count': 7, } result = self.multi_querier.filter_summary_lineage() - self._assert_lineages_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_filter_summary_lineage_success_5(self): """Test the success of filter_summary_lineage.""" @@ -512,7 +498,7 @@ class TestQuerier(TestCase): 'count': 1, } result = self.multi_querier.filter_summary_lineage(condition=condition) - self._assert_lineages_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_filter_summary_lineage_success_6(self): """Test the success of filter_summary_lineage.""" @@ -534,7 +520,7 @@ class TestQuerier(TestCase): 'count': 7, } result = self.multi_querier.filter_summary_lineage(condition=condition) - self._assert_lineages_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_filter_summary_lineage_success_7(self): """Test the success of filter_summary_lineage.""" @@ -556,7 +542,7 @@ class TestQuerier(TestCase): 'count': 7, } result = self.multi_querier.filter_summary_lineage(condition=condition) - self._assert_lineages_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_filter_summary_lineage_success_8(self): """Test the success of filter_summary_lineage.""" @@ -572,7 +558,7 @@ class TestQuerier(TestCase): 'count': 1, } result = self.multi_querier.filter_summary_lineage(condition=condition) - self._assert_lineages_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_filter_summary_lineage_success_9(self): """Test the success of filter_summary_lineage.""" @@ -586,7 +572,7 @@ class TestQuerier(TestCase): 'count': 7, } result = self.multi_querier.filter_summary_lineage(condition=condition) - self._assert_lineages_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_filter_summary_lineage_fail(self): """Test the function of filter_summary_lineage with exception.""" diff --git a/tests/ut/lineagemgr/querier/test_query_model.py b/tests/ut/lineagemgr/querier/test_query_model.py index 0da44de8c0c8156ac7d8c66ae1270ec82f1fa5b9..88d7691b412fd8751aa7c2f3a46fa9d80d9501ce 100644 --- a/tests/ut/lineagemgr/querier/test_query_model.py +++ b/tests/ut/lineagemgr/querier/test_query_model.py @@ -21,7 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj from . import event_data from .test_querier import create_filtration_result, create_lineage_info -from ....utils.tools import deal_float_for_dict +from ....utils.tools import assert_equal_lineages class TestLineageObj(TestCase): @@ -51,56 +51,65 @@ class TestLineageObj(TestCase): evaluation_lineage=lineage_info.eval_lineage ) - def _assert_dict_equal(self, dict1, dict2): - deal_float_for_dict(dict1, dict2) - self.assertDictEqual(dict1, dict2) - def test_property(self): """Test the function of getting property.""" self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) - self._assert_dict_equal( + assert_equal_lineages( event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], - self.lineage_obj.algorithm + self.lineage_obj.algorithm, + self.assertDictEqual ) - self._assert_dict_equal( + assert_equal_lineages( event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], - self.lineage_obj.model + self.lineage_obj.model, + self.assertDictEqual ) - self._assert_dict_equal( + assert_equal_lineages( event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], - self.lineage_obj.train_dataset + self.lineage_obj.train_dataset, + self.assertDictEqual ) - self._assert_dict_equal( + assert_equal_lineages( event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], - self.lineage_obj.hyper_parameters + self.lineage_obj.hyper_parameters, + self.assertDictEqual + ) + assert_equal_lineages( + event_data.METRIC_0, + self.lineage_obj.metric, + self.assertDictEqual ) - self._assert_dict_equal(event_data.METRIC_0, self.lineage_obj.metric) - self._assert_dict_equal( + assert_equal_lineages( event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], - self.lineage_obj.valid_dataset + self.lineage_obj.valid_dataset, + self.assertDictEqual ) def test_property_eval_not_exist(self): """Test the function of getting property with no evaluation event.""" self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) - self._assert_dict_equal( + assert_equal_lineages( event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], - self.lineage_obj_no_eval.algorithm + self.lineage_obj_no_eval.algorithm, + self.assertDictEqual ) - self._assert_dict_equal( + assert_equal_lineages( event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], - self.lineage_obj_no_eval.model + self.lineage_obj_no_eval.model, + self.assertDictEqual ) - self._assert_dict_equal( + assert_equal_lineages( event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], - self.lineage_obj_no_eval.train_dataset + self.lineage_obj_no_eval.train_dataset, + self.assertDictEqual ) - self._assert_dict_equal( + assert_equal_lineages( event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], - self.lineage_obj_no_eval.hyper_parameters + self.lineage_obj_no_eval.hyper_parameters, + self.assertDictEqual ) - self._assert_dict_equal({}, self.lineage_obj_no_eval.metric) - self._assert_dict_equal({}, self.lineage_obj_no_eval.valid_dataset) + assert_equal_lineages({}, self.lineage_obj_no_eval.metric, self.assertDictEqual) + assert_equal_lineages({}, self.lineage_obj_no_eval.valid_dataset, self.assertDictEqual) def test_get_summary_info(self): """Test the function of get_summary_info.""" @@ -111,7 +120,7 @@ class TestLineageObj(TestCase): 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] } result = self.lineage_obj.get_summary_info(filter_keys) - self._assert_dict_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_to_model_lineage_dict(self): """Test the function of to_model_lineage_dict.""" @@ -125,7 +134,7 @@ class TestLineageObj(TestCase): expected_result['model_lineage']['dataset_mark'] = None expected_result.pop('dataset_graph') result = self.lineage_obj.to_model_lineage_dict() - self._assert_dict_equal(expected_result, result) + assert_equal_lineages(expected_result, result, self.assertDictEqual) def test_to_dataset_lineage_dict(self): """Test the function of to_dataset_lineage_dict.""" diff --git a/tests/utils/tools.py b/tests/utils/tools.py index d8a1324ea1d2e8828b3659c4640a94a709453594..4b24813fced984a8ccca5551bb08c92d9d974227 100644 --- a/tests/utils/tools.py +++ b/tests/utils/tools.py @@ -83,9 +83,9 @@ def compare_result_with_file(result, expected_file_path): assert result == expected_results -def deal_float_for_dict(res: dict, expected_res: dict): +def deal_float_for_dict(res: dict, expected_res: dict, decimal_num=5): """ - Deal float rounded to five decimals in dict. + Deal float rounded to specified decimals in dict. For example: res:{ @@ -125,10 +125,9 @@ def deal_float_for_dict(res: dict, expected_res: dict): "metric": {"acc": 0.1234562} } } - + decimal_num (int): decimal rounded digits. """ - decimal_num = 5 for key in res: value = res[key] expected_value = expected_res[key] @@ -137,3 +136,22 @@ def deal_float_for_dict(res: dict, expected_res: dict): elif isinstance(value, float): res[key] = round(value, decimal_num) expected_res[key] = round(expected_value, decimal_num) + + +def _deal_float_for_list(list1, list2, decimal_num): + """Deal float for list1 and list2.""" + index = 0 + for _ in list1: + deal_float_for_dict(list1[index], list2[index], decimal_num) + index += 1 + + +def assert_equal_lineages(lineages1, lineages2, assert_func, decimal_num=2): + """Assert lineages.""" + if isinstance(lineages1, list) and isinstance(lineages2, list): + _deal_float_for_list(lineages1, lineages2, decimal_num) + elif lineages1.get('object') is not None and lineages2.get('object') is not None: + _deal_float_for_list(lineages1['object'], lineages2['object'], decimal_num) + else: + deal_float_for_dict(lineages1, lineages2, decimal_num) + assert_func(lineages1, lineages2)