/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <iostream>
#include <unordered_set>
#include <vector>
W
wanghuancoder 已提交
19

X
xuezhong 已提交
20 21 22
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/sampler.h"
23
#include "paddle/phi/core/ddim.h"
24

X
xuezhong 已提交
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
namespace paddle {
namespace operators {
namespace math {

using Tensor = framework::Tensor;

/**
 * Adjusts a sampling probability to account for unique sampling.
 *
 * If num_samples equals num_tries, the result is prob * num_samples.
 * Otherwise the result is -expm1(num_tries * log1p(-prob)), which is
 * 1 - (1 - prob)^num_tries evaluated through expm1/log1p.
 *
 * @param prob         probability to adjust
 * @param num_samples  number of samples requested
 * @param num_tries    number of draws that were performed
 * @return the adjusted probability, converted to T
 */
template <typename T>
static T adjust_prob(const T prob, const int num_samples, const int num_tries) {
  const bool every_try_kept = (num_tries == num_samples);
  if (every_try_kept) {
    return prob * num_samples;
  }
  // log of (1 - prob)^num_tries
  const auto log_all_missed = num_tries * log1p(-prob);
  return -expm1(log_all_missed);
}

template <typename DeviceContext, typename T>
class SampleWithProb {
 public:
45 46 47 48 49
  void operator()(const DeviceContext& context,
                  const Sampler& sampler,
                  const std::size_t num_samples,
                  const Tensor* L,
                  Tensor* S,
X
xuezhong 已提交
50 51
                  Tensor* P) {
    // UNDERSTAND: dimension issues
52
    const auto& lbl_dim = L->dims();
X
xuezhong 已提交
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
    const int batch_size = lbl_dim[0];
    const int num_true = lbl_dim[1];
    const int num_sampled_classes = num_true + num_samples;
    framework::DDim ret_dim{batch_size, num_sampled_classes};

    // UNDERSTAND: raw data view
    const int64_t* label_data = L->data<int64_t>();
    int64_t* samples_data =
        S->mutable_data<int64_t>(ret_dim, context.GetPlace());
    T* probabilities_data = P->mutable_data<T>(ret_dim, context.GetPlace());

    // temp sets for unique sampling
    std::unordered_set<int64_t> tmp_samples;
    int j = 0;  // column index
    // add true labels, not that efficient
    while (j < num_true) {
      for (int i = 0; i < batch_size; ++i) {
        auto samples_index = i * num_sampled_classes + j;
        auto v = label_data[i * num_true + j];
        samples_data[samples_index] = v;
        probabilities_data[samples_index] = sampler.Probability(v);
      }
      ++j;
    }

    // sample num_samles unique samples for an example, note that they are not
    // all negative samples
    tmp_samples.clear();
    int num_tries = 0;
    while (j < num_sampled_classes) {
      ++num_tries;
      auto v = sampler.Sample();
      auto insert_ok = tmp_samples.insert(v).second;
      if (!insert_ok) {
        continue;
      }
      auto p = sampler.Probability(v);
      for (int i = 0; i < batch_size; ++i) {
        auto samples_index = i * num_sampled_classes + j;
        samples_data[samples_index] = v;
        probabilities_data[samples_index] = p;
      }
      ++j;
    }

    // compute Q(y|x), because of unique sampling, probabilities need to be
    // adjusted
    for (int k = 0; k < num_sampled_classes; ++k) {
      for (int i = 0; i < batch_size; ++i) {
        auto samples_index = i * num_sampled_classes + k;
        probabilities_data[samples_index] = adjust_prob(
            probabilities_data[samples_index], num_samples, num_tries);
      }
    }
  }
};

110
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
X
xuezhong 已提交
111 112 113
template <typename T>
class GPUSampleWithProb {
 public:
114 115 116 117 118 119 120
  void operator()(const platform::CUDADeviceContext& context,
                  const int seed,
                  const int dict_size,
                  const bool uniq,
                  const std::size_t num_samples,
                  const Tensor* L,
                  Tensor* S,
X
xuezhong 已提交
121 122 123 124 125 126
                  Tensor* P);
};
#endif
}  // namespace math
}  // namespace operators
}  // namespace paddle