diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index cf7050450b3f50518a6327f4d23c694812aa9f4d..f572db0cdf97ade9cb9284e769ce05e13d405c5b 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -391,30 +391,6 @@ Status DEPipeline::CheckMindRecordPartitionInfo(const py::dict &args, std::vecto return Status::OK(); } -Status DEPipeline::GetMindrecordSampler(const std::string &sampler_name, const py::dict &args, - std::shared_ptr *ptr) { - std::vector indices; - for (auto &arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "indices") { - indices = ToIntVector(value); - } else { - std::string err_msg = "ERROR: parameter " + key + " is invalid."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - } - if (sampler_name == "SubsetRandomSampler") { - *ptr = std::make_shared(indices); - } else { - std::string err_msg = "ERROR: parameter sampler_name is invalid."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr *ptr) { if (args["dataset_file"].is_none()) { std::string err_msg = "Error: at least one of dataset_files is missing"; @@ -446,12 +422,10 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr(seed)); - } else if (key == "sampler_name") { - std::shared_ptr sample_op; - auto ret = GetMindrecordSampler(ToString(value), args["sampler_params"], &sample_op); - if (Status::OK() != ret) { - return ret; - } + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("_create_for_minddataset"); + std::shared_ptr sample_op = + create().cast>(); operators.push_back(sample_op); } } diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index 491a75390e0b709f7651afdf27a7f55bc9e384c0..acffc390cc059ddc2247f86c421e0a6665216ea6 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -145,9 +145,6 @@ class DEPipeline { Status ParseCelebAOp(const py::dict &args, std::shared_ptr *ptr); - Status GetMindrecordSampler(const std::string &sampler_name, const py::dict &args, - std::shared_ptr *ptr); - private: // Execution tree that links the dataset operators. std::shared_ptr tree_; diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 1b0d913f3eb046ab12f36f93588470d191a31164..3d543f946b3db3a8aacfd10252c244e538bed9f0 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -54,6 +54,9 @@ #include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/jagged_connector.h" #include "dataset/kernels/data/to_float16_op.h" +#include "dataset/util/random.h" +#include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_sample.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" #include "pybind11/stl_bind.h" @@ -382,6 +385,7 @@ void bindTensorOps4(py::module *m) { void bindSamplerOps(py::module *m) { (void)py::class_>(*m, "Sampler"); + (void)py::class_>(*m, "ShardOperator"); (void)py::class_>(*m, "DistributedSampler") .def(py::init(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"), @@ -399,6 +403,10 @@ void bindSamplerOps(py::module *m) { (void)py::class_>(*m, "SubsetRandomSampler") .def(py::init>(), py::arg("indices")); + (void)py::class_>( + *m, "MindrecordSubsetRandomSampler") + .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); + (void)py::class_>(*m, "WeightedRandomSampler") .def(py::init, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), py::arg("replacement")); diff --git a/mindspore/ccsrc/mindrecord/include/shard_category.h b/mindspore/ccsrc/mindrecord/include/shard_category.h index 08e5ac9c2e9412f7312d93ff5337bfff269eae66..b8a761154094783bd4d0b8ba319966930b701f63 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_category.h +++ b/mindspore/ccsrc/mindrecord/include/shard_category.h @@ -32,7 +32,7 @@ class ShardCategory : public ShardOperator { const std::vector> &get_categories() const; - MSRStatus operator()(ShardTask &tasks) override; + MSRStatus execute(ShardTask &tasks) override; private: std::vector> categories_; diff --git a/mindspore/ccsrc/mindrecord/include/shard_operator.h b/mindspore/ccsrc/mindrecord/include/shard_operator.h index 9d00fb7628de6096451f26fed137d1934a406e63..9f302e5321df7b215a5bab0af0cbf089eb531283 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_operator.h +++ b/mindspore/ccsrc/mindrecord/include/shard_operator.h @@ -24,7 +24,25 @@ namespace mindrecord { class ShardOperator { public: virtual ~ShardOperator() = default; - virtual MSRStatus operator()(ShardTask &tasks) = 0; + + MSRStatus operator()(ShardTask &tasks) { + if (SUCCESS != this->pre_execute(tasks)) { + return FAILED; + } + if (SUCCESS != this->execute(tasks)) { + return FAILED; + } + if (SUCCESS != this->suf_execute(tasks)) { + return FAILED; + } + return SUCCESS; + } + + virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; } + + virtual MSRStatus execute(ShardTask &tasks) = 0; + + virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h index aeb3374f2810daa1b05f0a74df98adfc5885a2cc..15353fd0ff5d34f9dacbe248f3668470f40caa87 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_sample.h +++ b/mindspore/ccsrc/mindrecord/include/shard_sample.h @@ -17,10 +17,12 @@ #ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ #define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ +#include #include #include #include #include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_shuffle.h" namespace mindspore { namespace mindrecord { @@ -32,21 +34,23 @@ class ShardSample : public ShardOperator { ShardSample(int num, int den, int par); - explicit ShardSample(const std::vector &indices); + ShardSample(const std::vector &indices, uint32_t seed); ~ShardSample() override{}; const std::pair get_partitions() const; - MSRStatus operator()(ShardTask &tasks) override; + MSRStatus execute(ShardTask &tasks) override; + MSRStatus suf_execute(ShardTask &tasks) override; private: int numerator_; int denominator_; int no_of_samples_; int partition_id_; - std::vector indices_; + std::vector indices_; SamplerType sampler_type_; + std::shared_ptr shuffle_op_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h index a9992ab4bcbe2ed9da6f68f188843618d5b3511c..464881aa7a0a06234cb19457b2a0c620e281107e 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h +++ b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h @@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator { ~ShardShuffle() override{}; - MSRStatus operator()(ShardTask &tasks) override; + MSRStatus execute(ShardTask &tasks) override; private: uint32_t shuffle_seed_; diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index 12aecea21fb9853ea548c46fa432d76b257973cb..2413da3737cd954b2fe89870c970e830d119db32 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -779,8 +779,12 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { // Sort row group by (group_id, shard_id), prepare for parallel reading std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups); - CreateTasks(row_group_summary, operators_); - MS_LOG(INFO) << "Launching read threads"; + if (CreateTasks(row_group_summary, operators_) != SUCCESS) { + MS_LOG(ERROR) << "Failed to launch read threads."; + interrupt_ = true; + return FAILED; + } + MS_LOG(INFO) << "Launching read threads."; if (isSimpleReader) return SUCCESS; @@ -1152,6 +1156,9 @@ std::vector, json>> ShardReader::GetBlockNext() } std::vector, json>> ShardReader::GetNext() { + if (interrupt_) { + return std::vector, json>>(); + } if (block_reader_) return GetBlockNext(); if (deliver_id_ >= static_cast(tasks_.Size())) { return std::vector, json>>(); diff --git a/mindspore/ccsrc/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/mindrecord/meta/shard_category.cc index c64a7bfc70853abdfd82aff53a351448c6137420..859a3b343fecc21ca24a0f8bab8e1362462c5e50 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_category.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_category.cc @@ -23,6 +23,6 @@ ShardCategory::ShardCategory(const std::vector> &ShardCategory::get_categories() const { return categories_; } -MSRStatus ShardCategory::operator()(ShardTask &tasks) { return SUCCESS; } +MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc index 367c7a5cf9d17a9c5cda04ba71ff2ec2954f7a24..ef627b0c09f98aed0b6fdf36acba4806facc5c14 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc @@ -46,13 +46,15 @@ ShardSample::ShardSample(int num, int den, int par) indices_({}), sampler_type_(kCustomTopPercentSampler) {} -ShardSample::ShardSample(const std::vector &indices) +ShardSample::ShardSample(const std::vector &indices, uint32_t seed) : numerator_(0), denominator_(0), no_of_samples_(0), partition_id_(0), indices_(indices), - sampler_type_(kSubsetRandomSampler) {} + sampler_type_(kSubsetRandomSampler) { + shuffle_op_ = std::make_shared(seed); +} const std::pair ShardSample::get_partitions() const { if (numerator_ == 1 && denominator_ > 1) { @@ -61,7 +63,7 @@ const std::pair ShardSample::get_partitions() const { return std::pair(-1, -1); } -MSRStatus ShardSample::operator()(ShardTask &tasks) { +MSRStatus ShardSample::execute(ShardTask &tasks) { int no_of_categories = static_cast(tasks.categories); int total_no = static_cast(tasks.Size()); @@ -115,5 +117,14 @@ MSRStatus ShardSample::operator()(ShardTask &tasks) { } return SUCCESS; } + +MSRStatus ShardSample::suf_execute(ShardTask &tasks) { + if (sampler_type_ == kSubsetRandomSampler) { + if (SUCCESS != (*shuffle_op_)(tasks)) { + return FAILED; + } + } + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc index 14816e9e9fa9b75957aa4f9cb3d8508f9951b731..f8ad2c341dc84ae0c6558354b8130f4ba382141f 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace mindrecord { ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {} -MSRStatus ShardShuffle::operator()(ShardTask &tasks) { +MSRStatus ShardShuffle::execute(ShardTask &tasks) { if (tasks.categories < 1) { return FAILED; } diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 4e0b0827347d60be9c3356f61bf1b134e2295dd5..f92f6f28a83fe07deebf145988c62352a3624dd1 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1683,9 +1683,7 @@ class MindDataset(SourceDataset): args["block_reader"] = self.block_reader args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id - if self.sampler: - args["sampler_name"] = self.sampler.__class__.__name__ - args["sampler_params"] = self.sampler.__dict__ + args["sampler"] = self.sampler return args def get_dataset_size(self): diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 62a3dbed18a1c1478d70e602dc93a972f1d42702..fd9c50e951ef629b8704d4dc944058018c28addf 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -195,6 +195,8 @@ class SubsetRandomSampler(): def create(self): return cde.SubsetRandomSampler(self.indices) + def _create_for_minddataset(self): + return cde.MindrecordSubsetRandomSampler(self.indices) class WeightedRandomSampler(): """ diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 143931658a258f6072abd50f976b502ba0e336dc..549e2140f41b1ce1da3cee33000f94b1a51eda6a 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -30,9 +30,9 @@ #include "mindrecord/include/shard_shuffle.h" #include "ut_common.h" -using mindspore::MsLogLevel::INFO; -using mindspore::ExceptionType::NoExceptionType; using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::INFO; namespace mindspore { namespace mindrecord { @@ -65,31 +65,31 @@ TEST_F(TestShardOperator, TestShardSampleBasic) { ASSERT_TRUE(i <= kSampleCount); } -// TEST_F(TestShardOperator, TestShardSampleWrongNumber) { -// MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); -// -// std::string file_name = "./imagenet.shard01"; -// auto column_list = std::vector{"file_name"}; -// -// const int kNum = 5; -// const int kDen = 0; -// std::vector> ops; -// ops.push_back(std::make_shared(kNum, kDen)); -// -// ShardReader dataset; -// dataset.Open(file_name, 4, column_list, ops); -// dataset.Launch(); -// -// int i = 0; -// while (true) { -// auto x = dataset.GetNext(); -// if (x.empty()) break; -// MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); -// i++; -// } -// dataset.Finish(); -// ASSERT_TRUE(i <= 5); -// } +TEST_F(TestShardOperator, TestShardSampleWrongNumber) { + MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); + + std::string file_name = "./imagenet.shard01"; + auto column_list = std::vector{"file_name"}; + + const int kNum = 5; + const int kDen = 0; + std::vector> ops; + ops.push_back(std::make_shared(kNum, kDen)); + + ShardReader dataset; + dataset.Open(file_name, 4, column_list, ops); + dataset.Launch(); + + int i = 0; + while (true) { + auto x = dataset.GetNext(); + if (x.empty()) break; + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); + i++; + } + dataset.Finish(); + ASSERT_TRUE(i <= 5); +} TEST_F(TestShardOperator, TestShardSampleRatio) { MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); @@ -117,7 +117,6 @@ TEST_F(TestShardOperator, TestShardSampleRatio) { ASSERT_TRUE(i <= 10); } - TEST_F(TestShardOperator, TestShardSamplePartition) { MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); std::string file_name = "./imagenet.shard01"; @@ -170,8 +169,8 @@ TEST_F(TestShardOperator, TestShardCategory) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); @@ -199,8 +198,8 @@ TEST_F(TestShardOperator, TestShardShuffle) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; } dataset.Finish(); @@ -224,8 +223,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; } dataset.Finish(); @@ -251,8 +250,8 @@ TEST_F(TestShardOperator, TestShardShuffleSample) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; } dataset.Finish(); @@ -278,8 +277,8 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; } dataset.Finish(); @@ -307,8 +306,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; auto y = compare_dataset.GetNext(); @@ -342,8 +341,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); @@ -376,8 +375,8 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); category_no++; @@ -410,8 +409,8 @@ TEST_F(TestShardOperator, TestShardCategorySample) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); @@ -448,8 +447,8 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) { while (true) { auto x = dataset.GetNext(); if (x.empty()) break; - MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) << - ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); + MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); i++; ASSERT_TRUE((std::get<1>(x[0]))["label"] == categories[category_no].second); diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py index 7662a0e390096d37ed3e22eaa855b76693885dda..3cad3877efa22841250bd8f05eeb81276d01df08 100644 --- a/tests/ut/python/dataset/test_minddataset_sampler.py +++ b/tests/ut/python/dataset/test_minddataset_sampler.py @@ -81,8 +81,6 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) logger.info( "-------------- item[label]: {} ----------------------------".format(item["label"])) - assert data[indices[num_iter]]['file_name'] == "".join( - [chr(x) for x in item['file_name']]) num_iter += 1 assert num_iter == 5 @@ -107,8 +105,6 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file): "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) logger.info( "-------------- item[label]: {} ----------------------------".format(item["label"])) - assert data[indices[num_iter]]['file_name'] == "".join( - [chr(x) for x in item['file_name']]) num_iter += 1 assert num_iter == 6 @@ -133,8 +129,6 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) logger.info( "-------------- item[label]: {} ----------------------------".format(item["label"])) - assert data[indices[num_iter]]['file_name'] == "".join( - [chr(x) for x in item['file_name']]) num_iter += 1 assert num_iter == 0 @@ -159,8 +153,6 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) logger.info( "-------------- item[label]: {} ----------------------------".format(item["label"])) - assert data[indices[num_iter] % len(data)]['file_name'] == "".join([ - chr(x) for x in item['file_name']]) num_iter += 1 assert num_iter == 5 @@ -185,8 +177,6 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): "-------------- item[file_name]: {} ------------------------".format(item["file_name"])) logger.info( "-------------- item[label]: {} ----------------------------".format(item["label"])) - assert data[indices[num_iter] % len(data)]['file_name'] == "".join([ - chr(x) for x in item['file_name']]) num_iter += 1 assert num_iter == 5