multinomial_kernel.cu 11.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_WITH_HIP
// TODO(qili93): this whole file is compiled out on ROCm (PADDLE_WITH_HIP);
// re-enable it once https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202
// is resolved.

#include "paddle/phi/kernels/multinomial_kernel.h"

21 22 23 24 25 26 27 28
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

29
#include "paddle/phi/backends/gpu/gpu_context.h"
30
#include "paddle/phi/common/data_type.h"
31 32
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/ddim.h"
33
#include "paddle/phi/core/kernel_registry.h"
34 35 36
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
37
#include "paddle/phi/kernels/funcs/eigen/common.h"
38 39
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/inclusive_scan.h"
40
#include "paddle/phi/kernels/funcs/multinomial_functor.h"
41 42
#include "paddle/phi/kernels/top_k_kernel.h"

43 44
namespace phi {

45 46
/**
 * Divides every element of each row of `in_data` by that row's sum so each
 * row becomes a probability distribution.
 *
 * Thread mapping: one thread per flattened element, with
 *   id = threadIdx.x + blockIdx.x * blockDim.x
 *        + blockIdx.y * gridDim.x * blockDim.x;
 * threads with id >= num_distributions * num_categories do nothing.
 *
 * @param norm_probs        output, num_distributions * num_categories values
 *                          laid out row by row (row = id / num_categories)
 * @param in_data           input weights, same layout as norm_probs
 * @param sum_rows          one precomputed sum per row
 * @param num_distributions number of rows
 * @param num_categories    number of elements per row
 *
 * Traps (device-side PADDLE_ENFORCE) if an input weight is negative or its
 * row sum is not positive.
 */
template <typename T, typename MT>
__global__ void NormalizeProbability(MT* norm_probs,
                                     const T* in_data,
                                     MT* sum_rows,
                                     int64_t num_distributions,
                                     int64_t num_categories) {
  // Build the flat index in 64-bit arithmetic: the dim3 components are
  // 32-bit unsigned, and an `int` index overflowed for inputs with more than
  // 2^31 elements, producing out-of-bounds reads/writes.
  const int64_t id = static_cast<int64_t>(threadIdx.x) +
                     static_cast<int64_t>(blockIdx.x) * blockDim.x +
                     static_cast<int64_t>(blockIdx.y) * gridDim.x * blockDim.x;
  if (id < num_distributions * num_categories) {
    // NOTE: the enforce conditions are stringized into the device error
    // message, so they are kept exactly as originally written.
    PADDLE_ENFORCE(
        static_cast<MT>(in_data[id]) >= 0.0,
        "The input of multinomial distribution should be >= 0, but got %f.",
        static_cast<MT>(in_data[id]));
    int64_t row_id = id / num_categories;
    PADDLE_ENFORCE(sum_rows[row_id] > 0.0,
                   "The sum of one multinomial distribution probability should "
                   "be > 0, but got %f.",
                   sum_rows[row_id]);
    norm_probs[id] = static_cast<MT>(in_data[id]) / sum_rows[row_id];
  }
}

/*
 * Picks the sampled category for one uniform random number.
 *
 * Performs a lower-bound search over the row's cumulative probabilities:
 * the result is the first index whose cumulative value is >= rng_number.
 * If rng_number is larger than every cumulative value (e.g. rounding left
 * the last entry slightly below 1), the last category is used instead.
 * If the chosen category has zero probability, the result steps left to the
 * nearest category with non-zero probability, stopping at index 0.
 */
template <typename T>
__device__ int binarySearchFunctor(T* cumulative_probs_data,
                                   T* norm_probs_data,
                                   int num_categories,
                                   T rng_number) {
  // Half-open search window [lo, hi).
  int lo = 0;
  int hi = num_categories;
  while (lo < hi) {
    const int probe = lo + ((hi - lo) >> 1);
    const bool go_right = cumulative_probs_data[probe] < rng_number;
    lo = go_right ? probe + 1 : lo;
    hi = go_right ? hi : probe;
  }

  // lo can only reach num_categories when no entry was >= rng_number.
  int pick = (lo < num_categories) ? lo : num_categories - 1;

  // Never return an impossible (zero-probability) category when a
  // lower-indexed alternative exists.
  for (; pick >= 1 && norm_probs_data[pick] == 0; --pick) {
  }
  return pick;
}

/**
 * Draws `num_samples` category indices, with replacement, from each of
 * `num_distributions` rows.
 *
 * Thread mapping: along x, thread `sample = blockIdx.x * blockDim.x +
 * threadIdx.x` produces one sample per row. Along y, a thread starts at row
 * blockIdx.y and strides by gridDim.y until every row is covered.
 *
 * cumulative_probs_data / norm_probs_data: row-major
 *   [num_distributions, num_categories] (per-row CDF and probabilities).
 * out_data: row-major [num_distributions, num_samples].
 *
 * Randomness: a Philox state initialised with (seed, unique thread id,
 * offset). Each visited row draws one curand_uniform4() and uses its .x lane,
 * so the host must reserve 4 values per row a thread may visit.
 */
template <typename T>
__global__ void sampleMultinomialWithReplacement(
    const int64_t num_samples,
    int64_t* out_data,
    const int64_t num_distributions,
    const int64_t num_categories,
    T* cumulative_probs_data,
    T* norm_probs_data,
    uint64_t seed,
    uint64_t offset) {
  const int64_t sample =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  // Threads past the last sample produce nothing (and draw no numbers).
  if (sample >= num_samples) {
    return;
  }

  // Unique Philox subsequence per thread of the 2-D grid. Widen before
  // multiplying: dim3 components are 32-bit unsigned, so the product used to
  // wrap for large grids and let different threads share a random stream.
  const size_t idx =
      static_cast<size_t>(gridDim.x) * blockDim.x * blockIdx.y +
      static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);

