index_sampler.h 3.5 KB
Newer Older
1
123malin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace distributed {

// Abstract interface for samplers over a tree index.
// Concrete samplers are created through IndexSampler::Init<T>(name) and are
// configured via one of the init_*_conf hooks; sample() is the only method a
// subclass must implement.
// NOTE(review): the expected call order (configure, then sample) is inferred
// from the API shape — confirm against callers.
class IndexSampler {
 public:
  virtual ~IndexSampler() = default;
  IndexSampler() = default;

  // Factory: constructs a T from `name` and returns it through the base
  // interface. T must derive from IndexSampler and be constructible from a
  // `const std::string&`.
  template <typename T>
  static std::shared_ptr<IndexSampler> Init(const std::string& name) {
    static_assert(std::is_base_of<IndexSampler, T>::value,
                  "IndexSampler::Init<T> requires T to derive from "
                  "IndexSampler");
    return std::make_shared<T>(name);
  }

  // Configuration hook for layer-wise sampling. No-op in the base class;
  // subclasses override it when they support this mode.
  // NOTE: default arguments of a virtual function are bound statically, so
  // the defaults below (start_sample_layer = 1, seed = 0) only apply when the
  // call is made through an IndexSampler pointer or reference.
  virtual void init_layerwise_conf(const std::vector<int>& layer_sample_counts,
                                   int start_sample_layer = 1, int seed = 0) {}

  // Configuration hook for beam-search sampling. No-op in the base class.
  virtual void init_beamsearch_conf(const int64_t k) {}

  // Samples ids for the given user inputs and targets; implemented by each
  // concrete sampler.
  // NOTE(review): the layout of the returned rows is defined by the concrete
  // sampler — confirm in its implementation.
  virtual std::vector<std::vector<uint64_t>> sample(
      const std::vector<std::vector<uint64_t>>& user_inputs,
      const std::vector<uint64_t>& input_targets,
      bool with_hierarchy = false) = 0;
};

class LayerWiseSampler : public IndexSampler {
 public:
  virtual ~LayerWiseSampler() {}
  explicit LayerWiseSampler(const std::string& name) {
    tree_ = IndexWrapper::GetInstance()->get_tree_index(name);
  }

  void init_layerwise_conf(const std::vector<int>& layer_sample_counts,
                           int start_sample_layer, int seed) override {
    seed_ = seed;
    start_sample_layer_ = start_sample_layer;

    PADDLE_ENFORCE_GT(
        start_sample_layer_, 0,
        paddle::platform::errors::InvalidArgument(
            "start sampler layer = [%d], it should greater than 0.",
            start_sample_layer_));
    PADDLE_ENFORCE_LT(start_sample_layer_, tree_->Height(),
                      paddle::platform::errors::InvalidArgument(
                          "start sampler layer = [%d], it should less than "
                          "max_layer, which is [%d].",
                          start_sample_layer_, tree_->Height()));

    size_t i = 0;
    layer_counts_sum_ = 0;
    layer_counts_.clear();
    int cur_layer = start_sample_layer_;
    while (cur_layer < tree_->Height()) {
      int layer_sample_num = 1;
      if (i < layer_sample_counts.size()) {
        layer_sample_num = layer_sample_counts[i];
      }
      layer_counts_sum_ += layer_sample_num + 1;
      layer_counts_.push_back(layer_sample_num);
      VLOG(3) << "[INFO] level " << cur_layer
              << " sample_layer_counts.push_back: " << layer_sample_num;
      cur_layer += 1;
      i += 1;
    }
    reverse(layer_counts_.begin(), layer_counts_.end());
    VLOG(3) << "sample counts sum: " << layer_counts_sum_;
  }
  std::vector<std::vector<uint64_t>> sample(
      const std::vector<std::vector<uint64_t>>& user_inputs,
      const std::vector<uint64_t>& target_ids, bool with_hierarchy) override;

 private:
  std::vector<int> layer_counts_;
  int64_t layer_counts_sum_{0};
  std::shared_ptr<TreeIndex> tree_{nullptr};
  int seed_{0};
  int start_sample_layer_{1};
};

}  // end namespace distributed
}  // end namespace paddle