提交 c65670b1 编写于 作者: E ervinzhang

changed sampler parameter to support std::move

上级 6aa65da5
......@@ -71,8 +71,8 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int
}
/// Function to create a Subset Random Sampler.
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples) {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples) {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), num_samples);
// Input validation
if (!sampler->ValidateParams()) {
return nullptr;
......@@ -81,9 +81,9 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in
}
/// Function to create a Weighted Random Sampler.
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples,
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples,
bool replacement) {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
auto sampler = std::make_shared<WeightedRandomSamplerObj>(std::move(weights), num_samples, replacement);
// Input validation
if (!sampler->ValidateParams()) {
return nullptr;
......@@ -190,8 +190,8 @@ std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
}
// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples)
: indices_(indices), num_samples_(num_samples) {}
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(std::move(indices)), num_samples_(num_samples) {}
bool SubsetRandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
......@@ -208,9 +208,8 @@ std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
}
// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples,
bool replacement)
: weights_(weights), num_samples_(num_samples), replacement_(replacement) {}
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}
bool WeightedRandomSamplerObj::ValidateParams() {
if (num_samples_ < 0) {
......
......@@ -87,8 +87,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0,
/// \param[in] indices - A vector sequence of indices.
/// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \return Shared pointer to the current Sampler.
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices,
int64_t num_samples = 0);
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
/// Function to create a Weighted Random Sampler.
/// \notes Samples the elements from [0, len(weights) - 1] randomly with the given
......@@ -97,8 +96,8 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in
/// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \param[in] replacement - If True, put the sample ID back for the next draw.
/// \return Shared pointer to the current Sampler.
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights,
int64_t num_samples = 0, bool replacement = true);
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
bool replacement = true);
/* ####################################### Derived Sampler classes ################################# */
class DistributedSamplerObj : public SamplerObj {
......@@ -169,7 +168,7 @@ class SequentialSamplerObj : public SamplerObj {
class SubsetRandomSamplerObj : public SamplerObj {
public:
SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples);
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
~SubsetRandomSamplerObj() = default;
......@@ -178,14 +177,13 @@ class SubsetRandomSamplerObj : public SamplerObj {
bool ValidateParams() override;
private:
const std::vector<int64_t> &indices_;
const std::vector<int64_t> indices_;
int64_t num_samples_;
};
class WeightedRandomSamplerObj : public SamplerObj {
public:
explicit WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples = 0,
bool replacement = true);
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
~WeightedRandomSamplerObj() = default;
......@@ -194,7 +192,7 @@ class WeightedRandomSamplerObj : public SamplerObj {
bool ValidateParams() override;
private:
const std::vector<double> &weights_;
const std::vector<double> weights_;
int64_t num_samples_;
bool replacement_;
};
......
......@@ -369,6 +369,16 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices);
EXPECT_FALSE(indices.empty());
EXPECT_NE(sampl1->Build(), nullptr);
std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices));
EXPECT_TRUE(indices.empty());
EXPECT_NE(sampl2->Build(), nullptr);
}
TEST_F(MindDataTestPipeline, TestPad) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册