randperm_kernel.cu 6.1 KB
Newer Older
L
Leo Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/randperm_kernel.h"

17 18
#ifdef __NVCC__
#include <curand_kernel.h>
19

20 21 22 23
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hiprand_kernel.h>
24

25 26 27 28
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
29
#include "paddle/phi/core/kernel_registry.h"
30 31 32
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/randint_kernel.h"
33

L
Leo Chen 已提交
34 35 36
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"

37 38
DECLARE_bool(use_curand);

L
Leo Chen 已提交
39 40
namespace phi {

// Shuffles every run of equal keys in an already sorted key array.
//
// A stable radix sort leaves values whose keys compare equal in their original
// relative order, which would bias the permutation. This kernel finds each
// maximal run (length >= 2) of equal keys in `key_out_data` and applies a
// Fisher-Yates shuffle to the matching slice of `out_data`.
//
// Launch: 1-D grid of any size. Positions are visited with a grid-stride loop,
// so a grid smaller than `n` threads still covers the whole array. Only the
// position that starts a run performs its shuffle, so no two iterations write
// the same slice of `out_data`.
//
//   key_out_data: `n` sorted keys (only read).
//   out_data:     `n` values; each run's slice is shuffled in place.
//   n:            number of elements.
//   seed, offset: Philox seed/offset; a run starting at position p uses
//                 subsequence p.
template <typename keyT, typename dataT>
__global__ void SwapRepeatKernel(keyT* key_out_data,
                                 dataT* out_data,
                                 int n,
                                 uint64_t seed,
                                 uint64_t offset) {
  // 64-bit index math: blockIdx.x * blockDim.x must not wrap in 32 bits.
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx + 1 < n;  // signed compare: also correct for n == 0 and n == 1
       idx += stride) {
    const keyT key = key_out_data[idx];
    // Only the first element of a run of >= 2 equal keys does any work.
    if (key_out_data[idx + 1] != key) continue;
    if (idx > 0 && key_out_data[idx - 1] == key) continue;

    // Find the end of the run without reading past key_out_data[n - 1].
    int64_t run_end = idx + 2;
    while (run_end < n && key_out_data[run_end] == key) ++run_end;
    const int repeat_size = static_cast<int>(run_end - idx);

#ifdef __NVCC__
    curandStatePhilox4_32_10_t state;
    curand_init(seed, static_cast<uint64_t>(idx), offset, &state);
#elif defined(__HIPCC__)
    hiprandStatePhilox4_32_10_t state;
    hiprand_init(seed, static_cast<uint64_t>(idx), offset, &state);
#endif
    // Fisher-Yates shuffle of out_data[idx, run_end).
    for (int i = repeat_size - 1; i > 0; --i) {
#ifdef __NVCC__
      const uint32_t r = curand(&state) % static_cast<uint32_t>(i + 1);
#elif defined(__HIPCC__)
      const uint32_t r = hiprand(&state) % static_cast<uint32_t>(i + 1);
#endif
      if (r != static_cast<uint32_t>(i)) {
        dataT tmp = out_data[idx + i];
        out_data[idx + i] = out_data[idx + r];
        out_data[idx + r] = tmp;
      }
    }
  }
}

