conv.cu.h 34.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <thrust/remove.h>
#include <thrust/unique.h>
#include "paddle/phi/kernels/sparse/conv_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"
#include "paddle/phi/kernels/primitive/compute_primitives.h"

namespace phi {
namespace sparse {

using Dims4D = phi::funcs::sparse::Dims4D;

// Vectorized gather of fixed-size slices (rows) from global memory.
//
// For every r in [0, index_size), copies the slice_size elements starting at
// params + indices[r] * slice_size to output + r * slice_size. Each loop
// iteration moves VecSize contiguous elements with one phi::Load / phi::Store
// of a phi::AlignedVector<T, VecSize>.
//
// Preconditions (not checked here): slice_size is a multiple of VecSize (the
// Gather launcher only picks VecSize > 1 in that case), and params/output are
// presumably required to be aligned for phi::AlignedVector<T, VecSize> —
// TODO confirm against phi::Load/Store.
// Launch: 1D; iteration is driven by CUDA_KERNEL_LOOP_TYPE over
// index_size * slice_size / VecSize vectors. No shared memory.
// In 3D point-cloud workloads slice_size is commonly 4, 8, 16, 32 or 64.
template <typename T, typename IndexT = int, int VecSize>
__global__ void GatherKernel(const T* params,
                             const IndexT* indices,
                             T* output,
                             size_t index_size,
                             size_t slice_size) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  // Vectors per slice; loop-invariant, so it is computed once per thread.
  const int64_t vec_slice_size = static_cast<int64_t>(slice_size / VecSize);
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size / VecSize, int64_t) {
    // 64-bit row/offset math: index_size may exceed INT_MAX.
    const int64_t indices_i = i / vec_slice_size;
    // Vector offset inside the slice.
    const int64_t slice_i = i - indices_i * vec_slice_size;
    const int64_t gather_i = static_cast<int64_t>(indices[indices_i]);
    const int64_t params_i =
        gather_i * static_cast<int64_t>(slice_size) + slice_i * VecSize;
    LoadT params_vec;
    phi::Load<T, VecSize>(params + params_i, &params_vec);
    phi::Store<T, VecSize>(params_vec, output + i * VecSize);
  }
}

// "Double sparse" gather (see GroupIndexs / GroupIndexsV2, which presumably
// build the groups consumed here): each input row is loaded once and stored to
// every output row recorded for it.
//
// index_counts: buffer_count consecutive arrays of non_zero_num ints;
//   index_counts[it * non_zero_num + r] is the number of output rows recorded
//   for input row r in buffer `it`.
// index_groups: buffer_count consecutive arrays of
//   non_zero_num * kernel_size * max_voxel ints; the output rows of input row r
//   in buffer `it` start at
//   it * kernel_size * max_voxel * non_zero_num + r * kernel_size * max_voxel.
// output: rows of `channels` elements addressed by those output-row ids.
// channels must be a multiple of VecSize. IntT is not used by the kernel body;
// it is kept so the template signature matches GatherV2.
// Launch: 1D grid-stride loop over non_zero_num * channels / VecSize vectors.
template <typename T, typename IntT, int VecSize>
__global__ void GatherKernelV2(const T* inputs,
                               const int* index_counts,
                               const int* index_groups,
                               const int non_zero_num,
                               const int kernel_size,
                               const int max_voxel,
                               const int channels,
                               const int buffer_count,
                               T* output) {
  using LoadT = phi::AlignedVector<T, VecSize>;
  const int vec_channels = channels / VecSize;
  // 64-bit arithmetic throughout: group offsets and out_i * channels can exceed
  // INT_MAX for large point clouds.
  const int64_t total = static_cast<int64_t>(non_zero_num) * vec_channels;
  const int64_t group_stride = static_cast<int64_t>(kernel_size) * max_voxel;
  const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < total;
       i += grid_stride) {
    const int64_t indices_i = i / vec_channels;
    const int64_t channels_i = i - indices_i * vec_channels;
    LoadT in_vec;
    phi::Load<T, VecSize>(inputs + indices_i * channels + channels_i * VecSize,
                          &in_vec);
    for (int it = 0; it < buffer_count; it++) {
      const int len =
          index_counts[indices_i + static_cast<int64_t>(it) * non_zero_num];
      const int64_t group_offset =
          static_cast<int64_t>(it) * group_stride * non_zero_num;
      const int* group = index_groups + group_offset + indices_i * group_stride;
      for (int j = 0; j < len; j++) {
        const int64_t out_i = group[j];
        phi::Store<T, VecSize>(
            in_vec, output + out_i * channels + channels_i * VecSize);
      }
    }
  }
}

