From b520ca908734828c401a32aad23bf141e1c4c61f Mon Sep 17 00:00:00 2001 From: liyong Date: Thu, 7 May 2020 14:53:41 +0800 Subject: [PATCH] fix pk sampler in mindrecord --- mindspore/ccsrc/mindrecord/io/shard_reader.cc | 10 +++++++--- mindspore/dataset/engine/datasets.py | 4 ++-- .../python/dataset/test_minddataset_exception.py | 14 ++++++++++++++ .../ut/python/dataset/test_minddataset_sampler.py | 14 ++++++++++++++ 4 files changed, 37 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index 804613e40..27854bdf0 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -316,11 +316,15 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, } MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { - if (column_schema_id_.find(category_field) == column_schema_id_.end()) { - MS_LOG(ERROR) << "Field " << category_field << " does not exist."; + std::map index_columns; + for (auto &field : get_shard_header()->get_fields()) { + index_columns[field.second] = field.first; + } + if (index_columns.find(category_field) == index_columns.end()) { + MS_LOG(ERROR) << "Index field " << category_field << " does not exist."; return FAILED; } - auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field)); + auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field)); if (SUCCESS != ret.first) { return FAILED; } diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 239b0afb8..622115ab6 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2224,8 +2224,8 @@ class MindDataset(SourceDataset): if block_reader is True and sampler is not None: raise ValueError("block reader not allowed true when use sampler") - if shuffle is True and sampler is not None: - raise ValueError("shuffle not allowed true when use sampler") + if shuffle is not None and sampler is not None: + raise ValueError("shuffle not allowed when use sampler") if block_reader is False and sampler is None: self.global_shuffle = not bool(shuffle is False) diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index e1d54fa7c..2a269ffc8 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -97,3 +97,17 @@ def test_cv_minddataset_pk_sample_error_class_column(): os.remove(CV_FILE_NAME) os.remove("{}.db".format(CV_FILE_NAME)) +def test_cv_minddataset_pk_sample_exclusive_shuffle(): + create_cv_mindrecord(1) + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(2) + with pytest.raises(Exception, match="shuffle not allowed when use sampler"): + data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, + sampler=sampler, shuffle=False) + num_iter = 0 + for item in data_set.create_dict_iterator(): + num_iter += 1 + os.remove(CV_FILE_NAME) + os.remove("{}.db".format(CV_FILE_NAME)) + diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py index 584bb8804..5656a08ae 100644 --- a/tests/ut/python/dataset/test_minddataset_sampler.py +++ b/tests/ut/python/dataset/test_minddataset_sampler.py @@ -60,7 +60,21 @@ def add_and_remove_cv_file(): os.remove("{}".format(x)) os.remove("{}.db".format(x)) +def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + num_readers = 4 + sampler = ds.PKSampler(2) + data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers, + sampler=sampler) + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] -- GitLab