提交 eaa4497d 编写于 作者: L liangyongxiong

fix bug of multi-scalars comparision under concurrency

上级 c4fc9bfb
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
from urllib.parse import unquote from urllib.parse import unquote
from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.utils.tools import if_nan_inf_to_none from mindinsight.datavisual.utils.tools import if_nan_inf_to_none
from mindinsight.datavisual.common.exceptions import ScalarNotExistError from mindinsight.datavisual.common.exceptions import ScalarNotExistError
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.processors.base_processor import BaseProcessor from mindinsight.datavisual.processors.base_processor import BaseProcessor
...@@ -71,25 +73,44 @@ class ScalarsProcessor(BaseProcessor): ...@@ -71,25 +73,44 @@ class ScalarsProcessor(BaseProcessor):
scalars = [] scalars = []
for train_id in train_ids: for train_id in train_ids:
for tag in tags: scalars += self._get_train_scalars(train_id, tags)
try:
tensors = self._data_manager.list_tensors(train_id, tag) return scalars
except ParamValueError:
continue def _get_train_scalars(self, train_id, tags):
"""
scalar = { Get scalar data for given train_id and tags.
'train_id': train_id,
'tag': tag, Args:
'values': [], train_id (str): Specify train job ID.
} tags (list): Specify list of tags.
for tensor in tensors: Returns:
scalar['values'].append({ list[dict], a list of dictionaries containing the `wall_time`, `step`, `value` for each scalar.
'wall_time': tensor.wall_time, """
'step': tensor.step, scalars = []
'value': if_nan_inf_to_none('scalar_value', tensor.value), for tag in tags:
}) try:
tensors = self._data_manager.list_tensors(train_id, tag)
scalars.append(scalar) except ParamValueError:
continue
except TrainJobNotExistError:
logger.warning('Can not find the given train job in cache.')
return []
scalar = {
'train_id': train_id,
'tag': tag,
'values': [],
}
for tensor in tensors:
scalar['values'].append({
'wall_time': tensor.wall_time,
'step': tensor.step,
'value': if_nan_inf_to_none('scalar_value', tensor.value),
})
scalars.append(scalar)
return scalars return scalars
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册