diff --git a/tests/ut/datavisual/processors/test_scalars_processor.py b/tests/ut/datavisual/processors/test_scalars_processor.py index f269166fb0240bf2bab4388a7fd18339d09ab3ef..3dc29978bb9a1ef2524abb090a4e706f134d9d47 100644 --- a/tests/ut/datavisual/processors/test_scalars_processor.py +++ b/tests/ut/datavisual/processors/test_scalars_processor.py @@ -116,3 +116,18 @@ class TestScalarsProcessor: assert recv_values.get('wall_time') == expected_values.get('wall_time') assert recv_values.get('step') == expected_values.get('step') assert abs(recv_values.get('value') - expected_values.get('value')) < 1e-6 + + @pytest.mark.usefixtures('load_scalar_record') + def test_get_scalars(self): + """Get scalars success.""" + scalar_processor = ScalarsProcessor(self._mock_data_manager) + scalars = scalar_processor.get_scalars([self._train_id], [self._complete_tag_name]) + scalar = scalars[0] + + assert scalar['train_id'] == self._train_id + assert scalar['tag'] == self._complete_tag_name + + for recv_values, expected_values in zip(scalar['values'], self._scalars_metadata): + assert recv_values.get('wall_time') == expected_values.get('wall_time') + assert recv_values.get('step') == expected_values.get('step') + assert abs(recv_values.get('value') - expected_values.get('value')) < 1e-6 diff --git a/tests/ut/datavisual/processors/test_train_task_manager.py b/tests/ut/datavisual/processors/test_train_task_manager.py index 98b1af1f20fef50ff040126ad2c3aa57f6cf0f57..80c4eb989fc1716b0a5f1ac45e6ce91db557db45 100644 --- a/tests/ut/datavisual/processors/test_train_task_manager.py +++ b/tests/ut/datavisual/processors/test_train_task_manager.py @@ -141,3 +141,11 @@ class TestTrainTaskManager: assert train_id in self._plugins_id_map.get(plugin_name) else: assert train_id not in self._plugins_id_map.get(plugin_name) + + @pytest.mark.usefixtures('load_data') + def test_cache_train_jobs(self): + """Test caching train jobs with train ids.""" + train_task_manager = TrainTaskManager(self._mock_data_manager) + + cache_result = train_task_manager.cache_train_jobs(self._train_id_list) + assert len(self._train_id_list) == len(cache_result)