Commit edd7e184 authored by ms_yan

modify config api

Parent 8844462e
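This change replaces the class-based `ConfigurationManager` in `core/configuration.py` with module-level functions in a new `core/config.py`, and drops `ConfigurationManager` from the public exports. A minimal sketch of what that means at call sites (illustrative only, based on the diff below):

```python
import mindspore.dataset as ds

# Unchanged: this pattern keeps working, because ds.config now resolves to
# the new config module instead of a shared ConfigurationManager instance.
ds.config.set_seed(1000)
ds.config.set_num_parallel_workers(8)
assert ds.config.get_seed() == 1000

# Removed: instantiating a manager directly is no longer possible.
# con = ds.engine.ConfigurationManager()  # AttributeError after this change
```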
@@ -18,7 +18,7 @@ datasets in special format, including mindrecord, tfrecord, manifest. Users
can also create samplers with this module to sample data.
"""
-from .core.configuration import config
+from .core import config
from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, NumpySlicesDataset, \
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CocoDataset, CelebADataset,\
TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration manager.
"""
import random
import numpy
import mindspore._c_dataengine as cde
__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
'get_num_parallel_workers', 'set_monitor_sampling_interval', 'get_monitor_sampling_interval', 'load']
INT32_MAX = 2147483647
UINT32_MAX = 4294967295
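# Handle to the global ConfigManager owned by the C++ data engine; the
# setters and getters below all delegate to it.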
_config = cde.GlobalContext.config_manager()
def set_seed(seed):
"""
Set the seed to be used in any random generator. This is used to produce deterministic results.
Note:
This function also sets the seed in the Python random library and in numpy.random,
so that Python augmentations which rely on randomness are deterministic. It should
be called again for every iterator created, to reset the random seed. Even so, the
pipeline does not guarantee deterministic results when num_parallel_workers > 1.
Args:
seed(int): seed to be set.
Raises:
ValueError: If seed is invalid (< 0 or > MAX_UINT_32).
Examples:
>>> import mindspore.dataset as ds
>>> # sets the new seed value, now operators with a random seed will use new seed value.
>>> ds.config.set_seed(1000)
"""
if seed < 0 or seed > UINT32_MAX:
raise ValueError("Seed given is not within the required range.")
_config.set_seed(seed)
random.seed(seed)
# numpy.random isn't thread safe
numpy.random.seed(seed)
def get_seed():
"""
Get the seed.
Returns:
Int, seed.
"""
return _config.get_seed()
def set_prefetch_size(size):
"""
Set the number of rows to be prefetched.
Args:
size (int): total number of rows to be prefetched.
Raises:
ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).
Examples:
>>> import mindspore.dataset as ds
>>> # sets the new prefetch value.
>>> ds.config.set_prefetch_size(1000)
"""
if size <= 0 or size > INT32_MAX:
raise ValueError("Prefetch size given is not within the required range.")
_config.set_op_connector_size(size)
def get_prefetch_size():
"""
Get the prefetch size in number of rows.
Returns:
Size, total number of rows to be prefetched.
"""
return _config.get_op_connector_size()
def set_num_parallel_workers(num):
"""
Set the default number of parallel workers.
Args:
num (int): number of parallel workers to be used as a default for each operation.
Raises:
ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).
Examples:
>>> import mindspore.dataset as ds
>>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers.
>>> ds.config.set_num_parallel_workers(8)
"""
if num <= 0 or num > INT32_MAX:
raise ValueError("Num workers given is not within the required range.")
_config.set_num_parallel_workers(num)
def get_num_parallel_workers():
"""
Get the default number of parallel workers.
Returns:
Int, number of parallel workers to be used as a default for each operation.
"""
return _config.get_num_parallel_workers()
def set_monitor_sampling_interval(interval):
"""
Set the default interval(ms) of monitor sampling.
Args:
interval (int): interval (ms) to be used for performance monitor sampling.
Raises:
ValueError: If interval is invalid (<= 0 or > MAX_INT_32).
Examples:
>>> import mindspore.dataset as ds
>>> # sets the new interval value.
>>> ds.config.set_monitor_sampling_interval(100)
"""
if interval <= 0 or interval > INT32_MAX:
raise ValueError("Interval given is not within the required range.")
_config.set_monitor_sampling_interval(interval)
def get_monitor_sampling_interval():
"""
Get the default interval of performance monitor sampling.
Returns:
Interval: interval(ms) of performance monitor sampling.
"""
return _config.get_monitor_sampling_interval()
def __str__():
"""
String representation of the configurations.
Returns:
Str, configurations.
"""
return str(_config)
def load(file):
"""
Load configuration from a file.
Args:
file (str): path of the config file to be loaded.
Raises:
RuntimeError: If file is invalid and parsing fails.
Examples:
>>> import mindspore.dataset as ds
>>> # sets the default value according to values in configuration file.
>>> ds.config.load("path/to/config/file")
>>> # example config file:
>>> # {
>>> # "logFilePath": "/tmp",
>>> # "rowsPerBuffer": 32,
>>> # "numParallelWorkers": 4,
>>> # "workerConnectorSize": 16,
>>> # "opConnectorSize": 16,
>>> # "seed": 5489,
>>> # "monitorSamplingInterval": 30
>>> # }
"""
_config.load(file)
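For reference, a hypothetical round trip for `load` (the file name is invented; the keys mirror the docstring example above, and a real deployment may need different values):

```python
import json
import mindspore.dataset as ds

# Write a config file with the keys shown in the docstring, then load it.
cfg = {
    "logFilePath": "/tmp",
    "rowsPerBuffer": 32,
    "numParallelWorkers": 4,
    "workerConnectorSize": 16,
    "opConnectorSize": 16,
    "seed": 5489,
    "monitorSamplingInterval": 30,
}
with open("dataset_config.json", "w") as f:
    json.dump(cfg, f)

ds.config.load("dataset_config.json")
assert ds.config.get_num_parallel_workers() == 4
assert ds.config.get_seed() == 5489
```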
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration manager.
"""
import random
import numpy
import mindspore._c_dataengine as cde
INT32_MAX = 2147483647
UINT32_MAX = 4294967295
class ConfigurationManager:
"""The configuration manager"""
def __init__(self):
self.config = cde.GlobalContext.config_manager()
def set_seed(self, seed):
"""
Set the seed to be used in any random generator. This is used to produce deterministic results.
Note:
This function also sets the seed in the Python random library and in numpy.random,
so that Python augmentations which rely on randomness are deterministic. It should
be called again for every iterator created, to reset the random seed. Even so, the
pipeline does not guarantee deterministic results when num_parallel_workers > 1.
Args:
seed (int): seed to be set.
Raises:
ValueError: If seed is invalid (< 0 or > MAX_UINT_32).
Examples:
>>> import mindspore.dataset as ds
>>> con = ds.engine.ConfigurationManager()
>>> # sets the new seed value, now operators with a random seed will use new seed value.
>>> con.set_seed(1000)
"""
if seed < 0 or seed > UINT32_MAX:
raise ValueError("Seed given is not within the required range")
self.config.set_seed(seed)
random.seed(seed)
# numpy.random isn't thread safe
numpy.random.seed(seed)
def get_seed(self):
"""
Get the seed.
Returns:
Int, seed.
"""
return self.config.get_seed()
def set_prefetch_size(self, size):
"""
Set the number of rows to be prefetched.
Args:
size (int): total number of rows to be prefetched.
Raises:
ValueError: If prefetch_size is invalid (<= 0 or > MAX_INT_32).
Examples:
>>> import mindspore.dataset as ds
>>> con = ds.engine.ConfigurationManager()
>>> # sets the new prefetch value.
>>> con.set_prefetch_size(1000)
"""
if size <= 0 or size > INT32_MAX:
raise ValueError("Prefetch size given is not within the required range")
self.config.set_op_connector_size(size)
def get_prefetch_size(self):
"""
Get the prefetch size in number of rows.
Returns:
Size, total number of rows to be prefetched.
"""
return self.config.get_op_connector_size()
def set_num_parallel_workers(self, num):
"""
Set the default number of parallel workers.
Args:
num (int): number of parallel workers to be used as a default for each operation.
Raises:
ValueError: If num_parallel_workers is invalid (<= 0 or > MAX_INT_32).
Examples:
>>> import mindspore.dataset as ds
>>> con = ds.engine.ConfigurationManager()
>>> # sets the new parallel_workers value, now parallel dataset operators will run with 8 workers.
>>> con.set_num_parallel_workers(8)
"""
if num <= 0 or num > INT32_MAX:
raise ValueError("Num workers given is not within the required range")
self.config.set_num_parallel_workers(num)
def get_num_parallel_workers(self):
"""
Get the default number of parallel workers.
Returns:
Int, number of parallel workers to be used as a default for each operation.
"""
return self.config.get_num_parallel_workers()
def set_monitor_sampling_interval(self, interval):
"""
Set the default interval(ms) of monitor sampling.
Args:
interval (int): interval (ms) to be used for performance monitor sampling.
Raises:
ValueError: If interval is invalid (<= 0 or > MAX_INT_32).
Examples:
>>> import mindspore.dataset as ds
>>> con = ds.engine.ConfigurationManager()
>>> # sets the new interval value.
>>> con.set_monitor_sampling_interval(100)
"""
if interval <= 0 or interval > INT32_MAX:
raise ValueError("Interval given is not within the required range")
self.config.set_monitor_sampling_interval(interval)
def get_monitor_sampling_interval(self):
"""
Get the default interval of performance monitor sampling.
Returns:
Interval: interval(ms) of performance monitor sampling.
"""
return self.config.get_monitor_sampling_interval()
def __str__(self):
"""
String representation of the configurations.
Returns:
Str, configurations.
"""
return str(self.config)
def load(self, file):
"""
Load configuration from a file.
Args:
file (str): path of the config file to be loaded.
Raises:
RuntimeError: If file is invalid and parsing fails.
Examples:
>>> import mindspore.dataset as ds
>>> con = ds.engine.ConfigurationManager()
>>> # sets the default value according to values in configuration file.
>>> con.load("path/to/config/file")
>>> # example config file:
>>> # {
>>> # "logFilePath": "/tmp",
>>> # "rowsPerBuffer": 32,
>>> # "numParallelWorkers": 4,
>>> # "workerConnectorSize": 16,
>>> # "opConnectorSize": 16,
>>> # "seed": 5489,
>>> # "monitorSamplingInterval": 30
>>> # }
"""
self.config.load(file)
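# Shared module-level instance exported as "config".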
config = ConfigurationManager()
@@ -26,10 +26,9 @@ from .datasets import *
from .iterators import *
from .serializer_deserializer import serialize, deserialize, show, compare
from .samplers import *
-from ..core.configuration import config, ConfigurationManager
+from ..core import config
-__all__ = ["config", "ConfigurationManager", "zip",
-    "ImageFolderDatasetV2", "MnistDataset",
+__all__ = ["config", "zip", "ImageFolderDatasetV2", "MnistDataset",
"MindDataset", "GeneratorDataset", "TFRecordDataset", "CLUEDataset",
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
"VOCDataset", "CocoDataset", "TextFileDataset", "Schema", "DistributedSampler",
......
@@ -22,7 +22,7 @@ import sys
from mindspore import log as logger
from . import datasets as de
from ..transforms.vision.utils import Inter, Border
-from ..core.configuration import config
+from ..core import config
def serialize(dataset, json_filepath=None):
"""
......