提交 e1d627ff 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!176 fix the float compare in lineage ut

Merge pull request !176 from luopengting/r0.3
...@@ -22,7 +22,7 @@ EVENT_TRAIN_DICT_0 = { ...@@ -22,7 +22,7 @@ EVENT_TRAIN_DICT_0 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum0', 'optimizer': 'ApplyMomentum0',
'learning_rate': 0.10000000149011612, 'learning_rate': 0.11,
'loss_function': '', 'loss_function': '',
'epoch': 1, 'epoch': 1,
'parallel_mode': 'stand_alone0', 'parallel_mode': 'stand_alone0',
...@@ -31,7 +31,7 @@ EVENT_TRAIN_DICT_0 = { ...@@ -31,7 +31,7 @@ EVENT_TRAIN_DICT_0 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell0', 'network': 'TrainOneStepCell0',
'loss': 2.3025848865509033 'loss': 2.3025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '', 'train_dataset_path': '',
...@@ -49,7 +49,7 @@ EVENT_TRAIN_DICT_1 = { ...@@ -49,7 +49,7 @@ EVENT_TRAIN_DICT_1 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum1', 'optimizer': 'ApplyMomentum1',
'learning_rate': 0.20000000298023224, 'learning_rate': 0.2100001,
'loss_function': 'loss_function1', 'loss_function': 'loss_function1',
'epoch': 1, 'epoch': 1,
'parallel_mode': 'stand_alone1', 'parallel_mode': 'stand_alone1',
...@@ -58,7 +58,7 @@ EVENT_TRAIN_DICT_1 = { ...@@ -58,7 +58,7 @@ EVENT_TRAIN_DICT_1 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell1', 'network': 'TrainOneStepCell1',
'loss': 2.4025847911834717 'loss': 2.4025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset1', 'train_dataset_path': '/path/to/train_dataset1',
...@@ -76,7 +76,7 @@ EVENT_TRAIN_DICT_2 = { ...@@ -76,7 +76,7 @@ EVENT_TRAIN_DICT_2 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum2', 'optimizer': 'ApplyMomentum2',
'learning_rate': 0.30000001192092896, 'learning_rate': 0.3100001,
'loss_function': 'loss_function2', 'loss_function': 'loss_function2',
'epoch': 2, 'epoch': 2,
'parallel_mode': 'stand_alone2', 'parallel_mode': 'stand_alone2',
...@@ -85,7 +85,7 @@ EVENT_TRAIN_DICT_2 = { ...@@ -85,7 +85,7 @@ EVENT_TRAIN_DICT_2 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell2', 'network': 'TrainOneStepCell2',
'loss': 2.502584934234619 'loss': 2.5025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset2', 'train_dataset_path': '/path/to/train_dataset2',
...@@ -103,7 +103,7 @@ EVENT_TRAIN_DICT_3 = { ...@@ -103,7 +103,7 @@ EVENT_TRAIN_DICT_3 = {
'train_lineage': { 'train_lineage': {
'hyper_parameters': { 'hyper_parameters': {
'optimizer': 'ApplyMomentum3', 'optimizer': 'ApplyMomentum3',
'learning_rate': 0.4000000059604645, 'learning_rate': 0.4,
'loss_function': 'loss_function3', 'loss_function': 'loss_function3',
'epoch': 2, 'epoch': 2,
'parallel_mode': 'stand_alone3', 'parallel_mode': 'stand_alone3',
...@@ -112,7 +112,7 @@ EVENT_TRAIN_DICT_3 = { ...@@ -112,7 +112,7 @@ EVENT_TRAIN_DICT_3 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell3', 'network': 'TrainOneStepCell3',
'loss': 2.6025848388671875 'loss': 2.6025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset3', 'train_dataset_path': '/path/to/train_dataset3',
...@@ -139,7 +139,7 @@ EVENT_TRAIN_DICT_4 = { ...@@ -139,7 +139,7 @@ EVENT_TRAIN_DICT_4 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell4', 'network': 'TrainOneStepCell4',
'loss': 2.702584981918335 'loss': 2.7025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_path': '/path/to/train_dataset4', 'train_dataset_path': '/path/to/train_dataset4',
...@@ -166,7 +166,7 @@ EVENT_TRAIN_DICT_5 = { ...@@ -166,7 +166,7 @@ EVENT_TRAIN_DICT_5 = {
}, },
'algorithm': { 'algorithm': {
'network': 'TrainOneStepCell5', 'network': 'TrainOneStepCell5',
'loss': 2.702584981918335 'loss': 2.7025841
}, },
'train_dataset': { 'train_dataset': {
'train_dataset_size': 35 'train_dataset_size': 35
...@@ -211,33 +211,33 @@ CUSTOMIZED_2 = { ...@@ -211,33 +211,33 @@ CUSTOMIZED_2 = {
} }
METRIC_1 = { METRIC_1 = {
'accuracy': 1.0000002, 'accuracy': 1.2000002,
'mae': 2.00000002, 'mae': 2.00000002,
'mse': 3.00000002 'mse': 3.00000002
} }
METRIC_2 = { METRIC_2 = {
'accuracy': 1.0000003, 'accuracy': 1.3000003,
'mae': 2.00000003, 'mae': 2.30000003,
'mse': 3.00000003 'mse': 3.30000003
} }
METRIC_3 = { METRIC_3 = {
'accuracy': 1.0000004, 'accuracy': 1.4000004,
'mae': 2.00000004, 'mae': 2.40000004,
'mse': 3.00000004 'mse': 3.40000004
} }
METRIC_4 = { METRIC_4 = {
'accuracy': 1.0000005, 'accuracy': 1.5000005,
'mae': 2.00000005, 'mae': 2.50000005,
'mse': 3.00000005 'mse': 3.50000005
} }
METRIC_5 = { METRIC_5 = {
'accuracy': 1.0000006, 'accuracy': 1.7000006,
'mae': 2.00000006, 'mae': 2.60000006,
'mse': 3.00000006 'mse': 3.60000006
} }
EVENT_EVAL_DICT_0 = { EVENT_EVAL_DICT_0 = {
......
...@@ -27,6 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier ...@@ -27,6 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo
from . import event_data from . import event_data
from ....utils.tools import deal_float_for_dict
def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict): def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict):
...@@ -266,7 +267,6 @@ class TestQuerier(TestCase): ...@@ -266,7 +267,6 @@ class TestQuerier(TestCase):
mock_file_handler = MagicMock() mock_file_handler = MagicMock()
mock_file_handler.size = 1 mock_file_handler.size = 1
args[2].return_value = [{'relative_path': './', 'update_time': 1}] args[2].return_value = [{'relative_path': './', 'update_time': 1}]
single_summary_path = '/path/to/summary0' single_summary_path = '/path/to/summary0'
lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs
...@@ -282,17 +282,31 @@ class TestQuerier(TestCase): ...@@ -282,17 +282,31 @@ class TestQuerier(TestCase):
lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs
self.multi_querier = Querier(lineage_objects) self.multi_querier = Querier(lineage_objects)
def _deal_float_for_list(self, list1, list2):
index = 0
for _ in list1:
deal_float_for_dict(list1[index], list2[index])
index += 1
def _assert_list_equal(self, list1, list2):
self._deal_float_for_list(list1, list2)
self.assertListEqual(list1, list2)
def _assert_lineages_equal(self, lineages1, lineages2):
self._deal_float_for_list(lineages1['object'], lineages2['object'])
self.assertDictEqual(lineages1, lineages2)
def test_get_summary_lineage_success_1(self): def test_get_summary_lineage_success_1(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result) self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_2(self): def test_get_summary_lineage_success_2(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0] expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage() result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result) self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_3(self): def test_get_summary_lineage_success_3(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
...@@ -306,7 +320,7 @@ class TestQuerier(TestCase): ...@@ -306,7 +320,7 @@ class TestQuerier(TestCase):
result = self.single_querier.get_summary_lineage( result = self.single_querier.get_summary_lineage(
filter_keys=['model', 'algorithm'] filter_keys=['model', 'algorithm']
) )
self.assertListEqual(expected_result, result) self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_4(self): def test_get_summary_lineage_success_4(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
...@@ -353,7 +367,7 @@ class TestQuerier(TestCase): ...@@ -353,7 +367,7 @@ class TestQuerier(TestCase):
} }
] ]
result = self.multi_querier.get_summary_lineage() result = self.multi_querier.get_summary_lineage()
self.assertListEqual(expected_result, result) self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_5(self): def test_get_summary_lineage_success_5(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
...@@ -361,7 +375,7 @@ class TestQuerier(TestCase): ...@@ -361,7 +375,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary1' summary_dir='/path/to/summary1'
) )
self.assertListEqual(expected_result, result) self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_6(self): def test_get_summary_lineage_success_6(self):
"""Test the success of get_summary_lineage.""" """Test the success of get_summary_lineage."""
...@@ -380,7 +394,7 @@ class TestQuerier(TestCase): ...@@ -380,7 +394,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage( result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary0', filter_keys=filter_keys summary_dir='/path/to/summary0', filter_keys=filter_keys
) )
self.assertListEqual(expected_result, result) self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_fail(self): def test_get_summary_lineage_fail(self):
"""Test the function of get_summary_lineage with exception.""" """Test the function of get_summary_lineage with exception."""
...@@ -423,7 +437,7 @@ class TestQuerier(TestCase): ...@@ -423,7 +437,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result) self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_2(self): def test_filter_summary_lineage_success_2(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -448,7 +462,7 @@ class TestQuerier(TestCase): ...@@ -448,7 +462,7 @@ class TestQuerier(TestCase):
'count': 2, 'count': 2,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result) self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_3(self): def test_filter_summary_lineage_success_3(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -465,7 +479,7 @@ class TestQuerier(TestCase): ...@@ -465,7 +479,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result) self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_4(self): def test_filter_summary_lineage_success_4(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -483,7 +497,7 @@ class TestQuerier(TestCase): ...@@ -483,7 +497,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage() result = self.multi_querier.filter_summary_lineage()
self.assertDictEqual(expected_result, result) self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_5(self): def test_filter_summary_lineage_success_5(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -498,7 +512,7 @@ class TestQuerier(TestCase): ...@@ -498,7 +512,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result) self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_6(self): def test_filter_summary_lineage_success_6(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -520,7 +534,7 @@ class TestQuerier(TestCase): ...@@ -520,7 +534,7 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result) self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_7(self): def test_filter_summary_lineage_success_7(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -542,14 +556,14 @@ class TestQuerier(TestCase): ...@@ -542,14 +556,14 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result) self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_8(self): def test_filter_summary_lineage_success_8(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
condition = { condition = {
'metric/accuracy': { 'metric/accuracy': {
'lt': 1.0000006, 'lt': 1.6000006,
'gt': 1.0000004 'gt': 1.4000004
} }
} }
expected_result = { expected_result = {
...@@ -558,7 +572,7 @@ class TestQuerier(TestCase): ...@@ -558,7 +572,7 @@ class TestQuerier(TestCase):
'count': 1, 'count': 1,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result) self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_9(self): def test_filter_summary_lineage_success_9(self):
"""Test the success of filter_summary_lineage.""" """Test the success of filter_summary_lineage."""
...@@ -572,14 +586,14 @@ class TestQuerier(TestCase): ...@@ -572,14 +586,14 @@ class TestQuerier(TestCase):
'count': 7, 'count': 7,
} }
result = self.multi_querier.filter_summary_lineage(condition=condition) result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result) self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_fail(self): def test_filter_summary_lineage_fail(self):
"""Test the function of filter_summary_lineage with exception.""" """Test the function of filter_summary_lineage with exception."""
condition = { condition = {
'xxx': { 'xxx': {
'lt': 1.0000006, 'lt': 1.6000006,
'gt': 1.0000004 'gt': 1.4000004
} }
} }
self.assertRaises( self.assertRaises(
......
...@@ -21,6 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj ...@@ -21,6 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj
from . import event_data from . import event_data
from .test_querier import create_filtration_result, create_lineage_info from .test_querier import create_filtration_result, create_lineage_info
from ....utils.tools import deal_float_for_dict
class TestLineageObj(TestCase): class TestLineageObj(TestCase):
...@@ -50,27 +51,31 @@ class TestLineageObj(TestCase): ...@@ -50,27 +51,31 @@ class TestLineageObj(TestCase):
evaluation_lineage=lineage_info.eval_lineage evaluation_lineage=lineage_info.eval_lineage
) )
def _assert_dict_equal(self, dict1, dict2):
deal_float_for_dict(dict1, dict2)
self.assertDictEqual(dict1, dict2)
def test_property(self): def test_property(self):
"""Test the function of getting property.""" """Test the function of getting property."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual( self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj.algorithm self.lineage_obj.algorithm
) )
self.assertDictEqual( self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj.model self.lineage_obj.model
) )
self.assertDictEqual( self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj.train_dataset self.lineage_obj.train_dataset
) )
self.assertDictEqual( self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj.hyper_parameters self.lineage_obj.hyper_parameters
) )
self.assertDictEqual(event_data.METRIC_0, self.lineage_obj.metric) self._assert_dict_equal(event_data.METRIC_0, self.lineage_obj.metric)
self.assertDictEqual( self._assert_dict_equal(
event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'], event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
self.lineage_obj.valid_dataset self.lineage_obj.valid_dataset
) )
...@@ -78,24 +83,24 @@ class TestLineageObj(TestCase): ...@@ -78,24 +83,24 @@ class TestLineageObj(TestCase):
def test_property_eval_not_exist(self): def test_property_eval_not_exist(self):
"""Test the function of getting property with no evaluation event.""" """Test the function of getting property with no evaluation event."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir) self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual( self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj_no_eval.algorithm self.lineage_obj_no_eval.algorithm
) )
self.assertDictEqual( self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj_no_eval.model self.lineage_obj_no_eval.model
) )
self.assertDictEqual( self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj_no_eval.train_dataset self.lineage_obj_no_eval.train_dataset
) )
self.assertDictEqual( self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'], event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj_no_eval.hyper_parameters self.lineage_obj_no_eval.hyper_parameters
) )
self.assertDictEqual({}, self.lineage_obj_no_eval.metric) self._assert_dict_equal({}, self.lineage_obj_no_eval.metric)
self.assertDictEqual({}, self.lineage_obj_no_eval.valid_dataset) self._assert_dict_equal({}, self.lineage_obj_no_eval.valid_dataset)
def test_get_summary_info(self): def test_get_summary_info(self):
"""Test the function of get_summary_info.""" """Test the function of get_summary_info."""
...@@ -106,7 +111,7 @@ class TestLineageObj(TestCase): ...@@ -106,7 +111,7 @@ class TestLineageObj(TestCase):
'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'] 'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']
} }
result = self.lineage_obj.get_summary_info(filter_keys) result = self.lineage_obj.get_summary_info(filter_keys)
self.assertDictEqual(expected_result, result) self._assert_dict_equal(expected_result, result)
def test_to_model_lineage_dict(self): def test_to_model_lineage_dict(self):
"""Test the function of to_model_lineage_dict.""" """Test the function of to_model_lineage_dict."""
...@@ -120,7 +125,7 @@ class TestLineageObj(TestCase): ...@@ -120,7 +125,7 @@ class TestLineageObj(TestCase):
expected_result['model_lineage']['dataset_mark'] = None expected_result['model_lineage']['dataset_mark'] = None
expected_result.pop('dataset_graph') expected_result.pop('dataset_graph')
result = self.lineage_obj.to_model_lineage_dict() result = self.lineage_obj.to_model_lineage_dict()
self.assertDictEqual(expected_result, result) self._assert_dict_equal(expected_result, result)
def test_to_dataset_lineage_dict(self): def test_to_dataset_lineage_dict(self):
"""Test the function of to_dataset_lineage_dict.""" """Test the function of to_dataset_lineage_dict."""
......
...@@ -81,3 +81,59 @@ def compare_result_with_file(result, expected_file_path): ...@@ -81,3 +81,59 @@ def compare_result_with_file(result, expected_file_path):
with open(expected_file_path, 'r') as file: with open(expected_file_path, 'r') as file:
expected_results = json.load(file) expected_results = json.load(file)
assert result == expected_results assert result == expected_results
def deal_float_for_dict(res: dict, expected_res: dict):
"""
Deal float rounded to five decimals in dict.
For example:
res:{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}
After:
res:{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}
Args:
res (dict): e.g.
{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res (dict):
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}
"""
decimal_num = 5
for key in res:
value = res[key]
expected_value = expected_res[key]
if isinstance(value, dict):
deal_float_for_dict(value, expected_value)
elif isinstance(value, float):
res[key] = round(value, decimal_num)
expected_res[key] = round(expected_value, decimal_num)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册