// stack_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/stack_kernel.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"

namespace phi {

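// Copies each input matrix of shape [rows, split_size] into its slot of the
// output matrix of shape [rows, cols], where cols = split_size * num_inputs.
// Threads stride over output columns along x and output rows along y;
// div_mod splits a flat output column into (input index, column offset).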
template <typename T, typename IndexT, typename ArrayT>
__global__ void StackCUDAKernel(ArrayT array,
                                funcs::GeneralDivMod<IndexT> divmoder,
                                IndexT split_size,
                                IndexT rows,
                                IndexT cols,
                                T* __restrict__ output) {
  IndexT grid_x = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
  IndexT grid_x_stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
  IndexT grid_y_stride = static_cast<IndexT>(blockDim.y) * gridDim.y;

  for (; grid_x < cols; grid_x += grid_x_stride) {
    IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;

    auto divmod_rslt = divmoder.div_mod(grid_x);
    IndexT split = divmod_rslt[0];       // grid_x / split_size
    IndexT col_offset = divmod_rslt[1];  // grid_x % split_size
    const T* input_ptr = array.data[split];
#pragma unroll
    for (; grid_y < rows; grid_y += grid_y_stride) {
      output[grid_y * cols + grid_x] =
          input_ptr[grid_y * split_size + col_offset];
    }
  }
}

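// Host-side launcher: allocates the output, builds a 2D launch configuration
// over (out_col, x_row), and hands the kernel the input pointers packed in
// setter.array together with a divider for x_col.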
template <typename Context,
          typename T,
          typename IndexT,
          funcs::SegmentedArraySize Size>
void LaunchStackKernel(const Context& ctx,
                       const IndexT x_col,
                       const IndexT x_row,
                       const IndexT out_col,
                       const std::vector<const DenseTensor*>& x,
                       DenseTensor* out) {
  T* out_ptr = ctx.template Alloc<T>(out);
  auto config = phi::backends::gpu::GetGpuLaunchConfig2D(ctx, out_col, x_row);

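  // ConstPointerArraySetter packs the inputs' device pointers into an array
  // that is passed to the kernel by value; GeneralDivMod pre-computes the
  // division by x_col (an assumption: it caches fast-division state so the
  // kernel avoids a hardware divide per element).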
  funcs::ConstPointerArraySetter<Context, T, Size> setter(ctx, x);
  funcs::GeneralDivMod<IndexT> divmoder(x_col);
  StackCUDAKernel<T, IndexT, decltype(setter.array)>
      <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
          setter.array, divmoder, x_col, x_row, out_col, out_ptr);
}

template <typename T, typename Context>
void StackKernel(const Context& ctx,
                 const std::vector<const DenseTensor*>& x,
                 int axis,
                 DenseTensor* out) {
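  // Stacking adds one dimension, so a negative axis is normalized against
  // rank(x[0]) + 1.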
  if (axis < 0) axis += (x[0]->dims().size() + 1);
  int num = static_cast<int>(x.size());

  // Flatten each input into an [x_row, x_col] matrix by splitting its dims
  // at axis.
  int64_t x_row = 1;
  for (int i = 0; i < axis; ++i) {
    x_row *= x[0]->dims()[i];
  }
  int64_t x_col = x[0]->numel() / x_row;
  int64_t out_col = x_col * num;

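  // Dispatch on index width: use 32-bit indexing whenever every output
  // offset fits in int32 (it is cheaper on the GPU), otherwise use int64.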
  if (out->numel() < std::numeric_limits<int32_t>::max()) {
    switch (funcs::CalcArraySize(num)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          LaunchStackKernel<Context, T, int32_t, kArraySize>(
              ctx, x_col, x_row, out_col, x, out));
    }
  } else {
    switch (funcs::CalcArraySize(num)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          LaunchStackKernel<Context, T, int64_t, kArraySize>(
              ctx, x_col, x_row, out_col, x, out));
    }
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(stack,
                   GPU,
                   ALL_LAYOUT,
                   phi::StackKernel,
                   float,
                   double,
                   bool,
                   int64_t,
                   int,
                   uint8_t,
                   int8_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
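
// Worked example (illustrative, not part of the original source): stacking
// two [2, 3] tensors along axis == 1 gives x_row = 2, x_col = 3, num = 2,
// out_col = 6. Viewed as a [2, 6] matrix, output row r holds input0's row r
// in columns [0, 3) and input1's row r in columns [3, 6), i.e. the [2, 2, 3]
// stacked result.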