// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "paddle/fluid/distributed/index_dataset/index_wrapper.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace distributed { class IndexSampler { public: virtual ~IndexSampler() {} IndexSampler() {} template static std::shared_ptr Init(const std::string& name) { std::shared_ptr instance = nullptr; instance.reset(new T(name)); return instance; } virtual void init_layerwise_conf(const std::vector& layer_sample_counts, int start_sample_layer = 1, int seed = 0) {} virtual void init_beamsearch_conf(const int64_t k) {} virtual std::vector> sample( const std::vector>& user_inputs, const std::vector& input_targets, bool with_hierarchy = false) = 0; }; class LayerWiseSampler : public IndexSampler { public: virtual ~LayerWiseSampler() {} explicit LayerWiseSampler(const std::string& name) { tree_ = IndexWrapper::GetInstance()->get_tree_index(name); } void init_layerwise_conf(const std::vector& layer_sample_counts, int start_sample_layer, int seed) override { seed_ = seed; start_sample_layer_ = start_sample_layer; PADDLE_ENFORCE_GT( start_sample_layer_, 0, paddle::platform::errors::InvalidArgument( "start sampler layer = [%d], it should greater than 0.", start_sample_layer_)); PADDLE_ENFORCE_LT(start_sample_layer_, tree_->Height(), paddle::platform::errors::InvalidArgument( "start sampler layer = [%d], it should less than " "max_layer, which is [%d].", start_sample_layer_, tree_->Height())); size_t i = 0; layer_counts_sum_ = 0; layer_counts_.clear(); int cur_layer = start_sample_layer_; while (cur_layer < tree_->Height()) { int layer_sample_num = 1; if (i < layer_sample_counts.size()) { layer_sample_num = layer_sample_counts[i]; } layer_counts_sum_ += layer_sample_num + 1; layer_counts_.push_back(layer_sample_num); VLOG(3) << "[INFO] level " << cur_layer << " sample_layer_counts.push_back: " << layer_sample_num; cur_layer += 1; i += 1; } reverse(layer_counts_.begin(), layer_counts_.end()); VLOG(3) << "sample counts sum: " << layer_counts_sum_; } std::vector> sample( const std::vector>& user_inputs, const std::vector& target_ids, bool with_hierarchy) override; private: std::vector layer_counts_; int64_t layer_counts_sum_{0}; std::shared_ptr tree_{nullptr}; int seed_{0}; int start_sample_layer_{1}; }; } // end namespace distributed } // end namespace paddle