// Gathers `indices_size` rows of `channels` elements each: output row r is a
// copy of row indices[r] of `inputs`.
// Uses the VecSize-wide GatherKernel when `channels` is divisible by
// VecBytes / sizeof(T), and the scalar GatherKernel otherwise.
// The kernel is enqueued asynchronously on dev_ctx.stream().
template <typename T, typename IntT>
inline void Gather(const GPUContext& dev_ctx,
                   const T* inputs,
                   const IntT* indices,
                   const int indices_size,
                   const int channels,
                   T* output) {
  // Nothing to copy; also avoids requesting a launch config for 0 elements.
  if (indices_size <= 0 || channels <= 0) {
    return;
  }
  // Widen before multiplying: indices_size * channels can exceed INT_MAX.
  const int64_t numel = static_cast<int64_t>(indices_size) * channels;
  const int VecSize = VecBytes / sizeof(T);
  if (channels % VecSize == 0) {
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
        dev_ctx, numel / VecSize, 1);
    GatherKernel<T, IntT, VecSize>
        <<<config.block_per_grid.x,
           config.thread_per_block.x,
           0,
           dev_ctx.stream()>>>(inputs, indices, output, indices_size, channels);
  } else {
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, 1);
    GatherKernel<T, IntT, 1>
        <<<config.block_per_grid.x,
           config.thread_per_block.x,
           0,
           dev_ctx.stream()>>>(inputs, indices, output, indices_size, channels);
  }
}

// Launches GatherKernelV2 ("double sparse" gather). Each of the non_zero_num
// input rows of `channels` elements is copied to every output row listed for
// it in index_groups (index_counts holds the list lengths). There are
// buffer_count (counts, groups) buffers, and each group has capacity
// kernel_size * max_voxel.
// Uses the VecSize-wide kernel when `channels` is divisible by
// VecBytes / sizeof(T), and the scalar kernel otherwise.
// The kernel is enqueued asynchronously on dev_ctx.stream().
template <typename T, typename IntT>
inline void GatherV2(const GPUContext& dev_ctx,
                     const T* inputs,
                     const int* index_counts,
                     const int* index_groups,
                     const int non_zero_num,
                     const int kernel_size,
                     const int max_voxel,
                     const int channels,
                     const int buffer_count,
                     T* output) {
  // Nothing to copy; also avoids requesting a launch config for 0 elements.
  if (non_zero_num <= 0 || channels <= 0) {
    return;
  }
  // Widen before multiplying: non_zero_num * channels can exceed INT_MAX.
  const int64_t numel = static_cast<int64_t>(non_zero_num) * channels;
  const int VecSize = VecBytes / sizeof(T);
  if (channels % VecSize == 0) {
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
        dev_ctx, numel / VecSize, 1);
    GatherKernelV2<T, IntT, VecSize><<<config.block_per_grid.x,
                                       config.thread_per_block.x,
                                       0,
                                       dev_ctx.stream()>>>(inputs,
                                                           index_counts,
                                                           index_groups,
                                                           non_zero_num,
                                                           kernel_size,
                                                           max_voxel,
                                                           channels,
                                                           buffer_count,
                                                           output);
  } else {
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, 1);
    GatherKernelV2<T, IntT, 1><<<config.block_per_grid.x,
                                 config.thread_per_block.x,
                                 0,
                                 dev_ctx.stream()>>>(inputs,
                                                     index_counts,
                                                     index_groups,
                                                     non_zero_num,
                                                     kernel_size,
                                                     max_voxel,
                                                     channels,
                                                     buffer_count,
                                                     output);
  }
}

// Collects the distinct output indices of the rulebook.
//
// Each thread handles at most one entry (no grid-stride loop), so the grid
// must cover rulebook_len. An entry's key is recorded the first time it is
// seen in out_index_table. Key 0 is marked with -1 so that 0 keeps meaning
// "unseen" (atomicOr only supports int). Newly seen keys are staged in dynamic
// shared memory (blockDim.x ints), then appended to out_indexs at a block
// offset reserved by an atomicAdd on *nnz.
template <typename IntT>
__global__ void UniqueKernel(const IntT* in_indexs,
                             const int rulebook_len,
                             int* out_index_table,
                             int* out_indexs,
                             int* nnz) {
  extern __shared__ int cache[];
  __shared__ int count, start;
  const int tx = threadIdx.x;

  if (tx == 0) {
    count = 0;
    start = 0;
  }
  __syncthreads();

  const int gid = tx + blockDim.x * blockIdx.x;
  if (gid < rulebook_len) {
    const int key = static_cast<int>(in_indexs[gid]);
    const int mark = (key == 0) ? -1 : key;
    // The previous table value is 0 only for the first thread to see `key`.
    if (atomicOr(out_index_table + key, mark) == 0) {
      cache[atomicAdd(&count, 1)] = key;
    }
  }
  __syncthreads();

  // Reserve this block's slice of the global output.
  if (tx == 0) {
    start = atomicAdd(nnz, count);
  }
  __syncthreads();

  for (int k = tx; k < count; k += blockDim.x) {
    out_indexs[start + k] = cache[k];
  }
}

