From 88932364179a6d8a059e75ee680a95eb09a20cb9 Mon Sep 17 00:00:00 2001 From: liangyongxiong Date: Mon, 18 May 2020 19:28:07 +0800 Subject: [PATCH] compare scalars within multiple train jobs --- .../backend/datavisual/task_manager_api.py | 28 +++-- .../backend/datavisual/train_visual_api.py | 11 ++ .../datavisual/data_transform/data_manager.py | 40 ++++++- .../data_transform/summary_watcher.py | 113 +++++++++++++----- .../processors/scalars_processor.py | 49 +++++++- .../processors/train_task_manager.py | 77 ++++++++++++ 6 files changed, 269 insertions(+), 49 deletions(-) diff --git a/mindinsight/backend/datavisual/task_manager_api.py b/mindinsight/backend/datavisual/task_manager_api.py index 8bba324..f817c48 100644 --- a/mindinsight/backend/datavisual/task_manager_api.py +++ b/mindinsight/backend/datavisual/task_manager_api.py @@ -25,10 +25,11 @@ from flask import request from flask import jsonify from mindinsight.conf import settings +from mindinsight.datavisual.common.validation import Validation +from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.datavisual.utils.tools import str_to_bool from mindinsight.datavisual.utils.tools import get_train_id from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager -from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER @@ -65,16 +66,11 @@ def query_train_jobs(): offset = request.args.get("offset", default=0) limit = request.args.get("limit", default=10) - summary_watcher = SummaryWatcher() - total, directories = summary_watcher.list_summary_directories_by_pagination( - settings.SUMMARY_BASE_DIR, offset, limit) + offset = Validation.check_offset(offset=offset) + limit = Validation.check_limit(limit, min_value=1, max_value=SummaryWatcher.MAX_SUMMARY_DIR_COUNT) - train_jobs = [{ - 'train_id': directory['relative_path'], - 'relative_path': directory['relative_path'], - 'create_time': directory['create_time'].strftime('%Y-%m-%d %H:%M:%S'), - 'update_time': directory['update_time'].strftime('%Y-%m-%d %H:%M:%S'), - } for directory in directories] + processor = TrainTaskManager(DATA_MANAGER) + total, train_jobs = processor.query_train_jobs(offset, limit) return jsonify({ 'name': os.path.basename(os.path.realpath(settings.SUMMARY_BASE_DIR)), @@ -83,6 +79,18 @@ def query_train_jobs(): }) +@BLUEPRINT.route("/datavisual/train-job-caches", methods=["POST"]) +def cache_train_jobs(): + """ Cache train jobs.""" + data = request.get_json(silent=True) + train_ids = data.get('train_ids', []) + + processor = TrainTaskManager(DATA_MANAGER) + cache_result = processor.cache_train_jobs(train_ids) + + return jsonify({'cache_result': cache_result}) + + def init_module(app): """ Init module entry. 
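The reworked GET /datavisual/train-jobs route above now delegates paging to TrainTaskManager.query_train_jobs, and the new POST /datavisual/train-job-caches route takes a JSON body with a train_ids list and answers with the cache_result list built by cache_train_jobs. A minimal client-side sketch of both calls, assuming a local MindInsight instance on localhost:8080 behind a /v1/mindinsight URL prefix and two runs named ./run1 and ./run2 (host, port, prefix and run names are illustrative assumptions, not taken from this patch):

import requests

BASE = 'http://localhost:8080/v1/mindinsight'  # assumed host/port/URL prefix

# Page through train jobs: offset is the page number, limit the page size
# (both are validated server-side before TrainTaskManager.query_train_jobs runs).
jobs = requests.get(f'{BASE}/datavisual/train-jobs',
                    params={'offset': 0, 'limit': 10}).json()
# The response carries the summary base dir name plus the total count and the page of jobs.
print(jobs.get('name'), jobs.get('total'))

# Ask the backend to pull selected runs into cache before comparing their scalars.
result = requests.post(f'{BASE}/datavisual/train-job-caches',
                       json={'train_ids': ['./run1', './run2']}).json()
print(result['cache_result'])  # [{'train_id': ..., 'cache_status': ...}, ...]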
diff --git a/mindinsight/backend/datavisual/train_visual_api.py b/mindinsight/backend/datavisual/train_visual_api.py index 24fcd6c..91871ae 100644 --- a/mindinsight/backend/datavisual/train_visual_api.py +++ b/mindinsight/backend/datavisual/train_visual_api.py @@ -162,6 +162,17 @@ def histogram(): return jsonify(response) +@BLUEPRINT.route("/datavisual/scalars", methods=["GET"]) +def get_scalars(): + """Get scalar data for given train_ids and tags.""" + train_ids = request.args.getlist('train_id') + tags = request.args.getlist('tag') + + processor = ScalarsProcessor(DATA_MANAGER) + scalars = processor.get_scalars(train_ids, tags) + return jsonify({'scalars': scalars}) + + def init_module(app): """ Init module entry. diff --git a/mindinsight/datavisual/data_transform/data_manager.py b/mindinsight/datavisual/data_transform/data_manager.py index 7c43a86..9970f48 100644 --- a/mindinsight/datavisual/data_transform/data_manager.py +++ b/mindinsight/datavisual/data_transform/data_manager.py @@ -45,7 +45,7 @@ from mindinsight.utils.exceptions import ParamValueError @enum.unique -class _CacheStatus(enum.Enum): +class CacheStatus(enum.Enum): """Train job cache status.""" NOT_IN_CACHE = "NOT_IN_CACHE" CACHING = "CACHING" @@ -63,13 +63,15 @@ class _BasicTrainJob: abs_summary_dir (str): The canonical path of summary directory. It should be the return value of realpath(). create_time (DateTime): The create time of summary directory. update_time (DateTime): The latest modify time of summary files directly in the summary directory. + profiler_dir (str): The relative path of profiler directory. """ - def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time): + def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time, profiler_dir): self._train_id = train_id self._abs_summary_base_dir = abs_summary_base_dir self._abs_summary_dir = abs_summary_dir self._create_time = create_time self._update_time = update_time + self._profiler_dir = profiler_dir @property def abs_summary_dir(self): @@ -86,6 +88,16 @@ class _BasicTrainJob: """Get train id.""" return self._train_id + @property + def profiler_dir(self): + """Get profiler directory path.""" + return self._profiler_dir + + @property + def create_time(self): + """Get create time.""" + return self._create_time + @property def update_time(self): """Get update time.""" @@ -108,7 +120,7 @@ class CachedTrainJob: # Other cached content is stored here. self._content = {} - self._cache_status = _CacheStatus.NOT_IN_CACHE + self._cache_status = CacheStatus.NOT_IN_CACHE self._key_locks = {} @property @@ -203,7 +215,7 @@ class TrainJob: self._brief = brief_train_job self._detail = detail_train_job if self._detail is None: - self._cache_status = _CacheStatus.NOT_IN_CACHE + self._cache_status = CacheStatus.NOT_IN_CACHE else: self._cache_status = self._detail.cache_status @@ -241,6 +253,20 @@ class TrainJob: """ return self._brief.get(key) + def get_basic_info(self): + """ + Get basic info. + + Returns: + basic_info (_BasicTrainJob): Basic info about the train job. + """ + return self._brief.basic_info + + @property + def cache_status(self): + """Get cache status.""" + return self._cache_status + class BaseCacheItemUpdater(abc.ABC): """Abstract base class for other modules to update cache content.""" @@ -686,7 +712,7 @@ class _DetailCacheManager(_BaseCacheManager): train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job) # Will assign real value in future. 
- train_job_obj.cache_status = _CacheStatus.CACHED + train_job_obj.cache_status = CacheStatus.CACHED return train_job_obj @@ -863,6 +889,7 @@ class DataManager: basic_train_jobs = [] for info in summaries_info: + profiler = info['profiler'] basic_train_jobs.append(_BasicTrainJob( train_id=info['relative_path'], abs_summary_base_dir=self._summary_base_dir, @@ -871,7 +898,8 @@ class DataManager: info['relative_path'] )), create_time=info['create_time'], - update_time=info['update_time'] + update_time=info['update_time'], + profiler_dir=None if profiler is None else profiler['directory'], )) self._brief_cache.update_cache(basic_train_jobs) diff --git a/mindinsight/datavisual/data_transform/summary_watcher.py b/mindinsight/datavisual/data_transform/summary_watcher.py index e9680c6..8a4fff4 100644 --- a/mindinsight/datavisual/data_transform/summary_watcher.py +++ b/mindinsight/datavisual/data_transform/summary_watcher.py @@ -31,6 +31,7 @@ class SummaryWatcher: SUMMARY_FILENAME_REGEX = r'summary\.(?P\d+)' PB_FILENAME_REGEX = r'\.pb$' + PROFILER_DIRECTORY_REGEX = r'^profiler$' MAX_SUMMARY_DIR_COUNT = 999 # scan at most 20000 files/directories (approximately 1 seconds) @@ -52,6 +53,8 @@ class SummaryWatcher: starting with "./". - create_time (datetime): Creation time of summary file. - update_time (datetime): Modification time of summary file. + - profiler (dict): profiler info, including profiler subdirectory path, profiler creation time and + profiler modification time. Examples: >>> from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher @@ -95,7 +98,7 @@ class SummaryWatcher: if entry.is_symlink(): pass elif entry.is_file(): - self._update_summary_dict(summary_dict, relative_path, entry) + self._update_summary_dict(summary_dict, summary_base_dir, relative_path, entry) elif entry.is_dir(): full_path = os.path.realpath(os.path.join(summary_base_dir, entry.name)) try: @@ -103,27 +106,39 @@ class SummaryWatcher: except PermissionError: logger.warning('Path of %s under summary base directory is not accessible.', entry.name) continue - self._scan_subdir_entries(summary_dict, subdir_entries, entry.name, counter) - - directories = [{ - 'relative_path': key, - 'create_time': value['ctime'], - 'update_time': value['mtime'], - } for key, value in summary_dict.items()] + self._scan_subdir_entries(summary_dict, summary_base_dir, subdir_entries, entry.name, counter) + + directories = [] + for key, value in summary_dict.items(): + directory = { + 'relative_path': key, + 'profiler': None, + 'create_time': value['ctime'], + 'update_time': value['mtime'], + } + profiler = value.get('profiler') + if profiler is not None: + directory['profiler'] = { + 'directory': profiler['directory'], + 'create_time': profiler['ctime'], + 'update_time': profiler['mtime'], + } + directories.append(directory) # sort by update time in descending order and relative path in ascending order directories.sort(key=lambda x: (-int(x['update_time'].timestamp()), x['relative_path'])) return directories - def _scan_subdir_entries(self, summary_dict, subdir_entries, entry_name, counter): + def _scan_subdir_entries(self, summary_dict, summary_base_dir, subdir_entries, entry_name, counter): """ Scan subdir entries. Args: summary_dict (dict): Temporary data structure to hold summary directory info. - subdir_entries(DirEntry): Directory entry instance. + summary_base_dir (str): Path of summary base directory. entry_name (str): Name of entry. + subdir_entries(DirEntry): Directory entry instance. 
counter (Counter): An instance of CountLimiter. """ @@ -139,8 +154,7 @@ class SummaryWatcher: subdir_relative_path = os.path.join('.', entry_name) if subdir_entry.is_symlink(): pass - elif subdir_entry.is_file(): - self._update_summary_dict(summary_dict, subdir_relative_path, subdir_entry) + self._update_summary_dict(summary_dict, summary_base_dir, subdir_relative_path, subdir_entry) def _contains_null_byte(self, **kwargs): """ @@ -194,40 +208,62 @@ class SummaryWatcher: return True - def _update_summary_dict(self, summary_dict, relative_path, entry): + def _update_summary_dict(self, summary_dict, summary_base_dir, relative_path, entry): """ Update summary_dict with ctime and mtime. Args: summary_dict (dict): Temporary data structure to hold summary directory info. + summary_base_dir (str): Path of summary base directory. relative_path (str): Relative path of summary directory, referring to summary base directory, starting with "./" . entry (DirEntry): Directory entry instance needed to check with regular expression. """ - summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name) - pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name) - if summary_pattern is None and pb_pattern is None: - return + ctime = datetime.datetime.fromtimestamp(entry.stat().st_ctime).astimezone() + mtime = datetime.datetime.fromtimestamp(entry.stat().st_mtime).astimezone() - if summary_pattern is not None: - timestamp = int(summary_pattern.groupdict().get('timestamp')) - try: - # extract created time from filename - ctime = datetime.datetime.fromtimestamp(timestamp).astimezone() - except OverflowError: + if entry.is_file(): + summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name) + pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name) + if summary_pattern is None and pb_pattern is None: + return + if summary_pattern is not None: + timestamp = int(summary_pattern.groupdict().get('timestamp')) + try: + # extract created time from filename + ctime = datetime.datetime.fromtimestamp(timestamp).astimezone() + except OverflowError: + return + if relative_path not in summary_dict: + summary_dict[relative_path] = { + 'ctime': ctime, + 'mtime': mtime, + 'profiler': None, + } + elif summary_dict[relative_path]['ctime'] < ctime: + summary_dict[relative_path].update({ + 'ctime': ctime, + 'mtime': mtime, + }) + elif entry.is_dir(): + profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name) + full_dir_path = os.path.join(summary_base_dir, relative_path, entry.name) + if profiler_pattern is None or self._is_empty_directory(full_dir_path): return - else: - ctime = datetime.datetime.fromtimestamp(entry.stat().st_ctime).astimezone() - - # extract modified time from filesystem - mtime = datetime.datetime.fromtimestamp(entry.stat().st_mtime).astimezone() - if relative_path not in summary_dict or summary_dict[relative_path]['ctime'] < ctime: - summary_dict[relative_path] = { + profiler = { + 'directory': os.path.join('.', entry.name), 'ctime': ctime, 'mtime': mtime, } + if relative_path not in summary_dict: + summary_dict[relative_path] = { + 'ctime': ctime, + 'mtime': mtime, + 'profiler': profiler, + } + def is_summary_directory(self, summary_base_dir, relative_path): """ Check if the given summary directory is valid. 
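With the profiler handling added to _update_summary_dict above, every entry produced by SummaryWatcher.list_summary_directories now carries a 'profiler' key alongside the timestamps. An illustrative sketch of the resulting shape, with made-up paths and times rather than real output; 'profiler' stays None when the run has no non-empty profiler subdirectory:

from datetime import datetime, timezone

# One element of the list returned by SummaryWatcher().list_summary_directories(base_dir);
# the values below are placeholders for illustration only.
example_directory = {
    'relative_path': './run1',
    'create_time': datetime(2020, 5, 18, 19, 28, tzinfo=timezone.utc),
    'update_time': datetime(2020, 5, 18, 19, 30, tzinfo=timezone.utc),
    'profiler': {
        'directory': './profiler',
        'create_time': datetime(2020, 5, 18, 19, 28, tzinfo=timezone.utc),
        'update_time': datetime(2020, 5, 18, 19, 30, tzinfo=timezone.utc),
    },
}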
@@ -259,15 +295,28 @@ class SummaryWatcher: raise FileSystemPermissionError('Path of summary base directory is not accessible.') for entry in entries: - if entry.is_symlink() or not entry.is_file(): + if entry.is_symlink(): continue + summary_pattern = re.search(self.SUMMARY_FILENAME_REGEX, entry.name) + if summary_pattern is not None and entry.is_file(): + return True + pb_pattern = re.search(self.PB_FILENAME_REGEX, entry.name) - if summary_pattern or pb_pattern: + if pb_pattern is not None and entry.is_file(): return True + profiler_pattern = re.search(self.PROFILER_DIRECTORY_REGEX, entry.name) + if profiler_pattern is not None and entry.is_dir(): + full_path = os.path.realpath(os.path.join(summary_directory, entry.name)) + if not self._is_empty_directory(full_path): + return True + return False + def _is_empty_directory(self, directory): + return not bool(os.listdir(directory)) + def list_summary_directories_by_pagination(self, summary_base_dir, offset=0, limit=10): """ List summary directories within base directory. diff --git a/mindinsight/datavisual/processors/scalars_processor.py b/mindinsight/datavisual/processors/scalars_processor.py index d7411fd..880422b 100644 --- a/mindinsight/datavisual/processors/scalars_processor.py +++ b/mindinsight/datavisual/processors/scalars_processor.py @@ -13,7 +13,10 @@ # limitations under the License. # ============================================================================ """Scalar Processor APIs.""" -from mindinsight.utils.exceptions import ParamValueError +from urllib.parse import unquote + +from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError +from mindinsight.datavisual.utils.tools import if_nan_inf_to_none from mindinsight.datavisual.common.exceptions import ScalarNotExistError from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.processors.base_processor import BaseProcessor @@ -46,3 +49,47 @@ class ScalarsProcessor(BaseProcessor): 'step': tensor.step, 'value': tensor.value}) return dict(metadatas=job_response) + + def get_scalars(self, train_ids, tags): + """ + Get scalar data for given train_ids and tags. + + Args: + train_ids (list): Specify list of train job ID. + tags (list): Specify list of tags. + + Returns: + list[dict], a list of dictionaries containing the `wall_time`, `step`, `value` for each scalar. 
+ """ + for index, train_id in enumerate(train_ids): + try: + train_id = unquote(train_id, errors='strict') + except UnicodeDecodeError: + raise UrlDecodeError('Unquote train id error with strict mode') + else: + train_ids[index] = train_id + + scalars = [] + for train_id in train_ids: + for tag in tags: + try: + tensors = self._data_manager.list_tensors(train_id, tag) + except ParamValueError: + continue + + scalar = { + 'train_id': train_id, + 'tag': tag, + 'values': [], + } + + for tensor in tensors: + scalar['values'].append({ + 'wall_time': tensor.wall_time, + 'step': tensor.step, + 'value': if_nan_inf_to_none('scalar_value', tensor.value), + }) + + scalars.append(scalar) + + return scalars diff --git a/mindinsight/datavisual/processors/train_task_manager.py b/mindinsight/datavisual/processors/train_task_manager.py index bf72e2b..5c5cfb8 100644 --- a/mindinsight/datavisual/processors/train_task_manager.py +++ b/mindinsight/datavisual/processors/train_task_manager.py @@ -14,11 +14,13 @@ # ============================================================================ """Train task manager.""" +from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.processors.base_processor import BaseProcessor from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY +from mindinsight.datavisual.data_transform.data_manager import CacheStatus class TrainTaskManager(BaseProcessor): @@ -75,3 +77,78 @@ class TrainTaskManager(BaseProcessor): return dict( plugins=plugins ) + + def query_train_jobs(self, offset=0, limit=10): + """ + Query train jobs. + + Args: + offset (int): Specify page number. Default is 0. + limit (int): Specify page size. Default is 10. + + Returns: + tuple, return quantity of total train jobs and list of train jobs specified by offset and limit. + """ + brief_cache = self._data_manager.get_brief_cache() + brief_train_jobs = list(brief_cache.get_train_jobs().values()) + brief_train_jobs.sort(key=lambda x: x.basic_info.update_time, reverse=True) + total = len(brief_train_jobs) + + start = offset * limit + end = (offset + 1) * limit + train_jobs = [] + + train_ids = [train_job.basic_info.train_id for train_job in brief_train_jobs[start:end]] + + for train_id in train_ids: + try: + train_job = self._data_manager.get_train_job(train_id) + except exceptions.TrainJobNotExistError: + logger.warning('Train job %s not existed', train_id) + continue + + basic_info = train_job.get_basic_info() + train_job_item = dict( + train_id=basic_info.train_id, + relative_path=basic_info.train_id, + create_time=basic_info.create_time.strftime('%Y-%m-%d %H:%M:%S'), + update_time=basic_info.update_time.strftime('%Y-%m-%d %H:%M:%S'), + profiler_dir=basic_info.profiler_dir, + cache_status=train_job.cache_status.value, + ) + plugins = self.get_plugins(train_id) + train_job_item.update(plugins) + train_jobs.append(train_job_item) + + return total, train_jobs + + def cache_train_jobs(self, train_ids): + """ + Cache train jobs. + + Args: + train_ids (list): Specify list of train_ids to be cached. + + Returns: + dict, indicates train job ID and its current cache status. 
+        """
+        brief_cache = self._data_manager.get_brief_cache()
+        brief_train_jobs = brief_cache.get_train_jobs()
+
+        for train_id in train_ids:
+            brief_train_job = brief_train_jobs.get(train_id)
+            if brief_train_job is None:
+                raise exceptions.TrainJobNotExistError(f'Train id {train_id} does not exist')
+
+        cache_result = []
+        for train_id in train_ids:
+            brief_train_job = brief_train_jobs.get(train_id)
+            if brief_train_job.cache_status.value == CacheStatus.NOT_IN_CACHE.value:
+                self._data_manager.cache_train_job(train_id)
+
+            cache_result.append({
+                'train_id': train_id,
+                'cache_status': brief_train_job.cache_status.value,
+            })
+
+        return cache_result
--
GitLab
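Taken together with the GET /datavisual/scalars route added in train_visual_api.py, these changes enable the cross-run scalar comparison named in the subject line: cache the runs of interest, then fetch the same tag from several train jobs in one request. A minimal end-to-end sketch, again assuming a local instance at localhost:8080 with a /v1/mindinsight prefix and runs ./run1 and ./run2 that both log a 'loss' scalar tag (all illustrative assumptions):

import requests

BASE = 'http://localhost:8080/v1/mindinsight'  # assumed host/port/URL prefix

# Repeat train_id and tag to pass lists; the backend reads them with getlist()
# and URL-decodes each train_id, so percent-encoded ids are accepted as well.
params = [
    ('train_id', './run1'),
    ('train_id', './run2'),
    ('tag', 'loss'),
]
scalars = requests.get(f'{BASE}/datavisual/scalars', params=params).json()['scalars']

for series in scalars:
    # Each series: {'train_id': ..., 'tag': ..., 'values': [{'wall_time', 'step', 'value'}, ...]};
    # 'value' comes back as null where the logged scalar was NaN or infinite.
    print(series['train_id'], series['tag'], len(series['values']))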