// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// CUDA and HIP use the same API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include <algorithm>

#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

namespace kps = phi::kps;

namespace phi {
namespace funcs {
using Mode = kps::details::ReduceMode;

/*
 * Count how many of the elements processed by the current block are true:
 * 1. Load data from global memory and cast from bool to int64_t
 * 2. Reduce within each thread (thread-level reduce)
 * 3. Reduce across the block (block-level reduce)
 * 4. Every block stores its own count; the first block additionally stores a
 *    leading 0
 */
template <typename T>
struct NonZeroFunctor {
  HOSTDEVICE NonZeroFunctor() {}
  HOSTDEVICE inline T operator()(const T in) {
    if (in) {
      return static_cast<T>(1);
    } else {
      return static_cast<T>(0);
    }
  }
};

template <typename InT, typename OutT, int VecSize, int IsBoundary>
__device__ void GetBlockCountImpl(const InT *in,
                                  OutT *out,
                                  int num,
                                  int repeat) {
  InT in_data[VecSize];
  OutT temp[VecSize];
  OutT result = static_cast<OutT>(0.0f);
  using Add = kps::AddFunctor<OutT>;
  using Cast = NonZeroFunctor<InT>;
  int store_fix = BLOCK_ID_X + repeat * GRID_NUM_X;

  kps::Init<InT, VecSize>(&in_data[0], static_cast<InT>(0.0f));
  kps::ReadData<InT, VecSize, 1, IsBoundary>(&in_data[0], in, num);
  kps::ElementwiseUnary<InT, OutT, VecSize, 1, Cast>(
      &temp[0], &in_data[0], Cast());
  kps::Reduce<OutT, VecSize, 1, Add, Mode::kLocalMode>(
      &result, &temp[0], Add(), true);
  kps::Reduce<OutT, 1, 1, Add, Mode::kGlobalMode>(
      &result, &result, Add(), true);
  if (store_fix == 0) {
    // the first block also writes a leading 0
    OutT tmp = static_cast<OutT>(0.0f);
    kps::WriteData<OutT, 1, 1, true>(out + store_fix, &tmp, 1);
  }

  // store the true-count of this block
  kps::WriteData<OutT, 1, 1, true>(out + store_fix + 1, &result, 1);
}

// Count how many elements are non-zero in the current block
template <typename InT, typename OutT, int VecSize>
__global__ void GetBlockCountKernel(const InT *in,
                                    OutT *out,
                                    int64_t numel,
                                    int64_t main_offset) {
  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
  int repeat = 0;
  for (; data_offset < main_offset; data_offset += stride) {
    GetBlockCountImpl<InT, OutT, VecSize, false>(
        in + data_offset, out, BLOCK_NUM_X * VecSize, repeat);
    repeat++;  // to get the real blockIdx
  }

  int num = numel - data_offset;
  if (num > 0) {
    GetBlockCountImpl<InT, OutT, VecSize, true>(
        in + data_offset, out, num, repeat);
  }
}
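// After GetBlockCountKernel, out holds [0, cnt(tile 0), cnt(tile 1), ...] with
// one entry per block per grid-stride repeat; an inclusive cumsum over it
// gives every tile its exclusive store offset, and the last entry is the
// total number of true elements.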

/*
 * Compute the prefix sum of the per-block counts using a single block;
 * VecSize must be 2:
 * 1. Each thread loads 2 elements: index threadIdx.x and threadIdx.x + blockDim.x
 * 2. The cumsum requires blockDim.x to be less than 512
 */

template <typename InT,
          typename OutT,
          typename Functor,
          int VecSize,
          bool IsBoundary>
__device__ void CumsumImpl(
    const InT *in, OutT *out, OutT *pre_cumsum, int num, Functor func) {
  __shared__ OutT max_thread_data;
  OutT temp[VecSize];
  InT arg[VecSize];
  OutT result[VecSize];
  // init arg with zeros
  kps::Init<InT, VecSize>(&arg[0], static_cast<InT>(0.0f));
  // set pre_cumsum
  kps::Init<OutT, VecSize>(&temp[0], *pre_cumsum);
  // load data to arg
  kps::ReadData<InT, InT, VecSize, 1, IsBoundary>(
      &arg[0], in, num, 1, BLOCK_NUM_X, 1);
  // block cumsum
  kps::Cumsum<InT, OutT, Functor>(&result[0], &arg[0], func);
  // result = cumsum_result + pre_cumsum
  kps::ElementwiseBinary<OutT, OutT, VecSize, 1, Functor>(
      &result[0], &result[0], &temp[0], func);
  // get the last prefix sum of this tile
  if ((THREAD_ID_X == BLOCK_NUM_X - 1) && !IsBoundary) {
    max_thread_data = result[VecSize - 1];
  }
  __syncthreads();
  // update pre_cumsum
  *pre_cumsum = max_thread_data;
  kps::WriteData<OutT, OutT, VecSize, 1, IsBoundary>(
      out, &result[0], num, 1, BLOCK_NUM_X, 1);
}

// Compute the store_offset of each block (prefix sum of the per-block counts)
template <typename InT, typename OutT, typename Functor, int VecSize>
__global__ void CumsumOneBlock(
    const InT *in, OutT *out, int numel, int main_offset, Functor func) {
  int stride = BLOCK_NUM_X * VecSize;
  int offset = 0;
  OutT pre_cumsum = static_cast<OutT>(0);
  for (; offset < main_offset; offset += stride) {
    CumsumImpl<InT, OutT, Functor, VecSize, false>(
        in + offset, out + offset, &pre_cumsum, BLOCK_NUM_X * VecSize, func);
  }

  int num = numel - offset;
  if (num > 0) {
    CumsumImpl<InT, OutT, Functor, VecSize, true>(
        in + offset, out + offset, &pre_cumsum, num, func);
  }
}
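// Note: after CumsumOneBlock, cumsum[i] is the exclusive store offset of
// count block i and the last entry equals the total true count, which the
// host-side SelectKernel below copies back to size the output.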

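// SelectCaller dispatches on MaskData: 0 stores the flattened indices
// (where_index), 1 gathers the selected input values (masked_select), and 2
// scatters compacted gradients back to their original positions
// (masked_select_grad).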
// where_index
template <typename OutT,
          typename MT,
          typename InT,
          typename Functor,
          int VecSize,
          int IsBoundary,
          int MaskData>
struct SelectCaller {
  __device__ void inline operator()(OutT *out,
                                    const MT *mask_data,
                                    const InT *in,
                                    Functor func,
                                    int data_offset,
                                    int store_num,
                                    int thread_fix,
                                    int num) {
    int64_t in_data[VecSize];
    OutT store_data[VecSize * phi::DDim::kMaxRank];
    // set index
    kps::InitWithDataIndex<int64_t, VecSize, 1>(&in_data[0], data_offset);
    // Get store data according to mask_data
    kps::OperatorTernary<MT, int64_t, OutT, Functor>(
        store_data, mask_data, &in_data[0], func, VecSize);
    kps::details::WriteData<OutT>(out + thread_fix, &store_data[0], store_num);
  }
};
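// Note on the where_index case above: `in` is unused; the inputs are the
// flattened indices generated by InitWithDataIndex, and func is expected to
// expand each selected index into up to phi::DDim::kMaxRank coordinates,
// which is why store_data is sized VecSize * kMaxRank.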

// masked_select
template <typename OutT,
          typename MT,
          typename InT,
          typename Functor,
          int VecSize,
          int IsBoundary>
struct SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, 1> {
  __device__ void inline operator()(OutT *out,
                                    const MT *mask_data,
                                    const InT *in,
                                    Functor func,
                                    int data_offset,
                                    int store_num,
                                    int thread_fix,
                                    int num) {
    InT in_data[VecSize];
    OutT store_data[VecSize * phi::DDim::kMaxRank];
    kps::ReadData<InT, VecSize, 1, IsBoundary>(&in_data[0], in, num);
    // Get store data according to mask_data
    kps::OperatorTernary<MT, InT, OutT, Functor>(
        store_data, mask_data, &in_data[0], func, VecSize);
    kps::details::WriteData<OutT>(out + thread_fix, &store_data[0], store_num);
  }
};

// masked_select_grad
template <typename OutT,
          typename MT,
          typename InT,
          typename Functor,
          int VecSize,
          int IsBoundary>
struct SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, 2> {
  __device__ void inline operator()(OutT *out,
                                    const MT *mask_data,
                                    const InT *in,
                                    Functor func,
                                    int data_offset,
                                    int store_num,
                                    int thread_fix,
                                    int num) {
    InT in_data[VecSize];
    OutT store_data[VecSize * phi::DDim::kMaxRank];
    kps::details::ReadData<InT>(&in_data[0], in + thread_fix, store_num);
    kps::OperatorTernary<MT, InT, OutT, Functor>(
        store_data, mask_data, &in_data[0], func, VecSize);
    kps::WriteData<OutT, VecSize, 1, IsBoundary>(out, &store_data[0], num);
  }
};
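// Note on the masked_select_grad case above: the data movement is reversed,
// reading store_num compacted gradient values starting at thread_fix and
// writing them back to the full-sized output at this tile's original
// positions.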

/**
 * Get the output for every element whose mask value is true
 */
template <typename InT,
          typename MT,
          typename OutT,
          typename Functor,
          int VecSize,
          int MaskData,  // 0: where_index, 1: masked_select, 2: masked_select_grad
          int IsBoundary>
__device__ void SelectKernelImpl(OutT *out,
                                 const MT *mask,
                                 const InT *in,
                                 Functor func,
                                 int num,
                                 int data_offset,
                                 int store_rank) {
  const int kCVecSize = 2;
  // each thread cumsums 2 elements
  using IdT = int64_t;
  // Set index data type
  using Add = kps::AddFunctor<IdT>;  // for cumsum
  using Cast = NonZeroFunctor<InT>;  // for mask

  IdT init_idx = static_cast<IdT>(0.0f);
  MT init_mask = static_cast<MT>(0.0f);

  IdT num_thread[kCVecSize];
  IdT cumsum_thread[kCVecSize];

  MT mask_data[VecSize];
  IdT mask_idt[VecSize];
  // init local buffers
  kps::Init<IdT, kCVecSize>(&cumsum_thread[0], init_idx);
  kps::Init<IdT, kCVecSize>(&num_thread[0], init_idx);
  kps::Init<MT, VecSize>(&mask_data[0], init_mask);
  // Load mask
  kps::ReadData<MT, VecSize, 1, IsBoundary>(&mask_data[0], mask, num);
  // Cast from MT to int
  kps::ElementwiseUnary<MT, IdT, VecSize, 1, Cast>(
      &mask_idt[0], &mask_data[0], Cast());
  // Reduce the thread's mask values to get its per-thread true count
  kps::Reduce<IdT, VecSize, 1, Add, Mode::kLocalMode>(
      &num_thread[0], &mask_idt[0], Add(), true);
  // Cumsum num_thread across the block; cumsum_thread[0] is this thread's
  // inclusive prefix, from which thread_fix is derived
  kps::Cumsum<IdT, IdT, Add>(&cumsum_thread[0], &num_thread[0], Add());
  // get thread_fix
  int thread_fix =
      (static_cast<int>(cumsum_thread[0] - num_thread[0]) * store_rank);
  // get how many elements this thread needs to store (may differ per thread)
  int store_num = static_cast<int>(num_thread[0]) * store_rank;
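  // Illustrative example (hypothetical values): with store_rank = 2 (a 2-D
  // where_index), a thread holding 3 true mask values whose exclusive prefix
  // within the block is 5 gets store_num = 3 * 2 = 6 and thread_fix = 5 * 2
  // = 10, i.e. it writes 6 values starting at offset 10 of this block's slice.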
  // Get store data(index) according to mask_idt
  SelectCaller<OutT, MT, InT, Functor, VecSize, IsBoundary, MaskData> select;
  select(out, mask_data, in, func, data_offset, store_num, thread_fix, num);
}

template <typename MT,
          typename InT,
          typename CT,
          typename OutT,
          typename Functor,
          int VecSize,
          int MaskData>
__global__ void SelectKernel(OutT *out,
                             const MT *mask,
                             const InT *in,
                             CT *cumsum,
                             Functor func,
                             const int64_t numel,
                             int64_t main_offset,
                             int store_rank) {
  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
  int repeat = 0;
  int size = VecSize * BLOCK_NUM_X;
  CT block_store_offset = 0;
  for (; data_offset < main_offset; data_offset += stride) {
    // Cumsum index
    int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
    kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
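    // For where_index / masked_select (MaskData < 2) the compacted side is
    // the output, so out is offset by this block's cumulative count; for
    // masked_select_grad the compacted side is the input instead.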
    int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
    int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
    SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, false>(
        out + out_fix,
        mask + data_offset,
        in + in_fix,
        func,
        size,
        data_offset,
        store_rank);
    repeat++;
  }

  int num = numel - data_offset;
  if (num > 0) {
    // Cumsum index
    int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
    kps::details::ReadData<CT>(&block_store_offset, cumsum + idx_cumsum, 1);
    int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset;
    int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank;
    SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, true>(
        out + out_fix,
        mask + data_offset,
        in + in_fix,
        func,
        num,
        data_offset,
        store_rank);
  }
}

inline int64_t Floor(int64_t in, int64_t div) { return in / div * div; }

// SelectData = 1 then masked_select; SelectData = 0 then where_index
template <typename MT,
          typename InT,
          typename OutT,
          int SelectData,
          typename Functor>
void SelectKernel(const KPDevice &dev_ctx,
                  const DenseTensor &condition,
                  const DenseTensor &in_data,
                  DenseTensor *out,
                  Functor func) {
  const MT *cond_data = condition.data<MT>();
  const int64_t numel = condition.numel();
  auto dims = condition.dims();
  int rank = SelectData ? 1 : dims.size();
  const InT *in_data_ptr = SelectData ? in_data.data<InT>() : nullptr;
  // calculate the inclusive prefix sum of "true_num_array"
  // to get the index of "out" tensor,
  // and the total number of cond_data[i]==true.
  // Example:
  // condition: F T T F F F T T
  // before:    0 1 1 0 0 0 1 1
  // after:     0 1 2 2 2 2 3 4
  // out:       1 2 6 7
  // alloc for cpu
  using CT = int64_t;  // set Count_data Type
  const int t_size = sizeof(CT);

  const paddle::platform::CUDAPlace &cuda_place = dev_ctx.GetPlace();
  paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();

  // 1.1 get the number of stored elements per block
  const int kVecSize = 4;
#ifdef PADDLE_WITH_XPU_KP
  int block = 64;
  auto stream = dev_ctx.x_context()->xpu_stream;
  const int num_per_block = kVecSize * block;
  const int need_grids = (numel + num_per_block - 1) / num_per_block;
  const int grid = std::min(need_grids, 8);
#else
  const int block = 256;
  const int num_per_block = kVecSize * block;
  const int need_grids = (numel + num_per_block - 1) / num_per_block;
  const int grid = std::min(need_grids, 256);
  auto stream = dev_ctx.stream();
#endif
  const int64_t main_offset = Floor(numel, num_per_block);
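  // Illustrative numbers (CUDA path): for numel = 1000 with block = 256 and
  // kVecSize = 4, num_per_block = 1024, need_grids = 1, grid = 1 and
  // main_offset = Floor(1000, 1024) = 0, so the whole input is handled by the
  // boundary branch of the kernels.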
  // 1.2 alloc tmp data for CountBlock
  const int size_count_block = need_grids + 1;
  std::vector<int> dims_vec = {size_count_block * 2};
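  // Note: only the first size_count_block entries of this buffer are written
  // and read by the kernels below; the factor of 2 is presumably headroom for
  // the VecSize = 2 cumsum reads.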
  IntArray dims_array(dims_vec);
  DenseTensor count_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
  CT *count_data = count_mem.data<CT>();
  // 1.3 launch GetBlockCountKernel
  GetBlockCountKernel<MT, CT, kVecSize>
      <<<grid, block, 0, stream>>>(cond_data, count_data, numel, main_offset);
  // 2.1 alloc cumsum data for the CountBlock prefix
  DenseTensor cumsum_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
  CT *cumsum_data = cumsum_mem.data<CT>();
  // 2.2 get prefix of count_data for real out_index
  CT total_true_num = static_cast<CT>(0);  // init
  const int kCumVesize = 2;
  const int block_c = 256;
  const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c));

  using Add = kps::AddFunctor<CT>;
  CumsumOneBlock<CT, CT, Add, kCumVesize><<<1, block_c, 0, stream>>>(
      count_data, cumsum_data, size_count_block, main_offset_c, Add());
  // 3.1 alloc for out
  // 3.1.1 get true_num; on the GPU the last cumsum entry is the total number
  // of true elements
  paddle::memory::Copy(cpu_place,
                       &total_true_num,
                       cuda_place,
                       cumsum_data + need_grids,
                       t_size,
                       dev_ctx.stream());

  dev_ctx.Wait();
  // 3.1.2 alloc out with total_true_num
  std::vector<int64_t> out_dim = {static_cast<int64_t>(total_true_num)};

  if (SelectData == 1) {
    out->Resize(phi::make_ddim(out_dim));
  } else if (SelectData == 0) {  // == 0 where_index
    out_dim.push_back(static_cast<int64_t>(rank));
    out->Resize(phi::make_ddim(out_dim));
  }
  auto out_data = out->mutable_data<OutT>(cuda_place);
  // 3.2 get true data's index according to cond_data and cumsum_data
  if (total_true_num <= 0) return;
  SelectKernel<MT, InT, CT, OutT, Functor, kVecSize, SelectData>
      <<<grid, block, 0, stream>>>(out_data,
                                   cond_data,
                                   in_data_ptr,
                                   cumsum_data,
                                   func,
                                   numel,
                                   main_offset,
                                   rank);
}

}  // namespace funcs
}  // namespace phi

#endif