diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc index 7117e142787f67c326f3d56b00845ede9806a09e..75c2c6bcc15727c236d384f8fee7daf921fd1306 100644 --- a/mindspore/ccsrc/minddata/dataset/api/samplers.cc +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -71,8 +71,8 @@ std::shared_ptr SequentialSampler(int64_t start_index, int } /// Function to create a Subset Random Sampler. -std::shared_ptr SubsetRandomSampler(const std::vector &indices, int64_t num_samples) { - auto sampler = std::make_shared(indices, num_samples); +std::shared_ptr SubsetRandomSampler(std::vector indices, int64_t num_samples) { + auto sampler = std::make_shared(std::move(indices), num_samples); // Input validation if (!sampler->ValidateParams()) { return nullptr; @@ -81,9 +81,9 @@ std::shared_ptr SubsetRandomSampler(const std::vector WeightedRandomSampler(const std::vector &weights, int64_t num_samples, +std::shared_ptr WeightedRandomSampler(std::vector weights, int64_t num_samples, bool replacement) { - auto sampler = std::make_shared(weights, num_samples, replacement); + auto sampler = std::make_shared(std::move(weights), num_samples, replacement); // Input validation if (!sampler->ValidateParams()) { return nullptr; @@ -190,8 +190,8 @@ std::shared_ptr SequentialSamplerObj::Build() { } // SubsetRandomSampler -SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector &indices, int64_t num_samples) - : indices_(indices), num_samples_(num_samples) {} +SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector indices, int64_t num_samples) + : indices_(std::move(indices)), num_samples_(num_samples) {} bool SubsetRandomSamplerObj::ValidateParams() { if (num_samples_ < 0) { @@ -208,9 +208,8 @@ std::shared_ptr SubsetRandomSamplerObj::Build() { } // WeightedRandomSampler -WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector &weights, int64_t num_samples, - bool replacement) - : weights_(weights), num_samples_(num_samples), replacement_(replacement) {} +WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector weights, int64_t num_samples, bool replacement) + : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {} bool WeightedRandomSamplerObj::ValidateParams() { if (num_samples_ < 0) { diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h index 3a06879b8ceafd6bf7c1374e79496ad3a913e4dc..9d423c78fa0bef8f700378f27e417780f4432aa4 100644 --- a/mindspore/ccsrc/minddata/dataset/include/samplers.h +++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h @@ -87,8 +87,7 @@ std::shared_ptr SequentialSampler(int64_t start_index = 0, /// \param[in] indices - A vector sequence of indices. /// \param[in] num_samples - The number of samples to draw (default to all elements). /// \return Shared pointer to the current Sampler. -std::shared_ptr SubsetRandomSampler(const std::vector &indices, - int64_t num_samples = 0); +std::shared_ptr SubsetRandomSampler(std::vector indices, int64_t num_samples = 0); /// Function to create a Weighted Random Sampler. /// \notes Samples the elements from [0, len(weights) - 1] randomly with the given @@ -97,8 +96,8 @@ std::shared_ptr SubsetRandomSampler(const std::vector WeightedRandomSampler(const std::vector &weights, - int64_t num_samples = 0, bool replacement = true); +std::shared_ptr WeightedRandomSampler(std::vector weights, int64_t num_samples = 0, + bool replacement = true); /* ####################################### Derived Sampler classes ################################# */ class DistributedSamplerObj : public SamplerObj { @@ -169,7 +168,7 @@ class SequentialSamplerObj : public SamplerObj { class SubsetRandomSamplerObj : public SamplerObj { public: - SubsetRandomSamplerObj(const std::vector &indices, int64_t num_samples); + SubsetRandomSamplerObj(std::vector indices, int64_t num_samples); ~SubsetRandomSamplerObj() = default; @@ -178,14 +177,13 @@ class SubsetRandomSamplerObj : public SamplerObj { bool ValidateParams() override; private: - const std::vector &indices_; + const std::vector indices_; int64_t num_samples_; }; class WeightedRandomSamplerObj : public SamplerObj { public: - explicit WeightedRandomSamplerObj(const std::vector &weights, int64_t num_samples = 0, - bool replacement = true); + explicit WeightedRandomSamplerObj(std::vector weights, int64_t num_samples = 0, bool replacement = true); ~WeightedRandomSamplerObj() = default; @@ -194,7 +192,7 @@ class WeightedRandomSamplerObj : public SamplerObj { bool ValidateParams() override; private: - const std::vector &weights_; + const std::vector weights_; int64_t num_samples_; bool replacement_; }; diff --git a/tests/ut/cpp/dataset/c_api_test.cc b/tests/ut/cpp/dataset/c_api_test.cc index 5ec2e5405641a82831e4195fc0e4b03781b5c109..4e5db80fe37b94f35c4df5b53d588dd217b69f2b 100644 --- a/tests/ut/cpp/dataset/c_api_test.cc +++ b/tests/ut/cpp/dataset/c_api_test.cc @@ -369,6 +369,16 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) { + std::vector indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23}; + std::shared_ptr sampl1 = SubsetRandomSampler(indices); + EXPECT_FALSE(indices.empty()); + EXPECT_NE(sampl1->Build(), nullptr); + std::shared_ptr sampl2 = SubsetRandomSampler(std::move(indices)); + EXPECT_TRUE(indices.empty()); + EXPECT_NE(sampl2->Build(), nullptr); +} + TEST_F(MindDataTestPipeline, TestPad) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad.";