Commit 88932364 authored by liangyongxiong

compare scalars within multiple train jobs

Parent 9e0d7cd5
@@ -25,10 +25,11 @@ from flask import request
 from flask import jsonify
 from mindinsight.conf import settings
+from mindinsight.datavisual.common.validation import Validation
+from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
 from mindinsight.datavisual.utils.tools import str_to_bool
 from mindinsight.datavisual.utils.tools import get_train_id
 from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
-from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
 from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
@@ -65,16 +66,11 @@ def query_train_jobs():
     offset = request.args.get("offset", default=0)
     limit = request.args.get("limit", default=10)

-    summary_watcher = SummaryWatcher()
-    total, directories = summary_watcher.list_summary_directories_by_pagination(
-        settings.SUMMARY_BASE_DIR, offset, limit)
-
-    train_jobs = [{
-        'train_id': directory['relative_path'],
-        'relative_path': directory['relative_path'],
-        'create_time': directory['create_time'].strftime('%Y-%m-%d %H:%M:%S'),
-        'update_time': directory['update_time'].strftime('%Y-%m-%d %H:%M:%S'),
-    } for directory in directories]
+    offset = Validation.check_offset(offset=offset)
+    limit = Validation.check_limit(limit, min_value=1, max_value=SummaryWatcher.MAX_SUMMARY_DIR_COUNT)
+
+    processor = TrainTaskManager(DATA_MANAGER)
+    total, train_jobs = processor.query_train_jobs(offset, limit)

     return jsonify({
         'name': os.path.basename(os.path.realpath(settings.SUMMARY_BASE_DIR)),
@@ -83,6 +79,18 @@ def query_train_jobs():
     })


+@BLUEPRINT.route("/datavisual/train-job-caches", methods=["POST"])
+def cache_train_jobs():
+    """Cache train jobs."""
+    data = request.get_json(silent=True)
+    train_ids = data.get('train_ids', [])
+
+    processor = TrainTaskManager(DATA_MANAGER)
+    cache_result = processor.cache_train_jobs(train_ids)
+    return jsonify({'cache_result': cache_result})
+
+
 def init_module(app):
     """
     Init module entry.
...
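For reference, the new endpoint can be exercised with any HTTP client. A minimal sketch, assuming a MindInsight instance on localhost:8080, the /v1/mindinsight API prefix, and a run named './run1' (host, port, prefix, and run names all depend on your deployment):

    import requests  # illustrative client-side code; not part of this commit

    # train ids are the relative summary paths under the summary base directory
    resp = requests.post(
        'http://localhost:8080/v1/mindinsight/datavisual/train-job-caches',
        json={'train_ids': ['./run1', './run2']},
    )
    print(resp.json())  # {'cache_result': [{'train_id': './run1', 'cache_status': ...}, ...]}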
@@ -162,6 +162,17 @@ def histogram():
     return jsonify(response)


+@BLUEPRINT.route("/datavisual/scalars", methods=["GET"])
+def get_scalars():
+    """Get scalar data for given train_ids and tags."""
+    train_ids = request.args.getlist('train_id')
+    tags = request.args.getlist('tag')
+    processor = ScalarsProcessor(DATA_MANAGER)
+    scalars = processor.get_scalars(train_ids, tags)
+    return jsonify({'scalars': scalars})
+
+
 def init_module(app):
     """
     Init module entry.
...
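The comparison endpoint accepts repeated train_id and tag query parameters. A minimal sketch under the same deployment assumptions as above:

    import requests

    resp = requests.get(
        'http://localhost:8080/v1/mindinsight/datavisual/scalars',
        params=[('train_id', './run1'), ('train_id', './run2'), ('tag', 'loss')],
    )
    # resp.json()['scalars'] holds one entry per (train_id, tag) pair that exists,
    # each carrying a 'values' list of {'wall_time', 'step', 'value'} points.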
@@ -45,7 +45,7 @@ from mindinsight.utils.exceptions import ParamValueError
 @enum.unique
-class _CacheStatus(enum.Enum):
+class CacheStatus(enum.Enum):
     """Train job cache status."""
     NOT_IN_CACHE = "NOT_IN_CACHE"
     CACHING = "CACHING"
@@ -63,13 +63,15 @@ class _BasicTrainJob:
         abs_summary_dir (str): The canonical path of summary directory. It should be the return value of realpath().
         create_time (DateTime): The create time of summary directory.
         update_time (DateTime): The latest modify time of summary files directly in the summary directory.
+        profiler_dir (str): The relative path of profiler directory.
     """
-    def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time):
+    def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time, profiler_dir):
         self._train_id = train_id
         self._abs_summary_base_dir = abs_summary_base_dir
         self._abs_summary_dir = abs_summary_dir
         self._create_time = create_time
         self._update_time = update_time
+        self._profiler_dir = profiler_dir

     @property
     def abs_summary_dir(self):
@@ -86,6 +88,16 @@ class _BasicTrainJob:
         """Get train id."""
         return self._train_id

+    @property
+    def profiler_dir(self):
+        """Get profiler directory path."""
+        return self._profiler_dir
+
+    @property
+    def create_time(self):
+        """Get create time."""
+        return self._create_time
+
     @property
     def update_time(self):
         """Get update time."""
@@ -108,7 +120,7 @@ class CachedTrainJob:
         # Other cached content is stored here.
         self._content = {}

-        self._cache_status = _CacheStatus.NOT_IN_CACHE
+        self._cache_status = CacheStatus.NOT_IN_CACHE
         self._key_locks = {}

     @property
@@ -203,7 +215,7 @@ class TrainJob:
         self._brief = brief_train_job
         self._detail = detail_train_job
         if self._detail is None:
-            self._cache_status = _CacheStatus.NOT_IN_CACHE
+            self._cache_status = CacheStatus.NOT_IN_CACHE
         else:
             self._cache_status = self._detail.cache_status
@@ -241,6 +253,20 @@ class TrainJob:
         """
         return self._brief.get(key)

+    def get_basic_info(self):
+        """
+        Get basic info.
+
+        Returns:
+            basic_info (_BasicTrainJob): Basic info about the train job.
+        """
+        return self._brief.basic_info
+
+    @property
+    def cache_status(self):
+        """Get cache status."""
+        return self._cache_status
+

 class BaseCacheItemUpdater(abc.ABC):
     """Abstract base class for other modules to update cache content."""
@@ -686,7 +712,7 @@ class _DetailCacheManager(_BaseCacheManager):
         train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job)

         # Will assign real value in future.
-        train_job_obj.cache_status = _CacheStatus.CACHED
+        train_job_obj.cache_status = CacheStatus.CACHED

         return train_job_obj
@@ -863,6 +889,7 @@ class DataManager:
         basic_train_jobs = []
         for info in summaries_info:
+            profiler = info['profiler']
             basic_train_jobs.append(_BasicTrainJob(
                 train_id=info['relative_path'],
                 abs_summary_base_dir=self._summary_base_dir,
@@ -871,7 +898,8 @@ class DataManager:
                     info['relative_path']
                 )),
                 create_time=info['create_time'],
-                update_time=info['update_time']
+                update_time=info['update_time'],
+                profiler_dir=None if profiler is None else profiler['directory'],
             ))

         self._brief_cache.update_cache(basic_train_jobs)
...
@@ -31,6 +31,7 @@ class SummaryWatcher:
     SUMMARY_FILENAME_REGEX = r'summary\.(?P<timestamp>\d+)'
     PB_FILENAME_REGEX = r'\.pb$'
+    PROFILER_DIRECTORY_REGEX = r'^profiler$'
     MAX_SUMMARY_DIR_COUNT = 999

     # scan at most 20000 files/directories (approximately 1 seconds)
@@ -52,6 +53,8 @@ class SummaryWatcher:
                 starting with "./".
                 - create_time (datetime): Creation time of summary file.
                 - update_time (datetime): Modification time of summary file.
+                - profiler (dict): profiler info, including profiler subdirectory path, profiler creation time and
+                    profiler modification time.

         Examples:
             >>> from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
@@ -95,7 +98,7 @@ class SummaryWatcher:
             if entry.is_symlink():
                 pass
             elif entry.is_file():
-                self._update_summary_dict(summary_dict, relative_path, entry)
+                self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry)
             elif entry.is_dir():
                 full_path = os.path.realpath(os.path.join(summary_base_dir, entry.name))
                 try:
@@ -103,27 +106,39 @@ class SummaryWatcher:
                 except PermissionError:
                     logger.warning('Path of %s under summary base directory is not accessible.', entry.name)
                     continue
-                self._scan_subdir_entries(summary_dict, subdir_entries, entry.name, counter)
+                self._scan_subdir_entries(summary_dict, summary_base_dir, subdir_entries, entry.name, counter)

-        directories = [{
-            'relative_path': key,
-            'create_time': value['ctime'],
-            'update_time': value['mtime'],
-        } for key, value in summary_dict.items()]
+        directories = []
+        for key, value in summary_dict.items():
+            directory = {
+                'relative_path': key,
+                'profiler': None,
+                'create_time': value['ctime'],
+                'update_time': value['mtime'],
+            }
+            profiler = value.get('profiler')
+            if profiler is not None:
+                directory['profiler'] = {
+                    'directory': profiler['directory'],
+                    'create_time': profiler['ctime'],
+                    'update_time': profiler['mtime'],
+                }
+            directories.append(directory)

         # sort by update time in descending order and relative path in ascending order
         directories.sort(key=lambda x: (-int(x['update_time'].timestamp()), x['relative_path']))

         return directories
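After this change, each entry returned by list_summary_directories carries a profiler field alongside the existing keys. An illustrative entry, with invented values:

    import datetime

    now = datetime.datetime.now().astimezone()
    # Shape sketch only; 'profiler' is None when the run has no
    # non-empty 'profiler' subdirectory.
    directory = {
        'relative_path': './run1',
        'create_time': now,
        'update_time': now,
        'profiler': {
            'directory': './profiler',
            'create_time': now,
            'update_time': now,
        },
    }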
-    def _scan_subdir_entries(self, summary_dict, subdir_entries, entry_name, counter):
+    def _scan_subdir_entries(self, summary_dict, summary_base_dir, subdir_entries, entry_name, counter):
         """
         Scan subdir entries.

         Args:
             summary_dict (dict): Temporary data structure to hold summary directory info.
-            subdir_entries(DirEntry): Directory entry instance.
+            summary_base_dir (str): Path of summary base directory.
             entry_name (str): Name of entry.
+            subdir_entries(DirEntry): Directory entry instance.
             counter (Counter): An instance of CountLimiter.
         """
@@ -139,8 +154,7 @@ class SummaryWatcher:
             subdir_relative_path = os.path.join('.', entry_name)
             if subdir_entry.is_symlink():
                 pass
-            elif subdir_entry.is_file():
-                self._update_summary_dict(summary_dict, subdir_relative_path, subdir_entry)
+            self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry)

     def _contains_null_byte(self, **kwargs):
         """
@@ -194,40 +208,62 @@ class SummaryWatcher:
             return True

-    def _update_summary_dict(self, summary_dict, relative_path, entry):
+    def _update_summary_dict(self, summary_dict, summary_base_dir, relative_path, entry):
         """
         Update summary_dict with ctime and mtime.

         Args:
             summary_dict (dict): Temporary data structure to hold summary directory info.
+            summary_base_dir (str): Path of summary base directory.
             relative_path (str): Relative path of summary directory, referring to summary base directory,
                 starting with "./" .
             entry (DirEntry): Directory entry instance needed to check with regular expression.
         """
-        summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
-        pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
-        if summary_pattern is None and pb_pattern is None:
-            return
-        if summary_pattern is not None:
-            timestamp = int(summary_pattern.groupdict().get('timestamp'))
-            try:
-                # extract created time from filename
-                ctime = datetime.datetime.fromtimestamp(timestamp).astimezone()
-            except OverflowError:
-                return
-        else:
-            ctime = datetime.datetime.fromtimestamp(entry.stat().st_ctime).astimezone()
-        # extract modified time from filesystem
-        mtime = datetime.datetime.fromtimestamp(entry.stat().st_mtime).astimezone()
-        if relative_path not in summary_dict or summary_dict[relative_path]['ctime'] < ctime:
-            summary_dict[relative_path] = {
-                'ctime': ctime,
-                'mtime': mtime,
-            }
+        ctime = datetime.datetime.fromtimestamp(entry.stat().st_ctime).astimezone()
+        mtime = datetime.datetime.fromtimestamp(entry.stat().st_mtime).astimezone()
+
+        if entry.is_file():
+            summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
+            pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
+            if summary_pattern is None and pb_pattern is None:
+                return
+            if summary_pattern is not None:
+                timestamp = int(summary_pattern.groupdict().get('timestamp'))
+                try:
+                    # extract created time from filename
+                    ctime = datetime.datetime.fromtimestamp(timestamp).astimezone()
+                except OverflowError:
+                    return
+            if relative_path not in summary_dict:
+                summary_dict[relative_path] = {
+                    'ctime': ctime,
+                    'mtime': mtime,
+                    'profiler': None,
+                }
+            elif summary_dict[relative_path]['ctime'] < ctime:
+                summary_dict[relative_path].update({
+                    'ctime': ctime,
+                    'mtime': mtime,
+                })
+        elif entry.is_dir():
+            profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name)
+            full_dir_path = os.path.join(summary_base_dir, relative_path, entry.name)
+            if profiler_pattern is None or self._is_empty_directory(full_dir_path):
+                return

+            profiler = {
+                'directory': os.path.join('.', entry.name),
+                'ctime': ctime,
+                'mtime': mtime,
+            }
+
+            if relative_path not in summary_dict:
+                summary_dict[relative_path] = {
+                    'ctime': ctime,
+                    'mtime': mtime,
+                    'profiler': profiler,
+                }
     def is_summary_directory(self, summary_base_dir, relative_path):
         """
         Check if the given summary directory is valid.
@@ -259,15 +295,28 @@ class SummaryWatcher:
             raise FileSystemPermissionError('Path of summary base directory is not accessible.')

         for entry in entries:
-            if entry.is_symlink() or not entry.is_file():
+            if entry.is_symlink():
                 continue

             summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
+            if summary_pattern is not None and entry.is_file():
+                return True
             pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
-            if summary_pattern or pb_pattern:
+            if pb_pattern is not None and entry.is_file():
                 return True
+            profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name)
+            if profiler_pattern is not None and entry.is_dir():
+                full_path = os.path.realpath(os.path.join(summary_directory, entry.name))
+                if not self._is_empty_directory(full_path):
+                    return True

         return False

+    def _is_empty_directory(self, directory):
+        return not bool(os.listdir(directory))
+
     def list_summary_directories_by_pagination(self, summary_base_dir, offset=0, limit=10):
         """
         List summary directories within base directory.
...
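With the profiler check added, a directory that holds only a non-empty profiler subdirectory now qualifies as a summary directory. A hedged sanity check (paths invented):

    watcher = SummaryWatcher()
    # True if ./run1 contains a summary.<timestamp> file, a .pb file,
    # or a non-empty 'profiler' subdirectory
    print(watcher.is_summary_directory('/home/user/summaries', './run1'))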
@@ -13,7 +13,10 @@
 # limitations under the License.
 # ============================================================================
 """Scalar Processor APIs."""
-from mindinsight.utils.exceptions import ParamValueError
+from urllib.parse import unquote
+
+from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError
+from mindinsight.datavisual.utils.tools import if_nan_inf_to_none
 from mindinsight.datavisual.common.exceptions import ScalarNotExistError
 from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.processors.base_processor import BaseProcessor
@@ -46,3 +49,47 @@ class ScalarsProcessor(BaseProcessor):
                 'step': tensor.step,
                 'value': tensor.value})
         return dict(metadatas=job_response)
+
+    def get_scalars(self, train_ids, tags):
+        """
+        Get scalar data for given train_ids and tags.
+
+        Args:
+            train_ids (list): Specify list of train job ID.
+            tags (list): Specify list of tags.
+
+        Returns:
+            list[dict], a list of dictionaries containing the `wall_time`, `step`, `value` for each scalar.
+        """
+        for index, train_id in enumerate(train_ids):
+            try:
+                train_id = unquote(train_id, errors='strict')
+            except UnicodeDecodeError:
+                raise UrlDecodeError('Unquote train id error with strict mode')
+            else:
+                train_ids[index] = train_id
+
+        scalars = []
+        for train_id in train_ids:
+            for tag in tags:
+                try:
+                    tensors = self._data_manager.list_tensors(train_id, tag)
+                except ParamValueError:
+                    continue
+
+                scalar = {
+                    'train_id': train_id,
+                    'tag': tag,
+                    'values': [],
+                }
+
+                for tensor in tensors:
+                    scalar['values'].append({
+                        'wall_time': tensor.wall_time,
+                        'step': tensor.step,
+                        'value': if_nan_inf_to_none('scalar_value', tensor.value),
+                    })
+
+                scalars.append(scalar)
+
+        return scalars
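A minimal sketch of driving the new method directly, assuming the jobs are already loaded by DATA_MANAGER and that the module path below matches the repository layout:

    from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
    from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor  # path assumed

    processor = ScalarsProcessor(DATA_MANAGER)
    # tags missing from a given job are skipped rather than raising,
    # per the ParamValueError handler above
    scalars = processor.get_scalars(['./run1', './run2'], ['loss', 'accuracy'])
    for scalar in scalars:
        print(scalar['train_id'], scalar['tag'], len(scalar['values']))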
@@ -14,11 +14,13 @@
 # ============================================================================
 """Train task manager."""
+from mindinsight.datavisual.common.log import logger
 from mindinsight.datavisual.common import exceptions
 from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.processors.base_processor import BaseProcessor
 from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
+from mindinsight.datavisual.data_transform.data_manager import CacheStatus
 class TrainTaskManager(BaseProcessor):
@@ -75,3 +77,78 @@ class TrainTaskManager(BaseProcessor):
         return dict(
             plugins=plugins
         )
+
+    def query_train_jobs(self, offset=0, limit=10):
+        """
+        Query train jobs.
+
+        Args:
+            offset (int): Specify page number. Default is 0.
+            limit (int): Specify page size. Default is 10.
+
+        Returns:
+            tuple, return quantity of total train jobs and list of train jobs specified by offset and limit.
+        """
+        brief_cache = self._data_manager.get_brief_cache()
+        brief_train_jobs = list(brief_cache.get_train_jobs().values())
+        brief_train_jobs.sort(key=lambda x: x.basic_info.update_time, reverse=True)
+        total = len(brief_train_jobs)
+
+        start = offset * limit
+        end = (offset + 1) * limit
+        train_jobs = []
+
+        train_ids = [train_job.basic_info.train_id for train_job in brief_train_jobs[start:end]]
+
+        for train_id in train_ids:
+            try:
+                train_job = self._data_manager.get_train_job(train_id)
+            except exceptions.TrainJobNotExistError:
+                logger.warning('Train job %s not existed', train_id)
+                continue
+
+            basic_info = train_job.get_basic_info()
+            train_job_item = dict(
+                train_id=basic_info.train_id,
+                relative_path=basic_info.train_id,
+                create_time=basic_info.create_time.strftime('%Y-%m-%d %H:%M:%S'),
+                update_time=basic_info.update_time.strftime('%Y-%m-%d %H:%M:%S'),
+                profiler_dir=basic_info.profiler_dir,
+                cache_status=train_job.cache_status.value,
+            )
+
+            plugins = self.get_plugins(train_id)
+            train_job_item.update(plugins)
+            train_jobs.append(train_job_item)
+
+        return total, train_jobs
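Note that offset is a page number rather than an item offset: the slice covers [offset * limit, (offset + 1) * limit). A worked example, assuming 25 cached jobs:

    total = 25                    # invented job count
    offset, limit = 2, 10         # third page
    start, end = offset * limit, (offset + 1) * limit
    print(start, end)             # 20 30
    # brief_train_jobs[20:30] would yield the remaining 5 jobs; slicing past
    # the end of a Python list is safe and simply truncates.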
+
+    def cache_train_jobs(self, train_ids):
+        """
+        Cache train jobs.
+
+        Args:
+            train_ids (list): Specify list of train_ids to be cached.
+
+        Returns:
+            dict, indicates train job ID and its current cache status.
+        """
+        brief_cache = self._data_manager.get_brief_cache()
+        brief_train_jobs = brief_cache.get_train_jobs()
+
+        for train_id in train_ids:
+            brief_train_job = brief_train_jobs.get(train_id)
+            if brief_train_job is None:
+                raise exceptions.TrainJobNotExistError(f'Train id {train_id} not exists')
+
+        cache_result = []
+        for train_id in train_ids:
+            brief_train_job = brief_train_jobs.get(train_id)
+            if brief_train_job.cache_status.value == CacheStatus.NOT_IN_CACHE.value:
+                self._data_manager.cache_train_job(train_id)
+            cache_result.append({
+                'train_id': train_id,
+                'cache_status': brief_train_job.cache_status.value,
+            })
+
+        return cache_result
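The first loop validates every requested id before the second loop triggers any caching, so a single unknown id rejects the whole request without side effects. A usage sketch, assuming './run1' is present in the brief cache:

    from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
    from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager

    processor = TrainTaskManager(DATA_MANAGER)
    result = processor.cache_train_jobs(['./run1'])
    # Each item reports the status observed at call time, e.g.
    # [{'train_id': './run1', 'cache_status': 'NOT_IN_CACHE'}];
    # the actual load is left to the data manager.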