提交 e1d627ff 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!176 fix the float compare in lineage ut

Merge pull request !176 from luopengting/r0.3
......@@ -22,7 +22,7 @@ EVENT_TRAIN_DICT_0 = {
'train_lineage': {
'hyper_parameters': {
'optimizer': 'ApplyMomentum0',
'learning_rate': 0.10000000149011612,
'learning_rate': 0.11,
'loss_function': '',
'epoch': 1,
'parallel_mode': 'stand_alone0',
......@@ -31,7 +31,7 @@ EVENT_TRAIN_DICT_0 = {
},
'algorithm': {
'network': 'TrainOneStepCell0',
'loss': 2.3025848865509033
'loss': 2.3025841
},
'train_dataset': {
'train_dataset_path': '',
......@@ -49,7 +49,7 @@ EVENT_TRAIN_DICT_1 = {
'train_lineage': {
'hyper_parameters': {
'optimizer': 'ApplyMomentum1',
'learning_rate': 0.20000000298023224,
'learning_rate': 0.2100001,
'loss_function': 'loss_function1',
'epoch': 1,
'parallel_mode': 'stand_alone1',
......@@ -58,7 +58,7 @@ EVENT_TRAIN_DICT_1 = {
},
'algorithm': {
'network': 'TrainOneStepCell1',
'loss': 2.4025847911834717
'loss': 2.4025841
},
'train_dataset': {
'train_dataset_path': '/path/to/train_dataset1',
......@@ -76,7 +76,7 @@ EVENT_TRAIN_DICT_2 = {
'train_lineage': {
'hyper_parameters': {
'optimizer': 'ApplyMomentum2',
'learning_rate': 0.30000001192092896,
'learning_rate': 0.3100001,
'loss_function': 'loss_function2',
'epoch': 2,
'parallel_mode': 'stand_alone2',
......@@ -85,7 +85,7 @@ EVENT_TRAIN_DICT_2 = {
},
'algorithm': {
'network': 'TrainOneStepCell2',
'loss': 2.502584934234619
'loss': 2.5025841
},
'train_dataset': {
'train_dataset_path': '/path/to/train_dataset2',
......@@ -103,7 +103,7 @@ EVENT_TRAIN_DICT_3 = {
'train_lineage': {
'hyper_parameters': {
'optimizer': 'ApplyMomentum3',
'learning_rate': 0.4000000059604645,
'learning_rate': 0.4,
'loss_function': 'loss_function3',
'epoch': 2,
'parallel_mode': 'stand_alone3',
......@@ -112,7 +112,7 @@ EVENT_TRAIN_DICT_3 = {
},
'algorithm': {
'network': 'TrainOneStepCell3',
'loss': 2.6025848388671875
'loss': 2.6025841
},
'train_dataset': {
'train_dataset_path': '/path/to/train_dataset3',
......@@ -139,7 +139,7 @@ EVENT_TRAIN_DICT_4 = {
},
'algorithm': {
'network': 'TrainOneStepCell4',
'loss': 2.702584981918335
'loss': 2.7025841
},
'train_dataset': {
'train_dataset_path': '/path/to/train_dataset4',
......@@ -166,7 +166,7 @@ EVENT_TRAIN_DICT_5 = {
},
'algorithm': {
'network': 'TrainOneStepCell5',
'loss': 2.702584981918335
'loss': 2.7025841
},
'train_dataset': {
'train_dataset_size': 35
......@@ -211,33 +211,33 @@ CUSTOMIZED_2 = {
}
METRIC_1 = {
'accuracy': 1.0000002,
'accuracy': 1.2000002,
'mae': 2.00000002,
'mse': 3.00000002
}
METRIC_2 = {
'accuracy': 1.0000003,
'mae': 2.00000003,
'mse': 3.00000003
'accuracy': 1.3000003,
'mae': 2.30000003,
'mse': 3.30000003
}
METRIC_3 = {
'accuracy': 1.0000004,
'mae': 2.00000004,
'mse': 3.00000004
'accuracy': 1.4000004,
'mae': 2.40000004,
'mse': 3.40000004
}
METRIC_4 = {
'accuracy': 1.0000005,
'mae': 2.00000005,
'mse': 3.00000005
'accuracy': 1.5000005,
'mae': 2.50000005,
'mse': 3.50000005
}
METRIC_5 = {
'accuracy': 1.0000006,
'mae': 2.00000006,
'mse': 3.00000006
'accuracy': 1.7000006,
'mae': 2.60000006,
'mse': 3.60000006
}
EVENT_EVAL_DICT_0 = {
......
......@@ -27,6 +27,7 @@ from mindinsight.lineagemgr.querier.querier import Querier
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageInfo
from . import event_data
from ....utils.tools import deal_float_for_dict
def create_lineage_info(train_event_dict, eval_event_dict, dataset_event_dict):
......@@ -266,7 +267,6 @@ class TestQuerier(TestCase):
mock_file_handler = MagicMock()
mock_file_handler.size = 1
args[2].return_value = [{'relative_path': './', 'update_time': 1}]
single_summary_path = '/path/to/summary0'
lineage_objects = LineageOrganizer(summary_base_dir=single_summary_path).super_lineage_objs
......@@ -282,17 +282,31 @@ class TestQuerier(TestCase):
lineage_objects = LineageOrganizer(summary_base_dir=summary_base_dir).super_lineage_objs
self.multi_querier = Querier(lineage_objects)
def _deal_float_for_list(self, list1, list2):
index = 0
for _ in list1:
deal_float_for_dict(list1[index], list2[index])
index += 1
def _assert_list_equal(self, list1, list2):
self._deal_float_for_list(list1, list2)
self.assertListEqual(list1, list2)
def _assert_lineages_equal(self, lineages1, lineages2):
self._deal_float_for_list(lineages1['object'], lineages2['object'])
self.assertDictEqual(lineages1, lineages2)
def test_get_summary_lineage_success_1(self):
"""Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_2(self):
"""Test the success of get_summary_lineage."""
expected_result = [LINEAGE_INFO_0]
result = self.single_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_3(self):
"""Test the success of get_summary_lineage."""
......@@ -306,7 +320,7 @@ class TestQuerier(TestCase):
result = self.single_querier.get_summary_lineage(
filter_keys=['model', 'algorithm']
)
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_4(self):
"""Test the success of get_summary_lineage."""
......@@ -353,7 +367,7 @@ class TestQuerier(TestCase):
}
]
result = self.multi_querier.get_summary_lineage()
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_5(self):
"""Test the success of get_summary_lineage."""
......@@ -361,7 +375,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary1'
)
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_success_6(self):
"""Test the success of get_summary_lineage."""
......@@ -380,7 +394,7 @@ class TestQuerier(TestCase):
result = self.multi_querier.get_summary_lineage(
summary_dir='/path/to/summary0', filter_keys=filter_keys
)
self.assertListEqual(expected_result, result)
self._assert_list_equal(expected_result, result)
def test_get_summary_lineage_fail(self):
"""Test the function of get_summary_lineage with exception."""
......@@ -423,7 +437,7 @@ class TestQuerier(TestCase):
'count': 2,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_2(self):
"""Test the success of filter_summary_lineage."""
......@@ -448,7 +462,7 @@ class TestQuerier(TestCase):
'count': 2,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_3(self):
"""Test the success of filter_summary_lineage."""
......@@ -465,7 +479,7 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_4(self):
"""Test the success of filter_summary_lineage."""
......@@ -483,7 +497,7 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage()
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_5(self):
"""Test the success of filter_summary_lineage."""
......@@ -498,7 +512,7 @@ class TestQuerier(TestCase):
'count': 1,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_6(self):
"""Test the success of filter_summary_lineage."""
......@@ -520,7 +534,7 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_7(self):
"""Test the success of filter_summary_lineage."""
......@@ -542,14 +556,14 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_8(self):
"""Test the success of filter_summary_lineage."""
condition = {
'metric/accuracy': {
'lt': 1.0000006,
'gt': 1.0000004
'lt': 1.6000006,
'gt': 1.4000004
}
}
expected_result = {
......@@ -558,7 +572,7 @@ class TestQuerier(TestCase):
'count': 1,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_success_9(self):
"""Test the success of filter_summary_lineage."""
......@@ -572,14 +586,14 @@ class TestQuerier(TestCase):
'count': 7,
}
result = self.multi_querier.filter_summary_lineage(condition=condition)
self.assertDictEqual(expected_result, result)
self._assert_lineages_equal(expected_result, result)
def test_filter_summary_lineage_fail(self):
"""Test the function of filter_summary_lineage with exception."""
condition = {
'xxx': {
'lt': 1.0000006,
'gt': 1.0000004
'lt': 1.6000006,
'gt': 1.4000004
}
}
self.assertRaises(
......
......@@ -21,6 +21,7 @@ from mindinsight.lineagemgr.querier.query_model import LineageObj
from . import event_data
from .test_querier import create_filtration_result, create_lineage_info
from ....utils.tools import deal_float_for_dict
class TestLineageObj(TestCase):
......@@ -50,27 +51,31 @@ class TestLineageObj(TestCase):
evaluation_lineage=lineage_info.eval_lineage
)
def _assert_dict_equal(self, dict1, dict2):
deal_float_for_dict(dict1, dict2)
self.assertDictEqual(dict1, dict2)
def test_property(self):
"""Test the function of getting property."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj.algorithm
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj.model
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj.train_dataset
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj.hyper_parameters
)
self.assertDictEqual(event_data.METRIC_0, self.lineage_obj.metric)
self.assertDictEqual(
self._assert_dict_equal(event_data.METRIC_0, self.lineage_obj.metric)
self._assert_dict_equal(
event_data.EVENT_EVAL_DICT_0['evaluation_lineage']['valid_dataset'],
self.lineage_obj.valid_dataset
)
......@@ -78,24 +83,24 @@ class TestLineageObj(TestCase):
def test_property_eval_not_exist(self):
"""Test the function of getting property with no evaluation event."""
self.assertEqual(self.summary_dir, self.lineage_obj.summary_dir)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['algorithm'],
self.lineage_obj_no_eval.algorithm
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['model'],
self.lineage_obj_no_eval.model
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['train_dataset'],
self.lineage_obj_no_eval.train_dataset
)
self.assertDictEqual(
self._assert_dict_equal(
event_data.EVENT_TRAIN_DICT_0['train_lineage']['hyper_parameters'],
self.lineage_obj_no_eval.hyper_parameters
)
self.assertDictEqual({}, self.lineage_obj_no_eval.metric)
self.assertDictEqual({}, self.lineage_obj_no_eval.valid_dataset)
self._assert_dict_equal({}, self.lineage_obj_no_eval.metric)
self._assert_dict_equal({}, self.lineage_obj_no_eval.valid_dataset)
def test_get_summary_info(self):
"""Test the function of get_summary_info."""
......@@ -106,7 +111,7 @@ class TestLineageObj(TestCase):
'model': event_data.EVENT_TRAIN_DICT_0['train_lineage']['model']
}
result = self.lineage_obj.get_summary_info(filter_keys)
self.assertDictEqual(expected_result, result)
self._assert_dict_equal(expected_result, result)
def test_to_model_lineage_dict(self):
"""Test the function of to_model_lineage_dict."""
......@@ -120,7 +125,7 @@ class TestLineageObj(TestCase):
expected_result['model_lineage']['dataset_mark'] = None
expected_result.pop('dataset_graph')
result = self.lineage_obj.to_model_lineage_dict()
self.assertDictEqual(expected_result, result)
self._assert_dict_equal(expected_result, result)
def test_to_dataset_lineage_dict(self):
"""Test the function of to_dataset_lineage_dict."""
......
......@@ -81,3 +81,59 @@ def compare_result_with_file(result, expected_file_path):
with open(expected_file_path, 'r') as file:
expected_results = json.load(file)
assert result == expected_results
def deal_float_for_dict(res: dict, expected_res: dict):
"""
Deal float rounded to five decimals in dict.
For example:
res:{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}
After:
res:{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}
expected_res:
{
"model_lineages": {
"metric": {"acc": 0.12346}
}
}
Args:
res (dict): e.g.
{
"model_lineages": {
"metric": {"acc": 0.1234561}
}
}
expected_res (dict):
{
"model_lineages": {
"metric": {"acc": 0.1234562}
}
}
"""
decimal_num = 5
for key in res:
value = res[key]
expected_value = expected_res[key]
if isinstance(value, dict):
deal_float_for_dict(value, expected_value)
elif isinstance(value, float):
res[key] = round(value, decimal_num)
expected_res[key] = round(expected_value, decimal_num)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册