// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/distributed/index_dataset/index_sampler.h" namespace paddle { namespace distributed { std::vector> LayerWiseSampler::sample( const std::vector>& user_inputs, const std::vector& target_ids, bool with_hierarchy) { auto input_num = target_ids.size(); auto user_feature_num = user_inputs[0].size(); std::vector> outputs( input_num * layer_counts_sum_, std::vector(user_feature_num + 2)); auto max_layer = tree_->Height(); size_t idx = 0; for (size_t i = 0; i < input_num; i++) { auto travel_codes = tree_->GetTravelCodes(target_ids[i], start_sample_layer_); auto travel_path = tree_->GetNodes(travel_codes); for (size_t j = 0; j < travel_path.size(); j++) { // user if (j > 0 && with_hierarchy) { auto ancestor_codes = tree_->GetAncestorCodes(user_inputs[i], max_layer - j - 1); auto hierarchical_user = tree_->GetNodes(ancestor_codes); for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) { for (size_t k = 0; k < user_feature_num; k++) { outputs[idx + idx_offset][k] = hierarchical_user[k].id(); } } } else { for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) { for (size_t k = 0; k < user_feature_num; k++) { outputs[idx + idx_offset][k] = user_inputs[i][k]; } } } // sampler ++ outputs[idx][user_feature_num] = travel_path[j].id(); outputs[idx][user_feature_num + 1] = 1.0; idx += 1; for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) { int sample_res = 0; do { sample_res = sampler_vec_[j]->Sample(); } while (layer_ids_[j][sample_res].id() == travel_path[j].id()); outputs[idx + idx_offset][user_feature_num] = layer_ids_[j][sample_res].id(); outputs[idx + idx_offset][user_feature_num + 1] = 0; } idx += layer_counts_[j]; } } return outputs; } } // end namespace distributed } // end namespace paddle