// Remaps each rule's output id through out_index_table, then groups the rules
// by the remapped id.
//
// For every rule i in [0, n):
//   * out_indexs[i] is overwritten in place with
//     real = out_index_table[out_indexs[i]];
//   * a slot j is claimed from out_index_counts[real], and
//     out_index_groups[real * offset + j] = i.
// `offset` is the per-id group capacity. There is no bound check, so the
// caller must size it for the largest group.
template <typename IntT>
__global__ void GroupIndexs(const int* out_index_table,
                            const int n,
                            const int offset,
                            IntT* out_indexs,
                            int* out_index_counts,
                            int* out_index_groups) {
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    const IntT index = out_indexs[i];
    const int real_index = out_index_table[index];
    out_indexs[i] = real_index;

    // Claim the next free slot in this id's group.
    const int j = atomicAdd(out_index_counts + real_index, 1);
    // 64-bit so real_index * offset cannot overflow int.
    out_index_groups[static_cast<int64_t>(real_index) * offset + j] =
        static_cast<int>(i);
  }
}

/**
 * @brief Produce the rulebook for a regular (non-submanifold) sparse conv.
 * for input_i in x_indices:
 *   if input_i participates in the convolution calculation:
 *       infer output_i from input_i and kernel_i
 *       save (input_i, output_i)
 *
 * x_indices: the indices of the input features, read as four rows of
 *   non_zero_num entries (batch, z, y, x)
 * x_dims: the input dims
 * kernel_dims: the kernel dims
 * out_dims: the output dims
 * non_zero_num: the number of input features
 * rulebook: 2 x (kernel_size * non_zero_num). For the pair (kernel position k,
 *   input i), column k * non_zero_num + i holds the input index in row 0 and
 *   the flattened output index in row 1. Pairs that do not contribute are
 *   written as -1 in both rows.
 * counter: counter[k] is increased by the number of inputs that contribute at
 *   kernel position k (the caller zeroes it beforehand).
 * Requires kernel_size * sizeof(int) bytes of dynamic shared memory.
 **/
template <typename T>
__global__ void ProductRuleBookKernel(const T* x_indices,
                                      const Dims4D x_dims,
                                      const Dims4D kernel_dims,
                                      const Dims4D out_dims,
                                      const int64_t non_zero_num,
                                      const Dims4D paddings,
                                      const Dims4D dilations,
                                      const Dims4D strides,
                                      T* rulebook,
                                      int* counter) {
  extern __shared__ int counter_buf[];  // kernel_size
  const int kernel_size = kernel_dims[3] * kernel_dims[2] * kernel_dims[1];
  // Start of rulebook row 1. 64-bit because kernel_size * non_zero_num can
  // exceed INT_MAX.
  const int64_t offset = kernel_size * non_zero_num;
  for (int k = threadIdx.x; k < kernel_size; k += blockDim.x) {
    counter_buf[k] = 0;
  }
  __syncthreads();

  const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < non_zero_num;
       i += grid_stride) {
    int kernel_index = 0;
    T batch = x_indices[i];
    T in_z = x_indices[i + non_zero_num];
    T in_y = x_indices[i + 2 * non_zero_num];
    T in_x = x_indices[i + 3 * non_zero_num];
    for (int kz = 0; kz < kernel_dims[1]; kz++) {
      for (int ky = 0; ky < kernel_dims[2]; ky++) {
        for (int kx = 0; kx < kernel_dims[3]; kx++) {
          // Use T, not int: the flattened output index may not fit in 32 bits
          // when T is int64_t.
          T in_i = -1, out_index = -1;
          if (phi::funcs::sparse::Check(x_dims,
                                        kernel_dims,
                                        paddings,
                                        dilations,
                                        strides,
                                        in_x,
                                        in_y,
                                        in_z,
                                        kx,
                                        ky,
                                        kz)) {
            T out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1];
            T out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2];
            T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3];
            in_i = static_cast<T>(i);
            out_index = phi::funcs::sparse::PointToIndex<Dims4D>(
                batch, out_x, out_y, out_z, out_dims);
            atomicAdd(&counter_buf[kernel_index], 1);
          }
          rulebook[kernel_index * non_zero_num + i] = in_i;
          rulebook[kernel_index * non_zero_num + offset + i] = out_index;
          ++kernel_index;
        }
      }
    }
  }
  __syncthreads();
  for (int k = threadIdx.x; k < kernel_size; k += blockDim.x) {
    atomicAdd(&counter[k], counter_buf[k]);
  }
}

// Builds voxel lookup tables for the input points.
//
// Point i in [0, non_zero_num) has coordinates (batch, z, y, x), read from the
// four rows of `indices`. For each point, with
// index = PointToIndex(batch, x, y, z, dims):
//   * if save_out_index: out_index_table[index] = i. Point 0 is stored as -1,
//     so 0 can mean "no point here" (ProductSubmRuleBookKernel maps -1 back
//     to 0);
//   * out_index_table2[index] counts the points that share `index`;
//   * *max_voxel is atomically raised to at least the largest such count.
// The tables are only updated here, never cleared; the subm path of
// ProductRuleBook memsets them before the launch.
// Uses one static __shared__ int per block for the block-local maximum.
template <typename IntT, bool save_out_index = true>
__global__ void GetOutIndexTable(const IntT* indices,
                                 const IntT non_zero_num,
                                 const Dims4D dims,
                                 int* out_index_table,
                                 int* out_index_table2,
                                 int* max_voxel) {
  __shared__ int cache_max;
  if (threadIdx.x == 0) {
    cache_max = 0;
  }
  __syncthreads();

  CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
    IntT batch = indices[i];
    IntT in_z = indices[i + non_zero_num];
    IntT in_y = indices[i + 2 * non_zero_num];
    IntT in_x = indices[i + 3 * non_zero_num];
    IntT index = phi::funcs::sparse::PointToIndex<Dims4D>(
        batch, in_x, in_y, in_z, dims);
    if (save_out_index) {
      out_index_table[index] = i == 0 ? -1 : i;
    }

    // Previous count at this voxel: how many earlier points share it.
    int count = atomicAdd(out_index_table2 + index, 1);
    atomicMax(&cache_max, count);
  }

  __syncthreads();
  if (threadIdx.x == 0) {
    atomicMax(max_voxel, cache_max + 1);
  }
}

