Commit a5cabe8c authored by mindspore-ci-bot, committed by Gitee


!440 Change the mindinsight multiprocessing computing code to use a unified manager, add new features
Merge pull request !440 from wenkai/pref_opt_0720_1cp1
@@ -34,11 +34,11 @@ class DataLoader:
         self._summary_dir = summary_dir
         self._loader = None
 
-    def load(self, workers_count=1):
+    def load(self, computing_resource_mgr):
         """Load the data when loader is exist.
 
         Args:
-            workers_count (int): The count of workers. Default value is 1.
+            computing_resource_mgr (ComputingResourceManager): The ComputingResourceManager instance.
         """
         if self._loader is None:
@@ -53,7 +53,7 @@ class DataLoader:
             logger.warning("No valid files can be loaded, summary_dir: %s.", self._summary_dir)
             raise exceptions.SummaryLogPathInvalid()
 
-        self._loader.load(workers_count)
+        self._loader.load(computing_resource_mgr)
 
     def get_events_data(self):
         """
......
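For reference, a minimal sketch of the updated call pattern, assuming only the API shown in this change (the summary directory path is a placeholder): callers now construct a ComputingResourceManager and hand it to load() instead of passing a workers_count.

from mindinsight.datavisual.data_transform.data_loader import DataLoader
from mindinsight.utils.computing_resource_mgr import ComputingResourceManager

# Placeholder directory; load() raises SummaryLogPathInvalid when it contains no valid summary files.
loader = DataLoader("/path/to/summary_dir")
with ComputingResourceManager(executors_cnt=1, max_processes_cnt=1) as mgr:
    loader.load(mgr)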
@@ -40,6 +40,7 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
 from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
+from mindinsight.utils.computing_resource_mgr import ComputingResourceManager
 from mindinsight.utils.exceptions import MindInsightException
 from mindinsight.utils.exceptions import ParamValueError
 from mindinsight.utils.exceptions import UnknownError
@@ -510,7 +511,7 @@ class _DetailCacheManager(_BaseCacheManager):
         logger.debug("delete loader %s", loader_id)
         self._loader_pool.pop(loader_id)
 
-    def _execute_loader(self, loader_id, workers_count):
+    def _execute_loader(self, loader_id, computing_resource_mgr):
         """
         Load data form data_loader.
@@ -518,7 +519,7 @@ class _DetailCacheManager(_BaseCacheManager):
         Args:
             loader_id (str): An ID for `Loader`.
-            workers_count (int): The count of workers.
+            computing_resource_mgr (ComputingResourceManager): The ComputingResourceManager instance.
         """
         try:
             with self._loader_pool_mutex:
@@ -527,7 +528,7 @@ class _DetailCacheManager(_BaseCacheManager):
                     logger.debug("Loader %r has been deleted, will not load data.", loader_id)
                     return
 
-            loader.data_loader.load(workers_count)
+            loader.data_loader.load(computing_resource_mgr)
 
             # Update loader cache status to CACHED.
             # Loader with cache status CACHED should remain the same cache status.
@@ -580,11 +581,15 @@ class _DetailCacheManager(_BaseCacheManager):
         logger.info("Start to execute load data. threads_count: %s.", threads_count)
 
+        with ComputingResourceManager(
+                executors_cnt=threads_count,
+                max_processes_cnt=settings.MAX_PROCESSES_COUNT) as computing_resource_mgr:
             with ThreadPoolExecutor(max_workers=threads_count) as executor:
                 futures = []
                 loader_pool = self._get_snapshot_loader_pool()
                 for loader_id in loader_pool:
-                    future = executor.submit(self._execute_loader, loader_id, threads_count)
+                    future = executor.submit(self._execute_loader, loader_id, computing_resource_mgr)
                     futures.append(future)
                 wait(futures, return_when=ALL_COMPLETED)
......
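To make the worker split above concrete, here is a small illustrative calculation of how the manager divides processes among executors (one executor per loader thread). Both input values are assumptions: MAX_PROCESSES_COUNT comes from local settings and threads_count depends on the deployment.

import fractions
import math

max_processes_cnt = 8   # assumed value of settings.MAX_PROCESSES_COUNT
threads_count = 3       # assumed loader thread count, one executor per thread

# Each executor initially receives an equal fractional share of the workers.
per_executor = fractions.Fraction(max_processes_cnt, threads_count)      # 8/3
# Effective process slots are the floor of that share, but never below 1.
effective_workers = 1 if per_executor <= 1 else math.floor(per_executor)  # 2
print(per_executor, effective_workers)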
@@ -19,17 +19,12 @@ This module is used to load the MindSpore training log file.
 Each instance will read an entire run, a run can contain one or
 more log file.
 """
-import concurrent.futures as futures
-import math
-import os
 import re
 import struct
-import threading
 
 from google.protobuf.message import DecodeError
 from google.protobuf.text_format import ParseError
 
-from mindinsight.conf import settings
 from mindinsight.datavisual.common import exceptions
 from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.log import logger
@@ -84,14 +79,14 @@ class MSDataLoader:
                            "we will reload all files in path %s.", self._summary_dir)
             self.__init__(self._summary_dir)
 
-    def load(self, workers_count=1):
+    def load(self, computing_resource_mgr):
         """
         Load all log valid files.
 
         When the file is reloaded, it will continue to load from where it left off.
 
         Args:
-            workers_count (int): The count of workers. Default value is 1.
+            computing_resource_mgr (ComputingResourceManager): The ComputingResourceManager instance.
         """
         logger.debug("Start to load data in ms data loader.")
         filenames = self.filter_valid_files()
@@ -102,8 +97,9 @@ class MSDataLoader:
         self._valid_filenames = filenames
         self._check_files_deleted(filenames, old_filenames)
 
+        with computing_resource_mgr.get_executor() as executor:
             for parser in self._parser_list:
-                parser.parse_files(workers_count, filenames, events_data=self._events_data)
+                parser.parse_files(executor, filenames, events_data=self._events_data)
 
     def filter_valid_files(self):
         """
@@ -133,12 +129,12 @@ class _Parser:
         self._latest_mtime = 0
         self._summary_dir = summary_dir
 
-    def parse_files(self, workers_count, filenames, events_data):
+    def parse_files(self, executor, filenames, events_data):
         """
         Load files and parse files content.
 
         Args:
-            workers_count (int): The count of workers.
+            executor (Executor): The executor instance.
             filenames (list[str]): File name list.
             events_data (EventsData): The container of event data.
         """
@@ -186,7 +182,7 @@ class _Parser:
 class _PbParser(_Parser):
     """This class is used to parse pb file."""
 
-    def parse_files(self, workers_count, filenames, events_data):
+    def parse_files(self, executor, filenames, events_data):
         pb_filenames = self.filter_files(filenames)
         pb_filenames = self.sort_files(pb_filenames)
         for filename in pb_filenames:
@@ -264,12 +260,12 @@ class _SummaryParser(_Parser):
         self._summary_file_handler = None
         self._events_data = None
 
-    def parse_files(self, workers_count, filenames, events_data):
+    def parse_files(self, executor, filenames, events_data):
         """
         Load summary file and parse file content.
 
         Args:
-            workers_count (int): The count of workers.
+            executor (Executor): The executor instance.
             filenames (list[str]): File name list.
             events_data (EventsData): The container of event data.
         """
@@ -295,7 +291,9 @@ class _SummaryParser(_Parser):
             self._latest_file_size = new_size
             try:
-                self._load_single_file(self._summary_file_handler, workers_count)
+                self._load_single_file(self._summary_file_handler, executor)
+                # Wait for data in this file to be processed to avoid loading multiple files at the same time.
+                executor.wait_all_tasks_finish()
             except UnknownError as ex:
                 logger.warning("Parse summary file failed, detail: %r,"
                                "file path: %s.", str(ex), file_path)
@@ -314,28 +312,14 @@
             lambda filename: (re.search(r'summary\.\d+', filename)
                               and not filename.endswith("_lineage")), filenames))
 
-    def _load_single_file(self, file_handler, workers_count):
+    def _load_single_file(self, file_handler, executor):
         """
         Load a log file data.
 
         Args:
             file_handler (FileHandler): A file handler.
-            workers_count (int): The count of workers.
+            executor (Executor): The executor instance.
         """
-        default_concurrency = 1
-        cpu_count = os.cpu_count()
-        if cpu_count is None:
-            concurrency = default_concurrency
-        else:
-            concurrency = min(math.floor(cpu_count / workers_count),
-                              math.floor(settings.MAX_PROCESSES_COUNT / workers_count))
-        if concurrency <= 0:
-            concurrency = default_concurrency
-        logger.debug("Load single summary file, file path: %s, concurrency: %s.", file_handler.file_path, concurrency)
-        semaphore = threading.Semaphore(value=concurrency)
-        with futures.ProcessPoolExecutor(max_workers=concurrency) as executor:
         while True:
             start_offset = file_handler.offset
             try:
@@ -344,8 +328,6 @@ class _SummaryParser(_Parser):
                 file_handler.reset_offset(start_offset)
                 break
 
-                # Make sure we have at most concurrency tasks not finished to save memory.
-                semaphore.acquire()
                 future = executor.submit(self._event_parse, event_str, self._latest_filename)
 
                 def _add_tensor_event_callback(future_value):
@@ -367,8 +349,6 @@ class _SummaryParser(_Parser):
                         # Log exception for debugging.
                         logger.exception(exc)
                         raise
-                    finally:
-                        semaphore.release()
 
                 future.add_done_callback(_add_tensor_event_callback)
             except exceptions.CRCFailedError:
......
@@ -213,7 +213,7 @@ class HistogramReservoir(Reservoir):
             visual_range.update(histogram_container.max, histogram_container.min)
 
         if visual_range.max == visual_range.min and not max_count:
-            logger.info("Max equals to min. Count is zero.")
+            logger.debug("Max equals to min. Count is zero.")
 
         bins = calc_histogram_bins(max_count)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Compute resource manager."""
import fractions
import math
import threading
from concurrent import futures
from mindinsight.utils.log import utils_logger as logger
from mindinsight.utils.constant import GeneralErrors
from mindinsight.utils.exceptions import MindInsightException
class ComputingResourceManager:
"""
Manager for computing resources.
This class provides executors for computing tasks. Executors can only be used once.
Args:
executors_cnt (int): Number of executors to be provided by this class.
max_processes_cnt (int): Max number of processes to be used for computing.
"""
def __init__(self, executors_cnt, max_processes_cnt):
self._max_processes_cnt = max_processes_cnt
self._executors_cnt = executors_cnt
self._lock = threading.Lock()
self._executors = {
ind: Executor(
self, executor_id=ind,
available_workers=fractions.Fraction(self._max_processes_cnt, self._executors_cnt))
for ind in range(self._executors_cnt)
}
self._remaining_executors = len(self._executors)
self._backend = futures.ProcessPoolExecutor(max_workers=max_processes_cnt)
logger.info("Initialized ComputingResourceManager with executors_cnt=%s, max_processes_cnt=%s.",
executors_cnt, max_processes_cnt)
def __enter__(self):
"""This method is not thread safe."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""
This should not block because every executor has already been waited on. If it blocks, there may be a problem.
This method is not thread safe.
"""
self._backend.shutdown()
def get_executor(self):
"""
Get an executor.
Returns:
Executor, which can be used for submitting tasks.
Raises:
ComputingResourceManagerException: When no more executors are available.
"""
with self._lock:
self._remaining_executors -= 1
if self._remaining_executors < 0:
raise ComputingResourceManagerException("No more executors.")
return self._executors[self._remaining_executors]
def destroy_executor(self, executor_id):
"""
Destroy an executor to reuse its workers.
Args:
executor_id (int): Id of the executor to be destroyed.
"""
with self._lock:
released_workers = self._executors[executor_id].available_workers
self._executors.pop(executor_id)
remaining_executors = len(self._executors)
logger.info("Destroy executor %s. Will release %s worker(s). Remaining executors: %s.",
executor_id, released_workers, remaining_executors)
if not remaining_executors:
return
for executor in self._executors.values():
executor.add_worker(
fractions.Fraction(
released_workers.numerator,
released_workers.denominator * remaining_executors))
def submit(self, *args, **kwargs):
"""
Submit a task.
See concurrent.futures.Executor.submit() for details.
This method should only be called by Executor. Users should not call this method directly.
"""
with self._lock:
return self._backend.submit(*args, **kwargs)
class ComputingResourceManagerException(MindInsightException):
"""
Indicates a computing resource error has occurred.
This exception should not be presented to end users.
Args:
msg (str): Exception message.
"""
def __init__(self, msg):
super().__init__(error=GeneralErrors.COMPUTING_RESOURCE_ERROR, message=msg)
class WrappedFuture:
"""
Wrap Future objects with custom logic to release compute slots.
Args:
executor (Executor): The executor which generates this future.
original_future (futures.Future): Original future object.
"""
def __init__(self, executor, original_future: futures.Future):
self._original_future = original_future
self._executor = executor
def add_done_callback(self, callback):
"""
Add done callback.
See futures.Future.add_done_callback() for details.
"""
def _wrapped_callback(*args, **kwargs):
logger.debug("Future callback called.")
try:
return callback(*args, **kwargs)
finally:
self._executor.release_slot()
self._executor.remove_done_future(self._original_future)
self._original_future.add_done_callback(_wrapped_callback)
class Executor:
"""
Task executor.
Args:
mgr (ComputingResourceManager): The ComputingResourceManager that generates this executor.
executor_id (int): Executor id.
available_workers (fractions.Fraction): Available workers.
"""
def __init__(self, mgr: ComputingResourceManager, executor_id, available_workers):
self._mgr = mgr
self.closed = False
self._available_workers = available_workers
self._effective_workers = self._calc_effective_workers(self._available_workers)
self._slots = threading.Semaphore(value=self._effective_workers)
self._id = executor_id
self._futures = set()
self._lock = threading.Lock()
logger.debug("Available workers: %s.", available_workers)
def __enter__(self):
"""This method is not thread safe."""
if self.closed:
raise ComputingResourceManagerException("Can not reopen closed executor.")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""This method is not thread safe."""
self._close()
def submit(self, *args, **kwargs):
"""
Submit task.
See concurrent.futures.Executor.submit() for details. This method is not thread safe.
"""
logger.debug("Task submitted to executor %s.", self._id)
if self.closed:
raise ComputingResourceManagerException("Cannot submit task to a closed executor.")
# Thread will wait on acquire().
self._slots.acquire()
future = self._mgr.submit(*args, **kwargs)
# set.add is atomic in c-python.
self._futures.add(future)
return WrappedFuture(self, future)
def release_slot(self):
"""
Release a slot for new tasks to be submitted.
Semaphore is itself thread safe, so no lock is needed.
This method should only be called by WrappedFuture.
"""
self._slots.release()
def remove_done_future(self, future):
"""
Remove done futures so the executor will not track them.
This method should only be called by WrappedFuture.
"""
# set.remove is atomic in c-python so no lock is needed.
self._futures.remove(future)
@staticmethod
def _calc_effective_workers(available_workers):
return 1 if available_workers <= 1 else math.floor(available_workers)
def _close(self):
self.closed = True
logger.debug("Executor is being closed, futures to wait: %s", self._futures)
futures.wait(self._futures)
logger.debug("Executor wait futures completed.")
self._mgr.destroy_executor(self._id)
logger.debug("Executor is closed.")
@property
def available_workers(self):
"""Get available workers."""
with self._lock:
return self._available_workers
def add_worker(self, added_available_workers):
"""This method should only be called by ComputeResourceManager."""
logger.debug("Add worker: %s", added_available_workers)
with self._lock:
self._available_workers += added_available_workers
new_effective_workers = self._calc_effective_workers(self._available_workers)
if new_effective_workers > self._effective_workers:
for _ in range(new_effective_workers - self._effective_workers):
self._slots.release()
self._effective_workers = new_effective_workers
def wait_all_tasks_finish(self):
"""
Wait for all submitted tasks to finish.
This method is not thread safe.
"""
futures.wait(self._futures)
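A minimal usage sketch of the new manager, assuming only the API defined above; square() and the pool sizes are illustrative, and a done callback is attached because, as in the summary parser, the callback is what releases an executor slot for the next submission.

from mindinsight.utils.computing_resource_mgr import ComputingResourceManager


def square(x):
    """A module-level (picklable) task for the process pool backend."""
    return x * x


if __name__ == "__main__":
    with ComputingResourceManager(executors_cnt=1, max_processes_cnt=2) as mgr:
        with mgr.get_executor() as executor:
            for i in range(4):
                future = executor.submit(square, i)
                # The wrapped callback frees the slot once the task completes.
                future.add_done_callback(lambda f: print(f.result()))
            # Block until every submitted task has been processed.
            executor.wait_all_tasks_finish()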
@@ -43,6 +43,7 @@ class GeneralErrors(Enum):
     FILE_SYSTEM_PERMISSION_ERROR = 8
     PORT_NOT_AVAILABLE_ERROR = 9
     URL_DECODE_ERROR = 10
+    COMPUTING_RESOURCE_ERROR = 11
 
 
 class ProfilerMgrErrors(Enum):
......
@@ -224,3 +224,6 @@ def setup_logger(sub_module, log_name, **kwargs):
     logger.addHandler(logfile_handler)
 
     return logger
+
+
+utils_logger = setup_logger("utils", "utils")
@@ -27,6 +27,7 @@ import pytest
 from mindinsight.datavisual.common.exceptions import SummaryLogPathInvalid
 from mindinsight.datavisual.data_transform import data_loader
 from mindinsight.datavisual.data_transform.data_loader import DataLoader
+from mindinsight.utils.computing_resource_mgr import ComputingResourceManager
 
 from ..mock import MockLogger
@@ -57,7 +58,7 @@ class TestDataLoader:
         """Test loading method with empty file list."""
         loader = DataLoader(self._summary_dir)
         with pytest.raises(SummaryLogPathInvalid):
-            loader.load()
+            loader.load(ComputingResourceManager(1, 1))
         assert 'No valid files can be loaded' in str(MockLogger.log_msg['warning'])
 
     def test_load_with_invalid_file_list(self):
@@ -66,7 +67,7 @@ class TestDataLoader:
         self._generate_files(self._summary_dir, file_list)
         loader = DataLoader(self._summary_dir)
         with pytest.raises(SummaryLogPathInvalid):
-            loader.load()
+            loader.load(ComputingResourceManager(1, 1))
         assert 'No valid files can be loaded' in str(MockLogger.log_msg['warning'])
 
     def test_load_success(self):
@@ -77,6 +78,6 @@ class TestDataLoader:
         file_list = ['summary.001', 'summary.002']
         self._generate_files(dir_path, file_list)
         dataloader = DataLoader(dir_path)
-        dataloader.load()
+        dataloader.load(ComputingResourceManager(1, 1))
         assert dataloader._loader is not None
         shutil.rmtree(dir_path)
@@ -30,6 +30,7 @@ from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
 from mindinsight.datavisual.data_transform.ms_data_loader import _PbParser
 from mindinsight.datavisual.data_transform.events_data import TensorEvent
 from mindinsight.datavisual.common.enums import PluginNameEnum
+from mindinsight.utils.computing_resource_mgr import ComputingResourceManager
 
 from ..mock import MockLogger
 from ....utils.log_generators.graph_pb_generator import create_graph_pb_file
@@ -85,7 +86,7 @@ class TestMsDataLoader:
         write_file(file1, SCALAR_RECORD)
         ms_loader = MSDataLoader(summary_dir)
         ms_loader._latest_summary_filename = 'summary.00'
-        ms_loader.load()
+        ms_loader.load(ComputingResourceManager(1, 1))
         shutil.rmtree(summary_dir)
         tag = ms_loader.get_events_data().list_tags_by_plugin('scalar')
         tensors = ms_loader.get_events_data().tensors(tag[0])
@@ -98,7 +99,7 @@ class TestMsDataLoader:
         file2 = os.path.join(summary_dir, 'summary.02')
         write_file(file2, SCALAR_RECORD)
         ms_loader = MSDataLoader(summary_dir)
-        ms_loader.load()
+        ms_loader.load(ComputingResourceManager(1, 1))
         shutil.rmtree(summary_dir)
         assert 'Check crc faild and ignore this file' in str(MockLogger.log_msg['warning'])
@@ -124,7 +125,7 @@ class TestMsDataLoader:
         summary_dir = tempfile.mkdtemp()
         create_graph_pb_file(output_dir=summary_dir, filename=filename)
         ms_loader = MSDataLoader(summary_dir)
-        ms_loader.load()
+        ms_loader.load(ComputingResourceManager(1, 1))
         events_data = ms_loader.get_events_data()
         plugins = events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
         shutil.rmtree(summary_dir)
......