提交 eda63a55 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!477 Fix VOC dataset test cases

Merge pull request !477 from xiefangqi/xfq_fix_voc
......@@ -3335,14 +3335,17 @@ class VOCDataset(SourceDataset):
decode (bool, optional): Decode the images after reading (default=False).
sampler (Sampler, optional): Object used to choose samples from the dataset
(default=None, expected order behavior shown in the table).
distribution (str, optional): Path to the json distribution file to configure
dataset sharding (default=None). This argument should be specified
only when no 'sampler' is used.
num_shards (int, optional): Number of shards that the dataset should be divided
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
argument should be specified only when num_shards is also specified.
Raises:
RuntimeError: If distribution and sampler are specified at the same time.
RuntimeError: If distribution is failed to read.
RuntimeError: If shuffle and sampler are specified at the same time.
RuntimeError: If sampler and shuffle are specified at the same time.
RuntimeError: If sampler and sharding are specified at the same time.
RuntimeError: If num_shards is specified but shard_id is None.
RuntimeError: If shard_id is specified but num_shards is None.
ValueError: If shard_id is invalid (< 0 or >= num_shards).
Examples:
>>> import mindspore.dataset as ds
......@@ -3356,27 +3359,15 @@ class VOCDataset(SourceDataset):
@check_vocdataset
def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
shuffle=None, decode=False, sampler=None, distribution=None):
shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers)
self.dataset_dir = dataset_dir
self.sampler = sampler
if distribution is not None:
if sampler is not None:
raise RuntimeError("Cannot specify distribution and sampler at the same time.")
try:
with open(distribution, 'r') as load_d:
json.load(load_d)
except json.decoder.JSONDecodeError:
raise RuntimeError("Json decode error when load distribution file")
except Exception:
raise RuntimeError("Distribution file has failed to load.")
elif shuffle is not None:
if sampler is not None:
raise RuntimeError("Cannot specify shuffle and sampler at the same time.")
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
self.decode = decode
self.distribution = distribution
self.shuffle_level = shuffle
self.num_shards = num_shards
self.shard_id = shard_id
def get_args(self):
args = super().get_args()
......@@ -3385,7 +3376,8 @@ class VOCDataset(SourceDataset):
args["sampler"] = self.sampler
args["decode"] = self.decode
args["shuffle"] = self.shuffle_level
args["distribution"] = self.distribution
args["num_shards"] = self.num_shards
args["shard_id"] = self.shard_id
return args
def get_dataset_size(self):
......
......@@ -286,7 +286,8 @@ def create_node(node):
elif dataset_op == 'VOCDataset':
sampler = construct_sampler(node.get('sampler'))
pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
node.get('shuffle'), node.get('decode'), sampler, node.get('distribution'))
node.get('shuffle'), node.get('decode'), sampler, node.get('num_shards'),
node.get('shard_id'))
elif dataset_op == 'CelebADataset':
sampler = construct_sampler(node.get('sampler'))
......
......@@ -443,9 +443,8 @@ def check_vocdataset(method):
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
nreq_param_int = ['num_samples', 'num_parallel_workers']
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_bool = ['shuffle', 'decode']
nreq_param_str = ['distribution']
# check dataset_dir; required argument
dataset_dir = param_dict.get('dataset_dir')
......@@ -457,7 +456,7 @@ def check_vocdataset(method):
check_param_type(nreq_param_bool, param_dict, bool)
check_param_type(nreq_param_str, param_dict, str)
check_sampler_shuffle_shard_options(param_dict)
return method(*args, **kwargs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册