提交 2412ee09 编写于 作者: L liyong 提交者: jonyguo

fix distributedSampler reshuffle and fix random_device failed

上级 23d0497d
......@@ -19,13 +19,16 @@
#if defined(_WIN32) || defined(_WIN64)
#include <stdlib.h>
#endif
#include <chrono>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
......@@ -35,6 +38,17 @@ inline std::mt19937 GetRandomDevice() {
rand_s(&number);
std::mt19937 random_device{static_cast<uint32_t>(number)};
#else
int i = 0;
while (i < 5) {
try {
std::mt19937 random_device{std::random_device("/dev/urandom")()};
return random_device;
} catch (const std::exception &e) {
MS_LOG(WARNING) << "Get std::random_device failed, retry: " << i << ", error: " << e.what();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
i++;
}
}
std::mt19937 random_device{std::random_device("/dev/urandom")()};
#endif
return random_device;
......
......@@ -44,8 +44,8 @@ class ShardDistributedSample : public ShardSample {
private:
bool shuffle_;
int no_of_padded_samples_;
bool init_judgment_; // we should judge the (num_sample + num_padded) % num_shards == 0 in first time
bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch
ShardTask task_; // maintain the input tasks in first epoch
};
} // namespace mindrecord
} // namespace mindspore
......
......@@ -17,6 +17,7 @@
#ifndef MINDRECORD_INCLUDE_SHARD_TASK_H_
#define MINDRECORD_INCLUDE_SHARD_TASK_H_
#include <algorithm>
#include <iostream>
#include <string>
#include <tuple>
......@@ -27,6 +28,14 @@ namespace mindspore {
namespace mindrecord {
class ShardTask {
public:
ShardTask();
ShardTask(const ShardTask &task); // copy construction
ShardTask &operator=(const ShardTask &task); // assignment operator
~ShardTask() = default;
void MakePerm();
void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
......@@ -46,10 +55,11 @@ class ShardTask {
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements);
uint32_t categories = 1;
uint32_t categories;
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
std::vector<int> permutation_;
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
};
} // namespace mindrecord
} // namespace mindspore
......
......@@ -1434,14 +1434,15 @@ void ShardReader::ShuffleTask() {
for (const auto &op : operators_) {
if (std::dynamic_pointer_cast<ShardShuffle>(op) && has_sharding == false) {
if (SUCCESS != (*op)(tasks_)) {
MS_LOG(WARNING) << "Reshuffle reader tasks failed.";
MS_LOG(WARNING) << "Redo randomSampler failed.";
}
} else if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
if (SUCCESS != op->PreExecute(tasks_)) {
MS_LOG(WARNING) << "Distribute reshuffle reader tasks failed.";
if (SUCCESS != (*op)(tasks_)) {
MS_LOG(WARNING) << "Redo distributeSampler failed.";
}
}
}
if (tasks_.permutation_.empty()) tasks_.MakePerm();
}
} // namespace mindrecord
......
......@@ -27,7 +27,7 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int
: ShardSample(1, num_shards, shard_id),
shuffle_(shuffle),
no_of_padded_samples_(no_of_padded_samples),
init_judgment_(false) {
first_epoch_(true) {
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
}
......@@ -54,8 +54,7 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
auto total_no = tasks.Size();
if (no_of_padded_samples_ > 0 && init_judgment_ == false) { // we only judge this in first time
init_judgment_ = true;
if (no_of_padded_samples_ > 0 && first_epoch_) {
if (total_no % denominator_ != 0) {
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. "
<< "task size: " << total_no << ", number padded: " << no_of_padded_samples_
......@@ -63,6 +62,12 @@ MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
return FAILED;
}
}
if (first_epoch_) {
first_epoch_ = false;
task_ = tasks;
} else {
tasks = task_;
}
if (shuffle_ == true) {
if (SUCCESS != (*shuffle_op_)(tasks)) {
return FAILED;
......
......@@ -43,6 +43,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
}
MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
if (reshuffle_each_epoch_) shuffle_seed_++;
if (tasks.categories < 1) {
return FAILED;
}
......@@ -81,7 +82,6 @@ MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
}
}
}
if (reshuffle_each_epoch_) shuffle_seed_++;
return SUCCESS;
}
} // namespace mindrecord
......
......@@ -24,6 +24,19 @@ using mindspore::MsLogLevel::DEBUG;
namespace mindspore {
namespace mindrecord {
ShardTask::ShardTask() : categories(1) {}
ShardTask::ShardTask(const ShardTask &other)
: categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {}
ShardTask &ShardTask::operator=(const ShardTask &other) {
ShardTask tmp(other);
std::swap(categories, tmp.categories);
permutation_.swap(tmp.permutation_);
task_list_.swap(tmp.task_list_);
return *this;
}
void ShardTask::MakePerm() {
permutation_ = std::vector<int>(task_list_.size());
for (uint32_t i = 0; i < task_list_.size(); i++) {
......
......@@ -278,6 +278,41 @@ def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_c
epoch3 = []
def test_cv_minddataset_partition_tutorial_check_whole_reshuffle_result_per_epoch(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "file_name", "label"]
num_readers = 4
num_shards = 3
epoch_result = [[["", "", "", ""], ["", "", "", ""], ["", "", "", ""]], # save partition 0 result
[["", "", "", ""], ["", "", "", ""], ["", "", "", ""]], # save partition 1 result
[["", "", "", ""], ["", "", "", ""], ["", "", "", ""]]] # svae partition 2 result
for partition_id in range(num_shards):
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
num_shards=num_shards, shard_id=partition_id)
data_set = data_set.repeat(3)
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- partition : {} ------------------------".format(partition_id))
logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
# total 3 partition, 4 result per epoch, total 12 result
epoch_result[partition_id][int(num_iter / 4)][num_iter % 4] = item["file_name"] # save epoch result
num_iter += 1
assert num_iter == 12
assert epoch_result[partition_id][0] not in (epoch_result[partition_id][1], epoch_result[partition_id][2])
assert epoch_result[partition_id][1] not in (epoch_result[partition_id][0], epoch_result[partition_id][2])
assert epoch_result[partition_id][2] not in (epoch_result[partition_id][1], epoch_result[partition_id][0])
epoch_result[partition_id][0].sort()
epoch_result[partition_id][1].sort()
epoch_result[partition_id][2].sort()
assert epoch_result[partition_id][0] != epoch_result[partition_id][1]
assert epoch_result[partition_id][1] != epoch_result[partition_id][2]
assert epoch_result[partition_id][2] != epoch_result[partition_id][0]
def test_cv_minddataset_check_shuffle_result(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "file_name", "label"]
......
......@@ -468,6 +468,64 @@ def test_nlp_minddataset_reader_basic_padded_samples_multi_epoch(add_and_remove_
partitions(5, 5, 3)
partitions(9, 8, 2)
def test_nlp_minddataset_reader_basic_padded_samples_check_whole_reshuffle_result_per_epoch(add_and_remove_nlp_file):
columns_list = ["input_ids", "id", "rating"]
padded_sample = {}
padded_sample['id'] = "-1"
padded_sample['input_ids'] = np.array([-1,-1,-1,-1], dtype=np.int64)
padded_sample['rating'] = 1.0
num_readers = 4
repeat_size = 3
def partitions(num_shards, num_padded, dataset_size):
num_padded_iter = 0
num_iter = 0
epoch_result = [[["" for i in range(dataset_size)] for i in range(repeat_size)] for i in range(num_shards)]
for partition_id in range(num_shards):
data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
num_shards=num_shards,
shard_id=partition_id,
padded_sample=padded_sample,
num_padded=num_padded)
assert data_set.get_dataset_size() == dataset_size
data_set = data_set.repeat(repeat_size)
inner_num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
logger.info("-------------- item[input_ids]: {}, shape: {} -----------------"
.format(item["input_ids"], item["input_ids"].shape))
if item['id'] == bytes('-1', encoding='utf-8'):
num_padded_iter += 1
assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
assert (item['input_ids'] == padded_sample['input_ids']).all()
assert (item['rating'] == padded_sample['rating']).all()
# save epoch result
epoch_result[partition_id][int(inner_num_iter / dataset_size)][inner_num_iter % dataset_size] = item["id"]
num_iter += 1
inner_num_iter += 1
assert epoch_result[partition_id][0] not in (epoch_result[partition_id][1], epoch_result[partition_id][2])
assert epoch_result[partition_id][1] not in (epoch_result[partition_id][0], epoch_result[partition_id][2])
assert epoch_result[partition_id][2] not in (epoch_result[partition_id][1], epoch_result[partition_id][0])
if dataset_size > 2:
epoch_result[partition_id][0].sort()
epoch_result[partition_id][1].sort()
epoch_result[partition_id][2].sort()
assert epoch_result[partition_id][0] != epoch_result[partition_id][1]
assert epoch_result[partition_id][1] != epoch_result[partition_id][2]
assert epoch_result[partition_id][2] != epoch_result[partition_id][0]
assert num_padded_iter == num_padded * repeat_size
assert num_iter == dataset_size * num_shards * repeat_size
partitions(4, 6, 4)
partitions(5, 5, 3)
partitions(9, 8, 2)
def get_data(dir_name):
"""
usage: get data from imagenet dataset
......
......@@ -586,6 +586,13 @@ def test_cv_minddataset_split_sharding(add_and_remove_cv_file):
assert epoch2_dataset not in (epoch1_dataset, epoch3_dataset)
assert epoch3_dataset not in (epoch1_dataset, epoch2_dataset)
epoch1_dataset.sort()
epoch2_dataset.sort()
epoch3_dataset.sort()
assert epoch1_dataset != epoch2_dataset
assert epoch2_dataset != epoch3_dataset
assert epoch3_dataset != epoch1_dataset
def get_data(dir_name, sampler=False):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册