// Fills `out` with a random permutation of 0, 1, ..., n - 1 (stored as T).
//
// Algorithm (see [Algorithm of randperm] https://osf.io/af2hy/, used to reduce
// the number of radix-sort passes):
//   1. Draw a random int32 key per position and keep only its low `end_bit`
//      bits, with `end_bit` derived from n by the formula below.
//   2. Radix-sort the pairs (key, 0..n-1) on bits [0, end_bit).
//   3. Values whose keys tie keep ascending order after the stable sort, so
//      SwapRepeatKernel shuffles every tied run.
//
// `dtype` is not read here; the element type comes from T.
// NOTE(review): `seed` is not read either -- all randomness comes from the
// device context's generator. Confirm whether a nonzero seed is meant to make
// the result reproducible.
template <typename T, typename Context>
void RandpermRawKernel(
    const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
  out->Resize(phi::make_ddim({n}));
  T* out_data = dev_ctx.template Alloc<T>(out);
  // Nothing to permute. Returning here also avoids launching kernels with an
  // empty grid below.
  if (n == 0) return;

  // Number of low-order key bits the radix sort has to look at for this n.
  const double n_d = static_cast<double>(n);
  const int sort_bits = static_cast<int>(
      std::ceil(std::log2(n_d - (6 * n_d * n_d + 1) / (12 * std::log(0.9)))));
  const int begin_bit = 0;
  const int end_bit = sort_bits < 32 ? sort_bits : 32;
  // The sort only distinguishes keys on bits [0, end_bit), while
  // SwapRepeatKernel compares whole keys. Dropping the unsorted high bits makes
  // "equal key" mean exactly "tied in the sort", so every tie gets shuffled.
  // The sort order itself is unchanged by the mask.
  const uint32_t key_mask =
      end_bit >= 32 ? 0xFFFFFFFFu : ((1u << end_bit) - 1u);

  DenseTensor key;
  RandintKernel<int, Context>(dev_ctx,
                              std::numeric_limits<int>::min(),
                              std::numeric_limits<int>::max(),
                              IntArray({n}),
                              phi::DataType::INT32,
                              &key);
  DenseTensor key_out = Empty<int, Context>(dev_ctx, IntArray({n}));
  DenseTensor range = Empty<T, Context>(dev_ctx, IntArray({n}));

  int* key_data = key.data<int>();
  int* key_out_data = key_out.data<int>();
  T* range_data = range.data<T>();

  // One pass: mask the keys and write the identity sequence 0..n-1.
  funcs::ForRange<Context> for_range(dev_ctx, n);
  for_range([key_data, range_data, key_mask] __device__(size_t idx) {
    key_data[idx] =
        static_cast<int>(static_cast<uint32_t>(key_data[idx]) & key_mask);
    range_data[idx] = static_cast<T>(idx);
  });

  // First call only queries the required temporary storage size.
  size_t temp_storage_bytes = 0;
  cub::DeviceRadixSort::SortPairs<int, T>(nullptr,
                                          temp_storage_bytes,
                                          key_data,
                                          key_out_data,
                                          range_data,
                                          out_data,
                                          n,
                                          begin_bit,
                                          end_bit,
                                          dev_ctx.stream());

  auto d_temp_storage = paddle::memory::Alloc(
      dev_ctx.GetPlace(),
      temp_storage_bytes,
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
  cub::DeviceRadixSort::SortPairs<int, T>(d_temp_storage->ptr(),
                                          temp_storage_bytes,
                                          key_data,
                                          key_out_data,
                                          range_data,
                                          out_data,
                                          n,
                                          begin_bit,
                                          end_bit,
                                          dev_ctx.stream());

  // Advance the device generator's offset by n; the returned (seed, offset)
  // pair drives the Philox streams of the tie shuffle.
  auto gen_cuda = dev_ctx.GetGenerator();
  auto seed_offset = gen_cuda->IncrementOffset(n);

  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n);
  SwapRepeatKernel<<<config.block_per_grid.x,
                     config.thread_per_block.x,
                     0,
                     dev_ctx.stream()>>>(
      key_out_data, out_data, n, seed_offset.first, seed_offset.second);
}

// Random permutation of 0..n-1 without an explicit seed: forwards to
// RandpermRawKernel with seed = 0.
template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx,
                    int n,
                    DataType dtype,
                    DenseTensor* out) {
  RandpermRawKernel<T>(dev_ctx, n, dtype, 0, out);
}

}  // namespace phi

166 167 168 169 170 171 172 173 174
// Registers phi::RandpermRawKernel as the GPU kernel "randperm_raw" for all
// layouts, instantiated for float, double, int and int64_t outputs.
PD_REGISTER_KERNEL(randperm_raw,
                   GPU,
                   ALL_LAYOUT,
                   phi::RandpermRawKernel,
                   float,
                   double,
                   int,
                   int64_t) {}

L
Leo Chen 已提交
175 176 177 178 179 180 181 182
// Registers phi::RandpermKernel as the GPU kernel "randperm" for all layouts,
// instantiated for float, double, int and int64_t outputs.
PD_REGISTER_KERNEL(randperm,
                   GPU,
                   ALL_LAYOUT,
                   phi::RandpermKernel,
                   float,
                   double,
                   int,
                   int64_t) {}