Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
edd7e184
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
edd7e184
编写于
7月 07, 2020
作者:
M
ms_yan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify config api
上级
8844462e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
199 addition
and
200 deletion
+199
-200
mindspore/dataset/__init__.py
mindspore/dataset/__init__.py
+1
-1
mindspore/dataset/core/config.py
mindspore/dataset/core/config.py
+195
-0
mindspore/dataset/core/configuration.py
mindspore/dataset/core/configuration.py
+0
-195
mindspore/dataset/engine/__init__.py
mindspore/dataset/engine/__init__.py
+2
-3
mindspore/dataset/engine/serializer_deserializer.py
mindspore/dataset/engine/serializer_deserializer.py
+1
-1
未找到文件。
mindspore/dataset/__init__.py
浏览文件 @
edd7e184
...
...
@@ -18,7 +18,7 @@ datasets in special format, including mindrecord, tfrecord, manifest. Users
can also create samplers with this module to sample data.
"""
from
.core
.configuration
import
config
from
.core
import
config
from
.engine.datasets
import
TFRecordDataset
,
ImageFolderDatasetV2
,
MnistDataset
,
MindDataset
,
NumpySlicesDataset
,
\
GeneratorDataset
,
ManifestDataset
,
Cifar10Dataset
,
Cifar100Dataset
,
VOCDataset
,
CocoDataset
,
CelebADataset
,
\
TextFileDataset
,
CLUEDataset
,
Schema
,
Shuffle
,
zip
,
RandomDataset
...
...
mindspore/dataset/core/config.py
0 → 100644
浏览文件 @
edd7e184
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration manager.
"""
import
random
import
numpy
import
mindspore._c_dataengine
as
cde
# Public API of the dataset configuration module.
__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size',
           'set_num_parallel_workers', 'get_num_parallel_workers',
           'set_monitor_sampling_interval', 'get_monitor_sampling_interval',
           'load']

# Upper bounds used by the argument validation in the setters below.
INT32_MAX = 2147483647   # 2**31 - 1
UINT32_MAX = 4294967295  # 2**32 - 1

# Shared handle to the C++ ConfigManager singleton; every function in this
# module delegates to this object.
_config = cde.GlobalContext.config_manager()
def set_seed(seed):
    """
    Set the global seed for every random generator used by the pipeline,
    so that runs produce deterministic results.

    Note:
        Besides the native pipeline seed, this also seeds Python's ``random``
        module and ``numpy.random`` so python-side augmentations are
        deterministic. Call it again for every iterator you create to reset
        the seed. Determinism is still not guaranteed when
        num_parallel_workers > 1.

    Args:
        seed(int): seed to be set.

    Raises:
        ValueError: If seed is invalid (< 0 or > MAX_UINT_32).

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the new seed value, now operators with a random seed will use new seed value.
        >>> ds.config.set_seed(1000)
    """
    # Reject anything outside the unsigned 32-bit range before touching state.
    if not 0 <= seed <= UINT32_MAX:
        raise ValueError("Seed given is not within the required range.")
    _config.set_seed(seed)
    random.seed(seed)
    # numpy.random isn't thread safe
    numpy.random.seed(seed)
def get_seed():
    """
    Return the current global seed.

    Returns:
        Int, seed.
    """
    seed = _config.get_seed()
    return seed
def set_prefetch_size(size):
    """
    Set how many rows each operator connector prefetches.

    Args:
        size (int): total number of rows to be prefetched.

    Raises:
        ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the new prefetch value.
        >>> ds.config.set_prefetch_size(1000)
    """
    # Must be a positive signed 32-bit value.
    if not 0 < size <= INT32_MAX:
        raise ValueError("Prefetch size given is not within the required range.")
    _config.set_op_connector_size(size)
def get_prefetch_size():
    """
    Return the configured prefetch size, in rows.

    Returns:
        Size, total number of rows to be prefetched.
    """
    size = _config.get_op_connector_size()
    return size
def set_num_parallel_workers(num):
    """
    Set the default worker count used by every parallel dataset operation.

    Args:
        num (int): number of parallel workers to be used as a default for each operation.

    Raises:
        ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers.
        >>> ds.config.set_num_parallel_workers(8)
    """
    # Must be a positive signed 32-bit value.
    if not 0 < num <= INT32_MAX:
        raise ValueError("Num workers given is not within the required range.")
    _config.set_num_parallel_workers(num)
