diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 8ef5f9dd31bccab6d416c7acc556232a5eef5297..16487347048912be1d357a2bc62e2dc85e5fba3a 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -3335,14 +3335,17 @@ class VOCDataset(SourceDataset):
         decode (bool, optional): Decode the images after reading (default=False).
         sampler (Sampler, optional): Object used to choose samples from the dataset
             (default=None, expected order behavior shown in the table).
-        distribution (str, optional): Path to the json distribution file to configure
-            dataset sharding (default=None). This argument should be specified
-            only when no 'sampler' is used.
+        num_shards (int, optional): Number of shards that the dataset should be divided
+            into (default=None).
+        shard_id (int, optional): The shard ID within num_shards (default=None). This
+            argument should be specified only when num_shards is also specified.
 
     Raises:
-        RuntimeError: If distribution and sampler are specified at the same time.
-        RuntimeError: If distribution is failed to read.
-        RuntimeError: If shuffle and sampler are specified at the same time.
+        RuntimeError: If sampler and shuffle are specified at the same time.
+        RuntimeError: If sampler and sharding are specified at the same time.
+        RuntimeError: If num_shards is specified but shard_id is None.
+        RuntimeError: If shard_id is specified but num_shards is None.
+        ValueError: If shard_id is invalid (< 0 or >= num_shards).
 
     Examples:
         >>> import mindspore.dataset as ds
@@ -3356,27 +3359,15 @@ class VOCDataset(SourceDataset):
 
     @check_vocdataset
     def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None,
-                 shuffle=None, decode=False, sampler=None, distribution=None):
+                 shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None):
         super().__init__(num_parallel_workers)
         self.dataset_dir = dataset_dir
-        self.sampler = sampler
-        if distribution is not None:
-            if sampler is not None:
-                raise RuntimeError("Cannot specify distribution and sampler at the same time.")
-            try:
-                with open(distribution, 'r') as load_d:
-                    json.load(load_d)
-            except json.decoder.JSONDecodeError:
-                raise RuntimeError("Json decode error when load distribution file")
-            except Exception:
-                raise RuntimeError("Distribution file has failed to load.")
-        elif shuffle is not None:
-            if sampler is not None:
-                raise RuntimeError("Cannot specify shuffle and sampler at the same time.")
+        self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
         self.decode = decode
-        self.distribution = distribution
         self.shuffle_level = shuffle
+        self.num_shards = num_shards
+        self.shard_id = shard_id
 
     def get_args(self):
         args = super().get_args()
@@ -3385,7 +3376,8 @@ class VOCDataset(SourceDataset):
         args["sampler"] = self.sampler
         args["decode"] = self.decode
         args["shuffle"] = self.shuffle_level
-        args["distribution"] = self.distribution
+        args["num_shards"] = self.num_shards
+        args["shard_id"] = self.shard_id
         return args
 
     def get_dataset_size(self):
diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py
index 61417e4d52e16df6c8e45b09ef2623342f1c9021..f588d572bbab188962b83cb9d1a68856cfed5169 100644
--- a/mindspore/dataset/engine/serializer_deserializer.py
+++ b/mindspore/dataset/engine/serializer_deserializer.py
@@ -286,7 +286,8 @@ def create_node(node):
     elif dataset_op == 'VOCDataset':
         sampler = construct_sampler(node.get('sampler'))
         pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
-                        node.get('shuffle'), node.get('decode'), sampler, node.get('distribution'))
+                        node.get('shuffle'), node.get('decode'), sampler, node.get('num_shards'),
+                        node.get('shard_id'))
 
     elif dataset_op == 'CelebADataset':
         sampler = construct_sampler(node.get('sampler'))
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index a8d18ab2c106878c26da35ef5329400c1885aaa4..29bce25bd128346ea57f21f773ba92fb5d66ffcc 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -443,9 +443,8 @@ def check_vocdataset(method):
     def new_method(*args, **kwargs):
         param_dict = make_param_dict(method, args, kwargs)
 
-        nreq_param_int = ['num_samples', 'num_parallel_workers']
+        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
         nreq_param_bool = ['shuffle', 'decode']
-        nreq_param_str = ['distribution']
 
         # check dataset_dir; required argument
         dataset_dir = param_dict.get('dataset_dir')
@@ -457,7 +456,7 @@ def check_vocdataset(method):
 
         check_param_type(nreq_param_bool, param_dict, bool)
 
-        check_param_type(nreq_param_str, param_dict, str)
+        check_sampler_shuffle_shard_options(param_dict)
 
         return method(*args, **kwargs)
 