提交 b520ca90 编写于 作者: L liyong

fix pk sampler in mindrecord

上级 5a03bd80
......@@ -316,11 +316,15 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
}
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
if (column_schema_id_.find(category_field) == column_schema_id_.end()) {
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
std::map<std::string, uint64_t> index_columns;
for (auto &field : get_shard_header()->get_fields()) {
index_columns[field.second] = field.first;
}
if (index_columns.find(category_field) == index_columns.end()) {
MS_LOG(ERROR) << "Index field " << category_field << " does not exist.";
return FAILED;
}
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field));
if (SUCCESS != ret.first) {
return FAILED;
}
......
......@@ -2224,8 +2224,8 @@ class MindDataset(SourceDataset):
if block_reader is True and sampler is not None:
raise ValueError("block reader not allowed true when use sampler")
if shuffle is True and sampler is not None:
raise ValueError("shuffle not allowed true when use sampler")
if shuffle is not None and sampler is not None:
raise ValueError("shuffle not allowed when use sampler")
if block_reader is False and sampler is None:
self.global_shuffle = not bool(shuffle is False)
......
......@@ -97,3 +97,17 @@ def test_cv_minddataset_pk_sample_error_class_column():
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
def test_cv_minddataset_pk_sample_exclusive_shuffle():
create_cv_mindrecord(1)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.PKSampler(2)
with pytest.raises(Exception, match="shuffle not allowed when use sampler"):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
sampler=sampler, shuffle=False)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
......@@ -60,7 +60,21 @@ def add_and_remove_cv_file():
os.remove("{}".format(x))
os.remove("{}.db".format(x))
def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
num_readers = 4
sampler = ds.PKSampler(2)
data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
sampler=sampler)
assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format("".join([chr(x) for x in item["file_name"]])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册