// Builds lookup tables from already-flattened voxel indices and expands each
// index back into coordinates.
//
// For every i in [0, non_zero_num), with index = indexs[i]:
//   * out_index_table[index] = i (entry 0 is stored as 0 here, with no -1
//     marker);
//   * out_index_table2[index] counts the entries that share `index`, and
//     *max_voxel is atomically raised to at least the largest such count;
//   * out_indices receives the (batch, z, y, x) coordinates of `index` under
//     out_dims, as four rows of non_zero_num entries;
//   * indexs[i] is reset to 0 after being read.
// Uses one static __shared__ int per block for the block-local maximum.
template <typename IntT>
__global__ void GetOutIndexTable(int* indexs,
                                 const int non_zero_num,
                                 const Dims4D out_dims,
                                 int* out_index_table,
                                 int* out_index_table2,
                                 int* max_voxel,
                                 IntT* out_indices) {
  __shared__ int cache_max;
  if (threadIdx.x == 0) {
    cache_max = 0;
  }
  __syncthreads();

  CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) {
    IntT index = static_cast<IntT>(indexs[i]);
    out_index_table[index] = i;

    // Previous count at this voxel: how many earlier entries share it.
    int count = atomicAdd(out_index_table2 + index, 1);
    atomicMax(&cache_max, count);

    IntT batch, x, y, z;
    phi::funcs::sparse::IndexToPoint<Dims4D>(
        index, out_dims, &batch, &x, &y, &z);
    // get out indices
    out_indices[i] = batch;
    out_indices[i + non_zero_num] = z;
    out_indices[i + non_zero_num * 2] = y;
    out_indices[i + non_zero_num * 3] = x;
    indexs[i] = 0;
  }

  __syncthreads();
  if (threadIdx.x == 0) {
    atomicMax(max_voxel, cache_max + 1);
  }
}

// Compacts the padded per-kernel rulebook into a dense 2 x len rulebook.
//
// in_rulebook: 2 x (kernel_size * non_zero_num). Kernel position k owns the
//   columns starting at k * non_zero_num in each row.
// offsets: exclusive prefix sum of the per-kernel counts, so output column i
//   belongs to the kernel k with offsets[k] <= i < offsets[k + 1] (the last
//   kernel when no such k < kernel_size - 1 exists).
// out_rulebook: 2 x len (row 0 at [0, len), row 1 at [len, 2 * len)).
// counters: not read by this kernel; kept for signature compatibility.
// Dynamic shared memory: at least kernel_size ints. Callers pass
// 2 * kernel_size ints, which is more than enough.
template <typename IntT>
__global__ void CopyRuleBook(const int* counters,
                             const int* offsets,
                             const IntT* in_rulebook,
                             const int len,
                             const int kernel_size,
                             const int non_zero_num,
                             IntT* out_rulebook) {
  int tid = threadIdx.x + blockDim.x * blockIdx.x;
  // Previously offsets were copied to shared memory but the search still read
  // them from global memory; the cached copy is now the one searched.
  extern __shared__ int cache_offsets[];
  for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) {
    cache_offsets[i] = offsets[i];
  }
  __syncthreads();

  const int64_t in_row_size = static_cast<int64_t>(kernel_size) * non_zero_num;
  for (int i = tid; i < len; i += gridDim.x * blockDim.x) {
    // Find the kernel position that owns output column i.
    int kernel_index = 0;
    for (; kernel_index < kernel_size - 1; kernel_index++) {
      if (i >= cache_offsets[kernel_index] &&
          i < cache_offsets[kernel_index + 1]) {
        break;
      }
    }
    const int inner_index = i - cache_offsets[kernel_index];
    const int64_t src =
        static_cast<int64_t>(kernel_index) * non_zero_num + inner_index;
    out_rulebook[i] = in_rulebook[src];
    out_rulebook[static_cast<int64_t>(len) + i] = in_rulebook[in_row_size + src];
  }
}

