提交 f3ebc731 编写于 作者: J jonyguo

fix: MindDataset padded log error

上级 b3f09b1d
......@@ -40,6 +40,8 @@ class ShardDistributedSample : public ShardSample {
private:
bool shuffle_;
int no_of_padded_samples_;
bool init_judgment_; // we should judge the (num_sample + num_padded) % num_shards == 0 in first time
};
} // namespace mindrecord
} // namespace mindspore
......
......@@ -24,7 +24,10 @@ namespace mindspore {
namespace mindrecord {
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle,
uint32_t seed)
: ShardSample(1, num_shards, shard_id), shuffle_(shuffle), no_of_padded_samples_(no_of_padded_samples) {
: ShardSample(1, num_shards, shard_id),
shuffle_(shuffle),
no_of_padded_samples_(no_of_padded_samples),
init_judgment_(false) {
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
}
......@@ -45,11 +48,15 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
}
return 0;
}
MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
auto total_no = tasks.Size();
if (no_of_padded_samples_ > 0) {
if (no_of_padded_samples_ > 0 && init_judgment_ == false) { // we only judge this in first time
init_judgment_ = true;
if (total_no % denominator_ != 0) {
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. "
<< "task size: " << total_no << ", number padded: " << no_of_padded_samples_
<< ", denominator: " << denominator_;
return FAILED;
}
}
......
......@@ -120,7 +120,7 @@ def test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file):
assert item['label'] == padded_sample['label']
assert (item['data'] == np.array(list(padded_sample['data']))).all()
num_iter += 1
assert num_padded_iter ==5
assert num_padded_iter == 5
assert num_iter == 15
......@@ -135,6 +135,8 @@ def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file):
num_readers = 4
def partitions(num_shards, num_padded, dataset_size):
num_padded_iter = 0
num_iter = 0
for partition_id in range(num_shards):
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
num_shards=num_shards,
......@@ -142,8 +144,6 @@ def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file):
padded_sample=padded_sample,
num_padded=num_padded)
assert data_set.get_dataset_size() == dataset_size
num_iter = 0
num_padded_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- partition : {} ------------------------".format(partition_id))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
......@@ -156,11 +156,53 @@ def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file):
assert item['label'] == padded_sample['label']
assert (item['data'] == np.array(list(padded_sample['data']))).all()
num_iter += 1
return num_iter
assert num_padded_iter == num_padded
return num_iter == dataset_size * num_shards
partitions(4, 2, 3)
partitions(5, 5, 3)
partitions(9, 8, 2)
def test_cv_minddataset_partition_padded_samples_multi_epoch(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
columns_list = ["data", "file_name", "label"]
data = get_data(CV_DIR_NAME)
padded_sample = data[0]
padded_sample['label'] = -2
padded_sample['file_name'] = 'dummy.jpg'
num_readers = 4
def partitions(num_shards, num_padded, dataset_size):
repeat_size = 5
num_padded_iter = 0
num_iter = 0
for partition_id in range(num_shards):
data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
num_shards=num_shards,
shard_id=partition_id,
padded_sample=padded_sample,
num_padded=num_padded)
assert data_set.get_dataset_size() == dataset_size
data_set = data_set.repeat(repeat_size)
for item in data_set.create_dict_iterator():
logger.info("-------------- partition : {} ------------------------".format(partition_id))
logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
if item['label'] == -2:
num_padded_iter += 1
assert item['file_name'] == bytes(padded_sample['file_name'], encoding='utf8')
assert item['label'] == padded_sample['label']
assert (item['data'] == np.array(list(padded_sample['data']))).all()
num_iter += 1
assert num_padded_iter == num_padded * repeat_size
assert num_iter == dataset_size * num_shards * repeat_size
assert partitions(4, 2, 3) == 3
assert partitions(5, 5, 3) == 3
assert partitions(9, 8, 2) == 2
partitions(4, 2, 3)
partitions(5, 5, 3)
partitions(9, 8, 2)
def test_cv_minddataset_partition_padded_samples_no_dividsible(add_and_remove_cv_file):
"""tutorial for cv minddataset."""
......@@ -308,6 +350,8 @@ def test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file):
num_readers = 4
def partitions(num_shards, num_padded, dataset_size):
num_padded_iter = 0
num_iter = 0
for partition_id in range(num_shards):
data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
num_shards=num_shards,
......@@ -315,22 +359,61 @@ def test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file):
padded_sample=padded_sample,
num_padded=num_padded)
assert data_set.get_dataset_size() == dataset_size
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(item["input_ids"], item["input_ids"].shape))
if item['id'] == '-1':
if item['id'] == bytes('-1', encoding='utf-8'):
num_padded_iter += 1
assert item['id'] == padded_sample['id']
assert item['input_ids'] == padded_sample['input_ids']
assert item['rating'] == padded_sample['rating']
assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
assert (item['input_ids'] == padded_sample['input_ids']).all()
assert (item['rating'] == padded_sample['rating']).all()
num_iter += 1
return num_iter
assert num_padded_iter == num_padded
assert num_iter == dataset_size * num_shards
partitions(4, 6, 4)
partitions(5, 5, 3)
partitions(9, 8, 2)
def test_nlp_minddataset_reader_basic_padded_samples_multi_epoch(add_and_remove_nlp_file):
columns_list = ["input_ids", "id", "rating"]
data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)]
padded_sample = data[0]
padded_sample['id'] = "-1"
padded_sample['input_ids'] = np.array([-1,-1,-1,-1], dtype=np.int64)
padded_sample['rating'] = 1.0
num_readers = 4
repeat_size = 3
def partitions(num_shards, num_padded, dataset_size):
num_padded_iter = 0
num_iter = 0
for partition_id in range(num_shards):
data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
num_shards=num_shards,
shard_id=partition_id,
padded_sample=padded_sample,
num_padded=num_padded)
assert data_set.get_dataset_size() == dataset_size
data_set = data_set.repeat(repeat_size)
for item in data_set.create_dict_iterator():
logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(item["input_ids"], item["input_ids"].shape))
if item['id'] == bytes('-1', encoding='utf-8'):
num_padded_iter += 1
assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
assert (item['input_ids'] == padded_sample['input_ids']).all()
assert (item['rating'] == padded_sample['rating']).all()
num_iter += 1
assert num_padded_iter == num_padded * repeat_size
assert num_iter == dataset_size * num_shards * repeat_size
assert partitions(4, 6, 4) == 4
assert partitions(5, 5, 3) == 3
assert partitions(9, 8, 2) == 2
partitions(4, 6, 4)
partitions(5, 5, 3)
partitions(9, 8, 2)
def get_data(dir_name):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册