提交 d3b797d9 编写于 作者: L liangyongxiong

add ut for scalars comparision

上级 dc75de77
...@@ -116,3 +116,18 @@ class TestScalarsProcessor: ...@@ -116,3 +116,18 @@ class TestScalarsProcessor:
assert recv_values.get('wall_time') == expected_values.get('wall_time') assert recv_values.get('wall_time') == expected_values.get('wall_time')
assert recv_values.get('step') == expected_values.get('step') assert recv_values.get('step') == expected_values.get('step')
assert abs(recv_values.get('value') - expected_values.get('value')) < 1e-6 assert abs(recv_values.get('value') - expected_values.get('value')) < 1e-6
@pytest.mark.usefixtures('load_scalar_record')
def test_get_scalars(self):
"""Get scalars success."""
scalar_processor = ScalarsProcessor(self._mock_data_manager)
scalars = scalar_processor.get_scalars([self._train_id], [self._complete_tag_name])
scalar = scalars[0]
assert scalar['train_id'] == self._train_id
assert scalar['tag'] == self._complete_tag_name
for recv_values, expected_values in zip(scalar['values'], self._scalars_metadata):
assert recv_values.get('wall_time') == expected_values.get('wall_time')
assert recv_values.get('step') == expected_values.get('step')
assert abs(recv_values.get('value') - expected_values.get('value')) < 1e-6
...@@ -141,3 +141,11 @@ class TestTrainTaskManager: ...@@ -141,3 +141,11 @@ class TestTrainTaskManager:
assert train_id in self._plugins_id_map.get(plugin_name) assert train_id in self._plugins_id_map.get(plugin_name)
else: else:
assert train_id not in self._plugins_id_map.get(plugin_name) assert train_id not in self._plugins_id_map.get(plugin_name)
@pytest.mark.usefixtures('load_data')
def test_cache_train_jobs(self):
"""Test caching train jobs with train ids."""
train_task_manager = TrainTaskManager(self._mock_data_manager)
cache_result = train_task_manager.cache_train_jobs(self._train_id_list)
assert len(self._train_id_list) == len(cache_result)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册