diff --git a/mindinsight/backend/data_manager/__init__.py b/mindinsight/backend/data_manager/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e23d2e7693e2872329bacee2d0baa89af995b66d
--- /dev/null
+++ b/mindinsight/backend/data_manager/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Trigger data manager load."""
+
+from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
+from mindinsight.datavisual.common.log import logger
+from mindinsight.conf import settings
+from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater
+
+
+def init_module(app):
+    """
+    Interface to init module.
+
+    Args:
+        app (Flask): An instance of Flask.
+
+    """
+    # Just to suppress pylint warning about unused arg.
+    logger.debug("App: %s", type(app))
+    DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater())
+    DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL),
+                                 max_threads_count=int(settings.MAX_THREADS_COUNT))
diff --git a/mindinsight/backend/datavisual/__init__.py b/mindinsight/backend/datavisual/__init__.py
index bc325afa52c0b8bd383d233c12285aab95297f25..278f18059725a994178ee02304c270a98ed1b1ca 100644
--- a/mindinsight/backend/datavisual/__init__.py
+++ b/mindinsight/backend/datavisual/__init__.py
@@ -18,9 +18,6 @@ from mindinsight.backend.datavisual.static_resource_api import init_module as st
 from mindinsight.backend.datavisual.task_manager_api import init_module as task_init_module
 from mindinsight.backend.datavisual.train_visual_api import init_module as train_init_module

-from mindinsight.conf import settings
-from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
-

 def init_module(app):
     """
@@ -33,6 +30,3 @@ def init_module(app):
     static_init_module(app)
     task_init_module(app)
     train_init_module(app)
-
-    DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL),
-                                 max_threads_count=int(settings.MAX_THREADS_COUNT))
diff --git a/mindinsight/datavisual/common/exceptions.py b/mindinsight/datavisual/common/exceptions.py
index 310bc48b396541b36ca2f6bc093b9d27fbffb550..2ec001ffa904241aab811574192120a824697cc7 100644
--- a/mindinsight/datavisual/common/exceptions.py
+++ b/mindinsight/datavisual/common/exceptions.py
@@ -150,3 +150,12 @@ class HistogramNotExistError(MindInsightException):
         super(HistogramNotExistError, self).__init__(DataVisualErrors.HISTOGRAM_NOT_EXIST,
                                                      error_msg,
                                                      http_code=400)
+
+
+class TrainJobDetailNotInCacheError(MindInsightException):
+    """Detail info of given train job is not in cache."""
+    def __init__(self, error_detail="no detail provided."):
+        error_msg = f'Detail info of the given train job is not in cache. Detail: {error_detail}'
+        super().__init__(DataVisualErrors.TRAIN_JOB_DETAIL_NOT_IN_CACHE,
+                         error_msg,
+                         http_code=400)
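For context before the data_manager.py changes below: `TrainJobDetailNotInCacheError` is raised by the new `TrainJob.get_detail()` and is meant to be caught by callers that can fall back to brief info, as `TrainTaskManager.get_plugins` does later in this diff. A minimal sketch of that calling pattern (illustrative only, not part of the patch; the helper name and the literal key strings are assumptions, the real keys are the `DATAVISUAL_CACHE_KEY` and `DATAVISUAL_PLUGIN_KEY` constants introduced below):

```python
from mindinsight.datavisual.common import exceptions


def plugins_or_empty(train_job, cache_key="datavisual", plugin_key="tag_mapping"):
    """Hypothetical helper: fall back to an empty mapping when detail info is not cached yet."""
    try:
        # get_detail() raises TrainJobDetailNotInCacheError while only brief info is cached.
        return train_job.get_detail(cache_key).get(plugin_key)
    except exceptions.TrainJobDetailNotInCacheError:
        return {}
```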
diff --git a/mindinsight/datavisual/data_transform/data_manager.py b/mindinsight/datavisual/data_transform/data_manager.py
index f068fdeb60496c61ffcff456113b60e4fd04eefe..ade1751b4c5a6651d0fae0a5c7bb86af22a7315c 100644
--- a/mindinsight/datavisual/data_transform/data_manager.py
+++ b/mindinsight/datavisual/data_transform/data_manager.py
@@ -20,11 +20,18 @@ It can read events data through the DataLoader.
 This module also acts as a thread pool manager.
 """
+import abc
+import enum
 import threading
 import time
+import datetime
+import os
+from typing import Iterable, Optional

 from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED

+from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
+
 from mindinsight.conf import settings
 from mindinsight.datavisual.common import exceptions
 from mindinsight.datavisual.common.log import logger
@@ -37,44 +44,369 @@ from mindinsight.utils.exceptions import MindInsightException
 from mindinsight.utils.exceptions import ParamValueError


-class DataManager:
+@enum.unique
+class _CacheStatus(enum.Enum):
+    """Train job cache status."""
+    NOT_IN_CACHE = "NOT_IN_CACHE"
+    CACHING = "CACHING"
+    CACHED = "CACHED"
+
+
+class _BasicTrainJob:
     """
-    DataManager manages a pool of loader which help access events data.
+    Basic info about train job.
+
+    Args:
+        train_id (str): Id of the train job.
+        abs_summary_base_dir (str): The canonical path of summary base directory. It should be the return value of
+            realpath().
+        abs_summary_dir (str): The canonical path of summary directory. It should be the return value of realpath().
+        create_time (DateTime): The create time of summary directory.
+        update_time (DateTime): The latest modify time of summary files directly in the summary directory.
+    """
+    def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time):
+        self._train_id = train_id
+        self._abs_summary_base_dir = abs_summary_base_dir
+        self._abs_summary_dir = abs_summary_dir
+        self._create_time = create_time
+        self._update_time = update_time

-    Each loader helps deal the data of the events.
-    A loader corresponds to an events_data.
-    The DataManager build a pool including all the data_loader.
-    The data_loader provides extracting
-    method to get the information of events.
+    @property
+    def summary_dir(self):
+        """Get summary directory path."""
+        return self._abs_summary_dir
+
+    @property
+    def train_id(self):
+        """Get train id."""
+        return self._train_id
+
+
+class CachedTrainJob:
     """
-    def __init__(self, loader_generators):
+    Cache item for BriefCacheManager.
+
+    DetailCacheManager will also wrap its return value with this class.
+
+    Args:
+        basic_info (_BasicTrainJob): Basic info about the train job.
+    """
+    def __init__(self, basic_info: _BasicTrainJob):
+        self._basic_info = basic_info
+        self._last_access_time = datetime.datetime.utcnow()
+
+        # Other cached content is stored here.
+        self._content = {}
+
+        self._cache_status = _CacheStatus.NOT_IN_CACHE
+
+    @property
+    def cache_status(self):
+        """Get cache status."""
+        return self._cache_status
+
+    @cache_status.setter
+    def cache_status(self, value):
+        """Set cache status."""
+        self._cache_status = value
+
+    def update_access_time(self):
+        """Update last access time of this cache item."""
+        self._last_access_time = datetime.datetime.utcnow()
+
+    @property
+    def last_access_time(self):
+        """Get last access time for purposes such as LRU."""
+        return self._last_access_time
+
+    @property
+    def summary_dir(self):
+        """Get summary directory path."""
+        return self._basic_info.summary_dir
+
+    def set(self, key, value):
+        """Set value to cache."""
+        self._content[key] = value
+
+    def get(self, key):
+        """Get value from cache."""
+        try:
+            return self._content[key]
+        except KeyError:
+            raise ParamValueError("Invalid cache key({}).".format(key))
+
+    @property
+    def basic_info(self):
+        """Get basic train job info."""
+        return self._basic_info
+
+    @basic_info.setter
+    def basic_info(self, value):
+        """Set basic train job info."""
+        self._basic_info = value
+
+
+class TrainJob:
+    """
+    Train job object.
+
+    You must not create TrainJob objects manually. You should always get TrainJob objects from DataManager.
+
+    Args:
+        brief_train_job (CachedTrainJob): Brief info about train job.
+        detail_train_job (Optional[CachedTrainJob]): Detailed info about train job. Default: None.
+    """
+    def __init__(self,
+                 brief_train_job: CachedTrainJob,
+                 detail_train_job: Optional[CachedTrainJob] = None):
+        self._brief = brief_train_job
+        self._detail = detail_train_job
+        if self._detail is None:
+            self._cache_status = _CacheStatus.NOT_IN_CACHE
+        else:
+            self._cache_status = self._detail.cache_status
+
+    def has_detail(self):
+        """Whether this train job has detailed info in cache."""
+        return bool(self._detail is not None)
+
+    def get_detail(self, key):
         """
-        Initialize the pool of loader and the dict of name-to-path.
+        Get detail content.

         Args:
-            loader_generators (list[LoaderGenerator]): Loader generators help generate loaders.
+            key (Any): Cache key.

-        self._status: Refer `datavisual.common.enums.DataManagerStatus`.
-        self._loader_pool: {'loader_id': <LoaderStruct>}.
+        Returns:
+            Any, cache content.
+
+        Raises:
+            TrainJobDetailNotInCacheError: When this train job has no detail cache.
+
+        """
+        if not self.has_detail():
+            raise exceptions.TrainJobDetailNotInCacheError()
+        return self._detail.get(key)
+
+    def get_brief(self, key):
+        """
+        Get brief content.
+
+        Args:
+            key (Any): Cache key.
+
+        Returns:
+            Any, cache content.
+        """
+        return self._brief.get(key)
+
+
+class BaseCacheItemUpdater(abc.ABC):
+    """Abstract base class for other modules to update cache content."""
+    def update_item(self, cache_item: CachedTrainJob):
+        """
+        Update cache item in place.
+
+        Args:
+            cache_item (CachedTrainJob): The cache item to be processed.
+        """
+        raise NotImplementedError()
+
+
+class _BaseCacheManager:
+    """Base class for cache manager."""
+
+    def __init__(self):
+        # Use dict to remove duplicate updaters.
+        self._updaters = {}
+
+        # key is train_id
+        self._lock = threading.Lock()
+        self._cache_items = {}
+
+    def size(self):
+        """Get used cache slots."""
+        return len(self._cache_items)
+
+    def register_cache_item_updater(self, updater: BaseCacheItemUpdater):
+        """Register cache item updater."""
+        self._updaters[updater.__class__.__qualname__] = updater
+
+    def get_train_jobs(self):
+        """Get cached train jobs."""
+        copied_train_jobs = dict(self._cache_items)
+        return copied_train_jobs
+
+    def get_train_job(self, train_id):
+        """Get cached train job."""
+        try:
+            return self._cache_items[train_id]
+        except KeyError:
+            raise TrainJobNotExistError(train_id)
+
+    def cache_train_job(self, train_id) -> bool:
+        """
+        Cache given train job and update train job's last access time.
+
+        This method should return True if reload actions should be taken to cache the train job.
+
+        Args:
+            train_id (str): Train Id.
+        """
+        raise NotImplementedError()
+
+    def delete_train_job(self, train_id):
+        """Delete train job from cache."""
+        if train_id in self._cache_items:
+            del self._cache_items[train_id]
+
+    def has_content(self):
+        """Whether this cache manager has train jobs."""
+        return bool(self._cache_items)
+
+    def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]):
+        """
+        Update cache according to given train jobs on disk.
+
+        Different cache managers should implement different cache update policies in this method.
+
+        Args:
+            disk_train_jobs (Iterable[_BasicTrainJob]): Train jobs on disk.
+        """
+        raise NotImplementedError()
+
+    def _merge_with_disk(self, disk_train_jobs: Iterable[_BasicTrainJob]):
+        """
+        Merge train jobs in cache with train jobs from disk.
+
+        This method will remove train jobs not on disk. Call this function with the lock held for thread safety.
+
+        Args:
+            disk_train_jobs (Iterable[_BasicTrainJob]): Basic train jobs info from disk.
+
+        Returns:
+            dict, a dict containing train jobs to be cached.
+        """
+        new_cache_items = {}
+        for train_job in disk_train_jobs:
+            if train_job.train_id not in self._cache_items:
+                new_cache_items[train_job.train_id] = CachedTrainJob(train_job)
+            else:
+                reused_train_job = self._cache_items[train_job.train_id]
+                reused_train_job.basic_info = train_job
+                new_cache_items[train_job.train_id] = reused_train_job
+
+        return new_cache_items
+
+
+class _BriefCacheManager(_BaseCacheManager):
+    """A cache manager that holds brief info for all train jobs on disk."""
+
+    def cache_train_job(self, train_id):
+        """
+        Cache given train job.
+
+        All disk train jobs are cached on every reload, so this method always returns False.
+
+        Args:
+            train_id (str): Train Id.
+        """
+        if train_id in self._cache_items:
+            self._cache_items[train_id].update_access_time()
+
+        return False
+
+    def update_cache(self, disk_train_jobs):
+        """Update cache."""
+        with self._lock:
+            new_cache_items = self._merge_with_disk(disk_train_jobs)
+            self._cache_items = new_cache_items
+        for updater in self._updaters.values():
+            for cache_item in self._cache_items.values():
+                updater.update_item(cache_item)
+
+
+# Key for plugin tags.
+DATAVISUAL_PLUGIN_KEY = "tag_mapping"
+# Detail train job cache key for datavisual content.
+DATAVISUAL_CACHE_KEY = "datavisual"
+
+
+class _DetailCacheManager(_BaseCacheManager):
+    """A cache manager that holds detailed info for most recently used train jobs."""
+    def __init__(self, loader_generators):
+        super().__init__()
         self._loader_pool = {}
         self._deleted_id_list = []
-        self._status = DataManagerStatus.INIT.value
-        self._status_mutex = threading.Lock()
         self._loader_pool_mutex = threading.Lock()
         self._max_threads_count = 30
-        self._reload_interval = 3
         self._loader_generators = loader_generators

+    def size(self):
+        """
+        Get the number of items in this cache manager.
+
+        To be implemented.
+
+        Returns:
+            int, the number of items in this cache manager.
+        """
+        raise NotImplementedError()
+
+    def loader_pool_size(self):
+        """Get loader pool size."""
+        return len(self._loader_pool)
+
+    def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]):
+        """
+        Update cache.
+
+        Will switch to using disk_train_jobs in the future.
+
+        Args:
+            disk_train_jobs (Iterable[_BasicTrainJob]): Basic info about train jobs on disk.
+
+        """
+        self._generate_loaders()
+        self._execute_load_data()
+
+    def cache_train_job(self, train_id):
+        """Cache given train job."""
+        loader = None
+        need_reload = False
+        with self._loader_pool_mutex:
+            if self._is_loader_in_loader_pool(train_id, self._loader_pool):
+                loader = self._loader_pool.get(train_id)
+
+            if loader is None:
+                for generator in self._loader_generators:
+                    tmp_loader = generator.generate_loader_by_train_id(train_id)
+                    if loader and loader.latest_update_time > tmp_loader.latest_update_time:
+                        continue
+                    loader = tmp_loader
+
+                if loader is None:
+                    raise TrainJobNotExistError(train_id)
+
+                self._add_loader(loader)
+                need_reload = True
+
+        self._update_loader_latest_update_time(loader.loader_id)
+        return need_reload
+
+    def get_train_jobs(self):
+        """
+        Get train jobs.
+
+        To be implemented.
+        """
+
     def _add_loader(self, loader):
         """
         Add a loader to load data.

         Args:
             loader (LoaderStruct): A object of `Loader`.
-
         """
         if len(self._loader_pool) >= MAX_DATA_LOADER_SIZE:
             delete_number = len(self._loader_pool) - MAX_DATA_LOADER_SIZE + 1
@@ -85,40 +417,328 @@ class DataManager:
                 self._delete_loader(delete_loader_id)
         self._loader_pool.update({loader.loader_id: loader})

-    def _delete_loader(self, loader_id):
-        """
-        Delete loader from loader pool by loader id.
+    def _delete_loader(self, loader_id):
+        """
+        Delete loader from loader pool by loader id.
+
+        Args:
+            loader_id (str): ID of loader.
+        """
+        if self._loader_pool.get(loader_id) is not None:
+            logger.debug("delete loader %s", loader_id)
+            self._loader_pool.pop(loader_id)
+
+    def _execute_loader(self, loader_id):
+        """
+        Load data from data_loader.
+
+        If something goes wrong during loading, log the error and delete the loader.
+
+        Args:
+            loader_id (str): An ID for `Loader`.
+
+        """
+        try:
+            with self._loader_pool_mutex:
+                loader = self._loader_pool.get(loader_id, None)
+                if loader is None:
+                    logger.debug("Loader %r has been deleted, will not load data.", loader_id)
+                    return
+
+            loader.data_loader.load()
+        except MindInsightException as ex:
+            logger.warning("Data loader %r load data failed. "
+                           "Delete data_loader. Detail: %s", loader_id, ex)
+
+            with self._loader_pool_mutex:
+                self._delete_loader(loader_id)
+
+    def _generate_loaders(self):
+        """This function generates the loader from given path."""
+        loader_dict = {}
+        for generator in self._loader_generators:
+            loader_dict.update(generator.generate_loaders(self._loader_pool))
+
+        sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time)
+        latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:]
+        self._deal_loaders(latest_loaders)
+
+    def _deal_loaders(self, latest_loaders):
+        """
+        This function determines which loaders to keep, remove, or add.
+
+        It is based on the given dict of loaders.
+
+        Args:
+            latest_loaders (list[dict]): A list of <loader_id: LoaderStruct>.
+        """
+
+        with self._loader_pool_mutex:
+            for loader_id, loader in latest_loaders:
+                if self._loader_pool.get(loader_id, None) is None:
+                    self._add_loader(loader)
+                    continue
+
+                # If this loader was updated manually before,
+                # its latest_update_time may be bigger than update_time in summary.
+                if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time:
+                    self._update_loader_latest_update_time(loader_id, loader.latest_update_time)
+
+    def _execute_load_data(self):
+        """Load data through multiple threads."""
+        threads_count = self._get_threads_count()
+        if not threads_count:
+            logger.info("Can not find any valid train log path to load, loader pool is empty.")
+            return
+
+        logger.info("Start to execute load data. threads_count: %s.", threads_count)
+
+        with ThreadPoolExecutor(max_workers=threads_count) as executor:
+            futures = []
+            loader_pool = self._get_snapshot_loader_pool()
+            for loader_id in loader_pool:
+                future = executor.submit(self._execute_loader, loader_id)
+                futures.append(future)
+            wait(futures, return_when=ALL_COMPLETED)
+
+    def _get_threads_count(self):
+        """
+        Use the maximum number of threads available.
+
+        Returns:
+            int, number of threads.
+
+        """
+        threads_count = min(self._max_threads_count, len(self._loader_pool))
+
+        return threads_count
+
+    def delete_train_job(self, train_id):
+        """
+        Delete train job with a train id.
+
+        Args:
+            train_id (str): ID for train job.
+
+        """
+        with self._loader_pool_mutex:
+            self._delete_loader(train_id)
+
+    def list_tensors(self, train_id, tag):
+        """
+        List tensors of the given train job and tag.
+
+        If the tensor can not be found by the given tag, an exception will be raised.
+
+        Args:
+            train_id (str): ID for train job.
+            tag (str): The tag name.
+
+        Returns:
+            NamedTuple, the tuple format is `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`.
+                The value will contain the given tag data.
+
+        """
+        loader_pool = self._get_snapshot_loader_pool()
+        if not self._is_loader_in_loader_pool(train_id, loader_pool):
+            raise TrainJobNotExistError("Can not find the given train job in cache.")
+
+        data_loader = loader_pool[train_id].data_loader
+        events_data = data_loader.get_events_data()
+
+        try:
+            tensors = events_data.tensors(tag)
+        except KeyError:
+            error_msg = "Can not find any data in this train job by given tag."
+            raise ParamValueError(error_msg)
+
+        return tensors
+
+    def _check_train_job_exist(self, train_id, loader_pool):
+        """
+        Check whether the train job exists; if it does not, raise an exception.
+
+        Args:
+            train_id (str): The given train job id.
+            loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.
+
+        Raises:
+            TrainJobNotExistError: Can not find train job in data manager.
+ """ + is_exist = False + if train_id in loader_pool: + return + for generator in self._loader_generators: + if generator.check_train_job_exist(train_id): + is_exist = True + break + if not is_exist: + raise TrainJobNotExistError("Can not find the train job in data manager.") + + def _is_loader_in_loader_pool(self, train_id, loader_pool): + """ + Check train job exist, if not exist, return False. Else, return True. + + Args: + train_id (str): The given train job id. + loader_pool (dict): See self._loader_pool. + + Returns: + bool, if loader in loader pool, return True. + """ + if train_id in loader_pool: + return True + return False + + def _get_snapshot_loader_pool(self): + """ + Create a snapshot of data loader pool to avoid concurrent mutation and iteration issues. + + Returns: + dict, a copy of `self._loader_pool`. + """ + with self._loader_pool_mutex: + return dict(self._loader_pool) + + def get_train_job(self, train_id): + """ + Get train job by train ID. + + This method overrides parent method. + + Args: + train_id (str): Train ID for train job. + Returns: + dict, single train job, if can not find any data, will return None. + """ + self._check_train_job_exist(train_id, self._loader_pool) + + loader = self._get_loader(train_id) + if loader is None: + logger.warning("No valid summary log in train job %s, " + "or it is not in the cache.", train_id) + return None + + train_job = loader.to_dict() + train_job.pop('data_loader') + + plugin_data = {} + for plugin_name in PluginNameEnum.list_members(): + job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name) + if job is None: + plugin_data[plugin_name] = [] + else: + plugin_data[plugin_name] = job['tags'] + + train_job.update({DATAVISUAL_PLUGIN_KEY: plugin_data}) + + # Will fill basic_info value in future. + train_job_obj = CachedTrainJob(basic_info=None) + train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job) + + # Will assign real value in future. + train_job_obj.cache_status = _CacheStatus.CACHED + + return train_job_obj + + def _get_loader(self, train_id): + """ + Get loader by train id. + + Args: + train_id (str): Train Id. + + Returns: + LoaderStruct, the loader. + """ + loader = None + with self._loader_pool_mutex: + if self._is_loader_in_loader_pool(train_id, self._loader_pool): + loader = self._loader_pool.get(train_id) + + return loader + + def _update_loader_latest_update_time(self, loader_id, latest_update_time=None): + """ + Update loader with latest_update_time. + + Args: + loader_id (str): ID of loader. + latest_update_time (float): Timestamp. + """ + if latest_update_time is None: + latest_update_time = time.time() + self._loader_pool[loader_id].latest_update_time = latest_update_time + + def get_train_job_by_plugin(self, train_id, plugin_name): + """ + Get a train job by train job id. + + If the given train job does not has the given plugin data, the tag list will be empty. + + Args: + train_id (str): Get train job info by the given id. + plugin_name (str): Get tags by given plugin. + + Returns: + TypedDict('TrainJobEntity', {'id': str, 'name': str, 'tags': List[str]}), + a train job object. 
+ + """ + self._check_train_job_exist(train_id, self._loader_pool) + + loader = self._get_loader(train_id) + if loader is None: + logger.warning("No valid summary log in train job %s, " + "or it is not in the cache.", train_id) + return None + + name = loader.name + data_loader = loader.data_loader + + tags = [] + try: + events_data = data_loader.get_events_data() + tags = events_data.list_tags_by_plugin(plugin_name) + except KeyError: + logger.debug("Plugin name %r does not exist " + "in train job %r, and set tags to empty list.", plugin_name, name) + except AttributeError: + logger.debug("Train job %r has been deleted or it has not loaded data, " + "and set tags to empty list.", name) + + result = dict(id=train_id, name=name, tags=tags) + return result - Args: - loader_id (str): ID of loader. - """ - if self._loader_pool.get(loader_id) is not None: - logger.debug("delete loader %s", loader_id) - self._loader_pool.pop(loader_id) - def _execute_loader(self, loader_id): - """ - Load data form data_loader. +class DataManager: + """ + DataManager manages a pool of loader which help access events data. - If there is something wrong by loading, add logs and delete the loader. + Each loader helps deal the data of the events. + A loader corresponds to an events_data. + The DataManager build a pool including all the data_loader. + The data_loader provides extracting + method to get the information of events. + """ + def __init__(self, summary_base_dir): + """ + Initialize the pool of loader and the dict of name-to-path. Args: - loader_id (str): An ID for `Loader`. + summary_base_dir (str): Base summary directory. + + self._status: Refer `datavisual.common.enums.DataManagerStatus`. """ - try: - with self._loader_pool_mutex: - loader = self._loader_pool.get(loader_id, None) - if loader is None: - logger.debug("Loader %r has been deleted, will not load data.", loader_id) - return - loader.data_loader.load() - except MindInsightException as ex: - logger.warning("Data loader %r load data failed. " - "Delete data_loader. 
Detail: %s", loader_id, ex) + self._summary_base_dir = os.path.realpath(summary_base_dir) + self._status = DataManagerStatus.INIT.value + self._status_mutex = threading.Lock() - with self._loader_pool_mutex: - self._delete_loader(loader_id) + self._reload_interval = 3 + + loader_generators = [DataLoaderGenerator(self._summary_base_dir)] + self._detail_cache = _DetailCacheManager(loader_generators) + self._brief_cache = _BriefCacheManager() def start_load_data(self, reload_interval=settings.RELOAD_INTERVAL, @@ -176,64 +796,31 @@ class DataManager: return self.status = DataManagerStatus.LOADING.value - self._generate_loaders() - self._execute_load_data() - - if not self._loader_pool: + summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir) + + basic_train_jobs = [] + for info in summaries_info: + basic_train_jobs.append(_BasicTrainJob( + train_id=info['relative_path'], + abs_summary_base_dir=self._summary_base_dir, + abs_summary_dir=os.path.realpath(os.path.join( + self._summary_base_dir, + info['relative_path'] + )), + create_time=info['create_time'], + update_time=info['update_time'] + )) + + self._brief_cache.update_cache(basic_train_jobs) + self._detail_cache.update_cache(basic_train_jobs) + + if not self._brief_cache.has_content() and not self._detail_cache.has_content(): self.status = DataManagerStatus.INVALID.value else: self.status = DataManagerStatus.DONE.value logger.info("Load event data end, status: %r, and loader pool size is %r.", - self.status, len(self._loader_pool)) - - def _generate_loaders(self): - """This function generates the loader from given path.""" - loader_dict = {} - for generator in self._loader_generators: - loader_dict.update(generator.generate_loaders(self._loader_pool)) - - sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time) - latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:] - self._deal_loaders(latest_loaders) - - def _deal_loaders(self, latest_loaders): - """ - This function determines which loaders to keep or remove or added. - - It is based on the given dict of loaders. - - Args: - latest_loaders (list[dict]): A list of . - """ - - with self._loader_pool_mutex: - for loader_id, loader in latest_loaders: - if self._loader_pool.get(loader_id, None) is None: - self._add_loader(loader) - continue - - # If this loader was updated manually before, - # its latest_update_time may bigger than update_time in summary. - if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time: - self._update_loader_latest_update_time(loader_id, loader.latest_update_time) - - def _execute_load_data(self): - """Load data through multiple threads.""" - threads_count = self._get_threads_count() - if not threads_count: - logger.info("Can not find any valid train log path to load, loader pool is empty.") - return - - logger.info("Start to execute load data. 
threads_count: %s.", threads_count) - - with ThreadPoolExecutor(max_workers=threads_count) as executor: - futures = [] - loader_pool = self._get_snapshot_loader_pool() - for loader_id in loader_pool: - future = executor.submit(self._execute_loader, loader_id) - futures.append(future) - wait(futures, return_when=ALL_COMPLETED) + self.status, self._detail_cache.loader_pool_size()) @staticmethod def check_reload_interval(reload_interval): @@ -262,18 +849,6 @@ class DataManager: if max_threads_count <= 0: raise ParamValueError("The value of max threads count should be > 0.") - def _get_threads_count(self): - """ - Use the maximum number of threads available. - - Returns: - int, number of threads. - - """ - threads_count = min(self._max_threads_count, len(self._loader_pool)) - - return threads_count - def get_train_job_by_plugin(self, train_id, plugin_name): """ Get a train job by train job id. @@ -290,32 +865,9 @@ class DataManager: """ self._check_status_valid() - self._check_train_job_exist(train_id, self._loader_pool) - - loader = self._get_loader(train_id) - if loader is None: - logger.warning("No valid summary log in train job %s, " - "or it is not in the cache.", train_id) - return None - - name = loader.name - data_loader = loader.data_loader - - tags = [] - try: - events_data = data_loader.get_events_data() - tags = events_data.list_tags_by_plugin(plugin_name) - except KeyError: - logger.debug("Plugin name %r does not exist " - "in train job %r, and set tags to empty list.", plugin_name, name) - except AttributeError: - logger.debug("Train job %r has been deleted or it has not loaded data, " - "and set tags to empty list.", name) - - result = dict(id=train_id, name=name, tags=tags) - return result + return self._detail_cache.get_train_job_by_plugin(train_id, plugin_name) - def delete_train_job(self, train_id): + def delete_train_job(self, train_id, only_delete_from_cache=True): """ Delete train job with a train id. @@ -323,8 +875,11 @@ class DataManager: train_id (str): ID for train job. """ - with self._loader_pool_mutex: - self._delete_loader(train_id) + if not only_delete_from_cache: + raise NotImplementedError("Delete from both cache and disk is not supported.") + + self._brief_cache.delete_train_job(train_id) + self._detail_cache.delete_train_job(train_id) def list_tensors(self, train_id, tag): """ @@ -342,66 +897,7 @@ class DataManager: """ self._check_status_valid() - loader_pool = self._get_snapshot_loader_pool() - if not self._is_loader_in_loader_pool(train_id, loader_pool): - raise TrainJobNotExistError("Can not find the given train job in cache.") - - data_loader = loader_pool[train_id].data_loader - events_data = data_loader.get_events_data() - - try: - tensors = events_data.tensors(tag) - except KeyError: - error_msg = "Can not find any data in this train job by given tag." - raise ParamValueError(error_msg) - - return tensors - - def _check_train_job_exist(self, train_id, loader_pool): - """ - Check train job exist, if not exist, will raise exception. - - Args: - train_id (str): The given train job id. - loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool. - - Raises: - TrainJobNotExistError: Can not find train job in data manager. 
- """ - is_exist = False - if train_id in loader_pool: - return - for generator in self._loader_generators: - if generator.check_train_job_exist(train_id): - is_exist = True - break - if not is_exist: - raise TrainJobNotExistError("Can not find the train job in data manager.") - - def _is_loader_in_loader_pool(self, train_id, loader_pool): - """ - Check train job exist, if not exist, return False. Else, return True. - - Args: - train_id (str): The given train job id. - loader_pool (dict): See self._loader_pool. - - Returns: - bool, if loader in loader pool, return True. - """ - if train_id in loader_pool: - return True - return False - - def _get_snapshot_loader_pool(self): - """ - Create a snapshot of data loader pool to avoid concurrent mutation and iteration issues. - - Returns: - dict, a copy of `self._loader_pool`. - """ - with self._loader_pool_mutex: - return dict(self._loader_pool) + return self._detail_cache.list_tensors(train_id, tag) def _check_status_valid(self): """Check if the status is valid to load data.""" @@ -409,90 +905,29 @@ class DataManager: if self.status == DataManagerStatus.INIT.value: raise exceptions.SummaryLogIsLoading("Data is being loaded, current status: %s." % self._status) - def get_single_train_job(self, train_id, manual_update=False): + def get_train_job(self, train_id): """ Get train job by train ID. Args: train_id (str): Train ID for train job. - manual_update (bool): If manual update, True. Returns: dict, single train job, if can not find any data, will return None. """ self._check_status_valid() - self._check_train_job_exist(train_id, self._loader_pool) - - loader = self._get_loader(train_id, manual_update) - if loader is None: - logger.warning("No valid summary log in train job %s, " - "or it is not in the cache.", train_id) - return None - - train_job = loader.to_dict() - train_job.pop('data_loader') - - plugin_data = {} - for plugin_name in PluginNameEnum.list_members(): - job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name) - if job is None: - plugin_data[plugin_name] = [] - else: - plugin_data[plugin_name] = job['tags'] - - train_job.update({'tag_mapping': plugin_data}) - - return train_job - - def _get_loader(self, train_id, manual_update=False): - """ - Get loader by train id. - - Args: - train_id (str): Train Id. - manual_update (bool): If manual, True. Else False. - - Returns: - LoaderStruct, the loader. - """ - loader = None - is_reload = False - with self._loader_pool_mutex: - if self._is_loader_in_loader_pool(train_id, self._loader_pool): - loader = self._loader_pool.get(train_id) - - if manual_update and loader is None: - for generator in self._loader_generators: - tmp_loader = generator.generate_loader_by_train_id(train_id) - if loader and loader.latest_update_time > tmp_loader.latest_update_time: - continue - loader = tmp_loader - - if loader is None: - return None - - self._add_loader(loader) - is_reload = True - - if manual_update: - self._update_loader_latest_update_time(loader.loader_id) - - if is_reload: - self.reload_data() + detail_train_job = self._detail_cache.get_train_job(train_id) + brief_train_job = self._brief_cache.get_train_job(train_id) - return loader + return TrainJob(brief_train_job, detail_train_job) - def _update_loader_latest_update_time(self, loader_id, latest_update_time=None): + def list_train_jobs(self): """ - Update loader with latest_update_time. + List train jobs. - Args: - loader_id (str): ID of loader. - latest_update_time (float): Timestamp. + To be implemented. 
""" - if latest_update_time is None: - latest_update_time = time.time() - self._loader_pool[loader_id].latest_update_time = latest_update_time + raise NotImplementedError() @property def status(self): @@ -509,6 +944,16 @@ class DataManager: """Set data manger status.""" self._status = status + def cache_train_job(self, train_id): + """Cache given train job (async).""" + brief_need_reload = self._brief_cache.cache_train_job(train_id) + detail_need_reload = self._detail_cache.cache_train_job(train_id) + if brief_need_reload or detail_need_reload: + self.reload_data() + + def register_brief_cache_item_updater(self, updater: BaseCacheItemUpdater): + """Register brief cache item updater for brief cache manager.""" + self._brief_cache.register_cache_item_updater(updater) + -_loader_generators = [DataLoaderGenerator(settings.SUMMARY_BASE_DIR)] -DATA_MANAGER = DataManager(_loader_generators) +DATA_MANAGER = DataManager(settings.SUMMARY_BASE_DIR) diff --git a/mindinsight/datavisual/processors/train_task_manager.py b/mindinsight/datavisual/processors/train_task_manager.py index 05b950891226349047c310180d759ff09773b9ac..bf72e2b5c0bb62e68271506d364220563ad29b1f 100644 --- a/mindinsight/datavisual/processors/train_task_manager.py +++ b/mindinsight/datavisual/processors/train_task_manager.py @@ -18,6 +18,7 @@ from mindinsight.datavisual.common import exceptions from mindinsight.datavisual.common.enums import PluginNameEnum from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.processors.base_processor import BaseProcessor +from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY class TrainTaskManager(BaseProcessor): @@ -53,13 +54,24 @@ class TrainTaskManager(BaseProcessor): dict, refer to restful api. """ Validation.check_param_empty(train_id=train_id) - train_job = self._data_manager.get_single_train_job(train_id, manual_update=manual_update) - if not train_job: + + if manual_update: + self._data_manager.cache_train_job(train_id) + + train_job = self._data_manager.get_train_job(train_id) + + try: + data_visual_content = train_job.get_detail(DATAVISUAL_CACHE_KEY) + plugins = data_visual_content.get(DATAVISUAL_PLUGIN_KEY) + except exceptions.TrainJobDetailNotInCacheError: + plugins = [] + + if not plugins: default_result = dict() for plugin_name in PluginNameEnum.list_members(): default_result.update({plugin_name: list()}) return dict(plugins=default_result) return dict( - plugins=train_job['tag_mapping'] + plugins=plugins ) diff --git a/mindinsight/lineagemgr/cache_item_updater.py b/mindinsight/lineagemgr/cache_item_updater.py new file mode 100644 index 0000000000000000000000000000000000000000..78648353087eae9b77b8030bade1ba13616d13fc --- /dev/null +++ b/mindinsight/lineagemgr/cache_item_updater.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""Cache item updater."""
+import os
+
+from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob
+from mindinsight.lineagemgr.querier.query_model import LineageObj
+from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageSummaryAnalyzer
+
+
+class LineageCacheItemUpdater(BaseCacheItemUpdater):
+    """Cache item updater for lineage info."""
+
+    def update_item(self, cache_item: CachedTrainJob):
+        """Update cache item in place."""
+        log_path = cache_item.summary_dir
+        log_dir = os.path.dirname(log_path)
+        lineage_info = LineageSummaryAnalyzer.get_summary_infos(log_path)
+        user_defined_info = LineageSummaryAnalyzer.get_user_defined_info(log_path)
+        lineage_obj = LineageObj(
+            log_dir,
+            train_lineage=lineage_info.train_lineage,
+            evaluation_lineage=lineage_info.eval_lineage,
+            dataset_graph=lineage_info.dataset_graph,
+            user_defined_info=user_defined_info
+        )
+        cache_item.set(key="lineage", value=lineage_obj)
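`LineageCacheItemUpdater` above is the first consumer of the new `BaseCacheItemUpdater` hook: any module can attach derived data to every brief cache item on each reload. A sketch of a custom updater under the same contract (the `DiskUsageCacheItemUpdater` class and its `disk_usage` key are made up for illustration; this is not part of the patch):

```python
import os

from mindinsight.datavisual.data_transform.data_manager import (
    DATA_MANAGER, BaseCacheItemUpdater, CachedTrainJob)


class DiskUsageCacheItemUpdater(BaseCacheItemUpdater):
    """Hypothetical updater that records the summary directory size."""

    def update_item(self, cache_item: CachedTrainJob):
        # Sum the sizes of files directly inside the summary directory.
        total_size = sum(entry.stat().st_size
                         for entry in os.scandir(cache_item.summary_dir)
                         if entry.is_file())
        cache_item.set(key="disk_usage", value=total_size)


# _BriefCacheManager.update_cache() runs every registered updater over every cache item.
DATA_MANAGER.register_brief_cache_item_updater(DiskUsageCacheItemUpdater())
```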
diff --git a/mindinsight/utils/constant.py b/mindinsight/utils/constant.py
index 97c860c4437e4daa302c5d3966a050b85f4f0915..4e4cb83f8f28d2e9df0cccb3ab3d5a69ec00c4fa 100644
--- a/mindinsight/utils/constant.py
+++ b/mindinsight/utils/constant.py
@@ -63,3 +63,4 @@ class DataVisualErrors(Enum):
     IMAGE_NOT_EXIST = 13
     SCALAR_NOT_EXIST = 14
     HISTOGRAM_NOT_EXIST = 15
+    TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
diff --git a/tests/st/func/datavisual/conftest.py b/tests/st/func/datavisual/conftest.py
index e7053d6b7fdafdb1408819bca9e94af47017e3b4..0fb0ec8c909f12384331641d575c421206e358c9 100644
--- a/tests/st/func/datavisual/conftest.py
+++ b/tests/st/func/datavisual/conftest.py
@@ -25,7 +25,6 @@ from flask import Response
 from mindinsight.conf import settings
 from mindinsight.datavisual.data_transform import data_manager
 from mindinsight.datavisual.data_transform.data_manager import DataManager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
 from mindinsight.datavisual.utils import tools
@@ -59,7 +58,7 @@ def init_summary_logs():
     log_operations = LogOperations()
     summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST,
                                                             constants.SUMMARY_DIR_PREFIX)
-    mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+    mock_data_manager = DataManager(summary_base_dir)
     mock_data_manager.start_load_data(reload_interval=0)
     check_loading_done(mock_data_manager)
diff --git a/tests/ut/datavisual/data_transform/test_data_manager.py b/tests/ut/datavisual/data_transform/test_data_manager.py
index 80bab30bb870b034cb74ab53dfe18fed811b7e04..4261f263ec5757a417d5af332e43378bc38daf2d 100644
--- a/tests/ut/datavisual/data_transform/test_data_manager.py
+++ b/tests/ut/datavisual/data_transform/test_data_manager.py
@@ -33,7 +33,6 @@ from mindinsight.datavisual.data_transform import data_manager, ms_data_loader
 from mindinsight.datavisual.data_transform.data_loader import DataLoader
 from mindinsight.datavisual.data_transform.data_manager import DataManager
 from mindinsight.datavisual.data_transform.events_data import EventsData
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
 from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct
 from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
@@ -89,7 +88,7 @@
             train_ids.append(f'./dir{i}')

         data_manager.logger = MockLogger
-        mock_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        mock_manager = data_manager.DataManager(summary_base_dir)
         mock_manager.start_load_data(reload_interval=0)
         check_loading_done(mock_manager)

@@ -112,7 +111,7 @@
     def test_start_load_data_with_invalid_params(self, params):
         """Test start_load_data with invalid reload_interval or invalid max_threads_count."""
         summary_base_dir = tempfile.mkdtemp()
-        d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+        d_manager = DataManager(summary_base_dir)
         with pytest.raises(ParamValueError):
             d_manager.start_load_data(**params)
         shutil.rmtree(summary_base_dir)
@@ -142,9 +141,9 @@
                               latest_update_time=modify_time_01,
                               data_loader=loader_01)
         loader_pool = {train_job_01: loader}
-        d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+        d_manager = DataManager(summary_base_dir)
         d_manager._status = DataManagerStatus.LOADING.value
-        d_manager._loader_pool = loader_pool
+        d_manager._detail_cache._loader_pool = loader_pool

         res = d_manager.list_tensors(train_job_01, tag)
         assert res == {'test result'}
@@ -169,9 +168,9 @@
                               latest_update_time=modify_time_01,
                               data_loader=loader_01)
         loader_pool = {train_job_01: loader}
-        d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+        d_manager = DataManager(summary_base_dir)
         d_manager._status = DataManagerStatus.LOADING.value
-        d_manager._loader_pool = loader_pool
+        d_manager._detail_cache._loader_pool = loader_pool
         tag = 'image'
         with pytest.raises(ParamValueError):
             d_manager.list_tensors(train_job_01, tag)
@@ -181,7 +180,7 @@
     def test_list_tensors_with_not_exist_train_job(self):
         """Test list_tensors method with parameter train_id not found in loader_pool."""
         summary_base_dir = tempfile.mkdtemp()
-        d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+        d_manager = DataManager(summary_base_dir)
         d_manager._status = DataManagerStatus.LOADING.value
         tag = 'image'
         train_job_01 = 'train_01'
@@ -200,13 +199,12 @@
         expected_loader_ids = list(loader_dict.keys())

         mock_generate_loaders.return_value = loader_dict
-        generators = [data_manager.DataLoaderGenerator(summary_base_dir)]
-        mock_data_manager = data_manager.DataManager(generators)
-        mock_data_manager._execute_load_data = Mock()
+        mock_data_manager = data_manager.DataManager(summary_base_dir)
+        mock_data_manager._detail_cache._execute_load_data = Mock()
         mock_data_manager.start_load_data(reload_interval=0)
         check_loading_done(mock_data_manager, 3)

-        current_loader_ids = mock_data_manager._loader_pool.keys()
+        current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys()

         assert sorted(current_loader_ids) == sorted(expected_loader_ids)
@@ -221,7 +219,7 @@
         mock_generate_loaders.return_value = loader_dict
         mock_data_manager.start_load_data(reload_interval=0)
         check_loading_done(mock_data_manager)
-        current_loader_ids = mock_data_manager._loader_pool.keys()
+        current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys()

         assert sorted(current_loader_ids) == sorted(expected_loader_ids)
diff --git a/tests/ut/datavisual/processors/test_graph_processor.py b/tests/ut/datavisual/processors/test_graph_processor.py
index 678e459d9a459257bcf302d0b29ad0aa462b5a97..231a954559ee84b974329951780eebd0d99ed852 100644
--- a/tests/ut/datavisual/processors/test_graph_processor.py
+++ b/tests/ut/datavisual/processors/test_graph_processor.py
@@ -30,7 +30,6 @@ from mindinsight.datavisual.common.exceptions import GraphNotExistError
 from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
 from mindinsight.datavisual.data_transform import data_manager
 from mindinsight.datavisual.data_transform.data_manager import DataManager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.graph_processor import GraphProcessor
 from mindinsight.datavisual.utils import crc32
 from mindinsight.utils.exceptions import ParamValueError
@@ -74,7 +73,7 @@
         self._temp_path, self._graph_dict, _ = log_operation.generate_log(PluginNameEnum.GRAPH.value, log_dir)
         self._generated_path.append(summary_base_dir)

-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
@@ -93,7 +92,7 @@
         self._generated_path.append(summary_base_dir)

-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
diff --git a/tests/ut/datavisual/processors/test_histogram_processor.py b/tests/ut/datavisual/processors/test_histogram_processor.py
index 638d2115fe4c3a00fc4b2fe412c7869559cb1309..0e5a46ce46c8ec2b4fb771e7770f10bb770a3e10 100644
--- a/tests/ut/datavisual/processors/test_histogram_processor.py
+++ b/tests/ut/datavisual/processors/test_histogram_processor.py
@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.common.exceptions import HistogramNotExistError
 from mindinsight.datavisual.data_transform import data_manager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor
 from mindinsight.datavisual.utils import crc32
@@ -72,7 +71,7 @@
             PluginNameEnum.HISTOGRAM.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
         self._generated_path.append(summary_base_dir)

-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
diff --git a/tests/ut/datavisual/processors/test_images_processor.py b/tests/ut/datavisual/processors/test_images_processor.py
index 36052d4954629e424d2e7750c43f7ffb2193a0d1..99846314a8f60b55f354a683546b0ad81387b369 100644
--- a/tests/ut/datavisual/processors/test_images_processor.py
+++ b/tests/ut/datavisual/processors/test_images_processor.py
@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.common.exceptions import ImageNotExistError
 from mindinsight.datavisual.data_transform import data_manager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.images_processor import ImageProcessor
 from mindinsight.datavisual.utils import crc32
@@ -81,7 +80,7 @@
             PluginNameEnum.IMAGE.value, log_dir, dict(steps=steps_list, tag=self._tag_name))
         self._generated_path.append(summary_base_dir)

-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
diff --git a/tests/ut/datavisual/processors/test_scalars_processor.py b/tests/ut/datavisual/processors/test_scalars_processor.py
index c741bc54cd92fd27e43791611d77ed584fb3a449..f269166fb0240bf2bab4388a7fd18339d09ab3ef 100644
--- a/tests/ut/datavisual/processors/test_scalars_processor.py
+++ b/tests/ut/datavisual/processors/test_scalars_processor.py
@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.common.exceptions import ScalarNotExistError
 from mindinsight.datavisual.data_transform import data_manager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
 from mindinsight.datavisual.utils import crc32
@@ -73,7 +72,7 @@
             PluginNameEnum.SCALAR.value, log_dir, dict(step=self._steps_list, tag=self._tag_name))
         self._generated_path.append(summary_base_dir)

-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
diff --git a/tests/ut/datavisual/processors/test_train_task_manager.py b/tests/ut/datavisual/processors/test_train_task_manager.py
index 55761f96de8646801610201dfec5336f6c71d54e..98b1af1f20fef50ff040126ad2c3aa57f6cf0f57 100644
--- a/tests/ut/datavisual/processors/test_train_task_manager.py
+++ b/tests/ut/datavisual/processors/test_train_task_manager.py
@@ -27,7 +27,6 @@ import pytest
 from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.data_transform import data_manager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
 from mindinsight.datavisual.utils import crc32
@@ -97,7 +96,7 @@
         self._generated_path.append(self._root_dir)

-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(self._root_dir)])
+        self._mock_data_manager = data_manager.DataManager(self._root_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         check_loading_done(self._mock_data_manager, time_limit=30)
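Taken together, the refactor splits caching into a brief level (always covering every train job on disk) and a detail level (loader-backed and LRU-bounded). A rough end-to-end usage sketch, assuming the public names from data_manager.py in this diff (`./run1` is a made-up train id and error handling is omitted; not part of the patch):

```python
from mindinsight.datavisual.data_transform.data_manager import (
    DATA_MANAGER, DATAVISUAL_CACHE_KEY, DATAVISUAL_PLUGIN_KEY)

train_id = "./run1"  # hypothetical

# Ask both cache levels to cache the job; a reload is triggered if either level needs one.
DATA_MANAGER.cache_train_job(train_id)

# get_train_job() now returns a TrainJob wrapping the brief item and, once loaded, the detail item.
train_job = DATA_MANAGER.get_train_job(train_id)
if train_job.has_detail():
    tag_mapping = train_job.get_detail(DATAVISUAL_CACHE_KEY).get(DATAVISUAL_PLUGIN_KEY)
lineage_obj = train_job.get_brief("lineage")  # attached by LineageCacheItemUpdater on reload
```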