diff --git a/paddle/fluid/distributed/index_dataset/index_sampler.cc b/paddle/fluid/distributed/index_dataset/index_sampler.cc index 58f85d98fb09c6576daa0816be2d58c90c5a8a42..3e573bbdd2de97130a109ddb583a724cf363c6be 100644 --- a/paddle/fluid/distributed/index_dataset/index_sampler.cc +++ b/paddle/fluid/distributed/index_dataset/index_sampler.cc @@ -13,13 +13,10 @@ // limitations under the License. #include "paddle/fluid/distributed/index_dataset/index_sampler.h" -#include "paddle/fluid/operators/math/sampler.h" namespace paddle { namespace distributed { -using Sampler = paddle::operators::math::Sampler; - std::vector> LayerWiseSampler::sample( const std::vector>& user_inputs, const std::vector& target_ids, bool with_hierarchy) { @@ -30,22 +27,7 @@ std::vector> LayerWiseSampler::sample( std::vector(user_feature_num + 2)); auto max_layer = tree_->Height(); - std::vector sampler_vec(max_layer - start_sample_layer_); - std::vector> layer_ids(max_layer - - start_sample_layer_); - - auto layer_index = max_layer - 1; size_t idx = 0; - while (layer_index >= start_sample_layer_) { - auto layer_codes = tree_->GetLayerCodes(layer_index); - layer_ids[idx] = tree_->GetNodes(layer_codes); - sampler_vec[idx] = new paddle::operators::math::UniformSampler( - layer_ids[idx].size() - 1, seed_); - layer_index--; - idx++; - } - - idx = 0; for (size_t i = 0; i < input_num; i++) { auto travel_codes = tree_->GetTravelCodes(target_ids[i], start_sample_layer_); @@ -76,18 +58,15 @@ std::vector> LayerWiseSampler::sample( for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) { int sample_res = 0; do { - sample_res = sampler_vec[j]->Sample(); - } while (layer_ids[j][sample_res].id() == travel_path[j].id()); + sample_res = sampler_vec_[j]->Sample(); + } while (layer_ids_[j][sample_res].id() == travel_path[j].id()); outputs[idx + idx_offset][user_feature_num] = - layer_ids[j][sample_res].id(); + layer_ids_[j][sample_res].id(); outputs[idx + idx_offset][user_feature_num + 1] = 0; } idx += layer_counts_[j]; } } - for (size_t i = 0; i < sampler_vec.size(); i++) { - delete sampler_vec[i]; - } return outputs; } diff --git a/paddle/fluid/distributed/index_dataset/index_sampler.h b/paddle/fluid/distributed/index_dataset/index_sampler.h index 66882bedc9b76593b9b28f184fc26ff4897494e6..8813421446a21c1379ca872952fe8b367d0724ca 100644 --- a/paddle/fluid/distributed/index_dataset/index_sampler.h +++ b/paddle/fluid/distributed/index_dataset/index_sampler.h @@ -16,6 +16,7 @@ #include #include "paddle/fluid/distributed/index_dataset/index_wrapper.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -83,6 +84,23 @@ class LayerWiseSampler : public IndexSampler { } reverse(layer_counts_.begin(), layer_counts_.end()); VLOG(3) << "sample counts sum: " << layer_counts_sum_; + + auto max_layer = tree_->Height(); + sampler_vec_.clear(); + layer_ids_.clear(); + + auto layer_index = max_layer - 1; + size_t idx = 0; + while (layer_index >= start_sample_layer_) { + auto layer_codes = tree_->GetLayerCodes(layer_index); + layer_ids_.push_back(tree_->GetNodes(layer_codes)); + auto sampler_temp = + std::make_shared( + layer_ids_[idx].size() - 1, seed_); + sampler_vec_.push_back(sampler_temp); + layer_index--; + idx++; + } } std::vector> sample( const std::vector>& user_inputs, @@ -94,6 +112,8 @@ class LayerWiseSampler : public IndexSampler { std::shared_ptr tree_{nullptr}; int seed_{0}; int start_sample_layer_{1}; + std::vector> sampler_vec_; + std::vector> layer_ids_; }; } // end namespace distributed