From d3b797d973961e37233095abf3271b881079c792 Mon Sep 17 00:00:00 2001 From: liangyongxiong Date: Fri, 29 May 2020 18:13:55 +0800 Subject: [PATCH] add ut for scalars comparision --- .../processors/test_scalars_processor.py | 15 +++++++++++++++ .../processors/test_train_task_manager.py | 8 ++++++++ 2 files changed, 23 insertions(+) diff --git a/tests/ut/datavisual/processors/test_scalars_processor.py b/tests/ut/datavisual/processors/test_scalars_processor.py index f269166..3dc2997 100644 --- a/tests/ut/datavisual/processors/test_scalars_processor.py +++ b/tests/ut/datavisual/processors/test_scalars_processor.py @@ -116,3 +116,18 @@ class TestScalarsProcessor: assert recv_values.get('wall_time') == expected_values.get('wall_time') assert recv_values.get('step') == expected_values.get('step') assert abs(recv_values.get('value') - expected_values.get('value')) < 1e-6 + + @pytest.mark.usefixtures('load_scalar_record') + def test_get_scalars(self): + """Get scalars success.""" + scalar_processor = ScalarsProcessor(self._mock_data_manager) + scalars = scalar_processor.get_scalars([self._train_id], [self._complete_tag_name]) + scalar = scalars[0] + + assert scalar['train_id'] == self._train_id + assert scalar['tag'] == self._complete_tag_name + + for recv_values, expected_values in zip(scalar['values'], self._scalars_metadata): + assert recv_values.get('wall_time') == expected_values.get('wall_time') + assert recv_values.get('step') == expected_values.get('step') + assert abs(recv_values.get('value') - expected_values.get('value')) < 1e-6 diff --git a/tests/ut/datavisual/processors/test_train_task_manager.py b/tests/ut/datavisual/processors/test_train_task_manager.py index 98b1af1..80c4eb9 100644 --- a/tests/ut/datavisual/processors/test_train_task_manager.py +++ b/tests/ut/datavisual/processors/test_train_task_manager.py @@ -141,3 +141,11 @@ class TestTrainTaskManager: assert train_id in self._plugins_id_map.get(plugin_name) else: assert train_id not in self._plugins_id_map.get(plugin_name) + + @pytest.mark.usefixtures('load_data') + def test_cache_train_jobs(self): + """Test caching train jobs with train ids.""" + train_task_manager = TrainTaskManager(self._mock_data_manager) + + cache_result = train_task_manager.cache_train_jobs(self._train_id_list) + assert len(self._train_id_list) == len(cache_result) -- GitLab