/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <thrust/transform.h>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/multinomial_op.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

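// Normalize each distribution (row) of `in_data` by its row sum so that
// every row's probabilities sum to 1. blockIdx.y indexes the row; the
// grid's x dimension covers the categories.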
template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
                                     T* sum_rows, int64_t num_categories) {
  int col = threadIdx.x + blockIdx.x * blockDim.x;
  // Guard the tail block: the grid may launch more threads per row than
  // there are categories.
  if (col < num_categories) {
    int64_t id = blockIdx.y * num_categories + col;
    norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
  }
}

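// Inclusive prefix sum of each distribution's normalized probabilities.
// Launched with one single-thread block per distribution; a Thrust
// algorithm invoked from device code executes sequentially within the
// calling thread.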
template <typename T>
__global__ void Cumsum(T* norm_probs_data, int64_t num_distributions,
                       int64_t num_categories, T* cumulative_probs) {
  for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
    thrust::inclusive_scan(thrust::device,
                           norm_probs_data + id * num_categories,
                           norm_probs_data + (id + 1) * num_categories,
                           cumulative_probs + id * num_categories);
  }
}

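// Maps an index n to the n-th draw of a seeded uniform [0, 1) sequence
// (rng.discard(n) advances the engine deterministically), so a parallel
// transform over indices yields reproducible random numbers.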
template <typename T>
struct RandomGeneratorCudaFunctor {
  unsigned int seed_;
  __host__ __device__ RandomGeneratorCudaFunctor(unsigned int seed)
      : seed_(seed) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(seed_);
    thrust::uniform_real_distribution<T> dist(0.0, 1.0);
    rng.discard(n);
    return dist(rng);
  }
};

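// Lower-bound binary search: return the first category whose cumulative
// probability in `cumdist` is >= val, then step back over trailing
// zero-probability categories in `dist`.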
template <typename T>
__device__ int binarySearchFunctor(T* cumdist, T* dist, int size, T val) {
  int left = 0;
  int right = size;

  // cumdist[size - 1] = 0 => all zero prob dist
  // CUDA_KERNEL_ASSERT(cumdist[size - 1] > static_cast<T>(0));

  while (right - left > 0) {
    int mid = left + (right - left) / 2;

    T midVal = cumdist[mid];
    if (midVal < val) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }

  if (left == size) {
    // No probability mass or precision problems; just set left to
    // size - 1 here, and the loop below will move it to the last
    // non-zero probability. This can actually happen when the random
    // number is 1 (github pytorch issue #4858).
    left = size - 1;
  }

  while (left >= 1 && dist[left] == 0) left--;

  return left;
}

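// Draw `num_samples` categories per distribution with replacement. Each
// thread handles one (distribution, sample) pair: it reads a pre-generated
// uniform random number and binary-searches the cumulative distribution.
// Grid-stride loops cover distributions (y) and samples (x).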
template <typename T>
__global__ void sampleMultinomialWithReplacement(
    T* rng_data, const int64_t num_samples, T* out_data,
    const int64_t num_distributions, const int64_t num_categories,
    T* cumulative_probs, T* norm_probs_data) {
  // At the moment, each warp computes one sample value in the binary
  // search due to divergence. It seems possible to compute multiple
  // values and limit divergence though later on.
  for (int curDist = blockIdx.y; curDist < num_distributions;
       curDist += gridDim.y) {
    for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
         sample < num_samples; sample += blockDim.x * gridDim.x) {
      // Each (distribution, sample) pair consumes one pre-generated
      // uniform random number.
      T uniform_random = rng_data[sample + curDist * num_samples];

      // Find the bucket that a uniform sample lies in
      int choice =
          binarySearchFunctor<T>(cumulative_probs + curDist * num_categories,
                                 norm_probs_data + curDist * num_categories,
                                 num_categories, uniform_random);

      out_data[sample + curDist * num_samples] = choice;
    }
  }
}

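// GPU kernel of the multinomial op. Sampling without replacement falls
// back to the CPU implementation; sampling with replacement normalizes
// the input, builds per-row cumulative distributions, and samples on the
// device.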
template <typename T>
class MultinomialOpKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto x = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");

    const int64_t num_samples = ctx.Attr<int>("num_samples");
    const bool replacement = ctx.Attr<bool>("replacement");

    auto* in_data = x->data<T>();
    auto* out_data = out->mutable_data<T>(ctx.GetPlace());

    auto in_dims = x->dims();
    int64_t in_rank = in_dims.size();
    const int64_t num_categories = in_dims[in_rank - 1];
    const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;

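    // Sampling without replacement is delegated to the CPU
    // implementation: copy the input to the host, sample there, and
    // copy the result back.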
    if (!replacement) {
      int in_data_numel = x->numel();
      int out_data_numel = out->numel();

      T* cpu_in_data = new T[in_data_numel];
      T* cpu_out_data = new T[out_data_numel];

      cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T),
                 cudaMemcpyDeviceToHost);

      MultinomialFunctor<T>(cpu_out_data, cpu_in_data, num_samples, replacement,
                            num_categories, num_distributions);
      cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(T),
                 cudaMemcpyHostToDevice);

      delete[] cpu_in_data;
      delete[] cpu_out_data;
      return;
    }

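    // Sum each row's (unnormalized) probabilities with an Eigen
    // reduction along the last axis; the sums are used to normalize
    // the input below.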
    framework::Tensor sum_rows_t;
    auto* sum_rows_data =
        sum_rows_t.mutable_data<T>({num_distributions}, ctx.GetPlace());

    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();

    if (num_distributions == 1) {
      auto eigen_input = framework::EigenVector<T>::Flatten(*x);
      auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t);
      eigen_sum_rows.device(place) =
          eigen_input.sum(Eigen::DSizes<int, 1>(1))
              .eval()
              .reshape(Eigen::DSizes<int, 1>(sum_rows_t.dims()[0]));
    } else {
      auto eigen_input = framework::EigenMatrix<T>::From(*x);
      auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t);
      eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
    }

    framework::Tensor norm_probs_t;
    auto* norm_probs_data = norm_probs_t.mutable_data<T>(
        {num_distributions, num_categories}, ctx.GetPlace());

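    // One 2D grid: y indexes distributions, x covers the categories of
    // a row.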
    dim3 block(num_categories < 512 ? num_categories : 512);
    dim3 grid((num_categories - 1) / block.x + 1, num_distributions);
    NormalizeProbability<
        T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
        norm_probs_data, in_data, sum_rows_data, num_categories);

    framework::Tensor cumulative_probs_t;
    auto* cumulative_probs = cumulative_probs_t.mutable_data<T>(
        {num_distributions, num_categories}, ctx.GetPlace());
    dim3 block1(1);
    dim3 grid1(num_distributions);
    Cumsum<T><<<grid1, block1, 0, ctx.cuda_device_context().stream()>>>(
        norm_probs_data, num_distributions, num_categories, cumulative_probs);

    VLOG(3) << "cumulative_probs device ptr: " << cumulative_probs;

    if (replacement) {
      dim3 block(128);
      dim3 grid((num_samples - 1) / block.x + 1, num_distributions);

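      // Pre-generate one uniform random number per (distribution,
      // sample) pair, seeded from std::random_device.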
      std::random_device rd;
      auto seed = rd();

      framework::Tensor rng_data_t;
      auto* rng_data = rng_data_t.mutable_data<T>(
          {num_distributions, num_samples}, ctx.GetPlace());

      thrust::counting_iterator<unsigned int> index_sequence_begin(0);
      platform::Transform<platform::CUDADeviceContext> trans;
      auto* context = static_cast<const platform::CUDADeviceContext*>(
          &ctx.device_context());
      trans(*context, index_sequence_begin,
            index_sequence_begin + num_distributions * num_samples, rng_data,
            RandomGeneratorCudaFunctor<T>(seed));

      sampleMultinomialWithReplacement<
          T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
          rng_data, num_samples, out_data, num_distributions, num_categories,
          cumulative_probs, norm_probs_data);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
    multinomial, ops::MultinomialOpKernel<plat::CUDADeviceContext, float>,
    ops::MultinomialOpKernel<plat::CUDADeviceContext, double>);