// Produces the rulebook for a submanifold sparse conv. A (kernel position,
// input) pair contributes only if the output voxel it hits is present in
// out_index_table.
//
// out_index_table: flattened output voxel -> output row. 0 means "absent", and
//   row 0 is stored as -1 (mapped back to 0 below).
// rulebook: 2 x (kernel_size * non_zero_num). Kernel position k writes its
//   pairs densely from column k * non_zero_num: input ids in row 0 and output
//   rows in row 1 (which starts at kernel_size * non_zero_num).
// counter: counter[k] accumulates the number of pairs at kernel position k.
//   The value before each addition is that chunk's write cursor.
//
// Each block stages pairs in dynamic shared memory before flushing them, so it
// needs (2 * kernel_size + 2 * kernel_size * blockDim.x) ints. The block walks
// the inputs in chunks of blockDim.x with a block-uniform trip count and
// flushes after every chunk. The staging buffer therefore never holds more
// than blockDim.x pairs per kernel position, whatever the grid size.
template <typename T>
__global__ void ProductSubmRuleBookKernel(const T* x_indices,
                                          const Dims4D x_dims,
                                          const Dims4D kernel_dims,
                                          const Dims4D out_dims,
                                          const int64_t non_zero_num,
                                          const Dims4D paddings,
                                          const Dims4D dilations,
                                          const Dims4D strides,
                                          const int* out_index_table,
                                          T* rulebook,
                                          int* counter) {
  const int kernel_size = kernel_dims[3] * kernel_dims[2] * kernel_dims[1];
  // [kernel_size]: pairs staged per kernel position in the current chunk.
  extern __shared__ int counter_buf[];
  // [kernel_size]: global write cursors reserved for the current chunk.
  int* counter_buf2 = counter_buf + kernel_size;
  // [2][kernel_size][blockDim.x]: staged input ids, then staged output rows.
  int* rulebook_buf = counter_buf + kernel_size * 2;

  // Start of rulebook row 1; 64-bit because it can exceed INT_MAX.
  const int64_t offset = kernel_size * non_zero_num;
  const int64_t chunk_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

  // `base` depends only on blockIdx, so every thread of a block runs the same
  // number of iterations and reaches each __syncthreads() below.
  for (int64_t base = static_cast<int64_t>(blockIdx.x) * blockDim.x;
       base < non_zero_num;
       base += chunk_stride) {
    for (int k = threadIdx.x; k < kernel_size; k += blockDim.x) {
      counter_buf[k] = 0;
    }
    __syncthreads();

    const int64_t i = base + threadIdx.x;
    if (i < non_zero_num) {
      int kernel_index = 0;
      const T batch = x_indices[i];
      const T in_z = x_indices[i + non_zero_num];
      const T in_y = x_indices[i + 2 * non_zero_num];
      const T in_x = x_indices[i + 3 * non_zero_num];
      for (int kz = 0; kz < kernel_dims[1]; kz++) {
        for (int ky = 0; ky < kernel_dims[2]; ky++) {
          for (int kx = 0; kx < kernel_dims[3]; kx++) {
            if (phi::funcs::sparse::Check(x_dims,
                                          kernel_dims,
                                          paddings,
                                          dilations,
                                          strides,
                                          in_x,
                                          in_y,
                                          in_z,
                                          kx,
                                          ky,
                                          kz)) {
              T out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1];
              T out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2];
              T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3];
              // Keep the flattened index in T: it may not fit in 32 bits when
              // T is int64_t.
              const T out_index = phi::funcs::sparse::PointToIndex<Dims4D>(
                  batch, out_x, out_y, out_z, out_dims);
              int real_out_index = out_index_table[out_index];
              if (real_out_index != 0) {
                real_out_index = real_out_index == -1 ? 0 : real_out_index;
                const int buf_i = atomicAdd(&counter_buf[kernel_index], 1);
                rulebook_buf[kernel_index * blockDim.x + buf_i] =
                    static_cast<int>(i);
                rulebook_buf[kernel_index * blockDim.x +
                             kernel_size * blockDim.x + buf_i] = real_out_index;
              }
            }
            ++kernel_index;
          }
        }
      }
    }
    __syncthreads();

    // Reserve room in the global rulebook for this chunk's pairs.
    for (int k = threadIdx.x; k < kernel_size; k += blockDim.x) {
      counter_buf2[k] = atomicAdd(&counter[k], counter_buf[k]);
    }
    __syncthreads();

    for (int k = 0; k < kernel_size; k++) {
      if (static_cast<int>(threadIdx.x) < counter_buf[k]) {
        rulebook[k * non_zero_num + counter_buf2[k] + threadIdx.x] =
            rulebook_buf[k * blockDim.x + threadIdx.x];
        rulebook[k * non_zero_num + offset + counter_buf2[k] + threadIdx.x] =
            rulebook_buf[k * blockDim.x + kernel_size * blockDim.x +
                         threadIdx.x];
      }
    }
    // The shared buffers are reused by the next chunk.
    __syncthreads();
  }
}

// Groups rulebook entries by id.
// For every rule i in [0, n) with id = indexs[i], a slot j is claimed from
// index_counts[id], and index_groups[id * offset + j] = i.
// `offset` is the per-id group capacity (the subm path passes
// kernel_size * max_voxel). There is no bound check here.
template <typename IntT>
__global__ void GroupIndexs(const int n,
                            const int offset,
                            const IntT* indexs,
                            int* index_counts,
                            int* index_groups) {
  CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) {
    const IntT index = indexs[i];
    // Claim the next free slot in this id's group.
    const int j = atomicAdd(index_counts + index, 1);
    // 64-bit so index * offset cannot overflow int.
    index_groups[static_cast<int64_t>(index) * offset + j] =
        static_cast<int>(i);
  }
}

