MindSpore / mindinsight, commit 7e17d6ff

refactor data manager and unify cache and data access/reload

Authored May 11, 2020 by wenkai
Parent: 46d44977

Showing 14 changed files with 828 additions and 301 deletions (+828 -301)
Changed files:

mindinsight/backend/data_manager/__init__.py                  +35   -0
mindinsight/backend/datavisual/__init__.py                    +0    -6
mindinsight/datavisual/common/exceptions.py                   +9    -0
mindinsight/datavisual/data_transform/data_manager.py         +711  -266
mindinsight/datavisual/processors/train_task_manager.py       +15   -3
mindinsight/lineagemgr/cache_item_updater.py                  +39   -0
mindinsight/utils/constant.py                                 +1    -0
tests/st/func/datavisual/conftest.py                          +1    -2
tests/ut/datavisual/data_transform/test_data_manager.py       +11   -13
tests/ut/datavisual/processors/test_graph_processor.py        +2    -3
tests/ut/datavisual/processors/test_histogram_processor.py    +1    -2
tests/ut/datavisual/processors/test_images_processor.py       +1    -2
tests/ut/datavisual/processors/test_scalars_processor.py      +1    -2
tests/ut/datavisual/processors/test_train_task_manager.py     +1    -2
mindinsight/backend/data_manager/__init__.py (new file, mode 0 → 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Trigger data manager load."""
from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
from mindinsight.datavisual.common.log import logger
from mindinsight.conf import settings
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater


def init_module(app):
    """
    Interface to init module.

    Args:
        app (Flask): An instance of Flask.
    """
    # Just to suppress pylint warning about unused arg.
    logger.debug("App: %s", type(app))

    DATA_MANAGER.register_brief_cache_item_updater(LineageCacheItemUpdater())
    DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL),
                                 max_threads_count=int(settings.MAX_THREADS_COUNT))
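For orientation, a minimal sketch of how this entry point is invoked. The `create_app` factory below is hypothetical (not part of this commit); only `init_module` and its import path come from the diff above:

# Hypothetical wiring sketch. MindInsight's backend calls each module's
# init_module(app) during startup; after this commit, data loading starts
# here instead of in mindinsight/backend/datavisual (see the next diff).
from flask import Flask

from mindinsight.backend.data_manager import init_module as data_manager_init


def create_app():
    """Hypothetical application factory, for illustration only."""
    app = Flask(__name__)
    data_manager_init(app)  # registers LineageCacheItemUpdater, starts loading
    return app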
mindinsight/backend/datavisual/__init__.py

@@ -18,9 +18,6 @@ from mindinsight.backend.datavisual.static_resource_api import init_module as st
 from mindinsight.backend.datavisual.task_manager_api import init_module as task_init_module
 from mindinsight.backend.datavisual.train_visual_api import init_module as train_init_module
-from mindinsight.conf import settings
-from mindinsight.datavisual.data_transform.data_manager import DATA_MANAGER
-

 def init_module(app):
     """
@@ -33,6 +30,3 @@ def init_module(app):
     static_init_module(app)
     task_init_module(app)
     train_init_module(app)
-
-    DATA_MANAGER.start_load_data(reload_interval=int(settings.RELOAD_INTERVAL),
-                                 max_threads_count=int(settings.MAX_THREADS_COUNT))
mindinsight/datavisual/common/exceptions.py

@@ -150,3 +150,12 @@ class HistogramNotExistError(MindInsightException):
         super(HistogramNotExistError, self).__init__(DataVisualErrors.HISTOGRAM_NOT_EXIST,
                                                      error_msg,
                                                      http_code=400)
+
+
+class TrainJobDetailNotInCacheError(MindInsightException):
+    """Detail info of given train job is not in cache."""
+    def __init__(self, error_detail="no detail provided."):
+        error_msg = f'Detail info of the given train job is not in cache. Detail: {error_detail}'
+        super().__init__(DataVisualErrors.TRAIN_JOB_DETAIL_NOT_IN_CACHE,
+                         error_msg,
+                         http_code=400)
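This new error code pairs with a fallback pattern used later in this commit (see the train_task_manager.py hunk below): callers that can serve a degraded response catch the exception instead of failing. A minimal sketch, assuming a `train_job` obtained from `DataManager.get_train_job`; the helper name is hypothetical:

from mindinsight.datavisual.common import exceptions


def get_detail_or_default(train_job, key, default=None):
    """Hypothetical helper: fall back when detail info is not cached yet."""
    try:
        return train_job.get_detail(key)
    except exceptions.TrainJobDetailNotInCacheError:
        return default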
mindinsight/datavisual/data_transform/data_manager.py

@@ -20,11 +20,18 @@ It can read events data through the DataLoader.
This module also acts as a thread pool manager.
"""
import abc
import enum
import threading
import time
import datetime
import os

from typing import Iterable, Optional
from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED

from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.conf import settings
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.log import logger
...
@@ -37,44 +44,369 @@ from mindinsight.utils.exceptions import MindInsightException
from mindinsight.utils.exceptions import ParamValueError
(Removed here: the old top-level `class DataManager` that took `loader_generators`; a reworked `DataManager` appears near the end of this file. The new cache classes follow.)

@enum.unique
class _CacheStatus(enum.Enum):
    """Train job cache status."""
    NOT_IN_CACHE = "NOT_IN_CACHE"
    CACHING = "CACHING"
    CACHED = "CACHED"


class _BasicTrainJob:
    """
    Basic info about train job.

    Args:
        train_id (str): Id of the train job.
        abs_summary_base_dir (str): The canonical path of summary base directory. It should be the return value of
            realpath().
        abs_summary_dir (str): The canonical path of summary directory. It should be the return value of realpath().
        create_time (DateTime): The create time of summary directory.
        update_time (DateTime): The latest modify time of summary files directly in the summary directory.
    """
    def __init__(self, train_id, abs_summary_base_dir, abs_summary_dir, create_time, update_time):
        self._train_id = train_id
        self._abs_summary_base_dir = abs_summary_base_dir
        self._abs_summary_dir = abs_summary_dir
        self._create_time = create_time
        self._update_time = update_time

    @property
    def summary_dir(self):
        """Get summary directory path."""
        return self._abs_summary_dir

    @property
    def train_id(self):
        """Get train id."""
        return self._train_id
class CachedTrainJob:
    """
    Cache item for BriefCacheManager.

    DetailCacheManager will also wrap it's return value with this class.

    Args:
        basic_info (_BasicTrainJob): Basic info about the train job.
    """
    def __init__(self, basic_info: _BasicTrainJob):
        self._basic_info = basic_info
        self._last_access_time = datetime.datetime.utcnow()

        # Other cached content is stored here.
        self._content = {}

        self._cache_status = _CacheStatus.NOT_IN_CACHE

    @property
    def cache_status(self):
        """Get cache status."""
        return self._cache_status

    @cache_status.setter
    def cache_status(self, value):
        """Set cache status."""
        self._cache_status = value

    def update_access_time(self):
        """Update last access time of this cache item."""
        self._last_access_time = datetime.datetime.utcnow()

    @property
    def last_access_time(self):
        """Get last access time for purposes such as LRU."""
        return self._last_access_time

    @property
    def summary_dir(self):
        """Get summary directory path."""
        return self._basic_info.summary_dir

    def set(self, key, value):
        """Set value to cache."""
        self._content[key] = value

    def get(self, key):
        """Get value from cache."""
        try:
            return self._content[key]
        except KeyError:
            raise ParamValueError("Invalid cache key({}).".format(key))

    @property
    def basic_info(self):
        """Get basic train job info."""
        return self._basic_info

    @basic_info.setter
    def basic_info(self, value):
        """Set basic train job info."""
        self._basic_info = value


class TrainJob:
    """
    Train job object.

    You must not create TrainJob objects manually. You should always get TrainJob objects from DataManager.

    Args:
        brief_train_job (CachedTrainJob): Brief info about train job.
        detail_train_job (Optional[CachedTrainJob]): Detailed info about train job. Default: None.
    """
    def __init__(self, brief_train_job: CachedTrainJob, detail_train_job: Optional[CachedTrainJob] = None):
        self._brief = brief_train_job
        self._detail = detail_train_job
        if self._detail is None:
            self._cache_status = _CacheStatus.NOT_IN_CACHE
        else:
            self._cache_status = self._detail.cache_status

    def has_detail(self):
        """Whether this train job has detailed info in cache."""
        return bool(self._detail is not None)

    def get_detail(self, key):
        """
        Get detail content.

        Args:
            key (Any): Cache key.

        Returns:
            Any, cache content.

        Raises:
            TrainJobDetailNotInCacheError: when this train job has no detail cache.
        """
        if not self.has_detail():
            raise exceptions.TrainJobDetailNotInCacheError()
        return self._detail.get(key)

    def get_brief(self, key):
        """
        Get brief content.

        Args:
            key (Any): Cache key.

        Returns:
            Any, cache content.
        """
        return self._brief.get(key)
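To make the brief/detail split concrete, a small usage sketch (values are hypothetical; real code should obtain TrainJob instances from DataManager, as the docstring above insists; `datetime` is already imported at the top of this module):

# Hypothetical values, for illustration only.
now = datetime.datetime.utcnow()
basic = _BasicTrainJob("./train_01", "/base", "/base/train_01", now, now)
brief = CachedTrainJob(basic_info=basic)
job = TrainJob(brief_train_job=brief)   # detail_train_job defaults to None

assert not job.has_detail()
# job.get_detail("datavisual") would raise TrainJobDetailNotInCacheError
# until a reload caches the detail info; brief content set via
# brief.set("lineage", ...) stays reachable through job.get_brief("lineage").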
class BaseCacheItemUpdater(abc.ABC):
    """Abstract base class for other modules to update cache content."""
    def update_item(self, cache_item: CachedTrainJob):
        """
        Update cache item in place.

        Args:
            cache_item (CachedTrainJob): The cache item to be processed.
        """
        raise NotImplementedError()


class _BaseCacheManager:
    """Base class for cache manager."""

    def __init__(self):
        # Use dict to remove duplicate updaters.
        self._updaters = {}

        # key is train_id
        self._lock = threading.Lock()
        self._cache_items = {}

    def size(self):
        """Gets used cache slots."""
        return len(self._cache_items)

    def register_cache_item_updater(self, updater: BaseCacheItemUpdater):
        """Register cache item updater."""
        self._updaters[updater.__class__.__qualname__] = updater

    def get_train_jobs(self):
        """Get cached train jobs."""
        copied_train_jobs = dict(self._cache_items)
        return copied_train_jobs

    def get_train_job(self, train_id):
        """Get cached train job."""
        try:
            return self._cache_items[train_id]
        except KeyError:
            raise TrainJobNotExistError(train_id)

    def cache_train_job(self, train_id) -> bool:
        """
        Cache given train job and update train job's last access time.

        This method should return true if reload actions should be taken to cache the train job.

        Args:
            train_id (str): Train Id.
        """
        raise NotImplementedError()

    def delete_train_job(self, train_id):
        """Delete train job from cache."""
        if train_id in self._cache_items:
            del self._cache_items[train_id]

    def has_content(self):
        """Whether this cache manager has train jobs."""
        return bool(self._cache_items)

    def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]):
        """
        Update cache according to given train jobs on disk.

        Different cache manager should implement different cache update policies in this method.

        Args:
            disk_train_jobs (Iterable[_BasicTrainJob]): Train jobs on disk.
        """
        raise NotImplementedError()

    def _merge_with_disk(self, disk_train_jobs: Iterable[_BasicTrainJob]):
        """
        Merge train jobs in cache with train jobs from disk.

        This method will remove train jobs not on disk. Call this function with lock for thread safety.

        Args:
            disk_train_jobs (Iterable[_BasicTrainJob]): Basic train jobs info from disk.

        Returns:
            dict, a dict containing train jobs to be cached.
        """
        new_cache_items = {}
        for train_job in disk_train_jobs:
            if train_job.train_id not in self._cache_items:
                new_cache_items[train_job.train_id] = CachedTrainJob(train_job)
            else:
                reused_train_job = self._cache_items[train_job.train_id]
                reused_train_job.basic_info = train_job
                new_cache_items[train_job.train_id] = reused_train_job

        return new_cache_items
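The merge policy is worth spelling out: existing cache items are reused, so previously cached content and access times survive a reload (only their basic info is refreshed), and jobs that disappeared from disk are dropped. A toy sketch with a hypothetical subclass, since `_BaseCacheManager.update_cache` itself is abstract:

class _ToyCacheManager(_BaseCacheManager):
    """Hypothetical subclass, only to exercise _merge_with_disk."""
    def update_cache(self, disk_train_jobs):
        with self._lock:
            self._cache_items = self._merge_with_disk(disk_train_jobs)


def _toy_basic(train_id):
    """Hypothetical factory for a _BasicTrainJob (datetime is imported above)."""
    now = datetime.datetime.utcnow()
    return _BasicTrainJob(train_id, "/base", "/base/" + train_id, now, now)


manager = _ToyCacheManager()
manager.update_cache([_toy_basic("job_a"), _toy_basic("job_b")])
item_a = manager.get_train_job("job_a")

# A later scan: job_b vanished from disk, job_c appeared.
manager.update_cache([_toy_basic("job_a"), _toy_basic("job_c")])
assert sorted(manager.get_train_jobs()) == ["job_a", "job_c"]
assert manager.get_train_job("job_a") is item_a   # same object, basic info refreshed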
class _BriefCacheManager(_BaseCacheManager):
    """A cache manager that holds all disk train jobs on disk."""

    def cache_train_job(self, train_id):
        """
        Cache given train job.

        All disk train jobs are cached on every reload, so this method always return false.

        Args:
            train_id (str): Train Id.
        """
        if train_id in self._cache_items:
            self._cache_items[train_id].update_access_time()

        return False

    def update_cache(self, disk_train_jobs):
        """Update cache."""
        with self._lock:
            new_cache_items = self._merge_with_disk(disk_train_jobs)
            self._cache_items = new_cache_items
        for updater in self._updaters.values():
            for cache_item in self._cache_items.values():
                updater.update_item(cache_item)


# Key for plugin tags.
DATAVISUAL_PLUGIN_KEY = "tag_mapping"
# Detail train job cache key for datavisual content.
DATAVISUAL_CACHE_KEY = "datavisual"


class _DetailCacheManager(_BaseCacheManager):
    """A cache manager that holds detailed info for most recently used train jobs."""

    def __init__(self, loader_generators):
        super().__init__()
        self._loader_pool = {}
        self._deleted_id_list = []
        self._status = DataManagerStatus.INIT.value
        self._status_mutex = threading.Lock()
        self._loader_pool_mutex = threading.Lock()
        self._max_threads_count = 30
        self._reload_interval = 3

        self._loader_generators = loader_generators

    def size(self):
        """
        Get the number of items in this cache manager.

        To be implemented.

        Returns:
            int, the number of items in this cache manager.
        """
        raise NotImplementedError()

    def loader_pool_size(self):
        """Get loader pool size."""
        return len(self._loader_pool)

    def update_cache(self, disk_train_jobs: Iterable[_BasicTrainJob]):
        """
        Update cache.

        Will switch to using disk_train_jobs in the future.

        Args:
            disk_train_jobs (Iterable[_BasicTrainJob]): Basic info about train jobs on disk.
        """
        self._generate_loaders()
        self._execute_load_data()

    def cache_train_job(self, train_id):
        """Cache given train job."""
        loader = None
        need_reload = False
        with self._loader_pool_mutex:
            if self._is_loader_in_loader_pool(train_id, self._loader_pool):
                loader = self._loader_pool.get(train_id)

            if loader is None:
                for generator in self._loader_generators:
                    tmp_loader = generator.generate_loader_by_train_id(train_id)
                    if loader and loader.latest_update_time > tmp_loader.latest_update_time:
                        continue
                    loader = tmp_loader

                if loader is None:
                    raise TrainJobNotExistError(train_id)

                self._add_loader(loader)
                need_reload = True

        self._update_loader_latest_update_time(loader.loader_id)
        return need_reload

    def get_train_jobs(self):
        """
        Get train jobs

        To be implemented.
        """

    def _add_loader(self, loader):
        """
        Add a loader to load data.

        Args:
            loader (LoaderStruct): A object of `Loader`.
        """
        if len(self._loader_pool) >= MAX_DATA_LOADER_SIZE:
            delete_number = len(self._loader_pool) - MAX_DATA_LOADER_SIZE + 1
...
@@ -85,40 +417,328 @@ class DataManager:
            self._delete_loader(delete_loader_id)
        self._loader_pool.update({loader.loader_id: loader})

(The methods below were lifted out of the old DataManager into _DetailCacheManager; the diff shows each old copy removed and an identical new copy added, so they are reconstructed once here.)

    def _delete_loader(self, loader_id):
        """
        Delete loader from loader pool by loader id.

        Args:
            loader_id (str): ID of loader.
        """
        if self._loader_pool.get(loader_id) is not None:
            logger.debug("delete loader %s", loader_id)
            self._loader_pool.pop(loader_id)

    def _execute_loader(self, loader_id):
        """
        Load data form data_loader.

        If there is something wrong by loading, add logs and delete the loader.

        Args:
            loader_id (str): An ID for `Loader`.
        """
        try:
            with self._loader_pool_mutex:
                loader = self._loader_pool.get(loader_id, None)
                if loader is None:
                    logger.debug("Loader %r has been deleted, will not load data.", loader_id)
                    return
            loader.data_loader.load()
        except MindInsightException as ex:
            logger.warning("Data loader %r load data failed. "
                           "Delete data_loader. Detail: %s", loader_id, ex)

            with self._loader_pool_mutex:
                self._delete_loader(loader_id)

    def _generate_loaders(self):
        """This function generates the loader from given path."""
        loader_dict = {}
        for generator in self._loader_generators:
            loader_dict.update(generator.generate_loaders(self._loader_pool))
        sorted_loaders = sorted(loader_dict.items(), key=lambda loader: loader[1].latest_update_time)
        latest_loaders = sorted_loaders[-MAX_DATA_LOADER_SIZE:]
        self._deal_loaders(latest_loaders)

    def _deal_loaders(self, latest_loaders):
        """
        This function determines which loaders to keep or remove or added.

        It is based on the given dict of loaders.

        Args:
            latest_loaders (list[dict]): A list of <loader_id: LoaderStruct>.
        """
        with self._loader_pool_mutex:
            for loader_id, loader in latest_loaders:
                if self._loader_pool.get(loader_id, None) is None:
                    self._add_loader(loader)
                    continue

                # If this loader was updated manually before,
                # its latest_update_time may bigger than update_time in summary.
                if self._loader_pool[loader_id].latest_update_time < loader.latest_update_time:
                    self._update_loader_latest_update_time(loader_id, loader.latest_update_time)

    def _execute_load_data(self):
        """Load data through multiple threads."""
        threads_count = self._get_threads_count()
        if not threads_count:
            logger.info("Can not find any valid train log path to load, loader pool is empty.")
            return

        logger.info("Start to execute load data. threads_count: %s.", threads_count)

        with ThreadPoolExecutor(max_workers=threads_count) as executor:
            futures = []
            loader_pool = self._get_snapshot_loader_pool()
            for loader_id in loader_pool:
                future = executor.submit(self._execute_loader, loader_id)
                futures.append(future)
            wait(futures, return_when=ALL_COMPLETED)

    def _get_threads_count(self):
        """
        Use the maximum number of threads available.

        Returns:
            int, number of threads.
        """
        threads_count = min(self._max_threads_count, len(self._loader_pool))
        return threads_count

    def delete_train_job(self, train_id):
        """
        Delete train job with a train id.

        Args:
            train_id (str): ID for train job.
        """
        with self._loader_pool_mutex:
            self._delete_loader(train_id)

    def list_tensors(self, train_id, tag):
        """
        List tensors of the given train job and tag.

        If the tensor can not find by the given tag, will raise exception.

        Args:
            train_id (str): ID for train job.
            tag (str): The tag name.

        Returns:
            NamedTuple, the tuple format is `collections.namedtuple('_Tensor', ['wall_time', 'event_step', 'value'])`.
                the value will contain the given tag data.
        """
        loader_pool = self._get_snapshot_loader_pool()
        if not self._is_loader_in_loader_pool(train_id, loader_pool):
            raise TrainJobNotExistError("Can not find the given train job in cache.")

        data_loader = loader_pool[train_id].data_loader
        events_data = data_loader.get_events_data()

        try:
            tensors = events_data.tensors(tag)
        except KeyError:
            error_msg = "Can not find any data in this train job by given tag."
            raise ParamValueError(error_msg)

        return tensors

    def _check_train_job_exist(self, train_id, loader_pool):
        """
        Check train job exist, if not exist, will raise exception.

        Args:
            train_id (str): The given train job id.
            loader_pool (dict[str, LoaderStruct]): Refer to self._loader_pool.

        Raises:
            TrainJobNotExistError: Can not find train job in data manager.
        """
        is_exist = False
        if train_id in loader_pool:
            return
        for generator in self._loader_generators:
            if generator.check_train_job_exist(train_id):
                is_exist = True
                break
        if not is_exist:
            raise TrainJobNotExistError("Can not find the train job in data manager.")

    def _is_loader_in_loader_pool(self, train_id, loader_pool):
        """
        Check train job exist, if not exist, return False. Else, return True.

        Args:
            train_id (str): The given train job id.
            loader_pool (dict): See self._loader_pool.

        Returns:
            bool, if loader in loader pool, return True.
        """
        if train_id in loader_pool:
            return True
        return False

    def _get_snapshot_loader_pool(self):
        """
        Create a snapshot of data loader pool to avoid concurrent mutation and iteration issues.

        Returns:
            dict, a copy of `self._loader_pool`.
        """
        with self._loader_pool_mutex:
            return dict(self._loader_pool)

    def get_train_job(self, train_id):
        """
        Get train job by train ID.

        This method overrides parent method.

        Args:
            train_id (str): Train ID for train job.

        Returns:
            dict, single train job, if can not find any data, will return None.
        """
        self._check_train_job_exist(train_id, self._loader_pool)

        loader = self._get_loader(train_id)
        if loader is None:
            logger.warning("No valid summary log in train job %s, "
                           "or it is not in the cache.", train_id)
            return None

        train_job = loader.to_dict()
        train_job.pop('data_loader')

        plugin_data = {}
        for plugin_name in PluginNameEnum.list_members():
            job = self.get_train_job_by_plugin(train_id, plugin_name=plugin_name)
            if job is None:
                plugin_data[plugin_name] = []
            else:
                plugin_data[plugin_name] = job['tags']

        train_job.update({DATAVISUAL_PLUGIN_KEY: plugin_data})

        # Will fill basic_info value in future.
        train_job_obj = CachedTrainJob(basic_info=None)
        train_job_obj.set(DATAVISUAL_CACHE_KEY, train_job)

        # Will assign real value in future.
        train_job_obj.cache_status = _CacheStatus.CACHED

        return train_job_obj

    def _get_loader(self, train_id):
        """
        Get loader by train id.

        Args:
            train_id (str): Train Id.

        Returns:
            LoaderStruct, the loader.
        """
        loader = None
        with self._loader_pool_mutex:
            if self._is_loader_in_loader_pool(train_id, self._loader_pool):
                loader = self._loader_pool.get(train_id)

        return loader

    def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
        """
        Update loader with latest_update_time.

        Args:
            loader_id (str): ID of loader.
            latest_update_time (float): Timestamp.
        """
        if latest_update_time is None:
            latest_update_time = time.time()
        self._loader_pool[loader_id].latest_update_time = latest_update_time

    def get_train_job_by_plugin(self, train_id, plugin_name):
        """
        Get a train job by train job id.

        If the given train job does not has the given plugin data, the tag list will be empty.

        Args:
            train_id (str): Get train job info by the given id.
            plugin_name (str): Get tags by given plugin.

        Returns:
            TypedDict('TrainJobEntity', {'id': str, 'name': str, 'tags': List[str]}),
                a train job object.
        """
        self._check_train_job_exist(train_id, self._loader_pool)

        loader = self._get_loader(train_id)
        if loader is None:
            logger.warning("No valid summary log in train job %s, "
                           "or it is not in the cache.", train_id)
            return None

        name = loader.name
        data_loader = loader.data_loader

        tags = []
        try:
            events_data = data_loader.get_events_data()
            tags = events_data.list_tags_by_plugin(plugin_name)
        except KeyError:
            logger.debug("Plugin name %r does not exist "
                         "in train job %r, and set tags to empty list.", plugin_name, name)
        except AttributeError:
            logger.debug("Train job %r has been deleted or it has not loaded data, "
                         "and set tags to empty list.", name)

        result = dict(id=train_id, name=name, tags=tags)
        return result
(Removed here: the remainder of the old DataManager's `_delete_loader` and `_execute_loader`; both now live in `_DetailCacheManager` above. The reworked DataManager follows.)

class DataManager:
    """
    DataManager manages a pool of loader which help access events data.

    Each loader helps deal the data of the events.
    A loader corresponds to an events_data.
    The DataManager build a pool including all the data_loader.
    The data_loader provides extracting
    method to get the information of events.
    """
    def __init__(self, summary_base_dir):
        """
        Initialize the pool of loader and the dict of name-to-path.

        Args:
            summary_base_dir (str): Base summary directory.

        self._status: Refer `datavisual.common.enums.DataManagerStatus`.
        """
        self._summary_base_dir = os.path.realpath(summary_base_dir)
        self._status = DataManagerStatus.INIT.value
        self._status_mutex = threading.Lock()

        self._reload_interval = 3

        loader_generators = [DataLoaderGenerator(self._summary_base_dir)]
        self._detail_cache = _DetailCacheManager(loader_generators)
        self._brief_cache = _BriefCacheManager()

    def start_load_data(self,
                        reload_interval=settings.RELOAD_INTERVAL,
...
@@ -176,64 +796,31 @@ class DataManager:
            return
        self.status = DataManagerStatus.LOADING.value

        summaries_info = SummaryWatcher().list_summary_directories(self._summary_base_dir)

        basic_train_jobs = []
        for info in summaries_info:
            basic_train_jobs.append(_BasicTrainJob(
                train_id=info['relative_path'],
                abs_summary_base_dir=self._summary_base_dir,
                abs_summary_dir=os.path.realpath(os.path.join(
                    self._summary_base_dir,
                    info['relative_path']
                )),
                create_time=info['create_time'],
                update_time=info['update_time']
            ))

        self._brief_cache.update_cache(basic_train_jobs)
        self._detail_cache.update_cache(basic_train_jobs)

        if not self._brief_cache.has_content() and not self._detail_cache.has_content():
            self.status = DataManagerStatus.INVALID.value
        else:
            self.status = DataManagerStatus.DONE.value

        logger.info("Load event data end, status: %r, and loader pool size is %r.",
                    self.status, self._detail_cache.loader_pool_size())

(Removed in this hunk: the old loader-based body, i.e. `self._generate_loaders()` / `self._execute_load_data()` and the `if not self._loader_pool` status check, together with the old `_generate_loaders`, `_deal_loaders` and `_execute_load_data` methods, which now live in `_DetailCacheManager`.)
    @staticmethod
    def check_reload_interval(reload_interval):
...
@@ -262,18 +849,6 @@
        if max_threads_count <= 0:
            raise ParamValueError("The value of max threads count should be > 0.")

(Removed here: the old `_get_threads_count` method, now on `_DetailCacheManager`.)

    def get_train_job_by_plugin(self, train_id, plugin_name):
        """
        Get a train job by train job id.
...
@@ -290,32 +865,9 @@
        """
        self._check_status_valid()
        return self._detail_cache.get_train_job_by_plugin(train_id, plugin_name)

(Its former inline body is removed in favor of delegation to the detail cache.)
    def delete_train_job(self, train_id, only_delete_from_cache=True):
        """
        Delete train job with a train id.
...
@@ -323,8 +875,11 @@
            train_id (str): ID for train job.
        """
        if not only_delete_from_cache:
            raise NotImplementedError("Delete from both cache and disk is not supported.")

        self._brief_cache.delete_train_job(train_id)
        self._detail_cache.delete_train_job(train_id)

    def list_tensors(self, train_id, tag):
        """
...
@@ -342,66 +897,7 @@
        """
        self._check_status_valid()
        return self._detail_cache.list_tensors(train_id, tag)

(Removed here: the old inline bodies of `list_tensors`, `_check_train_job_exist`, `_is_loader_in_loader_pool` and `_get_snapshot_loader_pool`; all of them now live in `_DetailCacheManager`.)
    def _check_status_valid(self):
        """Check if the status is valid to load data."""
...
@@ -409,90 +905,29 @@
        if self.status == DataManagerStatus.INIT.value:
            raise exceptions.SummaryLogIsLoading("Data is being loaded, current status: %s." % self._status)

    def get_train_job(self, train_id):
        """
        Get train job by train ID.

        Args:
            train_id (str): Train ID for train job.

        Returns:
            dict, single train job, if can not find any data, will return None.
        """
        self._check_status_valid()
        detail_train_job = self._detail_cache.get_train_job(train_id)
        brief_train_job = self._brief_cache.get_train_job(train_id)

        return TrainJob(brief_train_job, detail_train_job)

(Removed here: the old `get_single_train_job(train_id, manual_update=False)`, `_get_loader(train_id, manual_update=False)` and `_update_loader_latest_update_time` implementations; loader handling now lives in `_DetailCacheManager`.)

    def list_train_jobs(self):
        """
        List train jobs.

        To be implemented.
        """
        raise NotImplementedError()
    @property
    def status(self):
...
@@ -509,6 +944,16 @@
        """Set data manger status."""
        self._status = status

    def cache_train_job(self, train_id):
        """Cache given train job (async)."""
        brief_need_reload = self._brief_cache.cache_train_job(train_id)
        detail_need_reload = self._detail_cache.cache_train_job(train_id)
        if brief_need_reload or detail_need_reload:
            self.reload_data()

    def register_brief_cache_item_updater(self, updater: BaseCacheItemUpdater):
        """Register brief cache item updater for brief cache manager."""
        self._brief_cache.register_cache_item_updater(updater)


DATA_MANAGER = DataManager(settings.SUMMARY_BASE_DIR)

(Replaced: the old module tail `_loader_generators = [DataLoaderGenerator(settings.SUMMARY_BASE_DIR)]; DATA_MANAGER = DataManager(_loader_generators)`.)
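Taken together, the reworked module funnels everything through the two cache managers. A rough usage sketch (the path and train id below are hypothetical):

from mindinsight.datavisual.data_transform.data_manager import DataManager

manager = DataManager("/path/to/summary_base_dir")       # hypothetical path
manager.start_load_data(reload_interval=3, max_threads_count=4)

# Brief info for every job on disk is cached on each reload; detail info is
# cached lazily. cache_train_job() asks both caches and triggers a reload
# only if one of them reports work to do.
manager.cache_train_job("./train_01")                    # hypothetical train id
train_job = manager.get_train_job("./train_01")
if train_job.has_detail():
    content = train_job.get_detail("datavisual")         # DATAVISUAL_CACHE_KEY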
mindinsight/datavisual/processors/train_task_manager.py

@@ -18,6 +18,7 @@ from mindinsight.datavisual.common import exceptions
 from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.processors.base_processor import BaseProcessor
+from mindinsight.datavisual.data_transform.data_manager import DATAVISUAL_PLUGIN_KEY, DATAVISUAL_CACHE_KEY


 class TrainTaskManager(BaseProcessor):
@@ -53,13 +54,24 @@ class TrainTaskManager(BaseProcessor):
             dict, refer to restful api.
         """
         Validation.check_param_empty(train_id=train_id)
-        train_job = self._data_manager.get_single_train_job(train_id, manual_update=manual_update)
-        if not train_job:
+        if manual_update:
+            self._data_manager.cache_train_job(train_id)
+
+        train_job = self._data_manager.get_train_job(train_id)
+
+        try:
+            data_visual_content = train_job.get_detail(DATAVISUAL_CACHE_KEY)
+            plugins = data_visual_content.get(DATAVISUAL_PLUGIN_KEY)
+        except exceptions.TrainJobDetailNotInCacheError:
+            plugins = []
+
+        if not plugins:
             default_result = dict()
             for plugin_name in PluginNameEnum.list_members():
                 default_result.update({plugin_name: list()})
             return dict(plugins=default_result)

-        return dict(plugins=train_job['tag_mapping'])
+        return dict(plugins=plugins)
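Condensed, the new control flow is: optionally schedule caching, fetch the combined TrainJob, and degrade gracefully when detail info is missing. A sketch of the same logic outside the processor class (names come from this commit; the standalone function itself is hypothetical and assumes the file's imports):

def query_plugins(data_manager, train_id, manual_update=False):
    """Hypothetical standalone version of TrainTaskManager's new flow."""
    if manual_update:
        data_manager.cache_train_job(train_id)   # may trigger reload_data()

    train_job = data_manager.get_train_job(train_id)
    try:
        plugins = train_job.get_detail(DATAVISUAL_CACHE_KEY).get(DATAVISUAL_PLUGIN_KEY)
    except exceptions.TrainJobDetailNotInCacheError:
        plugins = []

    # Fall back to an empty tag list per plugin instead of failing.
    return plugins or {name: [] for name in PluginNameEnum.list_members()}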
mindinsight/lineagemgr/cache_item_updater.py (new file, mode 0 → 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cache item updater."""
import os

from mindinsight.datavisual.data_transform.data_manager import BaseCacheItemUpdater, CachedTrainJob
from mindinsight.lineagemgr.querier.query_model import LineageObj
from mindinsight.lineagemgr.summary.lineage_summary_analyzer import LineageSummaryAnalyzer


class LineageCacheItemUpdater(BaseCacheItemUpdater):
    """Cache item updater for lineage info."""

    def update_item(self, cache_item: CachedTrainJob):
        """Update cache item in place."""
        log_path = cache_item.summary_dir
        log_dir = os.path.dirname(log_path)
        lineage_info = LineageSummaryAnalyzer.get_summary_infos(log_path)
        user_defined_info = LineageSummaryAnalyzer.get_user_defined_info(log_path)
        lineage_obj = LineageObj(
            log_dir,
            train_lineage=lineage_info.train_lineage,
            evaluation_lineage=lineage_info.eval_lineage,
            dataset_graph=lineage_info.dataset_graph,
            user_defined_info=user_defined_info)
        cache_item.set(key="lineage", value=lineage_obj)
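For context, a sketch of how this updater is hooked up; this mirrors what the new mindinsight/backend/data_manager/__init__.py above does, with a hypothetical summary path. `_BriefCacheManager.update_cache` calls `update_item` on every cached job after each disk scan, so lineage info lands in the brief cache:

from mindinsight.datavisual.data_transform.data_manager import DataManager
from mindinsight.lineagemgr.cache_item_updater import LineageCacheItemUpdater

manager = DataManager("/path/to/summary_base_dir")        # hypothetical path
manager.register_brief_cache_item_updater(LineageCacheItemUpdater())
manager.start_load_data()

# After a reload completes, lineage info is available without touching the
# detail cache, e.g. manager.get_train_job("./train_01").get_brief("lineage").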
mindinsight/utils/constant.py

@@ -63,3 +63,4 @@ class DataVisualErrors(Enum):
     IMAGE_NOT_EXIST = 13
     SCALAR_NOT_EXIST = 14
     HISTOGRAM_NOT_EXIST = 15
+    TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
tests/st/func/datavisual/conftest.py

@@ -25,7 +25,6 @@ from flask import Response
 from mindinsight.conf import settings
 from mindinsight.datavisual.data_transform import data_manager
 from mindinsight.datavisual.data_transform.data_manager import DataManager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
 from mindinsight.datavisual.utils import tools
...
@@ -59,7 +58,7 @@ def init_summary_logs():
     log_operations = LogOperations()
     summaries_metadata = log_operations.create_summary_logs(summary_base_dir, constants.SUMMARY_DIR_NUM_FIRST,
                                                             constants.SUMMARY_DIR_PREFIX)
-    mock_data_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+    mock_data_manager = DataManager(summary_base_dir)
     mock_data_manager.start_load_data(reload_interval=0)
     check_loading_done(mock_data_manager)
...
tests/ut/datavisual/data_transform/test_data_manager.py

@@ -33,7 +33,6 @@ from mindinsight.datavisual.data_transform import data_manager, ms_data_loader
 from mindinsight.datavisual.data_transform.data_loader import DataLoader
 from mindinsight.datavisual.data_transform.data_manager import DataManager
 from mindinsight.datavisual.data_transform.events_data import EventsData
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.data_transform.loader_generators.loader_generator import MAX_DATA_LOADER_SIZE
 from mindinsight.datavisual.data_transform.loader_generators.loader_struct import LoaderStruct
 from mindinsight.datavisual.data_transform.ms_data_loader import MSDataLoader
...
@@ -89,7 +88,7 @@ class TestDataManager:
             train_ids.append(f'./dir{i}')
         data_manager.logger = MockLogger
-        mock_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        mock_manager = data_manager.DataManager(summary_base_dir)
         mock_manager.start_load_data(reload_interval=0)
         check_loading_done(mock_manager)
...
@@ -112,7 +111,7 @@ class TestDataManager:
     def test_start_load_data_with_invalid_params(self, params):
         """Test start_load_data with invalid reload_interval or invalid max_threads_count."""
         summary_base_dir = tempfile.mkdtemp()
-        d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+        d_manager = DataManager(summary_base_dir)
         with pytest.raises(ParamValueError):
             d_manager.start_load_data(**params)
         shutil.rmtree(summary_base_dir)
...
@@ -142,9 +141,9 @@ class TestDataManager:
                               latest_update_time=modify_time_01,
                               data_loader=loader_01)
         loader_pool = {train_job_01: loader}
-        d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+        d_manager = DataManager(summary_base_dir)
         d_manager._status = DataManagerStatus.LOADING.value
-        d_manager._loader_pool = loader_pool
+        d_manager._detail_cache._loader_pool = loader_pool
         res = d_manager.list_tensors(train_job_01, tag)
         assert res == {'test result'}
...
@@ -169,9 +168,9 @@ class TestDataManager:
                               latest_update_time=modify_time_01,
                               data_loader=loader_01)
         loader_pool = {train_job_01: loader}
-        d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+        d_manager = DataManager(summary_base_dir)
         d_manager._status = DataManagerStatus.LOADING.value
-        d_manager._loader_pool = loader_pool
+        d_manager._detail_cache._loader_pool = loader_pool
         tag = 'image'
         with pytest.raises(ParamValueError):
             d_manager.list_tensors(train_job_01, tag)
...
@@ -181,7 +180,7 @@ class TestDataManager:
     def test_list_tensors_with_not_exist_train_job(self):
         """Test list_tensors method with parameter train_id not found in loader_pool."""
         summary_base_dir = tempfile.mkdtemp()
-        d_manager = DataManager([DataLoaderGenerator(summary_base_dir)])
+        d_manager = DataManager(summary_base_dir)
         d_manager._status = DataManagerStatus.LOADING.value
         tag = 'image'
         train_job_01 = 'train_01'
...
@@ -200,13 +199,12 @@ class TestDataManager:
         expected_loader_ids = list(loader_dict.keys())
         mock_generate_loaders.return_value = loader_dict
-        generators = [data_manager.DataLoaderGenerator(summary_base_dir)]
-        mock_data_manager = data_manager.DataManager(generators)
-        mock_data_manager._execute_load_data = Mock()
+        mock_data_manager = data_manager.DataManager(summary_base_dir)
+        mock_data_manager._detail_cache._execute_load_data = Mock()
         mock_data_manager.start_load_data(reload_interval=0)
         check_loading_done(mock_data_manager, 3)
-        current_loader_ids = mock_data_manager._loader_pool.keys()
+        current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys()
         assert sorted(current_loader_ids) == sorted(expected_loader_ids)
...
@@ -221,7 +219,7 @@ class TestDataManager:
         mock_generate_loaders.return_value = loader_dict
         mock_data_manager.start_load_data(reload_interval=0)
         check_loading_done(mock_data_manager)
-        current_loader_ids = mock_data_manager._loader_pool.keys()
+        current_loader_ids = mock_data_manager._detail_cache._loader_pool.keys()
         assert sorted(current_loader_ids) == sorted(expected_loader_ids)
...
tests/ut/datavisual/processors/test_graph_processor.py

@@ -30,7 +30,6 @@ from mindinsight.datavisual.common.exceptions import GraphNotExistError
 from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
 from mindinsight.datavisual.data_transform import data_manager
 from mindinsight.datavisual.data_transform.data_manager import DataManager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.graph_processor import GraphProcessor
 from mindinsight.datavisual.utils import crc32
 from mindinsight.utils.exceptions import ParamValueError
...
@@ -74,7 +73,7 @@ class TestGraphProcessor:
         self._temp_path, self._graph_dict, _ = log_operation.generate_log(
             PluginNameEnum.GRAPH.value, log_dir)
         self._generated_path.append(summary_base_dir)
-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
...
@@ -93,7 +92,7 @@ class TestGraphProcessor:
         self._generated_path.append(summary_base_dir)
-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
...
tests/ut/datavisual/processors/test_histogram_processor.py

@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.common.exceptions import HistogramNotExistError
 from mindinsight.datavisual.data_transform import data_manager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor
 from mindinsight.datavisual.utils import crc32
...
@@ -72,7 +71,7 @@ class TestHistogramProcessor:
             PluginNameEnum.HISTOGRAM.value,
             log_dir,
             dict(step=self._steps_list, tag=self._tag_name))
         self._generated_path.append(summary_base_dir)
-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
...
tests/ut/datavisual/processors/test_images_processor.py

@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.common.exceptions import ImageNotExistError
 from mindinsight.datavisual.data_transform import data_manager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.images_processor import ImageProcessor
 from mindinsight.datavisual.utils import crc32
...
@@ -81,7 +80,7 @@ class TestImagesProcessor:
             PluginNameEnum.IMAGE.value,
             log_dir,
             dict(steps=steps_list, tag=self._tag_name))
         self._generated_path.append(summary_base_dir)
-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
...
tests/ut/datavisual/processors/test_scalars_processor.py

@@ -27,7 +27,6 @@ from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.common.exceptions import ScalarNotExistError
 from mindinsight.datavisual.data_transform import data_manager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
 from mindinsight.datavisual.utils import crc32
...
@@ -73,7 +72,7 @@ class TestScalarsProcessor:
             PluginNameEnum.SCALAR.value,
             log_dir,
             dict(step=self._steps_list, tag=self._tag_name))
         self._generated_path.append(summary_base_dir)
-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(summary_base_dir)])
+        self._mock_data_manager = data_manager.DataManager(summary_base_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)

         # wait for loading done
...
tests/ut/datavisual/processors/test_train_task_manager.py

@@ -27,7 +27,6 @@ import pytest
 from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.data_transform import data_manager
-from mindinsight.datavisual.data_transform.loader_generators.data_loader_generator import DataLoaderGenerator
 from mindinsight.datavisual.processors.train_task_manager import TrainTaskManager
 from mindinsight.datavisual.utils import crc32
...
@@ -97,7 +96,7 @@ class TestTrainTaskManager:
         self._generated_path.append(self._root_dir)
-        self._mock_data_manager = data_manager.DataManager([DataLoaderGenerator(self._root_dir)])
+        self._mock_data_manager = data_manager.DataManager(self._root_dir)
         self._mock_data_manager.start_load_data(reload_interval=0)
         check_loading_done(self._mock_data_manager, time_limit=30)
...