// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// CUDA and HIP use the same API
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include <algorithm>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

namespace kps = phi::kps;

namespace phi {
namespace funcs {
using Mode = kps::details::ReduceMode;

/*
* Count how many of the elements processed by the current block are true
* 1. Load data from global memory and cast from bool to int64_t
* 2. Reduce within each thread
* 3. Reduce across the block
* 4. The first block additionally stores a leading 0; every block stores its
*    own count
*/
template <typename T>
struct NonZeroFunctor {
  HOSTDEVICE NonZeroFunctor() {}
  HOSTDEVICE inline T operator()(const T in) {
    if (in) {
      return static_cast<T>(1);
    } else {
      return static_cast<T>(0);
    }
  }
};

template <typename InT, typename OutT, int VecSize, int IsBoundary>
__device__ void GetBlockCountImpl(const InT *in,
                                  OutT *out,
                                  int num,
                                  int repeat) {
  InT in_data[VecSize];
  OutT temp[VecSize];
  OutT result = static_cast<OutT>(0.0f);
  using Add = kps::AddFunctor<OutT>;
  using Cast = NonZeroFunctor<InT>;
  int store_fix = BLOCK_ID_X + repeat * GRID_NUM_X;

  kps::Init<InT, VecSize>(&in_data[0], static_cast<InT>(0.0f));
  kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(&in_data[0], in, num);
  kps::ElementwiseUnary<InT, OutT, VecSize, 1, 1, Cast>(
      &temp[0], &in_data[0], Cast());
  kps::Reduce<OutT, VecSize, 1, 1, Add, Mode::kLocalMode>(
      &result, &temp[0], Add(), true);
  kps::Reduce<OutT, 1, 1, 1, Add, Mode::kGlobalMode>(
      &result, &result, Add(), true);
  if (store_fix == 0) {
    // the first block writes a leading 0 at out[0]
    OutT tmp = static_cast<OutT>(0.0f);
    kps::WriteData<OutT, 1, 1, 1, true>(out + store_fix, &tmp, 1);
  }

  // store num of this block
  kps::WriteData<OutT, 1, 1, 1, true>(out + store_fix + 1, &result, 1);
}

// Count how many elements are non-zero in each block
template <typename InT, typename OutT, int VecSize>
__global__ void GetBlockCountKernel(const InT *in,
                                    OutT *out,
                                    int64_t numel,
                                    int64_t main_offset) {
  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
  int repeat = 0;
  for (; data_offset < main_offset; data_offset += stride) {
    GetBlockCountImpl<InT, OutT, VecSize, false>(
        in + data_offset, out, BLOCK_NUM_X * VecSize, repeat);
    repeat++;  // advance to the next logical block index
  }

  int num = numel - data_offset;
  if (num > 0) {
    GetBlockCountImpl<InT, OutT, VecSize, true>(
        in + data_offset, out, num, repeat);
  }
}
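// Illustrative layout of `out` after GetBlockCountKernel (a sketch, assuming
// gridDim.x = 2 and three chunks of BLOCK_NUM_X * VecSize elements):
//   out[0] = 0                                  (written by the first block)
//   out[1] = count of chunk 0  (block 0, repeat 0)
//   out[2] = count of chunk 1  (block 1, repeat 0)
//   out[3] = count of chunk 2  (block 0, repeat 1)
// i.e. the chunk handled at (repeat, blockIdx) is counted into
// out[repeat * GRID_NUM_X + BLOCK_ID_X + 1].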

/*
* Compute the prefix sum of the per-block counts using one block; VecSize must
* be 2
* 1. Each thread loads 2 elements: threadIdx.x and threadIdx.x + blockDim.x
* 2. Cumsum limitation: blockDim.x must be less than 512
*/

template <typename InT,
          typename OutT,
          typename Functor,
          int VecSize,
          bool IsBoundary>
__device__ void CumsumImpl(
    const InT *in, OutT *out, OutT *pre_cumsum, int num, Functor func) {
  __shared__ OutT max_thread_data;
  OutT temp[VecSize];
  InT arg[VecSize];
  OutT result[VecSize];
  // initialize the input registers
  kps::Init<InT, VecSize>(&arg[0], static_cast<InT>(0.0f));
  // set pre_cumsum
  kps::Init<OutT, VecSize>(&temp[0], *pre_cumsum);
  // load data to arg
  kps::ReadData<InT, InT, VecSize, 1, 1, IsBoundary>(
      &arg[0], in, num, 1, BLOCK_NUM_X, 1);
  // block cumsum
  kps::Cumsum<InT, OutT, 1, Functor>(&result[0], &arg[0], func);
  // result = cumsum_result + pre_cumsum
  kps::ElementwiseBinary<OutT, OutT, VecSize, 1, 1, Functor>(
      &result[0], &result[0], &temp[0], func);
  // the last thread's last element is the running total of this chunk
  if ((THREAD_ID_X == BLOCK_NUM_X - 1) && !IsBoundary) {
    max_thread_data = result[VecSize - 1];
  }
  __syncthreads();
  // update pre_cumsum
  *pre_cumsum = max_thread_data;
  kps::WriteData<OutT, OutT, VecSize, 1, 1, IsBoundary>(
      out, &result[0], num, 1, BLOCK_NUM_X, 1);
}
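// Note: across the loop iterations in CumsumOneBlock below, pre_cumsum carries
// the running total of all previously processed chunks, so the single launched
// block still produces a globally correct prefix sum over the whole buffer.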

// Compute the store offset of each chunk by prefix-summing the per-block counts
template <typename InT, typename OutT, typename Functor, int VecSize>
__global__ void CumsumOneBlock(
    const InT *in, OutT *out, int numel, int main_offset, Functor func) {
  int stride = BLOCK_NUM_X * VecSize;
  int offset = 0;
  OutT pre_cumsum = static_cast<OutT>(0);
  for (; offset < main_offset; offset += stride) {
    CumsumImpl<InT, OutT, Functor, VecSize, false>(
        in + offset, out + offset, &pre_cumsum, BLOCK_NUM_X * VecSize, func);
  }

  int num = numel - offset;
  if (num > 0) {
    CumsumImpl<InT, OutT, Functor, VecSize, true>(
        in + offset, out + offset, &pre_cumsum, num, func);
  }
}
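// Small worked example (a sketch): with per-block counts
//   count_data  = {0, 3, 1, 4}
// the inclusive prefix sum computed above is
//   cumsum_data = {0, 3, 4, 8}
// so cumsum_data[i] is the global store offset of chunk i, and the last
// element is the total number of selected elements.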

template <typename OutT,
          typename MT,
          typename InT,
          typename IdT,
          typename Functor,
          int VecSize,
          int IsBoundary,
          int IsMaskData>
struct SelectCaller {
  __device__ void inline operator()(OutT *store_data,
                                    const MT *mask_data,
                                    const InT *in,
                                    Functor func,
                                    int num,
                                    int data_offset) {
    // where_index op
    IdT index_reg[VecSize];
    // Fill index_reg with each element's global (flattened) index
    kps::InitWithDataIndex<IdT, VecSize, 1, 1>(&index_reg[0], data_offset);
    // Get store data according to mask_data
    kps::OperatorTernary<MT, IdT, OutT, Functor>(
        store_data, mask_data, &index_reg[0], func, VecSize);
  }
};

template <typename OutT,
          typename MT,
          typename InT,
          typename IdT,
          typename Functor,
          int VecSize,
          int IsBoundary>
struct SelectCaller<OutT,
                    MT,
                    InT,
                    IdT,
                    Functor,
                    VecSize,
                    IsBoundary,
                    1> {  // masked_select
  __device__ void inline operator()(OutT *store_data,
                                    const MT *mask_data,
                                    const InT *in,
                                    Functor func,
                                    int num,
                                    int data_offset) {
    InT in_data[VecSize];
    kps::ReadData<InT, VecSize, 1, 1, IsBoundary>(&in_data[0], in, num);
    // Get store data according to mask_data
    kps::OperatorTernary<MT, InT, OutT, Functor>(
        store_data, mask_data, &in_data[0], func, VecSize);
  }
};
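// The two SelectCaller variants differ only in what they feed to
// OperatorTernary: IsMaskData == 0 (where_index) generates each element's
// flattened global index with InitWithDataIndex, while IsMaskData == 1
// (masked_select) loads the corresponding input values from `in`; the functor
// `func` is expected to keep only the entries whose mask is true.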

/**
* Get the index (where_index) or value (masked_select) of each element whose
* mask is true
*/
template <typename InT,
          typename MT,
          typename OutT,
          typename Functor,
          int VecSize,
          int MaskData,
          int IsBoundary>  // MaskData = 1: masked_select; otherwise where_index
__device__ void
SelectKernelImpl(OutT *out,
                 const MT *mask,
                 const InT *in,
                 Functor func,
                 int num,
                 int data_offset,
                 int store_rank) {
  const int kCVecSize = 2;           // each thread cumsums 2 elements
  using IdT = int64_t;               // index data type
  using Add = kps::AddFunctor<IdT>;  // for cumsum
  using Cast = NonZeroFunctor<InT>;  // for mask

  IdT init_idx = static_cast<IdT>(0.0f);
  MT init_mask = static_cast<MT>(0.0f);

  IdT num_thread[kCVecSize];
  IdT cumsum_thread[kCVecSize];

  OutT store_data[VecSize * phi::DDim::kMaxRank];
  MT mask_data[VecSize];
  IdT mask_idt[VecSize];
  // initialize registers
  kps::Init<IdT, kCVecSize>(&cumsum_thread[0], init_idx);
  kps::Init<IdT, kCVecSize>(&num_thread[0], init_idx);
  kps::Init<MT, VecSize>(&mask_data[0], init_mask);
  // Load mask
  kps::ReadData<MT, VecSize, 1, 1, IsBoundary>(&mask_data[0], mask, num);
  // Cast from MT to int
  kps::ElementwiseUnary<MT, IdT, VecSize, 1, 1, Cast>(
      &mask_idt[0], &mask_data[0], Cast());
  // Get the number of true elements handled by this thread
  kps::Reduce<IdT, VecSize, 1, 1, Add, Mode::kLocalMode>(
      &num_thread[0], &mask_idt[0], Add(), true);
  // Inclusive cumsum across the block's threads; cumsum_thread[0] holds this
  // thread's inclusive prefix sum (used below to compute thread_fix)
  kps::Cumsum<IdT, IdT, 1, Add>(&cumsum_thread[0], &num_thread[0], Add());
  // Get store data(index) according to mask_idt
  SelectCaller<OutT, MT, InT, IdT, Functor, VecSize, IsBoundary, MaskData>
      compute;
  compute(&store_data[0], &mask_data[0], in, func, num, data_offset);
  // exclusive prefix: where this thread starts writing (scaled by store_rank)
  int thread_fix =
      (static_cast<int>(cumsum_thread[0] - num_thread[0]) * store_rank);
  // how many elements this thread needs to store
  int store_num = static_cast<int>(num_thread[0]) * store_rank;
  // each thread stores its own data; the count may differ across threads
  kps::details::WriteData<OutT>(out + thread_fix, &store_data[0], store_num);
}
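// Worked example (a sketch, assuming VecSize = 4 and where_index with
// store_rank = 1): if this thread's mask chunk is {1, 0, 1, 1} then
// num_thread[0] = 3; if the preceding threads selected 5 elements in total,
// cumsum_thread[0] = 8, so thread_fix = (8 - 3) * 1 = 5 and store_num = 3,
// i.e. this thread writes its three selected indices to out[5..7].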

template <typename MT,
          typename InT,
          typename CT,
          typename OutT,
          typename Functor,
          int VecSize,
          int MaskData>
__global__ void SelectKernel(OutT *out,
                             const MT *mask,
                             const InT *in,
                             CT *cumsum,
                             Functor func,
                             const int64_t numel,
                             int64_t main_offset,
                             int store_rank) {
  int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
  int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
  int repeat = 0;
  int size = VecSize * BLOCK_ID_X;
  for (; data_offset < main_offset; data_offset += stride) {
    // Cumsum index
    int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
    // TODO(niuliling): use the ReadData API
    // cumsum[idx_cumsum] = number of elements selected by all earlier chunks
    int block_store_offset = cumsum[idx_cumsum];
    SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, false>(
        out + block_store_offset * store_rank,
        mask + data_offset,
        in + data_offset,
        func,
        size,
        data_offset,
        store_rank);
    repeat++;
  }

  int num = numel - data_offset;
  if (num > 0) {
    // Cumsum index
    int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X;
    // TODO(niuliling): use the ReadData API
    int block_store_offset = static_cast<int>(cumsum[idx_cumsum]);
    SelectKernelImpl<InT, MT, OutT, Functor, VecSize, MaskData, true>(
        out + block_store_offset * store_rank,
        mask + data_offset,
        in + data_offset,
        func,
        num,
        data_offset,
        store_rank);
  }
}

// Round `in` down to the nearest multiple of `div`
inline int64_t Floor(int64_t in, int64_t div) { return in / div * div; }

// SelectData = 1 then masked_select; SelectData = 0 then where_index
template <typename MT,
          typename InT,
          typename OutT,
          int SelectData,
          typename Functor>
void SelectKernel(const KPDevice &dev_ctx,
                  const DenseTensor &condition,
                  const DenseTensor &in_data,
                  DenseTensor *out,
                  Functor func) {
  const MT *cond_data = condition.data<MT>();
  const int64_t numel = condition.numel();
  auto dims = condition.dims();
  int rank = SelectData ? 1 : dims.size();
  const InT *in_data_ptr = SelectData ? in_data.data<InT>() : nullptr;
  // calculate the inclusive prefix sum of "true_num_array"
  // to get the index of "out" tensor,
  // and the total number of elements where cond_data[i] == true.
  // Example:
  // condition: F T T F F F T T
  // before:    0 1 1 0 0 0 1 1
  // after:     0 1 2 2 2 2 3 4
  // out:       1 2 6 7
  using CT = int64_t;  // type used for the per-block counters
  const int t_size = sizeof(CT);

  const paddle::platform::CUDAPlace &cuda_place = dev_ctx.GetPlace();
  paddle::platform::CPUPlace cpu_place = paddle::platform::CPUPlace();

  // 1.1 get the number of stored elements per block
  CT total_true_num = 0;  // total number of true elements (filled in below)
  const int kVecSize = 4;
#ifdef PADDLE_WITH_XPU_KP
  int block = 64;
  auto stream = dev_ctx.x_context()->xpu_stream;
  const int num_per_block = kVecSize * block;
  const int need_grids = (numel + num_per_block - 1) / num_per_block;
  const int grid = std::min(need_grids, 8);
#else
  const int block = 256;
  const int num_per_block = kVecSize * block;
  const int need_grids = (numel + num_per_block - 1) / num_per_block;
  const int grid = std::min(need_grids, 256);
  auto stream = dev_ctx.stream();
#endif
  const int64_t main_offset = Floor(numel, num_per_block);
  // 1.2 alloc temp buffer for the per-block counts
  const int size_count_block = need_grids + 1;
  std::vector<int> dims_vec = {size_count_block * 2};
  ScalarArray dims_array(dims_vec);
  DenseTensor count_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
  CT *count_data = count_mem.data<CT>();
  // 1.3 launch GetBlockCountKernel
  GetBlockCountKernel<MT, CT, kVecSize><<<grid, block, 0, stream>>>(
      cond_data, count_data, numel, main_offset);
  // 2.1 alloc cumsum buffer for the per-block count prefix sum
  DenseTensor cumsum_mem = phi::Empty<CT, KPDevice>(dev_ctx, dims_array);
  CT *cumsum_data = cumsum_mem.data<CT>();
  // 2.2 compute the prefix sum of count_data to get the real out index
  const int kCumVecSize = 2;
  const int block_c = 256;
  const int main_offset_c = Floor(size_count_block, (kCumVecSize * block_c));
  using Add = kps::AddFunctor<CT>;
  CumsumOneBlock<CT, CT, Add, kCumVecSize><<<1, block_c, 0, stream>>>(
      count_data, cumsum_data, size_count_block, main_offset_c, Add());
  // 3.1 alloc for out
  // 3.1.1 get true_num: copy the last element of cumsum_data (the total
  // number of true elements) back to the CPU
  paddle::memory::Copy(cpu_place,
                       &total_true_num,
                       cuda_place,
                       cumsum_data + need_grids,
                       t_size,
                       dev_ctx.stream());

  dev_ctx.Wait();
  // 3.1.2 alloc out with total_true_num
  std::vector<int64_t> out_dim = {static_cast<int64_t>(total_true_num)};
  if (SelectData == 0) {  // where_index
    out_dim.push_back(rank);
  }
  out->Resize(phi::make_ddim(out_dim));
  auto out_data = out->mutable_data<OutT>(cuda_place);
  // 3.2 gather the selected indices/values according to cond_data and
  // cumsum_data
  if (total_true_num <= 0) return;
  SelectKernel<MT,
               InT,
               CT,
               OutT,
               Functor,
               kVecSize,
               SelectData><<<grid, block, 0, stream>>>(out_data,
                                                       cond_data,
                                                       in_data_ptr,
                                                       cumsum_data,
                                                       func,
                                                       numel,
                                                       main_offset,
                                                       rank);
}
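// Typical instantiations (a hedged sketch of how callers are expected to use
// this helper, not an exhaustive list):
//   - where_index:   SelectKernel<MT, InT, int64_t, 0, Functor>, with in_data
//                    unused and out holding `rank` coordinates per true
//                    element;
//   - masked_select: SelectKernel<MT, T, T, 1, Functor>, with out holding the
//                    selected input values.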

}  // namespace funcs
}  // namespace phi

#endif