// Two-buffer variant of GroupIndexs that spreads atomicAdd contention.
// Rules with i < half_kernel_offset use (index_counts, index_groups). The rest
// use (index_counts + non_zero_num, index_groups + non_zero_num * offset).
// Within the chosen buffer, rule i is stored at groups[id * offset + j], with
// id = indexs[i] and j the slot claimed from counts[id].
template <typename IntT>
__global__ void GroupIndexsV2(const int rulebook_len,
                              const int non_zero_num,
                              const int offset,
                              const int half_kernel_offset,
                              const IntT* indexs,
                              int* index_counts,
                              int* index_groups) {
  CUDA_KERNEL_LOOP_TYPE(i, rulebook_len, int64_t) {
    const IntT index = indexs[i];
    const bool first_half = i < half_kernel_offset;
    int* counts_ptr = first_half ? index_counts : index_counts + non_zero_num;
    // 64-bit so non_zero_num * offset cannot overflow int.
    int* groups_ptr =
        first_half
            ? index_groups
            : index_groups + static_cast<int64_t>(non_zero_num) * offset;
    // Claim the next free slot in this id's group.
    const int j = atomicAdd(counts_ptr + index, 1);
    groups_ptr[static_cast<int64_t>(index) * offset + j] = static_cast<int>(i);
  }
}

// Computes offsets_ptr[k] = counter_ptr[0] + ... + counter_ptr[k - 1] (an
// exclusive scan over kernel_size entries) on dev_ctx's stream. It then
// enqueues device-to-host copies of both the counters and the offsets.
// All work is asynchronous: h_counter_ptr / h_offsets_ptr are valid only after
// the caller synchronizes the stream.
inline void CallThrustScan(const GPUContext& dev_ctx,
                           const int* counter_ptr,
                           const int kernel_size,
                           int* offsets_ptr,
                           int* h_counter_ptr,
                           int* h_offsets_ptr) {
  auto stream = dev_ctx.stream();
#ifdef PADDLE_WITH_HIP
  auto policy = thrust::hip::par.on(stream);
#else
  auto policy = thrust::cuda::par.on(stream);
#endif
  const int* counter_end = counter_ptr + kernel_size;
  thrust::exclusive_scan(policy, counter_ptr, counter_end, offsets_ptr);

  const size_t num_bytes = kernel_size * sizeof(int);
  auto copy_to_host = [&](int* host_dst, const int* device_src) {
    phi::backends::gpu::GpuMemcpyAsync(
        host_dst, device_src, num_bytes, gpuMemcpyDeviceToHost, stream);
  };
  copy_to_host(h_counter_ptr, counter_ptr);
  copy_to_host(h_offsets_ptr, offsets_ptr);
}

// The basic algorithm is described in convolution_kernel.cc (or the second
// paper).
// Example:
// 1. the rulebook:
//  the kernel_index:                       0, 0, 0, 1, 1, 1, 2, 2, ....
//  the out_index(key):                     20, 30, 33, 30, 33, 20, 25
// 2. record the position of each out_index (value): 0, 1, 2, 3, 4, 5, 6, ....
// 3. sort the (key, value) pairs by key
// 4. unique the (key, value) pairs:
//  unique_key:     20, 25, 30, 33
//  unique_values:  0, 2, 3, 5  (where each key first appears in sorted order)
//  the index of unique_values is: 0, 1, 2, 3
// 5. replace each out_index by the index of its key in unique_key, using
// unique_key, unique_value and the index of unique_value:
//  the new out_index: 0, 2, 3, 2, 3, 0, 1
template <typename T, typename Context, typename IntT = int>
int ProductRuleBook(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    const std::vector<int>& kernel_sizes,
                    const std::vector<int>& paddings,
                    const std::vector<int>& dilations,
                    const std::vector<int>& strides,
                    const DDim& out_dims,
                    const bool subm,
                    DenseTensor* rulebook,
                    DenseTensor* counter_per_kernel,
                    DenseTensor* offsets_per_kernel,
                    DenseTensor* out_index,
                    DenseTensor* unique_value,
                    SparseCooTensor* out,
                    int* h_counter,
                    int* h_offsets) {
  auto indices_dtype = paddle::experimental::CppTypeToDataType<IntT>::Type();
  const int64_t non_zero_num = x.nnz();
  const auto& non_zero_indices = x.non_zero_indices();
  const IntT* indices_ptr = non_zero_indices.data<IntT>();
  int* counter_ptr = counter_per_kernel->data<int>();
  int* offsets_ptr = offsets_per_kernel->data<int>();
  int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];

  const auto x_dims = x.dims();
  Dims4D d_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]);
  Dims4D d_kernel_dims(1, kernel_sizes[2], kernel_sizes[1], kernel_sizes[0]);
  Dims4D d_out_dims(out_dims[0], out_dims[3], out_dims[2], out_dims[1]);
  Dims4D d_paddings(1, paddings[2], paddings[1], paddings[0]);
  Dims4D d_strides(1, strides[2], strides[1], strides[0]);
  Dims4D d_dilations(1, dilations[2], dilations[1], dilations[0]);
  // 1. product rule book
  phi::backends::gpu::GpuMemsetAsync(counter_ptr,
                                     0,
                                     sizeof(int) * counter_per_kernel->numel(),
                                     dev_ctx.stream());
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);

  const int rulebook_rows = 2;
  const int rulebook_cols = kernel_size * non_zero_num;
  DenseTensorMeta rulebook_meta(
      indices_dtype, {rulebook_rows, rulebook_cols}, DataLayout::NCHW);

  int64_t table_size = 1;
  for (int i = 0; i < out_dims.size() - 1; i++) {
    table_size *= out_dims[i];
  }
  DenseTensor out_index_table = phi::Empty<int>(dev_ctx, {table_size});
  int* out_index_table_ptr = out_index_table.data<int>();