  for (int64_t dist = blockIdx.y; dist < num_distributions;
       dist += gridDim.y) {
    T rng_number = static_cast<T>(curand_uniform4(&state).x);
    // Binary search for the selected category id such that
    // cumulative_probs_data[id - 1] < rng_number <= cumulative_probs_data[id].
    int selected_category =
        binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories,
                               norm_probs_data + dist * num_categories,
                               static_cast<int>(num_categories),
                               rng_number);
    out_data[sample + dist * num_samples] = selected_category;
  }
}

/**
 * GPU multinomial sampling: draws `num_samples` category indices per
 * distribution, using the (non-negative, not necessarily normalized) entries
 * of `x` as weights. Indices are written to `out` as int64.
 *
 * `x` is read as [num_categories] (a single distribution) or as
 * [num_distributions, num_categories] (last two dims). The output shape is
 * assumed to be set by the caller's InferMeta — TODO confirm
 * ([num_distributions, num_samples] expected).
 *
 * replacement == false:
 *   Validates on the host that every weight is >= 0 and that each row has at
 *   least `num_samples` positive weights. It then applies the exponential /
 *   Gumbel trick: argmax (num_samples == 1) or top-k of x / E, with
 *   E ~ Exp(1).
 * replacement == true:
 *   Normalizes each row, builds per-row CDFs with an inclusive scan, and
 *   binary-searches one uniform random number per sample.
 */
