Commit 88932364 authored by liangyongxiong

compare scalars within multiple train jobs

Parent 9e0d7cd5
......@@ -25,10 +25,11 @@ from flask import request
from flask import jsonify
from mindinsight.conf import settings
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.datavisual.utils.tools import str_to_bool
from mindinsight.datavisual.utils.tools import get_train_id
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
......@@ -65,16 +66,11 @@ def query_train_jobs():
offset = request.args.get("offset", default=0)
limit = request.args.get("limit", default=10)
summary_watcher = SummaryWatcher()
total, directories = summary_watcher.list_summary_directories_by_pagination(
settings.SUMMARY_BASE_DIR, offset, limit)
offset = Validation.check_offset(offset=offset)
limit = Validation.check_limit(limit, min_value=1, max_value=SummaryWatcher.MAX_SUMMARY_DIR_COUNT)
train_jobs = [{
'train_id': directory['relative_path'],
'relative_path': directory['relative_path'],
'create_time': directory['create_time'].strftime('%Y-%m-%d %H:%M:%S'),
'update_time': directory['update_time'].strftime('%Y-%m-%d %H:%M:%S'),
} for directory in directories]
processor = TrainTaskManager(DATA_MANAGER)
total, train_jobs = processor.query_train_jobs(offset, limit)
return jsonify({
'name': os.path.basename(os.path.realpath(settings.SUMMARY_BASE_DIR)),
......@@ -83,6 +79,18 @@ def query_train_jobs():
})
@BLUEPRINT.route("/datavisual/train-job-caches", methods=["POST"])
def cache_train_jobs():
""" Cache train jobs."""
data = request.get_json(silent=True)
train_ids = data.get('train_ids', [])
processor = TrainTaskManager(DATA_MANAGER)
cache_result = processor.cache_train_jobs(train_ids)
return jsonify({'cache_result': cache_result})
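A minimal client sketch for the new endpoint (not part of the commit; the host and port assume a default local MindInsight instance, and the train IDs are hypothetical relative summary paths):

    # Hypothetical call to POST /datavisual/train-job-caches.
    import requests

    resp = requests.post(
        'http://127.0.0.1:8080/datavisual/train-job-caches',
        json={'train_ids': ['./run1', './run2']},  # relative summary paths used as train IDs
    )
    print(resp.json())  # {'cache_result': [{'train_id': ..., 'cache_status': ...}, ...]}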
def init_module(app):
"""
Init module entry.
......
......@@ -162,6 +162,17 @@ def histogram():
return jsonify(response)
@BLUEPRINT.route("/datavisual/scalars", methods=["GET"])
def get_scalars():
"""Get scalar data for given train_ids and tags."""
train_ids = request.args.getlist('train_id')
tags = request.args.getlist('tag')
processor = ScalarsProcessor(DATA_MANAGER)
scalars = processor.get_scalars(train_ids, tags)
return jsonify({'scalars': scalars})
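A matching hedged sketch for the new scalars endpoint. Because the handler collects repeated train_id and tag query parameters via request.args.getlist(), a client passes them as a multi-valued list (host, port, IDs and tag are assumptions):

    # Hypothetical call to GET /datavisual/scalars.
    import requests

    resp = requests.get(
        'http://127.0.0.1:8080/datavisual/scalars',
        params=[('train_id', './run1'), ('train_id', './run2'), ('tag', 'loss')],
    )
    print(resp.json())  # {'scalars': [{'train_id': ..., 'tag': ..., 'values': [...]}, ...]}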
def init_module(app):
"""
Init module entry.
......
......@@ -45,7 +45,7 @@ from mindinsight.utils.exceptions import ParamValueError
@enum.unique
class _CacheStatus(enum.Enum):
class CacheStatus(enum.Enum):
"""Train job cache status."""
NOT_IN_CACHE = "NOT_IN_CACHE"
CACHING = "CACHING"
......@@ -63,13 +63,15 @@ class _BasicTrainJob:
abs_summary_dir (str): The canonical path of summary directory. It should be the return value of realpath().
create_time (DateTime): The create time of summary directory.
update_time (DateTime): The latest modify time of summary files directly in the summary directory.
profiler_dir (str): The relative path of profiler directory.
"""
def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time):
def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time, profiler_dir):
self._train_id = train_id
self._abs_summary_base_dir = abs_summary_base_dir
self._abs_summary_dir = abs_summary_dir
self._create_time = create_time
self._update_time = update_time
self._profiler_dir = profiler_dir
@property
def abs_summary_dir(self):
......@@ -86,6 +88,16 @@ class _BasicTrainJob:
"""Get train id."""
return self._train_id
@property
def profiler_dir(self):
"""Get profiler directory path."""
return self._profiler_dir
@property
def create_time(self):
"""Get create time."""
return self._create_time
@property
def update_time(self):
"""Get update time."""
......@@ -108,7 +120,7 @@ class CachedTrainJob:
# Other cached content is stored here.
self._content = {}
self._cache_status = _CacheStatus.NOT_IN_CACHE
self._cache_status = CacheStatus.NOT_IN_CACHE
self._key_locks = {}
@property
......@@ -203,7 +215,7 @@ class TrainJob:
self._brief = brief_train_job
self._detail = detail_train_job
if self._detail is None:
self._cache_status = _CacheStatus.NOT_IN_CACHE
self._cache_status = CacheStatus.NOT_IN_CACHE
else:
self._cache_status = self._detail.cache_status
......@@ -241,6 +253,20 @@ class TrainJob:
"""
return self._brief.get(key)
def get_basic_info(self):
"""
Get basic info.
Returns:
basic_info (_BasicTrainJob): Basic info about the train job.
"""
return self._brief.basic_info
@property
def cache_status(self):
"""Get cache status."""
return self._cache_status
class BaseCacheItemUpdater(abc.ABC):
"""Abstract base class for other modules to update cache content."""
......@@ -686,7 +712,7 @@ class _DetailCacheManager(_BaseCacheManager):
train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job)
# Will assign real value in future.
train_job_obj.cache_status = _CacheStatus.CACHED
train_job_obj.cache_status = CacheStatus.CACHED
return train_job_obj
......@@ -863,6 +889,7 @@ class DataManager:
basic_train_jobs = []
for info in summaries_info:
profiler = info['profiler']
basic_train_jobs.append(_BasicTrainJob(
train_id=info['relative_path'],
abs_summary_base_dir=self._summary_base_dir,
......@@ -871,7 +898,8 @@ class DataManager:
info['relative_path']
)),
create_time=info['create_time'],
update_time=info['update_time']
update_time=info['update_time'],
profiler_dir=None if profiler is None else profiler['directory'],
))
self._brief_cache.update_cache(basic_train_jobs)
......
......@@ -31,6 +31,7 @@ class SummaryWatcher:
SUMMARY_FILENAME_REGEX = r'summary\.(?P<timestamp>\d+)'
PB_FILENAME_REGEX = r'\.pb$'
PROFILER_DIRECTORY_REGEX = r'^profiler$'
MAX_SUMMARY_DIR_COUNT = 999
# scan at most 20000 files/directories (approximately 1 second)
......@@ -52,6 +53,8 @@ class SummaryWatcher:
starting with "./".
- create_time (datetime): Creation time of summary file.
- update_time (datetime): Modification time of summary file.
- profiler (dict): Profiler info, including the profiler subdirectory path, its creation time
and its modification time.
Examples:
>>> from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
......@@ -95,7 +98,7 @@ class SummaryWatcher:
if entry.is_symlink():
pass
elif entry.is_file():
self._update_summary_dict(summary_dict, relative_path, entry)
self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry)
elif entry.is_dir():
full_path = os.path.realpath(os.path.join(summary_base_dir, entry.name))
try:
......@@ -103,27 +106,39 @@ class SummaryWatcher:
except PermissionError:
logger.warning('Path of %s under summary base directory is not accessible.', entry.name)
continue
self._scan_subdir_entries(summary_dict, subdir_entries, entry.name, counter)
directories = [{
'relative_path': key,
'create_time': value['ctime'],
'update_time': value['mtime'],
} for key, value in summary_dict.items()]
self._scan_subdir_entries(summary_dict, summary_base_dir, subdir_entries, entry.name, counter)
directories = []
for key, value in summary_dict.items():
directory = {
'relative_path': key,
'profiler': None,
'create_time': value['ctime'],
'update_time': value['mtime'],
}
profiler = value.get('profiler')
if profiler is not None:
directory['profiler'] = {
'directory': profiler['directory'],
'create_time': profiler['ctime'],
'update_time': profiler['mtime'],
}
directories.append(directory)
# sort by update time in descending order and relative path in ascending order
directories.sort(key=lambda x: (-int(x['update_time'].timestamp()), x['relative_path']))
return directories
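A side note on the compound key above: negating the integer timestamp sorts update_time in descending order while relative_path still breaks ties in ascending order. A standalone sketch with made-up data:

    import datetime

    dirs = [
        {'relative_path': './b', 'update_time': datetime.datetime(2020, 1, 1)},
        {'relative_path': './a', 'update_time': datetime.datetime(2020, 1, 1)},
        {'relative_path': './c', 'update_time': datetime.datetime(2020, 6, 1)},
    ]
    dirs.sort(key=lambda x: (-int(x['update_time'].timestamp()), x['relative_path']))
    print([d['relative_path'] for d in dirs])  # ['./c', './a', './b']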
def _scan_subdir_entries(self, summary_dict, subdir_entries, entry_name, counter):
def _scan_subdir_entries(self, summary_dict, summary_base_dir, subdir_entries, entry_name, counter):
"""
Scan subdir entries.
Args:
summary_dict (dict): Temporary data structure to hold summary directory info.
subdir_entries (DirEntry): Directory entry instance.
summary_base_dir (str): Path of summary base directory.
entry_name (str): Name of entry.
subdir_entries (DirEntry): Directory entry instance.
counter (Counter): An instance of CountLimiter.
"""
......@@ -139,8 +154,7 @@ class SummaryWatcher:
subdir_relative_path = os.path.join('.', entry_name)
if subdir_entry.is_symlink():
pass
elif subdir_entry.is_file():
self._update_summary_dict(summary_dict, subdir_relative_path, subdir_entry)
self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry)
def _contains_null_byte(self, **kwargs):
"""
......@@ -194,40 +208,62 @@ class SummaryWatcher:
return True
def _update_summary_dict(self, summary_dict, relative_path, entry):
def _update_summary_dict(self, summary_dict, summary_base_dir, relative_path, entry):
"""
Update summary_dict with ctime and mtime.
Args:
summary_dict (dict): Temporary data structure to hold summary directory info.
summary_base_dir (str): Path of summary base directory.
relative_path (str): Relative path of summary directory, referring to summary base directory,
starting with "./".
entry (DirEntry): Directory entry instance needed to check with regular expression.
"""
summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
if summary_pattern is None and pb_pattern is None:
return
ctime = datetime.datetime.fromtimestamp(entry.stat().st_ctime).astimezone()
mtime = datetime.datetime.fromtimestamp(entry.stat().st_mtime).astimezone()
if summary_pattern is not None:
timestamp = int(summary_pattern.groupdict().get('timestamp'))
try:
# extract created time from filename
ctime = datetime.datetime.fromtimestamp(timestamp).astimezone()
except OverflowError:
if entry.is_file():
summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
if summary_pattern is None and pb_pattern is None:
return
if summary_pattern is not None:
timestamp = int(summary_pattern.groupdict().get('timestamp'))
try:
# extract created time from filename
ctime = datetime.datetime.fromtimestamp(timestamp).astimezone()
except OverflowError:
return
if relative_path not in summary_dict:
summary_dict[relative_path] = {
'ctime': ctime,
'mtime': mtime,
'profiler': None,
}
elif summary_dict[relative_path]['ctime'] < ctime:
summary_dict[relative_path].update({
'ctime': ctime,
'mtime': mtime,
})
elif entry.is_dir():
profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name)
full_dir_path = os.path.join(summary_base_dir, relative_path, entry.name)
if profiler_pattern is None or self._is_empty_directory(full_dir_path):
return
else:
ctime = datetime.datetime.fromtimestamp(entry.stat().st_ctime).astimezone()
# extract modified time from filesystem
mtime = datetime.datetime.fromtimestamp(entry.stat().st_mtime).astimezone()
if relative_path not in summary_dict or summary_dict[relative_path]['ctime'] < ctime:
summary_dict[relative_path] = {
profiler = {
'directory': os.path.join('.', entry.name),
'ctime': ctime,
'mtime': mtime,
}
if relative_path not in summary_dict:
summary_dict[relative_path] = {
'ctime': ctime,
'mtime': mtime,
'profiler': profiler,
}
def is_summary_directory(self, summary_base_dir, relative_path):
"""
Check if the given summary directory is valid.
......@@ -259,15 +295,28 @@ class SummaryWatcher:
raise FileSystemPermissionError('Path of summary base directory is not accessible.')
for entry in entries:
if entry.is_symlink() or not entry.is_file():
if entry.is_symlink():
continue
summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name)
if summary_pattern is not None and entry.is_file():
return True
pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name)
if summary_pattern or pb_pattern:
if pb_pattern is not None and entry.is_file():
return True
profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name)
if profiler_pattern is not None and entry.is_dir():
full_path = os.path.realpath(os.path.join(summary_directory, entry.name))
if not self._is_empty_directory(full_path):
return True
return False
def _is_empty_directory(self, directory):
return not bool(os.listdir(directory))
def list_summary_directories_by_pagination(self, summary_base_dir, offset=0, limit=10):
"""
List summary directories within base directory.
......
......@@ -13,7 +13,10 @@
# limitations under the License.
# ============================================================================
"""Scalar Processor APIs."""
from mindinsight.utils.exceptions import ParamValueError
from urllib.parse import unquote
from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError
from mindinsight.datavisual.utils.tools import if_nan_inf_to_none
from mindinsight.datavisual.common.exceptions import ScalarNotExistError
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.processors.base_processor import BaseProcessor
......@@ -46,3 +49,47 @@ class ScalarsProcessor(BaseProcessor):
'step': tensor.step,
'value': tensor.value})
return dict(metadatas=job_response)
def get_scalars(self, train_ids, tags):
"""
Get scalar data for given train_ids and tags.
Args:
train_ids (list): Specify list of train job IDs.
tags (list): Specify list of tags.
Returns:
list[dict], a list of dictionaries containing the `wall_time`, `step`, `value` for each scalar.
"""
for index, train_id in enumerate(train_ids):
try:
train_id = unquote(train_id, errors='strict')
except UnicodeDecodeError:
raise UrlDecodeError('Unquote train id error with strict mode')
else:
train_ids[index] = train_id
scalars = []
for train_id in train_ids:
for tag in tags:
try:
tensors = self._data_manager.list_tensors(train_id, tag)
except ParamValueError:
continue
scalar = {
'train_id': train_id,
'tag': tag,
'values': [],
}
for tensor in tensors:
scalar['values'].append({
'wall_time': tensor.wall_time,
'step': tensor.step,
'value': if_nan_inf_to_none('scalar_value', tensor.value),
})
scalars.append(scalar)
return scalars
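For reference, the list returned by get_scalars() is shaped as below (values are made up; NaN/Inf scalar values come back as None after if_nan_inf_to_none):

    scalars = [
        {
            'train_id': './run1',
            'tag': 'loss',
            'values': [
                {'wall_time': 1588221822.0, 'step': 1, 'value': 0.91},
                {'wall_time': 1588221830.0, 'step': 2, 'value': None},  # was NaN/Inf
            ],
        },
    ]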
......@@ -14,11 +14,13 @@
# ============================================================================
"""Train task manager."""
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
from mindinsight.datavisual.data_transform.data_manager import CacheStatus
class TrainTaskManager(BaseProcessor):
......@@ -75,3 +77,78 @@ class TrainTaskManager(BaseProcessor):
return dict(
plugins=plugins
)
def query_train_jobs(self, offset=0, limit=10):
"""
Query train jobs.
Args:
offset (int): Specify page number. Default is 0.
limit (int): Specify page size. Default is 10.
Returns:
tuple, the total number of train jobs and the list of train jobs selected by offset and limit.
"""
brief_cache = self._data_manager.get_brief_cache()
brief_train_jobs = list(brief_cache.get_train_jobs().values())
brief_train_jobs.sort(key=lambda x: x.basic_info.update_time, reverse=True)
total = len(brief_train_jobs)
start = offset * limit
end = (offset + 1) * limit
train_jobs = []
train_ids = [train_job.basic_info.train_id for train_job in brief_train_jobs[start:end]]
for train_id in train_ids:
try:
train_job = self._data_manager.get_train_job(train_id)
except exceptions.TrainJobNotExistError:
logger.warning('Train job %s does not exist', train_id)
continue
basic_info = train_job.get_basic_info()
train_job_item = dict(
train_id=basic_info.train_id,
relative_path=basic_info.train_id,
create_time=basic_info.create_time.strftime('%Y-%m-%d %H:%M:%S'),
update_time=basic_info.update_time.strftime('%Y-%m-%d %H:%M:%S'),
profiler_dir=basic_info.profiler_dir,
cache_status=train_job.cache_status.value,
)
plugins = self.get_plugins(train_id)
train_job_item.update(plugins)
train_jobs.append(train_job_item)
return total, train_jobs
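Note that offset is a page index rather than an item index, so the slice above covers items [offset * limit, (offset + 1) * limit). A quick check:

    offset, limit = 2, 10
    start, end = offset * limit, (offset + 1) * limit
    print(start, end)  # 20 30, i.e. items 20..29 form page 2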
def cache_train_jobs(self, train_ids):
"""
Cache train jobs.
Args:
train_ids (list): Specify list of train_ids to be cached.
Returns:
list[dict], train job IDs with their current cache status.
"""
brief_cache = self._data_manager.get_brief_cache()
brief_train_jobs = brief_cache.get_train_jobs()
for train_id in train_ids:
brief_train_job = brief_train_jobs.get(train_id)
if brief_train_job is None:
raise exceptions.TrainJobNotExistError(f'Train id {train_id} does not exist')
cache_result = []
for train_id in train_ids:
brief_train_job = brief_train_jobs.get(train_id)
if brief_train_job.cache_status.value == CacheStatus.NOT_IN_CACHE.value:
self._data_manager.cache_train_job(train_id)
cache_result.append({
'train_id': train_id,
'cache_status': brief_train_job.cache_status.value,
})
return cache_result
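A hedged usage sketch for this method (DATA_MANAGER is the module-level singleton imported in the blueprint; the train ID is hypothetical):

    processor = TrainTaskManager(DATA_MANAGER)
    result = processor.cache_train_jobs(['./run1'])
    # result is a list of {'train_id': ..., 'cache_status': ...} entries, where
    # cache_status carries a CacheStatus value such as 'NOT_IN_CACHE' or 'CACHING';
    # an unknown train ID raises TrainJobNotExistError before any caching starts.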