Commit 7e17d6ff authored by wenkai

refactor data manager and unify cache and data access/reload

Parent 46d44977
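For orientation before the diff: the refactor splits caching into a brief cache that tracks every train job on disk and a detail cache that holds fully loaded event data for the most recently used jobs, both owned by DataManager. A minimal usage sketch of the unified access path introduced here, assuming an installed MindInsight and existing summary data; the train id is hypothetical and not taken from this commit:

from mindinsight.datavisual.data_transform.data_manager import (
    DATA_MANAGER, DATAVISUAL_CACHE_KEY, DATAVISUAL_PLUGIN_KEY)

train_id = "./hypothetical_run"  # hypothetical relative summary path
# Ask both caches to cache this job; a reload is triggered only if one of them needs it.
DATA_MANAGER.cache_train_job(train_id)

train_job = DATA_MANAGER.get_train_job(train_id)
if train_job.has_detail():
    # The detail cache stores the datavisual payload under DATAVISUAL_CACHE_KEY.
    content = train_job.get_detail(DATAVISUAL_CACHE_KEY)
    tag_mapping = content.get(DATAVISUAL_PLUGIN_KEY)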
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Trigger data manager load."""
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
from mindinsight.datavisual.common.log import logger
from mindinsight.conf import settings
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
def init_module(app):
"""
Interface to init module.
Args:
app (Flask): An instance of Flask.
"""
# Just to suppress pylint warning about unused arg.
logger.debug("App: %s", type(app))
DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater())
DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL),
max_threads_count=int(settings.MAX_THREADS_COUNT))
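For context, a sketch of how this init module is typically wired into the web backend, assuming a plain Flask application object; the import path below is illustrative, since the diff does not show this file's name:

from flask import Flask
# Hypothetical import path for the module above.
from mindinsight.backend.data_manager import init_module

app = Flask(__name__)
init_module(app)  # registers the lineage updater and starts periodic data loading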
...@@ -18,9 +18,6 @@ from mindinsight.backend.datavisual.static_resource_api import init_module as st
from mindinsight.backend.datavisual.task_manager_api import init_module as task_init_module
from mindinsight.backend.datavisual.train_visual_api import init_module as train_init_module
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
def init_module(app):
"""
...@@ -33,6 +30,3 @@ def init_module(app):
static_init_module(app)
task_init_module(app)
train_init_module(app)
DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL),
max_threads_count=int(settings.MAX_THREADS_COUNT))
...@@ -150,3 +150,12 @@ class HistogramNotExistError(MindInsightException):
super(HistogramNotExistError, self).__init__(DataVisualErrors.HISTOGRAM_NOT_EXIST,
error_msg,
http_code=400)
class TrainJobDetailNotInCacheError(MindInsightException):
"""Detail info of given train job is not in cache."""
def __init__(self, error_detail="no detail provided."):
error_msg = f'Detail info of the given train job is not in cache. Detail: {error_detail}'
super().__init__(DataVisualErrors.TRAIN_JOB_DETAIL_NOT_IN_CACHE,
error_msg,
http_code=400)
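Callers are expected to fall back gracefully when only brief data is cached. A short sketch of the intended handling, mirroring the processor change later in this commit; the helper name is illustrative and the train_job argument is assumed to be a TrainJob obtained from DataManager.get_train_job:

from mindinsight.datavisual.common import exceptions

def plugins_or_default(train_job):
    """Return cached plugin tags, or an empty dict when detail data is not cached yet."""
    try:
        content = train_job.get_detail("datavisual")   # DATAVISUAL_CACHE_KEY
        return content.get("tag_mapping")              # DATAVISUAL_PLUGIN_KEY
    except exceptions.TrainJobDetailNotInCacheError:
        return {}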
...@@ -20,11 +20,18 @@ It can read events data through the DataLoader.
This module also acts as a thread pool manager.
"""
import abc
import enum
import threading
import time
import datetime
import os
from typing import Iterable, Optional
from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.conf import settings
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.log import logger
...@@ -37,44 +44,369 @@ from mindinsight.utils.exceptions import MindInsightException
from mindinsight.utils.exceptions import ParamValueError
@enum.unique
class _CacheStatus(enum.Enum):
"""Train job cache status."""
NOT_IN_CACHE = "NOT_IN_CACHE"
CACHING = "CACHING"
CACHED = "CACHED"
class _BasicTrainJob:
""" """
DataManager manages a pool of loader which help access events data. Basic info about train job.
Each loader helps deal the data of the events. Args:
A loader corresponds to an events_data. train_id (str): Id of the train job.
The DataManager build a pool including all the data_loader. abs_summary_base_dir (str): The canonical path of summary base directory. It should be the return value of
The data_loader provides extracting realpath().
method to get the information of events. abs_summary_dir (str): The canonical path of summary directory. It should be the return value of realpath().
create_time (DateTime): The create time of summary directory.
update_time (DateTime): The latest modify time of summary files directly in the summary directory.
"""
def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time):
self._train_id = train_id
self._abs_summary_base_dir = abs_summary_base_dir
self._abs_summary_dir = abs_summary_dir
self._create_time = create_time
self._update_time = update_time
@property
def summary_dir(self):
"""Get summary directory path."""
return self._abs_summary_dir
@property
def train_id(self):
"""Get train id."""
return self._train_id
class CachedTrainJob:
""" """
def __init__(self, loader_generators): Cache item for BriefCacheManager.
DetailCacheManager will also wrap it's return value with this class.
Args:
basic_info (_BasicTrainJob): Basic info about the train job.
""" """
Initialize the pool of loader and the dict of name-to-path. def __init__(self, basic_info: _BasicTrainJob):
self._basic_info = basic_info
self._last_access_time = datetime.datetime.utcnow()
# Other cached content is stored here.
self._content = {}
self._cache_status = _CacheStatus.NOT_IN_CACHE
@property
def cache_status(self):
"""Get cache status."""
return self._cache_status
@cache_status.setter
def cache_status(self, value):
"""Set cache status."""
self._cache_status = value
def update_access_time(self):
"""Update last access time of this cache item."""
self._last_access_time = datetime.datetime.utcnow()
@property
def last_access_time(self):
"""Get last access time for purposes such as LRU."""
return self._last_access_time
@property
def summary_dir(self):
"""Get summary directory path."""
return self._basic_info.summary_dir
def set(self, key, value):
"""Set value to cache."""
self._content[key] = value
def get(self, key):
"""Get value from cache."""
try:
return self._content[key]
except KeyError:
raise ParamValueError("Invalid cache key({}).".format(key))
@property
def basic_info(self):
"""Get basic train job info."""
return self._basic_info
@basic_info.setter
def basic_info(self, value):
"""Set basic train job info."""
self._basic_info = value
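# Illustrative sketch, not part of this commit: CachedTrainJob.set()/get() back arbitrary
# per-module payloads, and get() turns a missing key into ParamValueError rather than KeyError.
# Paths and key names below are hypothetical.
_example_basic = _BasicTrainJob(train_id="./run", abs_summary_base_dir="/tmp/summaries",
                                abs_summary_dir="/tmp/summaries/run", create_time=None, update_time=None)
_example_cached = CachedTrainJob(basic_info=_example_basic)
_example_cached.set("some_key", {"payload": 1})
assert _example_cached.get("some_key") == {"payload": 1}
# _example_cached.get("missing_key") would raise ParamValueError("Invalid cache key(missing_key).")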
class TrainJob:
"""
Train job object.
You must not create TrainJob objects manually. You should always get TrainJob objects from DataManager.
Args:
brief_train_job (CachedTrainJob): Brief info about train job.
detail_train_job (Optional[CachedTrainJob]): Detailed info about train job. Default: None.
"""
def __init__(self,
brief_train_job: CachedTrainJob,
detail_train_job: Optional[CachedTrainJob] = None):
self._brief = brief_train_job
self._detail = detail_train_job
if self._detail is None:
self._cache_status = _CacheStatus.NOT_IN_CACHE
else:
self._cache_status = self._detail.cache_status
def has_detail(self):
"""Whether this train job has detailed info in cache."""
return bool(self._detail is not None)
def get_detail(self, key):
"""
Get detail content.
Args:
key (Any): Cache key.
Returns:
Any, cache content.
Raises:
TrainJobDetailNotInCacheError: when this train job has no detail cache.
"""
if not self.has_detail():
raise exceptions.TrainJobDetailNotInCacheError()
return self._detail.get(key)
def get_brief(self, key):
"""
Get brief content.
Args:
key (Any): Cache key.
Returns:
Any, cache content.
"""
return self._brief.get(key)
class BaseCacheItemUpdater(abc.ABC):
"""Abstract base class for other modules to update cache content."""
def update_item(self, cache_item: CachedTrainJob):
"""
Update cache item in place.
Args:
cache_item (CachedTrainJob): The cache item to be processed.
"""
raise NotImplementedError()
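# Illustrative sketch, not part of this commit: a hypothetical updater built on the
# extension point above; only LineageCacheItemUpdater (in cache_item_updater.py below)
# ships with this change. It records when a brief cache item was last touched by a reload pass.
class LastSeenUpdater(BaseCacheItemUpdater):
    """Hypothetical example: stamp each cache item with the last reload time."""
    def update_item(self, cache_item: CachedTrainJob):
        cache_item.set(key="last_seen", value=datetime.datetime.utcnow())
# It would be registered the same way as the lineage updater:
# DATA_MANAGER.register_brief_cache_item_updater(LastSeenUpdater())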
class _BaseCacheManager:
"""Base class for cache manager."""
def __init__(self):
# Use dict to remove duplicate updaters.
self._updaters = {}
# key is train_id
self._lock = threading.Lock()
self._cache_items = {}
def size(self):
"""Gets used cache slots."""
return len(self._cache_items)
def register_cache_item_updater(self, updater: BaseCacheItemUpdater):
"""Register cache item updater."""
self._updaters[updater.__class__.__qualname__] = updater
def get_train_jobs(self):
"""Get cached train jobs."""
copied_train_jobs = dict(self._cache_items)
return copied_train_jobs
def get_train_job(self, train_id):
"""Get cached train job."""
try:
return self._cache_items[train_id]
except KeyError:
raise TrainJobNotExistError(train_id)
def cache_train_job(self, train_id) -> bool:
"""
Cache given train job and update train job's last access time.
This method should return true if reload actions should be taken to cache the train job.
Args:
train_id (str): Train Id.
"""
raise NotImplementedError()
def delete_train_job(self, train_id):
"""Delete train job from cache."""
if train_id in self._cache_items:
del self._cache_items[train_id]
def has_content(self):
"""Whether this cache manager has train jobs."""
return bool(self._cache_items)
def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]):
"""
Update cache according to given train jobs on disk.
Different cache manager should implement different cache update policies in this method.
Args:
disk_train_jobs (Iterable[_BasicTrainJob]): Train jobs on disk.
"""
raise NotImplementedError()
def _merge_with_disk(self, disk_train_jobs: Iterable[_BasicTrainJob]):
"""
Merge train jobs in cache with train jobs from disk
This method will remove train jobs not on disk. Call this function with lock for thread safety.
Args:
disk_train_jobs (Iterable[_BasicTrainJob]): Basic train jobs info from disk.
Returns:
dict, a dict containing train jobs to be cached.
"""
new_cache_items = {}
for train_job in disk_train_jobs:
if train_job.train_id not in self._cache_items:
new_cache_items[train_job.train_id] = CachedTrainJob(train_job)
else:
reused_train_job = self._cache_items[train_job.train_id]
reused_train_job.basic_info = train_job
new_cache_items[train_job.train_id] = reused_train_job
return new_cache_items
class _BriefCacheManager(_BaseCacheManager):
"""A cache manager that holds all disk train jobs on disk."""
def cache_train_job(self, train_id):
"""
Cache given train job.
All train jobs on disk are cached on every reload, so this method always returns False.
Args:
train_id (str): Train Id.
""" """
if train_id in self._cache_items:
self._cache_items[train_id].update_access_time()
return False
def update_cache(self, disk_train_jobs):
"""Update cache."""
with self._lock:
new_cache_items = self._merge_with_disk(disk_train_jobs)
self._cache_items = new_cache_items
for updater in self._updaters.values():
for cache_item in self._cache_items.values():
updater.update_item(cache_item)
# Key for plugin tags.
DATAVISUAL_PLUGIN_KEY = "tag_mapping"
# Detail train job cache key for datavisual content.
DATAVISUAL_CACHE_KEY = "datavisual"
class _DetailCacheManager(_BaseCacheManager):
"""A cache manager that holds detailed info for most recently used train jobs."""
def __init__(self, loader_generators):
super().__init__()
self._loader_pool = {}
self._deleted_id_list = []
self._status = DataManagerStatus.INIT.value
self._status_mutex = threading.Lock()
self._loader_pool_mutex = threading.Lock()
self._max_threads_count = 30
self._reload_interval = 3
self._loader_generators = loader_generators
def size(self):
"""
Get the number of items in this cache manager.
To be implemented.
Returns:
int, the number of items in this cache manager.
"""
raise NotImplementedError()
def loader_pool_size(self):
"""Get loader pool size."""
return len(self._loader_pool)
def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]):
"""
Update cache.
Will switch to using disk_train_jobs in the future.
Args:
disk_train_jobs (Iterable[_BasicTrainJob]): Basic info about train jobs on disk.
"""
self._generate_loaders()
self._execute_load_data()
def cache_train_job(self, train_id):
"""Cache given train job."""
loader = None
need_reload = False
with self._loader_pool_mutex:
if self._is_loader_in_loader_pool(train_id, self._loader_pool):
loader = self._loader_pool.get(train_id)
if loader is None:
for generator in self._loader_generators:
tmp_loader = generator.generate_loader_by_train_id(train_id)
if loader and loader.latest_update_time > tmp_loader.latest_update_time:
continue
loader = tmp_loader
if loader is None:
raise TrainJobNotExistError(train_id)
self._add_loader(loader)
need_reload = True
self._update_loader_latest_update_time(loader.loader_id)
return need_reload
def get_train_jobs(self):
"""
Get train jobs
To be implemented.
"""
def _add_loader(self, loader):
"""
Add a loader to load data.
Args:
loader (LoaderStruct): An object of `Loader`.
"""
if len(self._loader_pool) >= MAX_DATA_LOADER_SIZE:
delete_number = len(self._loader_pool) - MAX_DATA_LOADER_SIZE + 1
...@@ -85,40 +417,328 @@ class DataManager:
self._delete_loader(delete_loader_id)
self._loader_pool.update({loader.loader_id: loader})
def _delete_loader(self, loader_id):
"""
Delete loader from loader pool by loader id.
Args:
loader_id (str): ID of loader.
"""
if self._loader_pool.get(loader_id) is not None:
logger.debug("delete loader %s", loader_id)
self._loader_pool.pop(loader_id)
def _execute_loader(self, loader_id):
"""
Load data from data_loader.
If something goes wrong while loading, log it and delete the loader.
Args:
loader_id (str): An ID for `Loader`.
"""
try:
with self._loader_pool_mutex:
loader = self._loader_pool.get(loader_id, None)
if loader is None:
logger.debug("Loader %r has been deleted, will not load data.", loader_id)
return
loader.data_loader.load()
except MindInsightException as ex:
logger.warning("Data loader %r load data failed. "
"Delete data_loader. Detail: %s", loader_id, ex)
with self._loader_pool_mutex:
self._delete_loader(loader_id)
def _generate_loaders(self):
"""This function generates the loader from given path."""
loader_dict = {}
for generator in self._loader_generators:
loader_dict.update(generator.generate_loaders(self._loader_pool))
sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time)
latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:]
self._deal_loaders(latest_loaders)
def _deal_loaders(self, latest_loaders):
"""
This function determines which loaders to keep, remove, or add.
It is based on the given dict of loaders.
Args:
latest_loaders (list[dict]): A list of <loader_id: LoaderStruct>.
"""
with self._loader_pool_mutex:
for loader_id, loader in latest_loaders:
if self._loader_pool.get(loader_id, None) is None:
self._add_loader(loader)
continue
# If this loader was updated manually before,
# its latest_update_time may be bigger than the update_time in the summary.
if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time:
self._update_loader_latest_update_time(loader_id, loader.latest_update_time)
def _execute_load_data(self):
"""Load data through multiple threads."""
threads_count = self._get_threads_count()
if not threads_count:
logger.info("Can not find any valid train log path to load, loader pool is empty.")
return
logger.info("Start to execute load data. threads_count: %s.", threads_count)
with ThreadPoolExecutor(max_workers=threads_count) as executor:
futures = []
loader_pool = self._get_snapshot_loader_pool()
for loader_id in loader_pool:
future = executor.submit(self._execute_loader, loader_id)
futures.append(future)
wait(futures, return_when=ALL_COMPLETED)
def _get_threads_count(self):
"""
Use the maximum number of threads available.
Returns:
int, number of threads.
"""
threads_count = min(self._max_threads_count, len(self._loader_pool))
return threads_count
def delete_train_job(self, train_id):
"""
Delete train job with a train id.
Args:
train_id (str): ID for train job.
"""
with self._loader_pool_mutex:
self._delete_loader(train_id)
def list_tensors(self, train_id, tag):
"""
List tensors of the given train job and tag.
If no tensor can be found for the given tag, an exception will be raised.
Args:
train_id (str): ID for train job.
tag (str): The tag name.
Returns:
NamedTuple, the tuple format is `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`.
the value will contain the given tag data.
"""
loader_pool = self._get_snapshot_loader_pool()
if not self._is_loader_in_loader_pool(train_id, loader_pool):
raise TrainJobNotExistError("Can not find the given train job in cache.")
data_loader = loader_pool[train_id].data_loader
events_data = data_loader.get_events_data()
try:
tensors = events_data.tensors(tag)
except KeyError:
error_msg = "Can not find any data in this train job by given tag."
raise ParamValueError(error_msg)
return tensors
def _check_train_job_exist(self, train_id, loader_pool):
"""
Check whether the train job exists; if it does not, raise an exception.
Args:
train_id (str): The given train job id.
loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.
Raises:
TrainJobNotExistError: Can not find train job in data manager.
"""
is_exist = False
if train_id in loader_pool:
return
for generator in self._loader_generators:
if generator.check_train_job_exist(train_id):
is_exist = True
break
if not is_exist:
raise TrainJobNotExistError("Can not find the train job in data manager.")
def _is_loader_in_loader_pool(self, train_id, loader_pool):
"""
Check whether the train job exists in the loader pool; return True if it does, otherwise False.
Args:
train_id (str): The given train job id.
loader_pool (dict): See self._loader_pool.
Returns:
bool, if loader in loader pool, return True.
"""
if train_id in loader_pool:
return True
return False
def _get_snapshot_loader_pool(self):
"""
Create a snapshot of data loader pool to avoid concurrent mutation and iteration issues.
Returns:
dict, a copy of `self._loader_pool`.
"""
with self._loader_pool_mutex:
return dict(self._loader_pool)
def get_train_job(self, train_id):
"""
Get train job by train ID.
This method overrides parent method.
Args:
train_id (str): Train ID for train job.
Returns:
dict, single train job, if can not find any data, will return None.
"""
self._check_train_job_exist(train_id, self._loader_pool)
loader = self._get_loader(train_id)
if loader is None:
logger.warning("No valid summary log in train job %s, "
"or it is not in the cache.", train_id)
return None
train_job = loader.to_dict()
train_job.pop('data_loader')
plugin_data = {}
for plugin_name in PluginNameEnum.list_members():
job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name)
if job is None:
plugin_data[plugin_name] = []
else:
plugin_data[plugin_name] = job['tags']
train_job.update({DATAVISUAL_PLUGIN_KEY: plugin_data})
# Will fill basic_info value in future.
train_job_obj = CachedTrainJob(basic_info=None)
train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job)
# Will assign real value in future.
train_job_obj.cache_status = _CacheStatus.CACHED
return train_job_obj
def _get_loader(self, train_id):
"""
Get loader by train id.
Args:
train_id (str): Train Id.
Returns:
LoaderStruct, the loader.
"""
loader = None
with self._loader_pool_mutex:
if self._is_loader_in_loader_pool(train_id, self._loader_pool):
loader = self._loader_pool.get(train_id)
return loader
def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
"""
Update loader with latest_update_time.
Args:
loader_id (str): ID of loader.
latest_update_time (float): Timestamp.
"""
if latest_update_time is None:
latest_update_time = time.time()
self._loader_pool[loader_id].latest_update_time = latest_update_time
def get_train_job_by_plugin(self, train_id, plugin_name):
"""
Get a train job by train job id.
If the given train job does not have the given plugin data, the tag list will be empty.
Args:
train_id (str): Get train job info by the given id.
plugin_name (str): Get tags by given plugin.
Returns:
TypedDict('TrainJobEntity', {'id': str, 'name': str, 'tags': List[str]}),
a train job object.
"""
self._check_train_job_exist(train_id, self._loader_pool)
loader = self._get_loader(train_id)
if loader is None:
logger.warning("No valid summary log in train job %s, "
"or it is not in the cache.", train_id)
return None
name = loader.name
data_loader = loader.data_loader
tags = []
try:
events_data = data_loader.get_events_data()
tags = events_data.list_tags_by_plugin(plugin_name)
except KeyError:
logger.debug("Plugin name %r does not exist "
"in train job %r, and set tags to empty list.", plugin_name, name)
except AttributeError:
logger.debug("Train job %r has been deleted or it has not loaded data, "
"and set tags to empty list.", name)
result = dict(id=train_id, name=name, tags=tags)
return result
Args:
loader_id (str): ID of loader.
"""
if self._loader_pool.get(loader_id) is not None:
logger.debug("delete loader %s", loader_id)
self._loader_pool.pop(loader_id)
class DataManager:
"""
DataManager manages a pool of loaders which help access events data.
Each loader helps deal with the data of the events.
A loader corresponds to an events_data.
The DataManager builds a pool including all the data_loaders.
The data_loader provides an extracting method to get the information of events.
"""
def __init__(self, summary_base_dir):
"""
Initialize the pool of loader and the dict of name-to-path.
Args:
summary_base_dir (str): Base summary directory.
self._status: Refer `datavisual.common.enums.DataManagerStatus`.
"""
self._summary_base_dir = os.path.realpath(summary_base_dir)
self._status = DataManagerStatus.INIT.value
self._status_mutex = threading.Lock()
self._reload_interval = 3
loader_generators = [DataLoaderGenerator(self._summary_base_dir)]
self._detail_cache = _DetailCacheManager(loader_generators)
self._brief_cache = _BriefCacheManager()
def start_load_data(self,
reload_interval=settings.RELOAD_INTERVAL,
...@@ -176,64 +796,31 @@ class DataManager:
return
self.status = DataManagerStatus.LOADING.value
summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir)
basic_train_jobs = []
for info in summaries_info:
basic_train_jobs.append(_BasicTrainJob(
train_id=info['relative_path'],
abs_summary_base_dir=self._summary_base_dir,
abs_summary_dir=os.path.realpath(os.path.join(
self._summary_base_dir,
info['relative_path']
)),
create_time=info['create_time'],
update_time=info['update_time']
))
self._brief_cache.update_cache(basic_train_jobs)
self._detail_cache.update_cache(basic_train_jobs)
if not self._brief_cache.has_content() and not self._detail_cache.has_content():
self.status = DataManagerStatus.INVALID.value
else:
self.status = DataManagerStatus.DONE.value
logger.info("Load event data end, status: %r, and loader pool size is %r.",
self.status, self._detail_cache.loader_pool_size())
def _generate_loaders(self):
"""This function generates the loader from given path."""
loader_dict = {}
for generator in self._loader_generators:
loader_dict.update(generator.generate_loaders(self._loader_pool))
sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time)
latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:]
self._deal_loaders(latest_loaders)
def _deal_loaders(self, latest_loaders):
"""
This function determines which loaders to keep or remove or added.
It is based on the given dict of loaders.
Args:
latest_loaders (list[dict]): A list of <loader_id: LoaderStruct>.
"""
with self._loader_pool_mutex:
for loader_id, loader in latest_loaders:
if self._loader_pool.get(loader_id, None) is None:
self._add_loader(loader)
continue
# If this loader was updated manually before,
# its latest_update_time may bigger than update_time in summary.
if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time:
self._update_loader_latest_update_time(loader_id, loader.latest_update_time)
def _execute_load_data(self):
"""Load data through multiple threads."""
threads_count = self._get_threads_count()
if not threads_count:
logger.info("Can not find any valid train log path to load, loader pool is empty.")
return
logger.info("Start to execute load data. threads_count: %s.", threads_count)
with ThreadPoolExecutor(max_workers=threads_count) as executor:
futures = []
loader_pool = self._get_snapshot_loader_pool()
for loader_id in loader_pool:
future = executor.submit(self._execute_loader, loader_id)
futures.append(future)
wait(futures, return_when=ALL_COMPLETED)
@staticmethod
def check_reload_interval(reload_interval):
...@@ -262,18 +849,6 @@ class DataManager:
if max_threads_count <= 0:
raise ParamValueError("The value of max threads count should be > 0.")
def _get_threads_count(self):
"""
Use the maximum number of threads available.
Returns:
int, number of threads.
"""
threads_count = min(self._max_threads_count, len(self._loader_pool))
return threads_count
def get_train_job_by_plugin(self, train_id, plugin_name):
"""
Get a train job by train job id.
...@@ -290,32 +865,9 @@ class DataManager:
"""
self._check_status_valid()
return self._detail_cache.get_train_job_by_plugin(train_id, plugin_name)
loader = self._get_loader(train_id)
if loader is None:
logger.warning("No valid summary log in train job %s, "
"or it is not in the cache.", train_id)
return None
name = loader.name
data_loader = loader.data_loader
tags = []
try:
events_data = data_loader.get_events_data()
tags = events_data.list_tags_by_plugin(plugin_name)
except KeyError:
logger.debug("Plugin name %r does not exist "
"in train job %r, and set tags to empty list.", plugin_name, name)
except AttributeError:
logger.debug("Train job %r has been deleted or it has not loaded data, "
"and set tags to empty list.", name)
result = dict(id=train_id, name=name, tags=tags)
return result
def delete_train_job(self, train_id, only_delete_from_cache=True):
"""
Delete train job with a train id.
...@@ -323,8 +875,11 @@ class DataManager:
train_id (str): ID for train job.
"""
if not only_delete_from_cache:
raise NotImplementedError("Delete from both cache and disk is not supported.")
self._brief_cache.delete_train_job(train_id)
self._detail_cache.delete_train_job(train_id)
def list_tensors(self, train_id, tag):
"""
...@@ -342,66 +897,7 @@ class DataManager:
"""
self._check_status_valid()
return self._detail_cache.list_tensors(train_id, tag)
if not self._is_loader_in_loader_pool(train_id, loader_pool):
raise TrainJobNotExistError("Can not find the given train job in cache.")
data_loader = loader_pool[train_id].data_loader
events_data = data_loader.get_events_data()
try:
tensors = events_data.tensors(tag)
except KeyError:
error_msg = "Can not find any data in this train job by given tag."
raise ParamValueError(error_msg)
return tensors
def _check_train_job_exist(self, train_id, loader_pool):
"""
Check train job exist, if not exist, will raise exception.
Args:
train_id (str): The given train job id.
loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.
Raises:
TrainJobNotExistError: Can not find train job in data manager.
"""
is_exist = False
if train_id in loader_pool:
return
for generator in self._loader_generators:
if generator.check_train_job_exist(train_id):
is_exist = True
break
if not is_exist:
raise TrainJobNotExistError("Can not find the train job in data manager.")
def _is_loader_in_loader_pool(self, train_id, loader_pool):
"""
Check train job exist, if not exist, return False. Else, return True.
Args:
train_id (str): The given train job id.
loader_pool (dict): See self._loader_pool.
Returns:
bool, if loader in loader pool, return True.
"""
if train_id in loader_pool:
return True
return False
def _get_snapshot_loader_pool(self):
"""
Create a snapshot of data loader pool to avoid concurrent mutation and iteration issues.
Returns:
dict, a copy of `self._loader_pool`.
"""
with self._loader_pool_mutex:
return dict(self._loader_pool)
def _check_status_valid(self):
"""Check if the status is valid to load data."""
...@@ -409,90 +905,29 @@ class DataManager:
if self.status == DataManagerStatus.INIT.value:
raise exceptions.SummaryLogIsLoading("Data is being loaded, current status: %s." % self._status)
def get_train_job(self, train_id):
"""
Get train job by train ID.
Args:
train_id (str): Train ID for train job.
Returns:
dict, single train job, if can not find any data, will return None.
"""
self._check_status_valid()
detail_train_job = self._detail_cache.get_train_job(train_id)
brief_train_job = self._brief_cache.get_train_job(train_id)
return TrainJob(brief_train_job, detail_train_job)
def list_train_jobs(self):
"""
List train jobs.
To be implemented.
"""
raise NotImplementedError()
@property
def status(self):
...@@ -509,6 +944,16 @@ class DataManager:
"""Set data manager status."""
self._status = status
def cache_train_job(self, train_id):
"""Cache given train job (async)."""
brief_need_reload = self._brief_cache.cache_train_job(train_id)
detail_need_reload = self._detail_cache.cache_train_job(train_id)
if brief_need_reload or detail_need_reload:
self.reload_data()
def register_brief_cache_item_updater(self, updater: BaseCacheItemUpdater):
"""Register brief cache item updater for brief cache manager."""
self._brief_cache.register_cache_item_updater(updater)
_loader_generators = [DataLoaderGenerator(settings.SUMMARY_BASE_DIR)]
DATA_MANAGER = DataManager(_loader_generators)
DATA_MANAGER = DataManager(settings.SUMMARY_BASE_DIR)
...@@ -18,6 +18,7 @@ from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY
class TrainTaskManager(BaseProcessor):
...@@ -53,13 +54,24 @@ class TrainTaskManager(BaseProcessor):
dict, refer to restful api.
"""
Validation.check_param_empty(train_id=train_id)
train_job = self._data_manager.get_single_train_job(train_id, manual_update=manual_update)
if not train_job:
if manual_update:
self._data_manager.cache_train_job(train_id)
train_job = self._data_manager.get_train_job(train_id)
try:
data_visual_content = train_job.get_detail(DATAVISUAL_CACHE_KEY)
plugins = data_visual_content.get(DATAVISUAL_PLUGIN_KEY)
except exceptions.TrainJobDetailNotInCacheError:
plugins = []
if not plugins:
default_result = dict()
for plugin_name in PluginNameEnum.list_members():
default_result.update({plugin_name: list()})
return dict(plugins=default_result)
return dict(
plugins=plugins
)
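The manual-update path now goes through the cache explicitly: the processor first asks the DataManager to cache the job (possibly triggering a reload) and then reads whatever tags are available. A hedged call-level sketch; the method name get_plugins and the constructor argument are assumptions, since the hunk above does not show the enclosing signature:

from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager

processor = TrainTaskManager(DATA_MANAGER)  # assumed BaseProcessor(data_manager) signature
result = processor.get_plugins(train_id="./hypothetical_run", manual_update=True)  # assumed method name
plugins = result["plugins"]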
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cache item updater."""
import os
from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob
from mindinsight.lineagemgr.querier.query_model import LineageObj
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageSummaryAnalyzer
class LineageCacheItemUpdater(BaseCacheItemUpdater):
"""Cache item updater for lineage info."""
def update_item(self, cache_item: CachedTrainJob):
"""Update cache item in place."""
log_path = cache_item.summary_dir
log_dir = os.path.dirname(log_path)
lineage_info = LineageSummaryAnalyzer.get_summary_infos(log_path)
user_defined_info = LineageSummaryAnalyzer.get_user_defined_info(log_path)
lineage_obj = LineageObj(
log_dir,
train_lineage=lineage_info.train_lineage,
evaluation_lineage=lineage_info.eval_lineage,
dataset_graph=lineage_info.dataset_graph,
user_defined_info=user_defined_info
)
cache_item.set(key="lineage", value=lineage_obj)
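Downstream consumers read the lineage payload back through the brief cache. A sketch, assuming a TrainJob obtained from DataManager.get_train_job and that "lineage" remains the agreed cache key:

from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER

train_job = DATA_MANAGER.get_train_job("./hypothetical_run")  # hypothetical train id
lineage_obj = train_job.get_brief("lineage")  # the LineageObj stored by update_item above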
...@@ -63,3 +63,4 @@ class DataVisualErrors(Enum):
IMAGE_NOT_EXIST = 13
SCALAR_NOT_EXIST = 14
HISTOGRAM_NOT_EXIST = 15
TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
...@@ -25,7 +25,6 @@ from flask import Response
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.utils import tools
...@@ -59,7 +58,7 @@ def init_summary_logs():
log_operations = LogOperations()
summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST,
constants.SUMMARY_DIR_PREFIX)
mock_data_manager = DataManager(summary_base_dir)
mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager)
......
...@@ -33,7 +33,6 @@ from mindinsight.datavisual.data_transform import data_manager, ms_data_loader
from mindinsight.datavisual.data_transform.data_loader import DataLoader
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.events_data import EventsData
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct
from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
...@@ -89,7 +88,7 @@ class TestDataManager:
train_ids.append(f'./dir{i}')
data_manager.logger = MockLogger
mock_manager = data_manager.DataManager(summary_base_dir)
mock_manager.start_load_data(reload_interval=0)
check_loading_done(mock_manager)
...@@ -112,7 +111,7 @@ class TestDataManager:
def test_start_load_data_with_invalid_params(self, params):
"""Test start_load_data with invalid reload_interval or invalid max_threads_count."""
summary_base_dir = tempfile.mkdtemp()
d_manager = DataManager(summary_base_dir)
with pytest.raises(ParamValueError):
d_manager.start_load_data(**params)
shutil.rmtree(summary_base_dir)
...@@ -142,9 +141,9 @@ class TestDataManager:
latest_update_time=modify_time_01,
data_loader=loader_01)
loader_pool = {train_job_01: loader}
d_manager = DataManager(summary_base_dir)
d_manager._status = DataManagerStatus.LOADING.value
d_manager._detail_cache._loader_pool = loader_pool
res = d_manager.list_tensors(train_job_01, tag)
assert res == {'test result'}
...@@ -169,9 +168,9 @@ class TestDataManager:
latest_update_time=modify_time_01,
data_loader=loader_01)
loader_pool = {train_job_01: loader}
d_manager = DataManager(summary_base_dir)
d_manager._status = DataManagerStatus.LOADING.value
d_manager._detail_cache._loader_pool = loader_pool
tag = 'image'
with pytest.raises(ParamValueError):
d_manager.list_tensors(train_job_01, tag)
...@@ -181,7 +180,7 @@ class TestDataManager:
def test_list_tensors_with_not_exist_train_job(self):
"""Test list_tensors method with parameter train_id not found in loader_pool."""
summary_base_dir = tempfile.mkdtemp()
d_manager = DataManager(summary_base_dir)
d_manager._status = DataManagerStatus.LOADING.value
tag = 'image'
train_job_01 = 'train_01'
...@@ -200,13 +199,12 @@ class TestDataManager:
expected_loader_ids = list(loader_dict.keys())
mock_generate_loaders.return_value = loader_dict
mock_data_manager = data_manager.DataManager(summary_base_dir)
mock_data_manager._detail_cache._execute_load_data = Mock()
mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager, 3)
current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys()
assert sorted(current_loader_ids) == sorted(expected_loader_ids)
...@@ -221,7 +219,7 @@ class TestDataManager:
mock_generate_loaders.return_value = loader_dict
mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(mock_data_manager)
current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys()
assert sorted(current_loader_ids) == sorted(expected_loader_ids)
......
...@@ -30,7 +30,6 @@ from mindinsight.datavisual.common.exceptions import GraphNotExistError
from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.graph_processor import GraphProcessor
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import ParamValueError
...@@ -74,7 +73,7 @@ class TestGraphProcessor:
self._temp_path, self._graph_dict, _ = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir)
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
...@@ -93,7 +92,7 @@ class TestGraphProcessor:
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......
...@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.common.exceptions import HistogramNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor
from mindinsight.datavisual.utils import crc32
...@@ -72,7 +71,7 @@ class TestHistogramProcessor:
PluginNameEnum.HISTOGRAM.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......
...@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.common.exceptions import ImageNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.utils import crc32
...@@ -81,7 +80,7 @@ class TestImagesProcessor:
PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......
...@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.common.exceptions import ScalarNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from mindinsight.datavisual.utils import crc32
...@@ -73,7 +72,7 @@ class TestScalarsProcessor:
PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
self._generated_path.append(summary_base_dir)
self._mock_data_manager = data_manager.DataManager(summary_base_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
# wait for loading done
......
...@@ -27,7 +27,6 @@ import pytest
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.data_transform import data_manager
from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
from mindinsight.datavisual.utils import crc32
...@@ -97,7 +96,7 @@ class TestTrainTaskManager:
self._generated_path.append(self._root_dir)
self._mock_data_manager = data_manager.DataManager(self._root_dir)
self._mock_data_manager.start_load_data(reload_interval=0)
check_loading_done(self._mock_data_manager, time_limit=30)
......