未验证 提交 5ada0329 编写于 作者: 1 123malin 提交者: GitHub

test=develop, optimize index_sampler (#32663)

上级 8fd724a5
......@@ -13,13 +13,10 @@
// limitations under the License.
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
#include "paddle/fluid/operators/math/sampler.h"
namespace paddle {
namespace distributed {
using Sampler = paddle::operators::math::Sampler;
std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
const std::vector<uint64_t>& target_ids, bool with_hierarchy) {
......@@ -30,22 +27,7 @@ std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
std::vector<uint64_t>(user_feature_num + 2));
auto max_layer = tree_->Height();
std::vector<Sampler*> sampler_vec(max_layer - start_sample_layer_);
std::vector<std::vector<IndexNode>> layer_ids(max_layer -
start_sample_layer_);
auto layer_index = max_layer - 1;
size_t idx = 0;
while (layer_index >= start_sample_layer_) {
auto layer_codes = tree_->GetLayerCodes(layer_index);
layer_ids[idx] = tree_->GetNodes(layer_codes);
sampler_vec[idx] = new paddle::operators::math::UniformSampler(
layer_ids[idx].size() - 1, seed_);
layer_index--;
idx++;
}
idx = 0;
for (size_t i = 0; i < input_num; i++) {
auto travel_codes =
tree_->GetTravelCodes(target_ids[i], start_sample_layer_);
......@@ -76,18 +58,15 @@ std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) {
int sample_res = 0;
do {
sample_res = sampler_vec[j]->Sample();
} while (layer_ids[j][sample_res].id() == travel_path[j].id());
sample_res = sampler_vec_[j]->Sample();
} while (layer_ids_[j][sample_res].id() == travel_path[j].id());
outputs[idx + idx_offset][user_feature_num] =
layer_ids[j][sample_res].id();
layer_ids_[j][sample_res].id();
outputs[idx + idx_offset][user_feature_num + 1] = 0;
}
idx += layer_counts_[j];
}
}
for (size_t i = 0; i < sampler_vec.size(); i++) {
delete sampler_vec[i];
}
return outputs;
}
......
......@@ -16,6 +16,7 @@
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/math/sampler.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
......@@ -83,6 +84,23 @@ class LayerWiseSampler : public IndexSampler {
}
reverse(layer_counts_.begin(), layer_counts_.end());
VLOG(3) << "sample counts sum: " << layer_counts_sum_;
auto max_layer = tree_->Height();
sampler_vec_.clear();
layer_ids_.clear();
auto layer_index = max_layer - 1;
size_t idx = 0;
while (layer_index >= start_sample_layer_) {
auto layer_codes = tree_->GetLayerCodes(layer_index);
layer_ids_.push_back(tree_->GetNodes(layer_codes));
auto sampler_temp =
std::make_shared<paddle::operators::math::UniformSampler>(
layer_ids_[idx].size() - 1, seed_);
sampler_vec_.push_back(sampler_temp);
layer_index--;
idx++;
}
}
std::vector<std::vector<uint64_t>> sample(
const std::vector<std::vector<uint64_t>>& user_inputs,
......@@ -94,6 +112,8 @@ class LayerWiseSampler : public IndexSampler {
std::shared_ptr<TreeIndex> tree_{nullptr};
int seed_{0};
int start_sample_layer_{1};
std::vector<std::shared_ptr<paddle::operators::math::Sampler>> sampler_vec_;
std::vector<std::vector<IndexNode>> layer_ids_;
};
} // end namespace distributed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册