def get_num_parallel_workers():
    """
    Return the default worker count for parallel dataset operations.

    Returns:
        Int, number of parallel workers to be used as a default for each operation
    """
    num = _config.get_num_parallel_workers()
    return num
def set_monitor_sampling_interval(interval):
    """
    Set the sampling interval, in milliseconds, of the performance monitor.

    Args:
        interval (int): interval(ms) to be used to performance monitor sampling.

    Raises:
        ValueError: If interval is invalid (<= 0 or > MAX_INT_32).

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the new interval value.
        >>> ds.config.set_monitor_sampling_interval(100)
    """
    # Must be a positive signed 32-bit value.
    if not 0 < interval <= INT32_MAX:
        raise ValueError("Interval given is not within the required range.")
    _config.set_monitor_sampling_interval(interval)
def get_monitor_sampling_interval():
    """
    Return the performance monitor's sampling interval.

    Returns:
        Interval: interval(ms) of performance monitor sampling.
    """
    interval = _config.get_monitor_sampling_interval()
    return interval
def __str__():
    """
    String representation of the configurations.

    Returns:
        Str, configurations.
    """
    # NOTE(review): a module-level function named ``__str__`` is NOT hooked
    # into ``str(module)`` -- Python resolves dunder methods on the type, not
    # on the module object. Callers would have to invoke
    # ``config.__str__()`` explicitly; consider a plain public name instead.
    return str(_config)
def load(file):
    """
    Load the default configuration values from a JSON file.

    Args:
        file (str): path the config file to be loaded.

    Raises:
        RuntimeError: If file is invalid and parsing fails.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # sets the default value according to values in configuration file.
        >>> ds.config.load("path/to/config/file")
        >>> # example config file:
        >>> # {
        >>> #     "logFilePath": "/tmp",
        >>> #     "rowsPerBuffer": 32,
        >>> #     "numParallelWorkers": 4,
        >>> #     "workerConnectorSize": 16,
        >>> #     "opConnectorSize": 16,
        >>> #     "seed": 5489,
        >>> #     "monitorSamplingInterval": 30
        >>> # }
    """
    # Parsing and validation happen on the C++ side.
    _config.load(file)
