Commit 7e17d6ff authored by W wenkai

refactor data manager and unify cache and data access/reload

Parent 46d44977
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Trigger data manager load."""
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
from mindinsight.datavisual.common.log import logger
from mindinsight.conf import settings
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
def init_module(app):
"""
Interface to init module.
Args:
app (Flask): An instance of Flask.
"""
# Just to suppress pylint warning about unused arg.
logger.debug("App: %s", type(app))
DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater())
DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL),
max_threads_count=int(settings.MAX_THREADS_COUNT))
......@@ -18,9 +18,6 @@ from mindinsight.backend.datavisual.static_resource_api import init_module as st
from mindinsight.backend.datavisual.task_manager_api import init_module as task_init_module
from mindinsight.backend.datavisual.train_visual_api import init_module as train_init_module
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
def init_module(app):
"""
......@@ -33,6 +30,3 @@ def init_module(app):
static_init_module(app)
task_init_module(app)
train_init_module(app)
DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL),
max_threads_count=int(settings.MAX_THREADS_COUNT))
......@@ -150,3 +150,12 @@ class HistogramNotExistError(MindInsightException):
super(HistogramNotExistError, self).__init__(DataVisualErrors.HISTOGRAM_NOT_EXIST,
error_msg,
http_code=400)
class TrainJobDetailNotInCacheError(MindInsightException):
"""Detail info of given train job is not in cache."""
def __init__(self, error_detail="no detail provided."):
error_msg = f'Detail info of the given train job is not in cache. Detail: {error_detail}'
super().__init__(DataVisualErrors.TRAIN_JOB_DETAIL_NOT_IN_CACHE,
error_msg,
http_code=400)
......@@ -20,11 +20,18 @@ It can read events data through the DataLoader.
This module also acts as a thread pool manager.
"""
import abc
import enum
import threading
import time
import datetime
import os
from typing import Iterable, Optional
from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.conf import settings
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.log import logger
......@@ -37,44 +44,369 @@ from mindinsight.utils.exceptions import MindInsightException
from mindinsight.utils.exceptions import ParamValueError
@enum.unique
class _CacheStatus(enum.Enum):
"""Train job cache status."""
NOT_IN_CACHE = "NOT_IN_CACHE"
CACHING = "CACHING"
CACHED = "CACHED"
class _BasicTrainJob:
"""
Basic info about train job.
Args:
train_id (str): Id of the train job.
abs_summary_base_dir (str): The canonical path of the summary base directory. It should be the return value of
realpath().
abs_summary_dir (str): The canonical path of the summary directory. It should be the return value of realpath().
create_time (DateTime): The create time of the summary directory.
update_time (DateTime): The latest modification time of the summary files directly in the summary directory.
"""
def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time):
self._train_id = train_id
self._abs_summary_base_dir = abs_summary_base_dir
self._abs_summary_dir = abs_summary_dir
self._create_time = create_time
self._update_time = update_time
@property
def summary_dir(self):
"""Get summary directory path."""
return self._abs_summary_dir
@property
def train_id(self):
"""Get train id."""
return self._train_id
class CachedTrainJob:
"""
Cache item for BriefCacheManager.
DetailCacheManager will also wrap its return value with this class.
Args:
basic_info (_BasicTrainJob): Basic info about the train job.
"""
def __init__(self, basic_info: _BasicTrainJob):
self._basic_info = basic_info
self._last_access_time = datetime.datetime.utcnow()
# Other cached content is stored here.
self._content = {}
self._cache_status = _CacheStatus.NOT_IN_CACHE
@property
def cache_status(self):
"""Get cache status."""
return self._cache_status
@cache_status.setter
def cache_status(self, value):
"""Set cache status."""
self._cache_status = value
def update_access_time(self):
"""Update last access time of this cache item."""
self._last_access_time = datetime.datetime.utcnow()
@property
def last_access_time(self):
"""Get last access time for purposes such as LRU."""
return self._last_access_time
@property
def summary_dir(self):
"""Get summary directory path."""
return self._basic_info.summary_dir
def set(self, key, value):
"""Set value to cache."""
self._content[key] = value
def get(self, key):
"""Get value from cache."""
try:
return self._content[key]
except KeyError:
raise ParamValueError("Invalid cache key({}).".format(key))
@property
def basic_info(self):
"""Get basic train job info."""
return self._basic_info
@basic_info.setter
def basic_info(self, value):
"""Set basic train job info."""
self._basic_info = value
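A minimal usage sketch of the two classes above (the id and paths are hypothetical); the cache managers defined below drive these same calls during reload:

    import datetime

    # Hypothetical train job living under /logs/run1.
    basic = _BasicTrainJob(
        train_id="./run1",
        abs_summary_base_dir="/logs",
        abs_summary_dir="/logs/run1",
        create_time=datetime.datetime.utcnow(),
        update_time=datetime.datetime.utcnow(),
    )
    item = CachedTrainJob(basic_info=basic)
    item.set("lineage", {"train": None})  # any module may stash content under a key
    item.update_access_time()             # bump recency, e.g. for LRU-style eviction
    assert item.get("lineage") == {"train": None}
    assert item.cache_status is _CacheStatus.NOT_IN_CACHE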
class TrainJob:
"""
Train job object.
You must not create TrainJob objects manually. You should always get TrainJob objects from DataManager.
Args:
brief_train_job (CachedTrainJob): Brief info about train job.
detail_train_job (Optional[CachedTrainJob]): Detailed info about train job. Default: None.
"""
def __init__(self,
brief_train_job: CachedTrainJob,
detail_train_job: Optional[CachedTrainJob] = None):
self._brief = brief_train_job
self._detail = detail_train_job
if self._detail is None:
self._cache_status = _CacheStatus.NOT_IN_CACHE
else:
self._cache_status = self._detail.cache_status
def has_detail(self):
"""Whether this train job has detailed info in cache."""
return bool(self._detail is not None)
def get_detail(self, key):
"""
Initialize the pool of loader and the dict of name-to-path.
Get detail content.
Args:
loader_generators (list[LoaderGenerator]): Loader generators help generate loaders.
key (Any): Cache key.
self._status: Refer `datavisual.common.enums.DataManagerStatus`.
self._loader_pool: {'loader_id': <LoaderStruct>}.
Returns:
Any, cache content.
Raises:
TrainJobDetailNotInCacheError: when this train job has no detail cache.
"""
if not self.has_detail():
raise exceptions.TrainJobDetailNotInCacheError()
return self._detail.get(key)
def get_brief(self, key):
"""
Get brief content.
Args:
key (Any): Cache key.
Returns:
Any, cache content.
"""
return self._brief.get(key)
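Since brief info exists for every train job on disk while detail info may still be loading, callers are expected to guard detail access; a short sketch (DATAVISUAL_CACHE_KEY and DATAVISUAL_PLUGIN_KEY are defined later in this module):

    def read_plugins(job: TrainJob):
        """Return datavisual plugin tags, or None when detail info is not cached yet."""
        if not job.has_detail():
            return None
        content = job.get_detail(DATAVISUAL_CACHE_KEY)
        return content.get(DATAVISUAL_PLUGIN_KEY)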
class BaseCacheItemUpdater(abc.ABC):
"""Abstract base class for other modules to update cache content."""
def update_item(self, cache_item: CachedTrainJob):
"""
Update cache item in place.
Args:
cache_item (CachedTrainJob): The cache item to be processed.
"""
raise NotImplementedError()
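Other modules extend the brief cache by subclassing this; a minimal sketch of a hypothetical updater (the class and the "dir_size" key are made up for illustration):

    class DirSizeCacheItemUpdater(BaseCacheItemUpdater):
        """Hypothetical updater that records the total size of files in the summary dir."""

        def update_item(self, cache_item: CachedTrainJob):
            summary_dir = cache_item.summary_dir
            total = sum(
                os.path.getsize(os.path.join(summary_dir, name))
                for name in os.listdir(summary_dir)
            )
            cache_item.set(key="dir_size", value=total)

LineageCacheItemUpdater at the end of this commit is the real registered implementation.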
class _BaseCacheManager:
"""Base class for cache manager."""
def __init__(self):
# Use dict to remove duplicate updaters.
self._updaters = {}
# key is train_id
self._lock = threading.Lock()
self._cache_items = {}
def size(self):
"""Gets used cache slots."""
return len(self._cache_items)
def register_cache_item_updater(self, updater: BaseCacheItemUpdater):
"""Register cache item updater."""
self._updaters[updater.__class__.__qualname__] = updater
def get_train_jobs(self):
"""Get cached train jobs."""
copied_train_jobs = dict(self._cache_items)
return copied_train_jobs
def get_train_job(self, train_id):
"""Get cached train job."""
try:
return self._cache_items[train_id]
except KeyError:
raise TrainJobNotExistError(train_id)
def cache_train_job(self, train_id) -> bool:
"""
Cache given train job and update train job's last access time.
This method should return true if reload actions should be taken to cache the train job.
Args:
train_id (str): Train Id.
"""
raise NotImplementedError()
def delete_train_job(self, train_id):
"""Delete train job from cache."""
if train_id in self._cache_items:
del self._cache_items[train_id]
def has_content(self):
"""Whether this cache manager has train jobs."""
return bool(self._cache_items)
def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]):
"""
Update cache according to given train jobs on disk.
Different cache managers should implement different cache update policies in this method.
Args:
disk_train_jobs (Iterable[_BasicTrainJob]): Train jobs on disk.
"""
raise NotImplementedError()
def _merge_with_disk(self, disk_train_jobs: Iterable[_BasicTrainJob]):
"""
Merge train jobs in cache with train jobs from disk.
This method will remove train jobs that are not on disk. Call this function with the lock held for thread safety.
Args:
disk_train_jobs (Iterable[_BasicTrainJob]): Basic train jobs info from disk.
Returns:
dict, a dict containing train jobs to be cached.
"""
new_cache_items = {}
for train_job in disk_train_jobs:
if train_job.train_id not in self._cache_items:
new_cache_items[train_job.train_id] = CachedTrainJob(train_job)
else:
reused_train_job = self._cache_items[train_job.train_id]
reused_train_job.basic_info = train_job
new_cache_items[train_job.train_id] = reused_train_job
return new_cache_items
class _BriefCacheManager(_BaseCacheManager):
"""A cache manager that holds all disk train jobs on disk."""
def cache_train_job(self, train_id):
"""
Cache given train job.
All train jobs on disk are cached on every reload, so this method always returns False.
Args:
train_id (str): Train Id.
"""
if train_id in self._cache_items:
self._cache_items[train_id].update_access_time()
return False
def update_cache(self, disk_train_jobs):
"""Update cache."""
with self._lock:
new_cache_items = self._merge_with_disk(disk_train_jobs)
self._cache_items = new_cache_items
for updater in self._updaters.values():
for cache_item in self._cache_items.values():
updater.update_item(cache_item)
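The net effect of this policy: train jobs that disappear from disk are evicted on the next reload, while surviving jobs keep their cached content. An illustrative, self-contained sketch (ids and paths hypothetical):

    import datetime

    def make_job(train_id):
        """Build a hypothetical _BasicTrainJob for illustration."""
        now = datetime.datetime.utcnow()
        return _BasicTrainJob(train_id, "/logs", "/logs/" + train_id, now, now)

    mgr = _BriefCacheManager()
    mgr.update_cache([make_job("a"), make_job("b")])  # first reload caches both jobs
    mgr.get_train_job("a").set("note", 1)
    mgr.update_cache([make_job("a")])                 # "b" left disk and is evicted,
    assert mgr.size() == 1                            # while "a" keeps its cached "note"
    assert mgr.get_train_job("a").get("note") == 1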
# Key for plugin tags.
DATAVISUAL_PLUGIN_KEY = "tag_mapping"
# Detail train job cache key for datavisual content.
DATAVISUAL_CACHE_KEY = "datavisual"
class _DetailCacheManager(_BaseCacheManager):
"""A cache manager that holds detailed info for most recently used train jobs."""
def __init__(self, loader_generators):
super().__init__()
self._loader_pool = {}
self._deleted_id_list = []
self._status = DataManagerStatus.INIT.value
self._status_mutex = threading.Lock()
self._loader_pool_mutex = threading.Lock()
self._max_threads_count = 30
self._reload_interval = 3
self._loader_generators = loader_generators
def size(self):
"""
Get the number of items in this cache manager.
To be implemented.
Returns:
int, the number of items in this cache manager.
"""
raise NotImplementedError()
def loader_pool_size(self):
"""Get loader pool size."""
return len(self._loader_pool)
def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]):
"""
Update cache.
Will switch to using disk_train_jobs in the future.
Args:
disk_train_jobs (Iterable[_BasicTrainJob]): Basic info about train jobs on disk.
"""
self._generate_loaders()
self._execute_load_data()
def cache_train_job(self, train_id):
"""Cache given train job."""
loader = None
need_reload = False
with self._loader_pool_mutex:
if self._is_loader_in_loader_pool(train_id, self._loader_pool):
loader = self._loader_pool.get(train_id)
if loader is None:
for generator in self._loader_generators:
tmp_loader = generator.generate_loader_by_train_id(train_id)
if loader and loader.latest_update_time > tmp_loader.latest_update_time:
continue
loader = tmp_loader
if loader is None:
raise TrainJobNotExistError(train_id)
self._add_loader(loader)
need_reload = True
self._update_loader_latest_update_time(loader.loader_id)
return need_reload
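The generator loop above keeps whichever candidate loader has the newest latest_update_time; a standalone toy version of that selection logic:

    # (loader_id, latest_update_time) pairs from hypothetical generators.
    candidates = [("gen_a", 100.0), ("gen_b", 250.0), ("gen_c", 180.0)]
    chosen = None
    for loader_id, update_time in candidates:
        if chosen and chosen[1] > update_time:
            continue  # a previous candidate is newer; keep it
        chosen = (loader_id, update_time)
    assert chosen == ("gen_b", 250.0)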
def get_train_jobs(self):
"""
Get train jobs.
To be implemented.
"""
def _add_loader(self, loader):
"""
Add a loader to load data.
Args:
loader (LoaderStruct): An object of `Loader`.
"""
if len(self._loader_pool) >= MAX_DATA_LOADER_SIZE:
delete_number = len(self._loader_pool) - MAX_DATA_LOADER_SIZE + 1
......@@ -85,40 +417,328 @@ class DataManager:
self._delete_loader(delete_loader_id)
self._loader_pool.update({loader.loader_id: loader})
def _delete_loader(self, loader_id):
"""
Delete loader from loader pool by loader id.
Args:
loader_id (str): ID of loader.
"""
if self._loader_pool.get(loader_id) is not None:
logger.debug("delete loader %s", loader_id)
self._loader_pool.pop(loader_id)
def _execute_loader(self, loader_id):
"""
Load data from data_loader.
If something goes wrong while loading, log it and delete the loader.
Args:
loader_id (str): An ID for `Loader`.
"""
try:
with self._loader_pool_mutex:
loader = self._loader_pool.get(loader_id, None)
if loader is None:
logger.debug("Loader %r has been deleted, will not load data.", loader_id)
return
loader.data_loader.load()
except MindInsightException as ex:
logger.warning("Data loader %r load data failed. "
"Delete data_loader. Detail: %s", loader_id, ex)
with self._loader_pool_mutex:
self._delete_loader(loader_id)
def _generate_loaders(self):
"""This function generates the loader from given path."""
loader_dict = {}
for generator in self._loader_generators:
loader_dict.update(generator.generate_loaders(self._loader_pool))
sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time)
latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:]
self._deal_loaders(latest_loaders)
def _deal_loaders(self, latest_loaders):
"""
This function determines which loaders to keep, remove, or add.
It is based on the given list of loaders.
Args:
latest_loaders (list[tuple]): A list of (loader_id, LoaderStruct) pairs.
"""
with self._loader_pool_mutex:
for loader_id, loader in latest_loaders:
if self._loader_pool.get(loader_id, None) is None:
self._add_loader(loader)
continue
# If this loader was updated manually before,
# its latest_update_time may be bigger than the update_time in the summary.
if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time:
self._update_loader_latest_update_time(loader_id, loader.latest_update_time)
def _execute_load_data(self):
"""Load data through multiple threads."""
threads_count = self._get_threads_count()
if not threads_count:
logger.info("Can not find any valid train log path to load, loader pool is empty.")
return
logger.info("Start to execute load data. threads_count: %s.", threads_count)
with ThreadPoolExecutor(max_workers=threads_count) as executor:
futures = []
loader_pool = self._get_snapshot_loader_pool()
for loader_id in loader_pool:
future = executor.submit(self._execute_loader, loader_id)
futures.append(future)
wait(futures, return_when=ALL_COMPLETED)
def _get_threads_count(self):
"""
Use the maximum number of threads available.
Returns:
int, number of threads.
"""
threads_count = min(self._max_threads_count, len(self._loader_pool))
return threads_count
def delete_train_job(self, train_id):
"""
Delete train job with a train id.
Args:
train_id (str): ID for train job.
"""
with self._loader_pool_mutex:
self._delete_loader(train_id)
def list_tensors(self, train_id, tag):
"""
List tensors of the given train job and tag.
If no tensor can be found for the given tag, an exception will be raised.
Args:
train_id (str): ID for train job.
tag (str): The tag name.
Returns:
NamedTuple, the tuple format is `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`.
The value will contain the given tag data.
"""
loader_pool = self._get_snapshot_loader_pool()
if not self._is_loader_in_loader_pool(train_id, loader_pool):
raise TrainJobNotExistError("Can not find the given train job in cache.")
data_loader = loader_pool[train_id].data_loader
events_data = data_loader.get_events_data()
try:
tensors = events_data.tensors(tag)
except KeyError:
error_msg = "Can not find any data in this train job by given tag."
raise ParamValueError(error_msg)
return tensors
def _check_train_job_exist(self, train_id, loader_pool):
"""
Check whether the train job exists; raise an exception if it does not.
Args:
train_id (str): The given train job id.
loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.
Raises:
TrainJobNotExistError: Can not find train job in data manager.
"""
is_exist = False
if train_id in loader_pool:
return
for generator in self._loader_generators:
if generator.check_train_job_exist(train_id):
is_exist = True
break
if not is_exist:
raise TrainJobNotExistError("Can not find the train job in data manager.")
def _is_loader_in_loader_pool(self, train_id, loader_pool):
"""
Check whether the train job is in the loader pool; return True if it is, otherwise False.
Args:
train_id (str): The given train job id.
loader_pool (dict): See self._loader_pool.
Returns:
bool, if loader in loader pool, return True.
"""
if train_id in loader_pool:
return True
return False
def _get_snapshot_loader_pool(self):
"""
Create a snapshot of data loader pool to avoid concurrent mutation and iteration issues.
Returns:
dict, a copy of `self._loader_pool`.
"""
with self._loader_pool_mutex:
return dict(self._loader_pool)
def get_train_job(self, train_id):
"""
Get train job by train ID.
This method overrides the parent method.
Args:
train_id (str): Train ID for train job.
Returns:
CachedTrainJob, the cached train job with datavisual content; None if no data can be found.
"""
self._check_train_job_exist(train_id, self._loader_pool)
loader = self._get_loader(train_id)
if loader is None:
logger.warning("No valid summary log in train job %s, "
"or it is not in the cache.", train_id)
return None
train_job = loader.to_dict()
train_job.pop('data_loader')
plugin_data = {}
for plugin_name in PluginNameEnum.list_members():
job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name)
if job is None:
plugin_data[plugin_name] = []
else:
plugin_data[plugin_name] = job['tags']
train_job.update({DATAVISUAL_PLUGIN_KEY: plugin_data})
# Will fill basic_info value in future.
train_job_obj = CachedTrainJob(basic_info=None)
train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job)
# Will assign real value in future.
train_job_obj.cache_status = _CacheStatus.CACHED
return train_job_obj
def _get_loader(self, train_id):
"""
Get loader by train id.
Args:
train_id (str): Train Id.
Returns:
LoaderStruct, the loader.
"""
loader = None
with self._loader_pool_mutex:
if self._is_loader_in_loader_pool(train_id, self._loader_pool):
loader = self._loader_pool.get(train_id)
return loader
def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
"""
Update loader with latest_update_time.
Args:
loader_id (str): ID of loader.
latest_update_time (float): Timestamp.
"""
if latest_update_time is None:
latest_update_time = time.time()
self._loader_pool[loader_id].latest_update_time = latest_update_time
def get_train_job_by_plugin(self, train_id, plugin_name):
"""
Get a train job by train job id.
If the given train job does not have the given plugin data, the tag list will be empty.
Args:
train_id (str): Get train job info by the given id.
plugin_name (str): Get tags by given plugin.
Returns:
TypedDict('TrainJobEntity', {'id': str, 'name': str, 'tags': List[str]}),
a train job object.
"""
self._check_train_job_exist(train_id, self._loader_pool)
loader = self._get_loader(train_id)
if loader is None:
logger.warning("No valid summary log in train job %s, "
"or it is not in the cache.", train_id)
return None
name = loader.name
data_loader = loader.data_loader
tags = []
try:
events_data = data_loader.get_events_data()
tags = events_data.list_tags_by_plugin(plugin_name)
except KeyError:
logger.debug("Plugin name %r does not exist "
"in train job %r, and set tags to empty list.", plugin_name, name)
except AttributeError:
logger.debug("Train job %r has been deleted or it has not loaded data, "
"and set tags to empty list.", name)
result = dict(id=train_id, name=name, tags=tags)
return result
class DataManager:
"""
DataManager manages a pool of loaders which help access events data.
Each loader helps deal with the data of the events.
A loader corresponds to an events_data.
The DataManager builds a pool including all the data_loaders.
The data_loader provides extracting methods to get the information of events.
"""
def __init__(self, summary_base_dir):
"""
Initialize the pool of loaders and the dict of name-to-path.
Args:
summary_base_dir (str): Base summary directory.
self._status: Refer to `datavisual.common.enums.DataManagerStatus`.
"""
self._summary_base_dir = os.path.realpath(summary_base_dir)
self._status = DataManagerStatus.INIT.value
self._status_mutex = threading.Lock()
self._reload_interval = 3
loader_generators = [DataLoaderGenerator(self._summary_base_dir)]
self._detail_cache = _DetailCacheManager(loader_generators)
self._brief_cache = _BriefCacheManager()
def start_load_data(self,
reload_interval=settings.RELOAD_INTERVAL,
......@@ -176,64 +796,31 @@ class DataManager:
return
self.status = DataManagerStatus.LOADING.value
summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir)
basic_train_jobs = []
for info in summaries_info:
basic_train_jobs.append(_BasicTrainJob(
train_id=info['relative_path'],
abs_summary_base_dir=self._summary_base_dir,
abs_summary_dir=os.path.realpath(os.path.join(
self._summary_base_dir,
info['relative_path']
)),
create_time=info['create_time'],
update_time=info['update_time']
))
self._brief_cache.update_cache(basic_train_jobs)
self._detail_cache.update_cache(basic_train_jobs)
if not self._brief_cache.has_content() and not self._detail_cache.has_content():
self.status = DataManagerStatus.INVALID.value
else:
self.status = DataManagerStatus.DONE.value
logger.info("Load event data end, status: %r, and loader pool size is %r.",
self.status, self._detail_cache.loader_pool_size())
@staticmethod
def check_reload_interval(reload_interval):
......@@ -262,18 +849,6 @@ class DataManager:
if max_threads_count <= 0:
raise ParamValueError("The value of max threads count should be > 0.")
def get_train_job_by_plugin(self, train_id, plugin_name):
"""
Get a train job by train job id.
......@@ -290,32 +865,9 @@ class DataManager:
"""
self._check_status_valid()
return self._detail_cache.get_train_job_by_plugin(train_id, plugin_name)
def delete_train_job(self, train_id, only_delete_from_cache=True):
"""
Delete train job with a train id.
......@@ -323,8 +875,11 @@ class DataManager:
train_id (str): ID for train job.
only_delete_from_cache (bool): Whether to delete the train job from cache only. Default: True.
"""
if not only_delete_from_cache:
raise NotImplementedError("Delete from both cache and disk is not supported.")
self._brief_cache.delete_train_job(train_id)
self._detail_cache.delete_train_job(train_id)
def list_tensors(self, train_id, tag):
"""
......@@ -342,66 +897,7 @@ class DataManager:
"""
self._check_status_valid()
return self._detail_cache.list_tensors(train_id, tag)
def _check_status_valid(self):
"""Check if the status is valid to load data."""
......@@ -409,90 +905,29 @@ class DataManager:
if self.status == DataManagerStatus.INIT.value:
raise exceptions.SummaryLogIsLoading("Data is being loaded, current status: %s." % self._status)
def get_train_job(self, train_id):
"""
Get train job by train ID.
Args:
train_id (str): Train ID for train job.
Returns:
TrainJob, a train job object combining brief and detail cache info.
"""
self._check_status_valid()
detail_train_job = self._detail_cache.get_train_job(train_id)
brief_train_job = self._brief_cache.get_train_job(train_id)
return TrainJob(brief_train_job, detail_train_job)
def list_train_jobs(self):
"""
List train jobs.
To be implemented.
"""
raise NotImplementedError()
@property
def status(self):
......@@ -509,6 +944,16 @@ class DataManager:
"""Set data manger status."""
self._status = status
def cache_train_job(self, train_id):
"""Cache given train job (async)."""
brief_need_reload = self._brief_cache.cache_train_job(train_id)
detail_need_reload = self._detail_cache.cache_train_job(train_id)
if brief_need_reload or detail_need_reload:
self.reload_data()
def register_brief_cache_item_updater(self, updater: BaseCacheItemUpdater):
"""Register brief cache item updater for brief cache manager."""
self._brief_cache.register_cache_item_updater(updater)
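Taken together, DataManager stays the single entry point while the two cache managers do the work; a hedged wiring sketch mirroring init_module in the backend above (the base directory is hypothetical):

    from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater

    mgr = DataManager("/logs")
    mgr.register_brief_cache_item_updater(LineageCacheItemUpdater())
    mgr.start_load_data(reload_interval=3, max_threads_count=30)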
DATA_MANAGER = DataManager(settings.SUMMARY_BASE_DIR)
......@@ -18,6 +18,7 @@ from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
class TrainTaskManager(BaseProcessor):
......@@ -53,13 +54,24 @@ class TrainTaskManager(BaseProcessor):
dict, refer to restful api.
"""
Validation.check_param_empty(train_id=train_id)
if manual_update:
self._data_manager.cache_train_job(train_id)
train_job = self._data_manager.get_train_job(train_id)
try:
data_visual_content = train_job.get_detail(DATAVISUAL_CACHE_KEY)
plugins = data_visual_content.get(DATAVISUAL_PLUGIN_KEY)
except exceptions.TrainJobDetailNotInCacheError:
plugins = []
if not plugins:
default_result = dict()
for plugin_name in PluginNameEnum.list_members():
default_result.update({plugin_name: list()})
return dict(plugins=default_result)
return dict(
plugins=plugins
)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cache item updater."""
import os
from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob
from mindinsight.lineagemgr.querier.query_model import LineageObj
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageSummaryAnalyzer
class LineageCacheItemUpdater(BaseCacheItemUpdater):
"""Cache item updater for lineage info."""
def update_item(self, cache_item: CachedTrainJob):
"""Update cache item in place."""
log_path = cache_item.summary_dir
log_dir = os.path.dirname(log_path)
lineage_info = LineageSummaryAnalyzer.get_summary_infos(log_path)
user_defined_info = LineageSummaryAnalyzer.get_user_defined_info(log_path)
lineage_obj = LineageObj(
log_dir,
train_lineage=lineage_info.train_lineage,
evaluation_lineage=lineage_info.eval_lineage,
dataset_graph=lineage_info.dataset_graph,
user_defined_info=user_defined_info
)
cache_item.set(key="lineage", value=lineage_obj)
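With the updater registered at startup, every reload refreshes the "lineage" entry of each brief cache item; a hedged retrieval sketch (the train id is hypothetical and assumed to be cached):

    from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER

    job = DATA_MANAGER.get_train_job("./run1")  # raises TrainJobNotExistError if unknown
    lineage_obj = job.get_brief("lineage")      # the LineageObj stored by update_item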
......@@ -63,3 +63,4 @@ class DataVisualErrors(Enum):
IMAGE_NOT_EXIST = 13
SCALAR_NOT_EXIST = 14
HISTOGRAM_NOT_EXIST = 15
TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
......@@ -25,7 +25,6 @@ from flask import Response
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.utils import tools
......@@ -59,7 +58,7 @@ def init_summary_logs():
log_operations = LogOperations()
summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST,
constants.SUMMARY_DIR_PREFIX)
mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
mock_data_manager = DataManager(summary_base_dir)
mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager)
......
......@@ -33,7 +33,6 @@ from mindinsight.datavisual.data_transform import data_manager, ms_data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.events_data import EventsData
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
......@@ -89,7 +88,7 @@ class TestDataManager:
train_ids.append(f'./dir{i}')
data_manager.logger = MockLogger
mock_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
mock_manager = data_manager.DataManager(summary_base_dir)
mock_manager.start_load_data(reload_interval=0)
check_loading_done(mock_manager)
......@@ -112,7 +111,7 @@ class TestDataManager:
def test_start_load_data_with_invalid_params(self, params):
"""Test start_load_data with invalid reload_interval or invalid max_threads_count."""
summary_base_dir = tempfile.mkdtemp()
d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
d_manager = DataManager(summary_base_dir)
with pytest.raises(ParamValueError):
d_manager.start_load_data(**params)
shutil.rmtree(summary_base_dir)
......@@ -142,9 +141,9 @@ class TestDataManager:
latest_update_time=modify_time_01,
data_loader=loader_01)
loader_pool = {train_job_01: loader}
d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
d_manager = DataManager(summary_base_dir)
d_manager._status = DataManagerStatus.LOADING.value
d_manager._loader_pool = loader_pool
d_manager._detail_cache._loader_pool = loader_pool
res = d_manager.list_tensors(train_job_01, tag)
assert res == {'test result'}
......@@ -169,9 +168,9 @@ class TestDataManager:
latest_update_time=modify_time_01,
data_loader=loader_01)
loader_pool = {train_job_01: loader}
d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
d_manager = DataManager(summary_base_dir)
d_manager._status = DataManagerStatus.LOADING.value
d_manager._loader_pool = loader_pool
d_manager._detail_cache._loader_pool = loader_pool
tag = 'image'
with pytest.raises(ParamValueError):
d_manager.list_tensors(train_job_01, tag)
......@@ -181,7 +180,7 @@ class TestDataManager:
def test_list_tensors_with_not_exist_train_job(self):
"""Test list_tensors method with parameter train_id not found in loader_pool."""
summary_base_dir = tempfile.mkdtemp()
d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
d_manager = DataManager(summary_base_dir)
d_manager._status = DataManagerStatus.LOADING.value
tag = 'image'
train_job_01 = 'train_01'
......@@ -200,13 +199,12 @@ class TestDataManager:
expected_loader_ids = list(loader_dict.keys())
mock_generate_loaders.return_value = loader_dict
generators = [data_manager.DataLoaderGenerator(summary_base_dir)]
mock_data_manager = data_manager.DataManager(generators)
mock_data_manager._execute_load_data = Mock()
mock_data_manager = data_manager.DataManager(summary_base_dir)
mock_data_manager._detail_cache._execute_load_data = Mock()
mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager, 3)
current_loader_ids = mock_data_manager._loader_pool.keys()
current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys()
assert sorted(current_loader_ids) == sorted(expected_loader_ids)
......@@ -221,7 +219,7 @@ class TestDataManager:
mock_generate_loaders.return_value = loader_dict
mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager)
current_loader_ids = mock_data_manager._loader_pool.keys()
current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys()
assert sorted(current_loader_ids) == sorted(expected_loader_ids)
......
......@@ -30,7 +30,6 @@ from mindinsight.datavisual.common.exceptions import GraphNotExistError
from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
......@@ -74,7 +73,7 @@ class TestGraphProcessor:
self._temp_path, self._graph_dict, _ = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir)
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......@@ -93,7 +92,7 @@ class TestGraphProcessor:
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......
......@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.common.exceptions import HistogramNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor
from mindinsight.datavisual.utils import crc32
......@@ -72,7 +71,7 @@ class TestHistogramProcessor:
PluginNameEnum.HISTOGRAM.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......
......@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.common.exceptions import ImageNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.utils import crc32
......@@ -81,7 +80,7 @@ class TestImagesProcessor:
PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......
......@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.common.exceptions import ScalarNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from mindinsight.datavisual.utils import crc32
......@@ -73,7 +72,7 @@ class TestScalarsProcessor:
PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......
......@@ -27,7 +27,6 @@ import pytest
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
from mindinsight.datavisual.utils import crc32
......@@ -97,7 +96,7 @@ class TestTrainTaskManager:
self._generated_path.append(self._root_dir)
self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(self._root_dir)])
self._mock_data_manager = data_manager.DataManager(self._root_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(self._mock_data_manager, time_limit=30)
......