625 626 627 628
  DenseTensor out_index_table2 = phi::Empty<int>(dev_ctx, {table_size + 1});
  int* out_index_table2_ptr = out_index_table2.data<int>();
  int* h_max_voxel = h_counter + kernel_size;

629 630 631 632 633 634 635 636 637 638 639 640
  if (subm) {
    DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta));
    IntT* rulebook_ptr = tmp_rulebook.data<IntT>();
    DenseTensor out_indices =
        phi::EmptyLike<IntT>(dev_ctx, x.non_zero_indices());
    DenseTensor out_values = phi::Empty<T>(dev_ctx, {x.nnz(), kernel_sizes[4]});

    phi::Copy(
        dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), false, &out_indices);

    phi::backends::gpu::GpuMemsetAsync(
        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
641 642 643 644
    phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr,
                                       0,
                                       sizeof(int) * (table_size + 1),
                                       dev_ctx.stream());
645 646 647

    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663
    GetOutIndexTable<IntT>
        <<<config.block_per_grid,
           config.thread_per_block,
           0,
           dev_ctx.stream()>>>(out_indices.data<IntT>(),
                               non_zero_num,
                               d_x_dims,
                               out_index_table_ptr,
                               out_index_table2_ptr,
                               out_index_table2_ptr + table_size);
    phi::backends::gpu::GpuMemcpyAsync(h_max_voxel,
                                       out_index_table2_ptr + table_size,
                                       sizeof(int),
                                       gpuMemcpyDeviceToHost,
                                       dev_ctx.stream());
    dev_ctx.Wait();
664

Z
zhangkaihuo 已提交
665 666 667
    size_t cache_size =
        kernel_size * 2 * sizeof(int) +
        kernel_size * config.thread_per_block.x * 2 * sizeof(int);
668 669 670 671 672 673 674
    const int MAX_CACHE_SIZE = 48 * 1024;
    while (cache_size >= MAX_CACHE_SIZE) {
      config.thread_per_block.x /= 2;
      config.block_per_grid.x *= 2;
      PADDLE_ENFORCE_GE(config.thread_per_block.x,
                        32,
                        phi::errors::Fatal("the shared memory is not enough"));
Z
zhangkaihuo 已提交
675
      cache_size = kernel_size * 2 * sizeof(int) +
676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716
                   kernel_size * config.thread_per_block.x * 2 * sizeof(int);
    }
    ProductSubmRuleBookKernel<IntT><<<config.block_per_grid.x,
                                      config.thread_per_block.x,
                                      cache_size,
                                      dev_ctx.stream()>>>(indices_ptr,
                                                          d_x_dims,
                                                          d_kernel_dims,
                                                          d_out_dims,
                                                          non_zero_num,
                                                          d_paddings,
                                                          d_dilations,
                                                          d_strides,
                                                          out_index_table_ptr,
                                                          rulebook_ptr,
                                                          counter_ptr);

    out->SetMember(out_indices, out_values, out_dims, false);

    CallThrustScan(
        dev_ctx, counter_ptr, kernel_size, offsets_ptr, h_counter, h_offsets);

    dev_ctx.Wait();
    int rulebook_len = h_offsets[kernel_size - 1] + h_counter[kernel_size - 1];
    DenseTensor out_rulebook =
        phi::Empty<IntT>(dev_ctx, {rulebook_rows, rulebook_len});
    IntT* out_rulebook_ptr = out_rulebook.data<IntT>();
    config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
    cache_size = kernel_size * 2 * sizeof(int);
    CopyRuleBook<IntT><<<config.block_per_grid,
                         config.thread_per_block,
                         cache_size,
                         dev_ctx.stream()>>>(counter_ptr,
                                             offsets_ptr,
                                             rulebook_ptr,
                                             rulebook_len,
                                             kernel_size,
                                             non_zero_num,
                                             out_rulebook_ptr);
    *rulebook = out_rulebook;

