// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/randperm_kernel.h"

#ifdef __NVCC__
#include <curand_kernel.h>

#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hiprand_kernel.h>

#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/randint_kernel.h"

namespace phi {

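// cub::DeviceRadixSort is stable, so elements whose random keys collide would
// keep a deterministic relative order and bias the permutation. This kernel
// locates each run of equal keys in key_out_data and applies a Fisher-Yates
// shuffle to the matching span of out_data, seeding a Philox counter-based
// RNG per run so no cross-thread coordination is needed.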
template <typename keyT, typename dataT>
__global__ void SwapRepeatKernel(keyT* key_out_data,
                                 dataT* out_data,
                                 int n,
                                 uint64_t seed,
                                 uint64_t offset) {
  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  if (idx >= static_cast<size_t>(n - 1)) return;  // out of range

  // Only the first element of each run of equal keys performs the shuffle.
  bool is_first_repeat = false;
  if (key_out_data[idx] == key_out_data[idx + 1]) {
    if (idx == 0) {
      is_first_repeat = true;
    } else if (key_out_data[idx] != key_out_data[idx - 1]) {
      is_first_repeat = true;
    }
  }

  if (!is_first_repeat) return;

  // Measure the length of the run of equal keys starting at idx.
  int repeat_size = 1;
  for (int i = idx; i < n - 1; ++i) {  // i + 1 must stay within [0, n)
    if (key_out_data[i] == key_out_data[i + 1]) {
      ++repeat_size;
    } else {
      break;
    }
  }

#ifdef __NVCC__
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);
#elif defined(__HIPCC__)
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, offset, &state);
#endif

  // Fisher-Yates shuffle of out_data[idx .. idx + repeat_size - 1]: elements
  // whose keys collided are placed in a uniformly random relative order.
  for (int i = repeat_size - 1; i > 0; i--) {
#ifdef __NVCC__
    uint32_t r = curand(&state) % (i + 1);
#elif defined(__HIPCC__)
    uint32_t r = hiprand(&state) % (i + 1);
#endif
    if (r != static_cast<uint32_t>(i)) {
      dataT tmp = out_data[idx + i];
      out_data[idx + i] = out_data[idx + r];
      out_data[idx + r] = tmp;
    }
  }
}

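// Generates a random permutation of [0, n) in three steps: (1) draw random
// int32 keys, (2) radix-sort (key, identity) pairs over a reduced bit range,
// (3) shuffle any spans whose keys collided.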
template <typename T, typename Context>
void RandpermRawKernel(
    const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
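  // Draw a uniformly random 32-bit key for each of the n output positions;
  // sorting by these keys produces a random permutation (up to key ties).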
  DenseTensor key;
  RandintKernel<int, Context>(dev_ctx,
                              std::numeric_limits<int>::min(),
                              std::numeric_limits<int>::max(),
                              IntArray({n}),
                              phi::DataType::INT32,
                              &key);
  DenseTensor key_out = Empty<int, Context>(dev_ctx, IntArray({n}));

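  // Fill `range` with the identity permutation [0, 1, ..., n - 1] on device.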
  DenseTensor range = Empty<T, Context>(dev_ctx, IntArray({n}));
  T* range_data = range.data<T>();
  funcs::ForRange<Context> for_range(dev_ctx, n);
  for_range([range_data] __device__(size_t idx) {
    range_data[idx] = static_cast<T>(idx);
  });

  out->Resize(phi::make_ddim({n}));
  T* out_data = dev_ctx.template Alloc<T>(out);

  // Refer to [Algorithm of randperm] (https://osf.io/af2hy/): end_bit is
  // chosen so that sorting fewer key bits (fewer radix-sort passes) still
  // keeps key collisions unlikely; any remaining ties are repaired by
  // SwapRepeatKernel below.
  double n_d = static_cast<double>(n);
  int begin_bit = 0;
  int end_bit =
      std::ceil(std::log2(n_d - (6 * n_d * n_d + 1) / (12 * std::log(0.9))));

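  // CUB two-phase pattern: the first SortPairs call, with a null workspace
  // pointer, only computes the required workspace size.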
  size_t temp_storage_bytes = 0;
  cub::DeviceRadixSort::SortPairs<int, T>(nullptr,
                                          temp_storage_bytes,
                                          key.data<int>(),
                                          key_out.data<int>(),
                                          range.data<T>(),
                                          out_data,
                                          n,
                                          begin_bit,
                                          end_bit < 32 ? end_bit : 32,
                                          dev_ctx.stream());

  auto d_temp_storage = phi::memory_utils::Alloc(
      dev_ctx.GetPlace(),
      temp_storage_bytes,
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
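  // Second call performs the actual key-value radix sort on the stream.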
  cub::DeviceRadixSort::SortPairs<int, T>(d_temp_storage->ptr(),
                                          temp_storage_bytes,
                                          key.data<int>(),
                                          key_out.data<int>(),
                                          range.data<T>(),
                                          out_data,
                                          n,
                                          begin_bit,
                                          end_bit < 32 ? end_bit : 32,
                                          dev_ctx.stream());

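  // Reserve a (seed, offset) slice of the device Philox generator so the
  // tie-breaking shuffle is reproducible for a fixed generator state.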
  auto gen_cuda = dev_ctx.GetGenerator();
  auto seed_offset = gen_cuda->IncrementOffset(n);

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
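  // One thread per element; threads not at the head of a run of equal keys
  // exit immediately.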
  SwapRepeatKernel<<<config.block_per_grid.x,
                     config.thread_per_block.x,
                     0,
                     dev_ctx.stream()>>>(
      key_out.data<int>(), out_data, n, seed_offset.first, seed_offset.second);
}

template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx,
                    int n,
                    DataType dtype,
                    DenseTensor* out) {
  RandpermRawKernel<T>(dev_ctx, n, dtype, 0, out);
}

}  // namespace phi

PD_REGISTER_KERNEL(randperm_raw,
                   GPU,
                   ALL_LAYOUT,
                   phi::RandpermRawKernel,
                   float,
                   double,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(randperm,
                   GPU,
                   ALL_LAYOUT,
                   phi::RandpermKernel,
                   float,
                   double,
                   int,
                   int64_t) {}