mindspore/dataset/core/configuration.py
已删除
100644 → 0
浏览文件 @
8844462e
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration manager.
"""
import
random
import
numpy
import
mindspore._c_dataengine
as
cde
INT32_MAX
=
2147483647
UINT32_MAX
=
4294967295
class ConfigurationManager:
    """The configuration manager"""

    def __init__(self):
        # Handle to the underlying C++ ConfigManager singleton.
        self.config = cde.GlobalContext.config_manager()

    def set_seed(self, seed):
        """
        Set the seed used by every random generator, for deterministic results.

        Note:
            Also seeds Python's ``random`` module and ``numpy.random`` so
            python-side augmentations are deterministic. Call it again for
            every iterator created to reset the seed; determinism is still
            not guaranteed with num_parallel_workers > 1.

        Args:
            seed(int): seed to be set

        Raises:
            ValueError: If seed is invalid (< 0 or > MAX_UINT_32).

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the new seed value, now operators with a random seed will use new seed value.
            >>> con.set_seed(1000)
        """
        # Seed must fit in an unsigned 32-bit value.
        if not 0 <= seed <= UINT32_MAX:
            raise ValueError("Seed given is not within the required range")
        self.config.set_seed(seed)
        random.seed(seed)
        # numpy.random isn't thread safe
        numpy.random.seed(seed)

    def get_seed(self):
        """
        Return the current seed.

        Returns:
            Int, seed.
        """
        return self.config.get_seed()

    def set_prefetch_size(self, size):
        """
        Set how many rows each operator connector prefetches.

        Args:
            size: total number of rows to be prefetched.

        Raises:
            ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the new prefetch value.
            >>> con.set_prefetch_size(1000)
        """
        # Must be a positive signed 32-bit value.
        if not 0 < size <= INT32_MAX:
            raise ValueError("Prefetch size given is not within the required range")
        self.config.set_op_connector_size(size)

    def get_prefetch_size(self):
        """
        Return the configured prefetch size, in rows.

        Returns:
            Size, total number of rows to be prefetched.
        """
        return self.config.get_op_connector_size()

    def set_num_parallel_workers(self, num):
        """
        Set the default worker count for parallel dataset operations.

        Args:
            num: number of parallel workers to be used as a default for each operation

        Raises:
            ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers.
            >>> con.set_num_parallel_workers(8)
        """
        # Must be a positive signed 32-bit value.
        if not 0 < num <= INT32_MAX:
            raise ValueError("Num workers given is not within the required range")
        self.config.set_num_parallel_workers(num)

    def get_num_parallel_workers(self):
        """
        Return the default worker count for parallel dataset operations.

        Returns:
            Int, number of parallel workers to be used as a default for each operation
        """
        return self.config.get_num_parallel_workers()

    def set_monitor_sampling_interval(self, interval):
        """
        Set the sampling interval, in milliseconds, of the performance monitor.

        Args:
            interval: interval(ms) to be used to performance monitor sampling.

        Raises:
            ValueError: If interval is invalid (<= 0 or > MAX_INT_32).

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the new interval value.
            >>> con.set_monitor_sampling_interval(100)
        """
        # Must be a positive signed 32-bit value.
        if not 0 < interval <= INT32_MAX:
            raise ValueError("Interval given is not within the required range")
        self.config.set_monitor_sampling_interval(interval)

    def get_monitor_sampling_interval(self):
        """
        Return the performance monitor's sampling interval.

        Returns:
            Interval: interval(ms) of performance monitor sampling.
        """
        return self.config.get_monitor_sampling_interval()

    def __str__(self):
        """
        String representation of the configurations.

        Returns:
            Str, configurations.
        """
        return str(self.config)

    def load(self, file):
        """
        Load the default configuration values from a JSON file.

        Args:
            file: path the config file to be loaded

        Raises:
            RuntimeError: If file is invalid and parsing fails.

        Examples:
            >>> import mindspore.dataset as ds
            >>> con = ds.engine.ConfigurationManager()
            >>> # sets the default value according to values in configuration file.
            >>> con.load("path/to/config/file")
            >>> # example config file:
            >>> # {
            >>> #     "logFilePath": "/tmp",
            >>> #     "rowsPerBuffer": 32,
            >>> #     "numParallelWorkers": 4,
            >>> #     "workerConnectorSize": 16,
            >>> #     "opConnectorSize": 16,
            >>> #     "seed": 5489,
            >>> #     "monitorSamplingInterval": 30
            >>> # }
        """
        # Parsing and validation happen on the C++ side.
        self.config.load(file)
# Module-level singleton; other modules import this ``config`` object to
# read and write the shared dataset configuration.
config = ConfigurationManager()
mindspore/dataset/engine/__init__.py
浏览文件 @
edd7e184
...
...
@@ -26,10 +26,9 @@ from .datasets import *
from
.iterators
import
*
from
.serializer_deserializer
import
serialize
,
deserialize
,
show
,
compare
from
.samplers
import
*
from
..core
.configuration
import
config
,
ConfigurationManager
from
..core
import
config
__all__
=
[
"config"
,
"ConfigurationManager"
,
"zip"
,
"ImageFolderDatasetV2"
,
"MnistDataset"
,
__all__
=
[
"config"
,
"zip"
,
"ImageFolderDatasetV2"
,
"MnistDataset"
,
"MindDataset"
,
"GeneratorDataset"
,
"TFRecordDataset"
,
"CLUEDataset"
,
"ManifestDataset"
,
"Cifar10Dataset"
,
"Cifar100Dataset"
,
"CelebADataset"
,
"VOCDataset"
,
"CocoDataset"
,
"TextFileDataset"
,
"Schema"
,
"DistributedSampler"
,
...
...
mindspore/dataset/engine/serializer_deserializer.py
浏览文件 @
edd7e184
...
...
@@ -22,7 +22,7 @@ import sys
from
mindspore
import
log
as
logger
from
.
import
datasets
as
de
from
..transforms.vision.utils
import
Inter
,
Border
from
..core
.configuration
import
config
from
..core
import
config
def
serialize
(
dataset
,
json_filepath
=
None
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录