提交 e9422646 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4023 change sampler parameter to support std::move

Merge pull request !4023 from 章一智/cpp_api_sampler
...@@ -71,8 +71,8 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int ...@@ -71,8 +71,8 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int
} }
/// Function to create a Subset Random Sampler. /// Function to create a Subset Random Sampler.
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples) { std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices, num_samples); auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples);
// Input validation // Input validation
if (!sampler->ValidateParams()) { if (!sampler->ValidateParams()) {
return nullptr; return nullptr;
...@@ -81,9 +81,9 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in ...@@ -81,9 +81,9 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in
} }
/// Function to create a Weighted Random Sampler. /// Function to create a Weighted Random Sampler.
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples,
bool replacement) { bool replacement) {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement); auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement);
// Input validation // Input validation
if (!sampler->ValidateParams()) { if (!sampler->ValidateParams()) {
return nullptr; return nullptr;
...@@ -190,8 +190,8 @@ std::shared_ptr<Sampler> SequentialSamplerObj::Build() { ...@@ -190,8 +190,8 @@ std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
} }
// SubsetRandomSampler // SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples) SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(indices), num_samples_(num_samples) {} : indices_(std::move(indices)), num_samples_(num_samples) {}
bool SubsetRandomSamplerObj::ValidateParams() { bool SubsetRandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) { if (num_samples_ < 0) {
...@@ -208,9 +208,8 @@ std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() { ...@@ -208,9 +208,8 @@ std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
} }
// WeightedRandomSampler // WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples, WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
bool replacement) : weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
: weights_(weights), num_samples_(num_samples), replacement_(replacement) {}
bool WeightedRandomSamplerObj::ValidateParams() { bool WeightedRandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) { if (num_samples_ < 0) {
......
...@@ -87,8 +87,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0, ...@@ -87,8 +87,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0,
/// \param[in] indices - A vector sequence of indices. /// \param[in] indices - A vector sequence of indices.
/// \param[in] num_samples - The number of samples to draw (default to all elements). /// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \return Shared pointer to the current Sampler. /// \return Shared pointer to the current Sampler.
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices, std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
int64_t num_samples = 0);
/// Function to create a Weighted Random Sampler. /// Function to create a Weighted Random Sampler.
/// \notes Samples the elements from [0, len(weights) - 1] randomly with the given /// \notes Samples the elements from [0, len(weights) - 1] randomly with the given
...@@ -97,8 +96,8 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in ...@@ -97,8 +96,8 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in
/// \param[in] num_samples - The number of samples to draw (default to all elements). /// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \param[in] replacement - If True, put the sample ID back for the next draw. /// \param[in] replacement - If True, put the sample ID back for the next draw.
/// \return Shared pointer to the current Sampler. /// \return Shared pointer to the current Sampler.
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights, std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
int64_t num_samples = 0, bool replacement = true); bool replacement = true);
/* ####################################### Derived Sampler classes ################################# */ /* ####################################### Derived Sampler classes ################################# */
class DistributedSamplerObj : public SamplerObj { class DistributedSamplerObj : public SamplerObj {
...@@ -169,7 +168,7 @@ class SequentialSamplerObj : public SamplerObj { ...@@ -169,7 +168,7 @@ class SequentialSamplerObj : public SamplerObj {
class SubsetRandomSamplerObj : public SamplerObj { class SubsetRandomSamplerObj : public SamplerObj {
public: public:
SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples); SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
~SubsetRandomSamplerObj() = default; ~SubsetRandomSamplerObj() = default;
...@@ -178,14 +177,13 @@ class SubsetRandomSamplerObj : public SamplerObj { ...@@ -178,14 +177,13 @@ class SubsetRandomSamplerObj : public SamplerObj {
bool ValidateParams() override; bool ValidateParams() override;
private: private:
const std::vector<int64_t> &indices_; const std::vector<int64_t> indices_;
int64_t num_samples_; int64_t num_samples_;
}; };
class WeightedRandomSamplerObj : public SamplerObj { class WeightedRandomSamplerObj : public SamplerObj {
public: public:
explicit WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples = 0, explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
bool replacement = true);
~WeightedRandomSamplerObj() = default; ~WeightedRandomSamplerObj() = default;
...@@ -194,7 +192,7 @@ class WeightedRandomSamplerObj : public SamplerObj { ...@@ -194,7 +192,7 @@ class WeightedRandomSamplerObj : public SamplerObj {
bool ValidateParams() override; bool ValidateParams() override;
private: private:
const std::vector<double> &weights_; const std::vector<double> weights_;
int64_t num_samples_; int64_t num_samples_;
bool replacement_; bool replacement_;
}; };
......
...@@ -369,6 +369,16 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) { ...@@ -369,6 +369,16 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) {
iter->Stop(); iter->Stop();
} }
TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices);
EXPECT_FALSE(indices.empty());
EXPECT_NE(sampl1->Build(), nullptr);
std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices));
EXPECT_TRUE(indices.empty());
EXPECT_NE(sampl2->Build(), nullptr);
}
TEST_F(MindDataTestPipeline, TestPad) { TEST_F(MindDataTestPipeline, TestPad) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad."; MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册