提交 e2ea1fa0 编写于 作者: L liyong

activate num_samples in distributed samplers

上级 11732f0e
......@@ -784,7 +784,7 @@ void bindSamplerOps(py::module *m) {
(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t>());
.def(py::init<int64_t, int64_t, bool, uint32_t, int64_t>());
(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
*m, "MindrecordRandomSampler")
......
......@@ -29,9 +29,10 @@ namespace mindspore {
namespace mindrecord {
class ShardDistributedSample : public ShardSample {
public:
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed,
int no_of_samples = 0);
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, int no_of_samples = 0);
void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
......
......@@ -32,7 +32,7 @@ class ShardSample : public ShardOperator {
ShardSample(int num, int den);
ShardSample(int num, int den, int par);
ShardSample(int num, int den, int par, int no_of_samples = 0);
ShardSample(const std::vector<int64_t> &indices, uint32_t seed);
......
......@@ -23,16 +23,17 @@ using mindspore::MsLogLevel::ERROR;
namespace mindspore {
namespace mindrecord {
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle,
uint32_t seed)
: ShardSample(1, num_shards, shard_id),
uint32_t seed, int no_of_samples)
: ShardSample(1, num_shards, shard_id, no_of_samples),
shuffle_(shuffle),
no_of_padded_samples_(no_of_padded_samples),
first_epoch_(true) {
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
}
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed,
int no_of_samples)
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed, no_of_samples) {}
int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (no_of_padded_samples_ <= 0) {
......
......@@ -38,11 +38,11 @@ ShardSample::ShardSample(int num, int den)
indices_({}),
sampler_type_(kCustomTopPercentSampler) {}
ShardSample::ShardSample(int num, int den, int par)
ShardSample::ShardSample(int num, int den, int par, int no_of_samples)
: numerator_(num),
denominator_(den),
partition_id_(par),
no_of_samples_(0),
no_of_samples_(no_of_samples),
indices_({}),
sampler_type_(kCustomTopPercentSampler) {}
......@@ -110,8 +110,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python
}
} else {
int count = 0;
for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
if (no_of_samples_ != 0 && count == no_of_samples_) break;
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
count++;
}
}
std::swap(tasks, new_tasks);
......@@ -121,8 +124,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
return FAILED;
}
total_no = static_cast<int>(tasks.permutation_.size());
int count = 0;
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
if (no_of_samples_ != 0 && count == no_of_samples_) break;
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
count++;
}
std::swap(tasks, new_tasks);
}
......
......@@ -270,7 +270,9 @@ class DistributedSampler(BuiltinSampler):
return c_sampler
def create_for_minddataset(self):
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
self.seed, num_samples)
c_child_sampler = self.create_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler
......
......@@ -238,6 +238,72 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
assert partitions(5) == 2
assert partitions(9) == 2
def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
def partitions(num_shards):
for partition_id in range(num_shards):
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
num_shards=num_shards,
shard_id=partition_id, num_samples=1)
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- partition : {} ------------------------".format(partition_id))
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
num_iter += 1
return num_iter
assert partitions(4) == 1
assert partitions(5) == 1
assert partitions(9) == 1
def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
def partitions(num_shards):
for partition_id in range(num_shards):
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
num_shards=num_shards,
shard_id=partition_id, num_samples=2)
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- partition : {} ------------------------".format(partition_id))
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
num_iter += 1
return num_iter
assert partitions(4) == 2
assert partitions(5) == 2
assert partitions(9) == 2
def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
def partitions(num_shards):
for partition_id in range(num_shards):
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
num_shards=num_shards,
shard_id=partition_id, num_samples=3)
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- partition : {} ------------------------".format(partition_id))
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
num_iter += 1
return num_iter
assert partitions(4) == 3
assert partitions(5) == 2
assert partitions(9) == 2
def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
......
......@@ -228,3 +228,24 @@ def test_minddataset_shard_id_bigger_than_num_shard():
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_minddataset_partition_num_samples_equals_0():
"""tutorial for cv minddataset."""
create_cv_mindrecord(1)
columns_list = ["data", "label"]
num_readers = 4
def partitions(num_shards):
for partition_id in range(num_shards):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
num_shards=num_shards,
shard_id=partition_id, num_samples=0)
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
with pytest.raises(Exception) as error_info:
partitions(5)
assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info)
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册