717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732
    unique_value->ResizeAndAllocate(
        {static_cast<int>(non_zero_num * h_max_voxel[0] * kernel_size)});
    int* unique_value_ptr = unique_value->data<int>();
    out_index->ResizeAndAllocate({static_cast<int>(rulebook_len)});
    int* out_index_ptr = out_index->data<int>();
    phi::backends::gpu::GpuMemsetAsync(
        out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
    GroupIndexs<<<config.block_per_grid,
                  config.thread_per_block,
                  0,
                  dev_ctx.stream()>>>(rulebook_len,
                                      kernel_size * h_max_voxel[0],
                                      out_rulebook_ptr + rulebook_len,
                                      out_index_ptr,
                                      unique_value_ptr);

733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806
    return rulebook_len;

  } else {
    *rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta));
    IntT* rulebook_ptr = rulebook->data<IntT>();
    ProductRuleBookKernel<IntT><<<config.block_per_grid.x,
                                  config.thread_per_block.x,
                                  kernel_size * sizeof(int),
                                  dev_ctx.stream()>>>(indices_ptr,
                                                      d_x_dims,
                                                      d_kernel_dims,
                                                      d_out_dims,
                                                      non_zero_num,
                                                      d_paddings,
                                                      d_dilations,
                                                      d_strides,
                                                      rulebook_ptr,
                                                      counter_ptr);

    // 2. remove -1
#ifdef PADDLE_WITH_HIP
    IntT* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()),
#else
    IntT* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()),
#endif
                                rulebook_ptr,
                                rulebook_ptr + rulebook_rows * rulebook_cols,
                                -1);

    IntT rulebook_len = (last - rulebook_ptr) / 2;

    CallThrustScan(
        dev_ctx, counter_ptr, kernel_size, offsets_ptr, h_counter, h_offsets);

    rulebook->Resize({rulebook_rows, static_cast<int>(rulebook_len)});
    // 3. sorted or merge the out index
    out_index->ResizeAndAllocate({static_cast<int>(rulebook_len)});
    DenseTensor unique_key =
        phi::Empty<int>(dev_ctx, {static_cast<int>(rulebook_len)});
    int* out_index_ptr = out_index->data<int>();
    int* unique_key_ptr = unique_key.data<int>();

    phi::backends::gpu::GpuMemsetAsync(
        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
    phi::backends::gpu::GpuMemsetAsync(
        unique_key_ptr, 0, sizeof(int), dev_ctx.stream());

    config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
    size_t cache_size = sizeof(int) * config.thread_per_block.x;
    UniqueKernel<IntT><<<config.block_per_grid,
                         config.thread_per_block,
                         cache_size,
                         dev_ctx.stream()>>>(rulebook_ptr + rulebook_len,
                                             rulebook_len,
                                             out_index_table_ptr,
                                             out_index_ptr,
                                             unique_key_ptr);
    int out_nnz = 0;
    phi::backends::gpu::GpuMemcpyAsync(&out_nnz,
                                       unique_key_ptr,
                                       sizeof(int),
                                       gpuMemcpyDeviceToHost,
                                       dev_ctx.stream());
    dev_ctx.Wait();

    const int64_t sparse_dim = 4;
    phi::DenseTensor out_indices =
        phi::Empty<IntT>(dev_ctx, {sparse_dim, out_nnz});
    phi::DenseTensor out_values =
        phi::Empty<T>(dev_ctx, {out_nnz, kernel_sizes[4]});
    out->SetMember(out_indices, out_values, out_dims, false);

    IntT* out_indices_ptr = out_indices.data<IntT>();

807 808 809 810 811 812 813
    phi::backends::gpu::GpuMemsetAsync(
        out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream());
    phi::backends::gpu::GpuMemsetAsync(out_index_table2_ptr,
                                       0,
                                       sizeof(int) * (table_size + 1),
                                       dev_ctx.stream());

814
    config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1);
815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832
    GetOutIndexTable<IntT>
        <<<config.block_per_grid,
           config.thread_per_block,
           0,
           dev_ctx.stream()>>>(out_index_ptr,
                               out_nnz,
                               d_out_dims,
                               out_index_table_ptr,
                               out_index_table2_ptr,
                               out_index_table2_ptr + table_size,
                               out_indices_ptr);
    phi::backends::gpu::GpuMemcpyAsync(h_max_voxel,
                                       out_index_table2_ptr + table_size,
                                       sizeof(int),
                                       gpuMemcpyDeviceToHost,
                                       dev_ctx.stream());
    dev_ctx.Wait();

833
    config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
834 835
    unique_value->ResizeAndAllocate(
        {static_cast<int>(out_nnz * h_max_voxel[0] * kernel_size)});
836 837 838 839 840 841 842
    int* unique_value_ptr = unique_value->data<int>();

    GroupIndexs<<<config.block_per_grid,
                  config.thread_per_block,
                  0,
                  dev_ctx.stream()>>>(out_index_table_ptr,
                                      rulebook_len,
843
                                      kernel_size * h_max_voxel[0],
844 845 846 847 848 849 850 851 852 853
                                      rulebook_ptr + rulebook_len,
                                      out_index_ptr,
                                      unique_value_ptr);

    return rulebook_len;
  }
}

}  // namespace sparse
}  // namespace phi