Commit af1fde39 authored by M mindspore-ci-bot, committed by Gitee

!1509 dataset: PR1457 fix 3 bug reports for split

Merge pull request !1509 from ms_yan/r0.3_split
......@@ -71,7 +71,7 @@ if __name__ == '__main__':
model = Model(network, loss, opt, {'acc': Accuracy()})
print("============== Starting Training ==============")
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
......
......@@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status RandomSampler::InitSampler() {
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");
rnd_.seed(seed_);
if (replacement_ == false) {
num_samples_ = std::min(num_samples_, num_rows_);
shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) {
shuffled_ids_.push_back(i);
}
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
} else {
num_samples_ = std::min(num_samples_, user_num_samples_);
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK();
}
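For readers following the C++ change above, here is a rough Python-level sketch of the intended behavior, not part of the commit itself (the directory below is a placeholder and the exact counts depend on the data):

import mindspore.dataset as ds

DATA_DIR = "../data/dataset/testPK/data"   # placeholder path to a small ImageFolder-style dataset

# Without replacement, the sampler can return at most num_rows samples,
# so an oversized num_samples is silently capped at the dataset size.
no_replacement = ds.RandomSampler(replacement=False, num_samples=100)
data1 = ds.ImageFolderDatasetV2(DATA_DIR, sampler=no_replacement)

# With replacement, the user-requested num_samples is honored,
# even when it exceeds the number of rows on disk.
with_replacement = ds.RandomSampler(replacement=True, num_samples=100)
data2 = ds.ImageFolderDatasetV2(DATA_DIR, sampler=with_replacement)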
......
......@@ -32,9 +32,7 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
}
// Handshake and init child first.
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
}
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
}
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
Status SubsetSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
num_samples_ = subset_size_;
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
......@@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
WeightedRandomSampler, SubsetSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show
from .engine.graphdata import GraphData
......
......@@ -633,9 +633,9 @@ class Dataset:
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.
Note:
1. Dataset cannot be sharded if split is going to be called.
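A minimal usage sketch of the behavior this docstring describes, added here for orientation only (the file path is a placeholder; not part of this commit):

import mindspore.dataset as ds

data = ds.TextFileDataset("/path/to/data.txt", shuffle=False)   # placeholder path

# Float sizes must sum to 1; each split size is rounded to the nearest integer
# and an error is raised if any rounded size equals 0.
train, test = data.split([0.8, 0.2])

# With randomize=False each split is built from consecutive rows instead.
train, test = data.split([0.8, 0.2], randomize=False)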
......@@ -678,7 +678,8 @@ class Dataset:
ds = copy.deepcopy(self)
if randomize:
# want to shuffle the same way every epoch before split
ds = ds.shuffle()
# in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
ds = ds.shuffle(10000)
ds.reshuffle_each_epoch = False
if rows_to_skip > 0:
......@@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset):
>>> new_sampler = ds.DistributedSampler(10, 2)
>>> data.use_sampler(new_sampler)
"""
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise TypeError("new_sampler is not an instance of a sampler.")
self.sampler = self.sampler.child_sampler
self.add_sampler(new_sampler)
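A short sketch of the type check added above (the dataset path is a placeholder; the same behavior is exercised by test_add_sampler_invalid_input further down in this change):

import mindspore.dataset as ds

data = ds.ImageFolderDatasetV2("/path/to/images")   # placeholder path

# Replacing the current sampler with another sampler instance is fine.
data.use_sampler(ds.DistributedSampler(10, 2))

# Anything that is not a Sampler/BuiltinSampler instance is now rejected.
try:
    data.use_sampler(1)
except TypeError as e:
    print(e)   # "new_sampler is not an instance of a sampler."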
......@@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset):
def is_sharded(self):
raise NotImplementedError("MappableDataset must implement is_sharded.")
def _get_sampler_dataset_size(self):
if self.sampler is not None:
return self.sampler.get_dataset_size()
return None
@check_split
def split(self, sizes, randomize=True):
......@@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset):
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.
Note:
1. Dataset should not be sharded if split is going to be called. Instead, create a
......@@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp):
self.iterator = TupleIterator(self)
class RangeDataset(MappableDataset):
"""
A source dataset that reads and parses datasets stored on disk in a range.
......@@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset):
else:
num_samples = self.num_samples
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
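The same min(rows from sampler, rows per shard) pattern is repeated for each mappable dataset below. A hedged sketch of the visible effect (the manifest path is a placeholder, and this assumes the manifest holds more rows than the sampler requests):

import mindspore.dataset as ds

sampler = ds.SubsetSampler(1, 3)   # start_index=1, subset_size=3
data = ds.ManifestDataset("/path/to/manifest.json", sampler=sampler)   # placeholder path

# With the sampler attached, get_dataset_size() now reports the sampler's size (3)
# instead of the full row count of the manifest.
print(data.get_dataset_size())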
def num_classes(self):
"""
......@@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset):
num_samples = self.num_samples
num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return get_num_rows(num_rows, self.num_shards)
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
if self.shuffle_level is None:
......@@ -2926,7 +2944,12 @@ class GeneratorDataset(MappableDataset):
Return:
Number, number of batches.
"""
return self._dataset_size
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self._dataset_size
return min(rows_from_sampler, self._dataset_size)
# manually set dataset_size as a temporary solution.
def set_dataset_size(self, value):
......@@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset):
class_indexing = self.class_indexing
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return get_num_rows(num_rows, self.num_shards)
return min(rows_from_sampler, rows_per_shard)
def num_classes(self):
"""
......@@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset):
num_samples = self.num_samples
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
if self.shuffle_level is None:
......@@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset):
num_samples = self.num_samples
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return get_num_rows(num_rows, self.num_shards)
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
if self.shuffle_level is None:
......@@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset):
Return:
Number, number of batches.
"""
return num_samples
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self.num_samples
return min(rows_from_sampler, self.num_samples)
def is_shuffled(self):
return True
......@@ -3871,7 +3914,12 @@ class VOCDataset(MappableDataset):
Return:
Number, number of batches.
"""
return self.num_samples
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self.num_samples
return min(rows_from_sampler, self.num_samples)
def get_class_indexing(self):
"""
......
......@@ -114,6 +114,9 @@ class Sampler:
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self._get_indices().size
class BuiltinSampler:
"""
......@@ -146,6 +149,12 @@ class BuiltinSampler:
def is_sharded(self):
raise NotImplementedError("Sampler must implement is_sharded.")
def get_dataset_size(self):
if self.child_sampler is not None:
return self.child_sampler.get_dataset_size()
return None
class DistributedSampler(BuiltinSampler):
"""
......@@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.num_samples
class SequentialSampler(BuiltinSampler):
"""
......@@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.subset_size
class SubsetRandomSampler(BuiltinSampler):
"""
......@@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler):
return cde.MindrecordSubsetRandomSampler(self.indices)
def get_dataset_size(self):
return len(self.indices)
class WeightedRandomSampler(BuiltinSampler):
"""
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
......@@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler):
return False
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.num_samples
......@@ -30,6 +30,14 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_basic():
"""
Test basic configuration functions
"""
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
prefetch_size_original = ds.config.get_prefetch_size()
seed_original = ds.config.get_seed()
ds.config.load('../data/dataset/declient.cfg')
# assert ds.config.get_rows_per_buffer() == 32
......@@ -50,6 +58,11 @@ def test_basic():
assert ds.config.get_prefetch_size() == 4
assert ds.config.get_seed() == 5
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_prefetch_size(prefetch_size_original)
ds.config.set_seed(seed_original)
def test_get_seed():
"""
......@@ -62,6 +75,9 @@ def test_pipeline():
"""
Test that our configuration pipeline works when we set parameters at different locations in dataset code
"""
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(2)
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
......@@ -85,6 +101,9 @@ def test_pipeline():
except IOError:
logger.info("Error while deleting: {}".format(f))
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
def test_deterministic_run_fail():
"""
......@@ -92,6 +111,10 @@ def test_deterministic_run_fail():
"""
logger.info("test_deterministic_run_fail")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
# when we set the seed all operations within our dataset should be deterministic
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
......@@ -120,12 +143,21 @@ def test_deterministic_run_fail():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_deterministic_run_pass():
"""
Test deterministic run with setting the seed
"""
logger.info("test_deterministic_run_pass")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
......@@ -152,13 +184,23 @@ def test_deterministic_run_pass():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_seed_undeterministic():
"""
Test seed with num parallel workers in C; this test is expected to fail some of the time
"""
logger.info("test_seed_undeterministic")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
......@@ -178,6 +220,10 @@ def test_seed_undeterministic():
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_deterministic_run_distribution():
"""
......@@ -185,6 +231,10 @@ def test_deterministic_run_distribution():
"""
logger.info("test_deterministic_run_distribution")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
# when we set the seed all operations within our dataset should be deterministic
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
......@@ -206,12 +256,21 @@ def test_deterministic_run_distribution():
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal(item1["image"], item2["image"])
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_deterministic_python_seed():
"""
Test deterministic execution with seed in python
"""
logger.info("deterministic_random_crop_op_python_2")
# Save original configuration values
num_parallel_workers_original = ds.config.get_num_parallel_workers()
seed_original = ds.config.get_seed()
ds.config.set_seed(0)
ds.config.set_num_parallel_workers(1)
......@@ -242,12 +301,20 @@ def test_deterministic_python_seed():
np.testing.assert_equal(data1_output, data2_output)
# Restore original configuration values
ds.config.set_num_parallel_workers(num_parallel_workers_original)
ds.config.set_seed(seed_original)
def test_deterministic_python_seed_multi_thread():
"""
Test deterministic execution with seed in Python; this fails with multi-thread pyfunc runs
"""
logger.info("deterministic_random_crop_op_python_2")
# Save original configuration values
seed_original = ds.config.get_seed()
ds.config.set_seed(0)
# when we set the seed all operations within our dataset should be deterministic
# First dataset
......@@ -282,6 +349,9 @@ def test_deterministic_python_seed_multi_thread():
logger.info("Got an exception in DE: {}".format(str(e)))
assert "Array" in str(e)
# Restore original configuration values
ds.config.set_seed(seed_original)
if __name__ == '__main__':
test_basic()
......
......@@ -14,6 +14,8 @@
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
from util import config_get_set_num_parallel_workers
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*"
......@@ -38,7 +40,7 @@ def test_textline_dataset_all_file():
def test_textline_dataset_totext():
ds.config.set_num_parallel_workers(4)
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
count = 0
line = ["This is a text file.", "Another file.",
......@@ -48,6 +50,8 @@ def test_textline_dataset_totext():
assert (str == line[count])
count += 1
assert (count == 5)
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_textline_dataset_num_samples():
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
......@@ -164,6 +165,35 @@ def test_python_sampler():
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
def test_subset_sampler():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_samples, start_index, subset_size):
sampler = ds.SubsetSampler(start_index, subset_size)
d = ds.ManifestDataset(manifest_file, sampler=sampler)
res = []
for item in d.create_dict_iterator():
res.append(map[(item["image"].shape[0], item["label"].item())])
return res
with pytest.raises(RuntimeError) as info:
test_config(5, 0, 0)
assert "subset_size <= 0" in str(info.value)
assert test_config(5, 0, 1) == [0]
assert test_config(5, 0, 2) == [0, 1]
assert test_config(5, 0, 3) == [0, 1, 2]
assert test_config(5, 0, 4) == [0, 1, 2, 3]
assert test_config(5, 0, 5) == [0, 1, 2, 3, 4]
assert test_config(5, 1, 1) == [1]
assert test_config(5, 2, 3) == [2, 3, 4]
assert test_config(5, 3, 2) == [3, 4]
assert test_config(5, 4, 1) == [4]
def test_sampler_chain():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
......@@ -190,10 +220,26 @@ def test_sampler_chain():
assert test_config(5, 3) == [3]
assert test_config(5, 4) == [4]
def test_add_sampler_invalid_input():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
data1 = ds.ManifestDataset(manifest_file)
with pytest.raises(TypeError) as info:
data1.use_sampler(1)
assert "not an instance of a sampler" in str(info.value)
with pytest.raises(TypeError) as info:
data1.use_sampler("sampler")
assert "not an instance of a sampler" in str(info.value)
if __name__ == '__main__':
test_sequential_sampler(True)
test_random_sampler(True)
test_random_sampler_multi_iter(True)
test_sampler_py_api()
test_python_sampler()
test_subset_sampler()
test_sampler_chain()
test_add_sampler_invalid_input()
......@@ -14,6 +14,8 @@
# ==============================================================================
import pytest
import mindspore.dataset as ds
from util import config_get_set_num_parallel_workers
# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
# the label of each image is [0,0,0,1,1] each image can be uniquely identified
......@@ -21,7 +23,11 @@ import mindspore.dataset as ds
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def split_with_invalid_inputs(d):
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]
def split_with_invalid_inputs(d):
with pytest.raises(ValueError) as info:
s1, s2 = d.split([])
assert "sizes cannot be empty" in str(info.value)
......@@ -66,8 +72,8 @@ def split_with_invalid_inputs(d):
s1, s2 = d.split([0.05, 0.95])
assert "percentage 0.05 is too small" in str(info.value)
def test_unmappable_invalid_input():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
d = ds.TextFileDataset(text_file_dataset_path)
split_with_invalid_inputs(d)
......@@ -76,11 +82,10 @@ def test_unmappable_invalid_input():
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
def test_unmappable_split():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]
ds.config.set_num_parallel_workers(4)
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1], randomize=False)
......@@ -123,6 +128,145 @@ def test_unmappable_split():
assert s1_output == text_file_data[0:2]
assert s2_output == text_file_data[2:]
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_unmappable_randomize_deterministic():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds.config.set_seed(53)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
for _ in range(10):
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
# note no overlap
assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
assert s2_output == [text_file_data[3]]
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_unmappable_randomize_repeatable():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds.config.set_seed(53)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
num_epochs = 5
s1 = s1.repeat(num_epochs)
s2 = s2.repeat(num_epochs)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
# note no overlap
assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs
assert s2_output == [text_file_data[3]] * num_epochs
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_unmappable_get_dataset_size():
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
assert d.get_dataset_size() == 5
assert s1.get_dataset_size() == 4
assert s2.get_dataset_size() == 1
def test_unmappable_multi_split():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds.config.set_seed(53)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1])
s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
assert s1_output == s1_correct_output
# no randomize in second split
s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)
s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(item["text"].item().decode("utf8"))
s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(item["text"].item().decode("utf8"))
s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(item["text"].item().decode("utf8"))
assert s1s1_output == [s1_correct_output[0]]
assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
assert s1s3_output == [s1_correct_output[3]]
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s2_output == [text_file_data[3]]
# randomize in second split
# the labels outputted by the ShuffleOp for seed 53 is [2, 3, 1, 0]
shuffled_ids = [2, 3, 1, 0]
s1s1, s1s2, s1s3 = s1.split([1, 2, 1])
s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(item["text"].item().decode("utf8"))
s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(item["text"].item().decode("utf8"))
s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(item["text"].item().decode("utf8"))
assert s1s1_output == [s1_correct_output[shuffled_ids[0]]]
assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]]
assert s1s3_output == [s1_correct_output[shuffled_ids[3]]]
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s2_output == [text_file_data[3]]
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_mappable_invalid_input():
d = ds.ManifestDataset(manifest_file)
split_with_invalid_inputs(d)
......@@ -132,6 +276,7 @@ def test_mappable_invalid_input():
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
def test_mappable_split_general():
d = ds.ManifestDataset(manifest_file, shuffle=False)
d = d.take(5)
......@@ -178,6 +323,7 @@ def test_mappable_split_general():
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]
def test_mappable_split_optimized():
d = ds.ManifestDataset(manifest_file, shuffle=False)
......@@ -223,9 +369,9 @@ def test_mappable_split_optimized():
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]
def test_mappable_randomize_deterministic():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
......@@ -244,9 +390,9 @@ def test_mappable_randomize_deterministic():
assert s1_output == [0, 1, 3, 4]
assert s2_output == [2]
def test_mappable_randomize_repeatable():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
......@@ -268,9 +414,10 @@ def test_mappable_randomize_repeatable():
assert s1_output == [0, 1, 3, 4] * num_epochs
assert s2_output == [2] * num_epochs
def test_mappable_sharding():
# set arbitrary seed for repeatability for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)
num_epochs = 5
......@@ -331,12 +478,94 @@ def test_mappable_sharding():
assert s2_output == [2]
assert d2s2_output == [2]
def test_mappable_get_dataset_size():
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([4, 1])
assert d.get_dataset_size() == 5
assert s1.get_dataset_size() == 4
assert s2.get_dataset_size() == 1
def test_mappable_multi_split():
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([4, 1])
s1_correct_output = [0, 1, 3, 4]
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == s1_correct_output
# no randomize in second split
s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)
s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1s1_output == [s1_correct_output[0]]
assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
assert s1s3_output == [s1_correct_output[3]]
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s2_output == [2]
# randomize in second split
# the labels outputted by the RandomSampler for seed 53 is [3, 1, 2, 0]
random_sampler_ids = [3, 1, 2, 0]
s1s1, s1s2, s1s3 = s1.split([1, 2, 1])
s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]]
assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]]
assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]]
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s2_output == [2]
if __name__ == '__main__':
test_unmappable_invalid_input()
test_unmappable_split()
test_unmappable_randomize_deterministic()
test_unmappable_randomize_repeatable()
test_unmappable_get_dataset_size()
test_unmappable_multi_split()
test_mappable_invalid_input()
test_mappable_split_general()
test_mappable_split_optimized()
test_mappable_randomize_deterministic()
test_mappable_randomize_repeatable()
test_mappable_sharding()
test_mappable_get_dataset_size()
test_mappable_multi_split()
......@@ -15,11 +15,11 @@
import hashlib
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import os
# import jsbeautifier
import mindspore.dataset as ds
from mindspore import log as logger
# These are the column names defined in the testTFTestAllTypes dataset
......@@ -221,3 +221,26 @@ def visualize(image_original, image_transformed):
plt.title("Transformed image")
plt.show()
def config_get_set_seed(seed_new):
"""
Get and return the original configuration seed value.
Set the new configuration seed value.
"""
seed_original = ds.config.get_seed()
ds.config.set_seed(seed_new)
logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
return seed_original
def config_get_set_num_parallel_workers(num_parallel_workers_new):
"""
Get and return the original configuration num_parallel_workers value.
Set the new configuration num_parallel_workers value.
"""
num_parallel_workers_original = ds.config.get_num_parallel_workers()
ds.config.set_num_parallel_workers(num_parallel_workers_new)
logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
num_parallel_workers_new))
return num_parallel_workers_original
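A sketch of how these helpers are intended to be used in a test, mirroring the save/restore pattern added to test_config.py above (the test name is illustrative, not part of the commit):

import mindspore.dataset as ds
from util import config_get_set_seed, config_get_set_num_parallel_workers

def test_example():
    # Apply new values while capturing the originals.
    seed_original = config_get_set_seed(0)
    workers_original = config_get_set_num_parallel_workers(1)

    # ... build datasets and run deterministic checks here ...

    # Restore original configuration values so later tests are unaffected.
    ds.config.set_seed(seed_original)
    ds.config.set_num_parallel_workers(workers_original)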