提交 4e8e82f2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1457 fix 3 bug reports for split

Merge pull request !1457 from Peilin/splitOp-after-testing
......@@ -71,7 +71,7 @@ if __name__ == '__main__':
model = Model(network, loss, opt, {'acc': Accuracy()})
print("============== Starting Training ==============")
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, repeat_num=cfg.num_epochs)
ds_train = create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
......
......@@ -70,21 +70,26 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) {
}
Status RandomSampler::InitSampler() {
num_samples_ = (user_num_samples_ < num_samples_) ? user_num_samples_ : num_samples_;
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive.");
rnd_.seed(seed_);
if (replacement_ == false) {
num_samples_ = std::min(num_samples_, num_rows_);
shuffled_ids_.reserve(num_rows_);
for (int64_t i = 0; i < num_rows_; i++) {
shuffled_ids_.push_back(i);
}
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
} else {
num_samples_ = std::min(num_samples_, user_num_samples_);
dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
}
CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive.");
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
return Status::OK();
}
......
......@@ -32,10 +32,8 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
}
// Handshake and init child first.
if (HasChildSampler()) {
RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
}
}
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_));
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -28,9 +28,9 @@ SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size)
: Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {}
Status SubsetSampler::InitSampler() {
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size_ <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows_\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n");
CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n");
num_samples_ = subset_size_;
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......
......@@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
WeightedRandomSampler, SubsetSampler, Sampler
from .engine.serializer_deserializer import serialize, deserialize, show
from .engine.graphdata import GraphData
......
......@@ -633,9 +633,9 @@ class Dataset:
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.
Note:
1. Dataset cannot be sharded if split is going to be called.
......@@ -678,7 +678,8 @@ class Dataset:
ds = copy.deepcopy(self)
if randomize:
# want to shuffle the same way every epoch before split
ds = ds.shuffle()
# in alter_tree, shuffle buffer is minimum 10000, so use 10000 here
ds = ds.shuffle(10000)
ds.reshuffle_each_epoch = False
if rows_to_skip > 0:
......@@ -1209,6 +1210,9 @@ class MappableDataset(SourceDataset):
>>> new_sampler = ds.DistributedSampler(10, 2)
>>> data.use_sampler(new_sampler)
"""
if new_sampler is not None and not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise TypeError("new_sampler is not an instance of a sampler.")
self.sampler = self.sampler.child_sampler
self.add_sampler(new_sampler)
......@@ -1218,6 +1222,11 @@ class MappableDataset(SourceDataset):
def is_sharded(self):
raise NotImplementedError("MappableDataset must implement is_sharded.")
def _get_sampler_dataset_size(self):
if self.sampler is not None:
return self.sampler.get_dataset_size()
return None
@check_split
def split(self, sizes, randomize=True):
......@@ -1236,9 +1245,9 @@ class MappableDataset(SourceDataset):
Datasets of size f1*K, f2*K, …, fn*K (rounded to nearest integer) where K is the size
of the original dataset. If after rounding, any size equals 0, an error will occur.
All floats must be between 0 and 1 and must sum to 1, otherwise an error will occur.
randomize (bool): determines whether or not to split the data randomly. If true, the data
will be randomly split. Otherwise, each split will be created with consecutive rows
from the dataset.
randomize (bool, optional): determines whether or not to split the data randomly (default=True).
If true, the data will be randomly split. Otherwise, each split will be created with
consecutive rows from the dataset.
Note:
1. Dataset should not be sharded if split is going to be called. Instead, create a
......@@ -2105,7 +2114,6 @@ class TransferDataset(DatasetOp):
self.iterator = TupleIterator(self)
class RangeDataset(MappableDataset):
"""
A source dataset that reads and parses datasets stored on disk in a range.
......@@ -2296,8 +2304,13 @@ class ImageFolderDatasetV2(MappableDataset):
else:
num_samples = self.num_samples
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
def num_classes(self):
"""
......@@ -2425,8 +2438,13 @@ class MnistDataset(MappableDataset):
num_samples = self.num_samples
num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return get_num_rows(num_rows, self.num_shards)
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
if self.shuffle_level is None:
......@@ -2926,8 +2944,13 @@ class GeneratorDataset(MappableDataset):
Return:
Number, number of batches.
"""
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self._dataset_size
return min(rows_from_sampler, self._dataset_size)
# manually set dataset_size as a temporary solution.
def set_dataset_size(self, value):
if value >= 0:
......@@ -3220,8 +3243,13 @@ class ManifestDataset(MappableDataset):
class_indexing = self.class_indexing
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
def num_classes(self):
"""
......@@ -3379,8 +3407,13 @@ class Cifar10Dataset(MappableDataset):
num_samples = self.num_samples
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return get_num_rows(num_rows, self.num_shards)
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
if self.shuffle_level is None:
......@@ -3498,8 +3531,13 @@ class Cifar100Dataset(MappableDataset):
num_samples = self.num_samples
num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
return get_num_rows(num_rows, self.num_shards)
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
def is_shuffled(self):
if self.shuffle_level is None:
......@@ -3562,7 +3600,12 @@ class RandomDataset(SourceDataset):
Return:
Number, number of batches.
"""
return num_samples
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self.num_samples
return min(rows_from_sampler, self.num_samples)
def is_shuffled(self):
return True
......@@ -3871,8 +3914,13 @@ class VOCDataset(MappableDataset):
Return:
Number, number of batches.
"""
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return self.num_samples
return min(rows_from_sampler, self.num_samples)
def get_class_indexing(self):
"""
Get the class index.
......
......@@ -114,6 +114,9 @@ class Sampler:
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self._get_indices().size
class BuiltinSampler:
"""
......@@ -146,6 +149,12 @@ class BuiltinSampler:
def is_sharded(self):
raise NotImplementedError("Sampler must implement is_sharded.")
def get_dataset_size(self):
if self.child_sampler is not None:
return self.child_sampler.get_dataset_size()
return None
class DistributedSampler(BuiltinSampler):
"""
......@@ -330,6 +339,9 @@ class RandomSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.num_samples
class SequentialSampler(BuiltinSampler):
"""
......@@ -421,6 +433,9 @@ class SubsetSampler(BuiltinSampler):
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.subset_size
class SubsetRandomSampler(BuiltinSampler):
"""
......@@ -467,6 +482,10 @@ class SubsetRandomSampler(BuiltinSampler):
return cde.MindrecordSubsetRandomSampler(self.indices)
def get_dataset_size(self):
return len(indices)
class WeightedRandomSampler(BuiltinSampler):
"""
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
......@@ -522,3 +541,6 @@ class WeightedRandomSampler(BuiltinSampler):
return False
return self.child_sampler.is_sharded()
def get_dataset_size(self):
return self.num_samples
......@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
......@@ -164,6 +165,35 @@ def test_python_sampler():
assert list(sp1.get_indices()) == [0, 1, 2, 3, 4]
def test_subset_sampler():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
def test_config(num_samples, start_index, subset_size):
sampler = ds.SubsetSampler(start_index, subset_size)
d = ds.ManifestDataset(manifest_file, sampler=sampler)
res = []
for item in d.create_dict_iterator():
res.append(map[(item["image"].shape[0], item["label"].item())])
return res
with pytest.raises(RuntimeError) as info:
test_config(5, 0, 0)
assert "subset_size <= 0" in str(info.value)
assert test_config(5, 0, 1) == [0]
assert test_config(5, 0, 2) == [0, 1]
assert test_config(5, 0, 3) == [0, 1, 2]
assert test_config(5, 0, 4) == [0, 1, 2, 3]
assert test_config(5, 0, 5) == [0, 1, 2, 3, 4]
assert test_config(5, 1, 1) == [1]
assert test_config(5, 2, 3) == [2, 3, 4]
assert test_config(5, 3, 2) == [3, 4]
assert test_config(5, 4, 1) == [4]
def test_sampler_chain():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
......@@ -190,10 +220,26 @@ def test_sampler_chain():
assert test_config(5, 3) == [3]
assert test_config(5, 4) == [4]
def test_add_sampler_invalid_input():
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
data1 = ds.ManifestDataset(manifest_file)
with pytest.raises(TypeError) as info:
data1.use_sampler(1)
assert "not an instance of a sampler" in str(info.value)
with pytest.raises(TypeError) as info:
data1.use_sampler("sampler")
assert "not an instance of a sampler" in str(info.value)
if __name__ == '__main__':
test_sequential_sampler(True)
test_random_sampler(True)
test_random_sampler_multi_iter(True)
test_sampler_py_api()
test_python_sampler()
test_subset_sampler()
test_sampler_chain()
test_add_sampler_invalid_input()
......@@ -23,6 +23,10 @@ from util import config_get_set_num_parallel_workers
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
manifest_map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4}
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]
def split_with_invalid_inputs(d):
with pytest.raises(ValueError) as info:
s1, s2 = d.split([])
......@@ -68,8 +72,8 @@ def split_with_invalid_inputs(d):
s1, s2 = d.split([0.05, 0.95])
assert "percentage 0.05 is too small" in str(info.value)
def test_unmappable_invalid_input():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
d = ds.TextFileDataset(text_file_dataset_path)
split_with_invalid_inputs(d)
......@@ -78,11 +82,10 @@ def test_unmappable_invalid_input():
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
def test_unmappable_split():
text_file_dataset_path = "../data/dataset/testTextFileDataset/*"
text_file_data = ["This is a text file.", "Another file.", "Be happy every day.",
"End of file.", "Good luck to everyone."]
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1], randomize=False)
......@@ -124,6 +127,142 @@ def test_unmappable_split():
assert s1_output == text_file_data[0:2]
assert s2_output == text_file_data[2:]
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_unmappable_randomize_deterministic():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds.config.set_seed(53)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
for _ in range(10):
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
# note no overlap
assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
assert s2_output == [text_file_data[3]]
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_unmappable_randomize_repeatable():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds.config.set_seed(53)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
num_epochs = 5
s1 = s1.repeat(num_epochs)
s2 = s2.repeat(num_epochs)
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
# note no overlap
assert s1_output == [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]] * num_epochs
assert s2_output == [text_file_data[3]] * num_epochs
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_unmappable_get_dataset_size():
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([0.8, 0.2])
assert d.get_dataset_size() == 5
assert s1.get_dataset_size() == 4
assert s2.get_dataset_size() == 1
def test_unmappable_multi_split():
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
# the labels outputted by ShuffleOp for seed 53 is [0, 2, 1, 4, 3]
ds.config.set_seed(53)
d = ds.TextFileDataset(text_file_dataset_path, shuffle=False)
s1, s2 = d.split([4, 1])
s1_correct_output = [text_file_data[0], text_file_data[2], text_file_data[1], text_file_data[4]]
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(item["text"].item().decode("utf8"))
assert s1_output == s1_correct_output
# no randomize in second split
s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)
s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(item["text"].item().decode("utf8"))
s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(item["text"].item().decode("utf8"))
s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(item["text"].item().decode("utf8"))
assert s1s1_output == [s1_correct_output[0]]
assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
assert s1s3_output == [s1_correct_output[3]]
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s2_output == [text_file_data[3]]
# randomize in second split
# the labels outputted by the ShuffleOp for seed 53 is [2, 3, 1, 0]
shuffled_ids = [2, 3, 1, 0]
s1s1, s1s2, s1s3 = s1.split([1, 2, 1])
s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(item["text"].item().decode("utf8"))
s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(item["text"].item().decode("utf8"))
s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(item["text"].item().decode("utf8"))
assert s1s1_output == [s1_correct_output[shuffled_ids[0]]]
assert s1s2_output == [s1_correct_output[shuffled_ids[1]], s1_correct_output[shuffled_ids[2]]]
assert s1s3_output == [s1_correct_output[shuffled_ids[3]]]
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(item["text"].item().decode("utf8"))
assert s2_output == [text_file_data[3]]
# Restore configuration num_parallel_workers
ds.config.set_num_parallel_workers(original_num_parallel_workers)
......@@ -137,6 +276,7 @@ def test_mappable_invalid_input():
s1, s2 = d.split([4, 1])
assert "dataset should not be sharded before split" in str(info.value)
def test_mappable_split_general():
d = ds.ManifestDataset(manifest_file, shuffle=False)
d = d.take(5)
......@@ -183,6 +323,7 @@ def test_mappable_split_general():
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]
def test_mappable_split_optimized():
d = ds.ManifestDataset(manifest_file, shuffle=False)
......@@ -228,9 +369,9 @@ def test_mappable_split_optimized():
assert s1_output == [0, 1]
assert s2_output == [2, 3, 4]
def test_mappable_randomize_deterministic():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
......@@ -249,9 +390,9 @@ def test_mappable_randomize_deterministic():
assert s1_output == [0, 1, 3, 4]
assert s2_output == [2]
def test_mappable_randomize_repeatable():
# set arbitrary seed for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
......@@ -273,9 +414,10 @@ def test_mappable_randomize_repeatable():
assert s1_output == [0, 1, 3, 4] * num_epochs
assert s2_output == [2] * num_epochs
def test_mappable_sharding():
# set arbitrary seed for repeatability for shard after split
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4]
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)
num_epochs = 5
......@@ -336,12 +478,94 @@ def test_mappable_sharding():
assert s2_output == [2]
assert d2s2_output == [2]
def test_mappable_get_dataset_size():
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([4, 1])
assert d.get_dataset_size() == 5
assert s1.get_dataset_size() == 4
assert s2.get_dataset_size() == 1
def test_mappable_multi_split():
# the labels outputted by ManifestDataset for seed 53 is [0, 1, 3, 4, 2]
ds.config.set_seed(53)
d = ds.ManifestDataset(manifest_file, shuffle=False)
s1, s2 = d.split([4, 1])
s1_correct_output = [0, 1, 3, 4]
s1_output = []
for item in s1.create_dict_iterator():
s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1_output == s1_correct_output
# no randomize in second split
s1s1, s1s2, s1s3 = s1.split([1, 2, 1], randomize=False)
s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1s1_output == [s1_correct_output[0]]
assert s1s2_output == [s1_correct_output[1], s1_correct_output[2]]
assert s1s3_output == [s1_correct_output[3]]
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s2_output == [2]
# randomize in second split
# the labels outputted by the RandomSampler for seed 53 is [3, 1, 2, 0]
random_sampler_ids = [3, 1, 2, 0]
s1s1, s1s2, s1s3 = s1.split([1, 2, 1])
s1s1_output = []
for item in s1s1.create_dict_iterator():
s1s1_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s1s2_output = []
for item in s1s2.create_dict_iterator():
s1s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
s1s3_output = []
for item in s1s3.create_dict_iterator():
s1s3_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s1s1_output == [s1_correct_output[random_sampler_ids[0]]]
assert s1s2_output == [s1_correct_output[random_sampler_ids[1]], s1_correct_output[random_sampler_ids[2]]]
assert s1s3_output == [s1_correct_output[random_sampler_ids[3]]]
s2_output = []
for item in s2.create_dict_iterator():
s2_output.append(manifest_map[(item["image"].shape[0], item["label"].item())])
assert s2_output == [2]
if __name__ == '__main__':
test_unmappable_invalid_input()
test_unmappable_split()
test_unmappable_randomize_deterministic()
test_unmappable_randomize_repeatable()
test_unmappable_get_dataset_size()
test_unmappable_multi_split()
test_mappable_invalid_input()
test_mappable_split_general()
test_mappable_split_optimized()
test_mappable_randomize_deterministic()
test_mappable_randomize_repeatable()
test_mappable_sharding()
test_mappable_get_dataset_size()
test_mappable_multi_split()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册