template <typename T, typename Context>
void MultinomialKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const Scalar& num_samples,
                       bool replacement,
                       DenseTensor* out) {
  using MT = typename kps::details::MPTypeTrait<T>::Type;

  auto int_num_samples = num_samples.to<int>();
  auto* in_data = x.data<T>();
  int64_t* out_data = dev_ctx.template Alloc<int64_t>(out);
  auto in_dims = x.dims();
  int64_t dim_size = in_dims.size();
  const int64_t num_categories = in_dims[dim_size - 1];
  const int64_t num_distributions = dim_size > 1 ? in_dims[dim_size - 2] : 1;

  // If replacement is False, it's not a replaceable sample. Every category
  // can be used only once.
  if (!replacement) {
    // Validate the weights on the host. The copy is issued on dev_ctx's
    // stream (blocking=false), so it is ordered after whatever produced `x`.
    // Wait for it explicitly before the host reads cpu_tensor instead of
    // relying on the device-to-host copy happening to be synchronous.
    phi::DenseTensor cpu_tensor;
    phi::Copy<Context>(dev_ctx, x, phi::CPUPlace(), false, &cpu_tensor);
    dev_ctx.Wait();
    const T* cpu_in_data = cpu_tensor.data<T>();
    for (int64_t i = 0; i < num_distributions; ++i) {
      int zero_num = 0;
      for (int64_t j = 0; j < num_categories; ++j) {
        T weight = cpu_in_data[i * num_categories + j];
        PADDLE_ENFORCE_GE(
            static_cast<MT>(weight),
            0,
            errors::InvalidArgument(
                "Each element of multinomial'input must >= 0, but got %f.",
                static_cast<MT>(weight)));
        if (weight == static_cast<T>(0)) {
          zero_num++;
        }
      }
      int valid_samples = num_categories - zero_num;
      PADDLE_ENFORCE_LE(
          int_num_samples,
          valid_samples,
          errors::InvalidArgument("When replacement=False, 'num_samples' "
                                  "must less than or eaqual to the number of "
                                  "positive item of input"));
    }

    // Refer to [gumbel softmax algorithm]
    DenseTensor rand = EmptyLike<T, Context>(dev_ctx, x);
    T* rand_data = rand.data<T>();
    funcs::uniform_distribution<MT> dist;
    funcs::exponential_transform<MT> trans(1.0);
    funcs::distribution_and_transform<T>(dev_ctx, &rand, dist, trans);

    // rand_data[i] = x[i] / E[i]; zero weights map to 0, so the top
    // entries of each row select distinct positive-weight categories.
    funcs::ForRange<Context> for_range(dev_ctx, x.numel());
    for_range([rand_data, in_data] __device__(size_t idx) {
      rand_data[idx] = in_data[idx] / rand_data[idx];
    });

    if (int_num_samples == 1) {
      ArgMaxKernel<T, Context>(
          dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out);
    } else {
      std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims());
      DenseTensor value = Empty<T, Context>(dev_ctx, IntArray(out_dim_vec));
      TopkKernel<T, Context>(
          dev_ctx, rand, num_samples, -1, true, true, &value, out);
    }
    return;
  }

  // Nothing to sample. Returning early avoids an invalid launch: with zero
  // rows grid_y would be 0, and with num_samples == 0 the unsigned grid.x
  // computation below would wrap to a huge value.
  if (out->numel() == 0) {
    return;
  }
  // num_categories == 0 would make block_size 0 and divide by zero below.
  PADDLE_ENFORCE_GT(
      num_categories,
      0,
      errors::InvalidArgument(
          "The number of categories of multinomial's input must be > 0, "
          "but got %d.",
          num_categories));

  // Sum of input may not be 1. To get probability in range [0, 1], calculate
  // sum of each row of input, and then use the sum to normalize the input.
  // sum_rows_data: sum of each row (accumulated in MT so that float16 /
  // bfloat16 inputs do not lose precision during the reduction).
  DenseTensor sum_rows_tensor;
  sum_rows_tensor.Resize({num_distributions});
  auto* sum_rows_data = dev_ctx.template Alloc<MT>(&sum_rows_tensor);
  auto& place = *dev_ctx.eigen_device();

  if (num_distributions == 1) {
    // Single distribution: a full reduction over the flattened input.
    // (Reducing along dimension 1 of this rank-1 view is out of range.)
    auto eigen_input = EigenVector<T>::Flatten(x);
    auto eigen_sum_rows = EigenVector<MT>::Flatten(sum_rows_tensor);
    eigen_sum_rows.device(place) =
        eigen_input.template cast<MT>()
            .sum()
            .eval()
            .reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
  } else {
    // Row sums: reduce the category dimension (dim 1) of the 2-D view.
    auto eigen_input = EigenMatrix<T>::From(x);
    auto eigen_sum_rows = EigenVector<MT>::Flatten(sum_rows_tensor);
    eigen_sum_rows.device(place) =
        eigen_input.template cast<MT>().sum(Eigen::DSizes<int, 1>(1));
  }

  // Normalize row of each distribution to get the probability in range
  // [0, 1].
  // norm_probs_data: probability of the distribution
  DenseTensor norm_probs_tensor;
  norm_probs_tensor.Resize({num_distributions, num_categories});
  auto* norm_probs_data = dev_ctx.template Alloc<MT>(&norm_probs_tensor);
  const int64_t num_elements = num_distributions * num_categories;
  // number of threads in a block is min(num_categories, 512)
  int block_size = num_categories < 512 ? num_categories : 512;
  dim3 block_norm(block_size);
  dim3 grid_norm((num_elements - 1) / block_norm.x + 1);

  NormalizeProbability<T, MT>
      <<<grid_norm, block_norm, 0, dev_ctx.stream()>>>(norm_probs_data,
                                                       in_data,
                                                       sum_rows_data,
                                                       num_distributions,
                                                       num_categories);

  // Get cumulative probability of each distribution. It's the same function
  // of ``cumsum`` op.
  DenseTensor cumulative_probs_tensor;
  cumulative_probs_tensor.Resize({num_distributions, num_categories});
  auto* cumulative_probs_data =
      dev_ctx.template Alloc<MT>(&cumulative_probs_tensor);
  // 'phi::funcs::InclusiveScan' has higher accuracy than
  // 'thrust::inclusive_scan'
  funcs::InclusiveScan<MT, std::plus<MT>>(
      /*in*/ norm_probs_data,
      /*out*/ cumulative_probs_data,
      /*outer_dim*/ static_cast<size_t>(num_distributions),
      /*mid_dim*/ static_cast<size_t>(num_categories),
      /*inner_dim*/ static_cast<size_t>(1),
      /*init*/ static_cast<MT>(0),
      std::plus<MT>(),
      /*reverse=*/false,
      dev_ctx);

  // Sample the multinomial distributions.
  dim3 block(128);
  int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
  const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
  int grid_y = std::min<int64_t>(num_distributions, prop.maxGridSize[1]);
  dim3 grid((int_num_samples - 1) / block.x + 1, grid_y);

  auto gen_cuda = dev_ctx.GetGenerator();
  // Each sampling thread calls curand_uniform4 once for every row it visits,
  // i.e. up to ceil(num_distributions / grid_y) times, and every call
  // consumes 4 Philox values. Reserve all of them so the next launch does not
  // reuse random numbers. (The previous formula,
  // ceil(num_distributions / (4 * grid_y)) * 4, under-reserved by about 4x
  // whenever num_distributions > grid_y.) 'increment' is a multiple of 4.
  uint64_t curand4_loop_times = (num_distributions + grid_y - 1) / grid_y;
  uint64_t increment = curand4_loop_times * 4;
  auto seed_offset = gen_cuda->IncrementOffset(increment);

  sampleMultinomialWithReplacement<MT>
      <<<grid, block, 0, dev_ctx.stream()>>>(int_num_samples,
                                             out_data,
                                             num_distributions,
                                             num_categories,
                                             cumulative_probs_data,
                                             norm_probs_data,
                                             seed_offset.first,
                                             seed_offset.second);
}

}  // namespace phi

// Register the GPU multinomial kernel for float16, bfloat16, float and double
// inputs. The sampled indices are always int64, so the output dtype is set
// explicitly in the registration body.
PD_REGISTER_KERNEL(multinomial,  // cuda_only
                   GPU,
                   ALL_LAYOUT,
                   phi::MultinomialKernel,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   float,
                   double) {
  // Output 0 holds category indices, not values of the input dtype.
  kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
296 297

#endif