stack_and_unstack.h 9.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/fast_divmod.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"

namespace phi {
namespace funcs {

// Gathers the inputs in `array.data[...]` into `output`.
//
// `output` is addressed as a row-major [rows, cols] matrix. Input pointer n is
// addressed as a row-major [rows, split_size] matrix. Output column `c` is
// taken from input `c / split_size`, column `c % split_size`; the quotient and
// remainder come from `divmoder`, which is expected to divide by split_size.
//
// Launch layout: 2D. The x dimension (threadIdx.x/blockIdx.x) walks output
// columns and the y dimension walks rows. Both are advanced by the full grid
// extent, so any grid size covers the whole matrix.
template <typename T, typename IndexT, typename ArrayT>
__global__ void StackCudaKernel(ArrayT array,
                                GeneralDivMod<IndexT> divmoder,
                                IndexT split_size,
                                IndexT rows,
                                IndexT cols,
                                T* __restrict__ output) {
  const IndexT col_step = static_cast<IndexT>(blockDim.x) * gridDim.x;
  const IndexT row_step = static_cast<IndexT>(blockDim.y) * gridDim.y;
  // Starting row is the same for every column this thread visits.
  const IndexT row_begin =
      static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;

  for (IndexT col = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
       col < cols;
       col += col_step) {
    auto quot_rem = divmoder.div_mod(col);
    const T* src = array.data[quot_rem[0]];  // input holding this column
    const IndexT src_col = quot_rem[1];      // column inside that input
#pragma unroll
    for (IndexT row = row_begin; row < rows; row += row_step) {
      output[row * cols + col] = src[row * split_size + src_col];
    }
  }
}

// Host-side launcher for StackCudaKernel.
//
// Allocates `out`, builds a 2D launch configuration over
// [out_col (x), x_row (y)], wraps the input pointers with
// ConstPointerArraySetter (capacity selected by `Size`), and launches the
// kernel on ctx's stream. x_col is the per-input column count, used as the
// divisor that maps an output column back to its source input.
template <typename Context,
          typename T,
          typename IndexT,
          SegmentedArraySize Size>
void LaunchStackKernel(const Context& ctx,
                       const IndexT x_col,
                       const IndexT x_row,
                       const IndexT out_col,
                       const std::vector<const DenseTensor*>& x,
                       DenseTensor* out) {
  T* dst = ctx.template Alloc<T>(out);
  const auto launch_cfg =
      phi::backends::gpu::GetGpuLaunchConfig2D(ctx, out_col, x_row);

  ConstPointerArraySetter<Context, T, Size> src_ptrs(ctx, x);
  GeneralDivMod<IndexT> col_divmod(x_col);

  using InputArrayT = decltype(src_ptrs.array);
  StackCudaKernel<T, IndexT, InputArrayT>
      <<<launch_cfg.block_per_grid,
         launch_cfg.thread_per_block,
         0,
         ctx.stream()>>>(
          src_ptrs.array, col_divmod, x_col, x_row, out_col, dst);
}

// Stacks the tensors in `x` along a new dimension `axis` of `out` on the GPU.
//
// Each input is viewed as a [x_row, x_col] matrix, where x_row is the product
// of the input dims before `axis`. `out` is then viewed as a
// [x_row, num * x_col] matrix, and input n fills columns
// [n * x_col, (n + 1) * x_col).
// A negative `axis` is counted from the end of the output rank
// (x[0]->dims().size() + 1).
// Assumes `x` is non-empty, its tensors share one shape, and `out`'s dims were
// already set by the caller (its numel() is read before allocation).
// int32 indexing is used when out fits below INT32_MAX elements, else int64.
template <typename T, typename Context>
void StackRawKernel(const Context& ctx,
                    const std::vector<const DenseTensor*>& x,
                    int axis,
                    DenseTensor* out) {
  // An empty output has nothing to copy. Returning early also avoids the
  // division by x_row below (x_row is 0 when a leading dim is 0) and a
  // kernel launch over an empty range.
  if (out->numel() == 0) {
    ctx.template Alloc<T>(out);
    return;
  }

  if (axis < 0) axis += (x[0]->dims().size() + 1);
  int num = static_cast<int>(x.size());

  // Split x dim from axis to matrix of shape [x_row, x_col], and the output
  // tensor's shape is [x_row, out_col].
  int64_t x_row = 1;
  for (int i = 0; i < axis; ++i) {
    x_row *= x[0]->dims()[i];
  }
  int64_t x_col = x[0]->numel() / x_row;
  int64_t out_col = x_col * num;

  if (out->numel() < std::numeric_limits<int32_t>::max()) {
    switch (CalcArraySize(num)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          LaunchStackKernel<Context, T, int32_t, kArraySize>(
              ctx, x_col, x_row, out_col, x, out));
    }
  } else {
    switch (CalcArraySize(num)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          LaunchStackKernel<Context, T, int64_t, kArraySize>(
              ctx, x_col, x_row, out_col, x, out));
    }
  }
}

// Scatters `input` into the output pointers in `array.data[...]`.
//
// `input` is addressed as a row-major [out_row, split_dim, out_col] tensor.
// The split axis is cut into `num_splits` equal groups of
// each_dim_size = split_dim / num_splits slices. Output n receives slices
// [n * each_dim_size, (n + 1) * each_dim_size) and is written as a row-major
// [out_row, each_dim_size, out_col] block. When each_dim_size == 1 that block
// is simply [out_row, out_col]. Null output pointers are skipped.
// `col_divmoder` is expected to divide by out_col.
//
// Launch layout: 1D (blockDim.y == blockDim.z == 1, asserted). Threads stride
// over all out_row * split_dim * out_col input elements by the grid size.
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernel(const T* __restrict__ input,
                                  IndexT out_row,
                                  IndexT split_dim,
                                  IndexT out_col,
                                  IndexT num_splits,
                                  GeneralDivMod<IndexT> col_divmoder,
                                  ArrayT array) {
  assert(blockDim.y == 1);
  assert(blockDim.z == 1);
  // The split axis must divide evenly into num_splits groups.
  assert(split_dim % num_splits == 0);

  IndexT numel = out_row * split_dim * out_col;
  IndexT each_dim_size = split_dim / num_splits;
  IndexT split_dim_with_out_col = split_dim * out_col;

  // Widen before multiplying: blockIdx.x * blockDim.x and
  // blockDim.x * gridDim.x are otherwise evaluated in 32-bit unsigned
  // arithmetic and can wrap for very large inputs on the int64 path.
  IndexT offset = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
  const IndexT stride = static_cast<IndexT>(blockDim.x) * gridDim.x;

  if (each_dim_size == 1) {
    for (; offset < numel; offset += stride) {
      auto col_divmod_rslt = col_divmoder.div_mod(offset);

      IndexT i = offset / split_dim_with_out_col;
      IndexT j = col_divmod_rslt[0] - i * split_dim;
      IndexT k = col_divmod_rslt[1];  // offset % out_col

      T* output = array.data[j];
      if (output) {
        IndexT output_idx = i * out_col + k;
        *(output + output_idx) = input[offset];
      }
    }
  } else {
    for (; offset < numel; offset += stride) {
      auto col_divmod_rslt = col_divmoder.div_mod(offset);

      IndexT i = offset / split_dim_with_out_col;
      IndexT j = col_divmod_rslt[0] - i * split_dim;
      IndexT k = col_divmod_rslt[1];  // offset % out_col

      T* output = array.data[j / each_dim_size];
      if (output) {
        // Each output holds each_dim_size slices per row i, so row i starts
        // at i * each_dim_size within it. Using just `i` here would make
        // different rows overwrite each other.
        IndexT output_idx =
            (i * each_dim_size + j % each_dim_size) * out_col + k;
        *(output + output_idx) = input[offset];
      }
    }
  }
}

// Splits `in_data`, addressed as a row-major [rows, cols] matrix, into `cols`
// outputs: column `c` is copied into the 1-D buffer `array.data[c]`.
// Null output pointers are skipped.
//
// Launch layout: 2D. threadIdx.y/blockIdx.y select the column. threadIdx.x
// walks rows inside a tile of blockDim.x rows, and blockIdx.x strides over
// `tile_x_num` tiles by gridDim.x.
//
// No shared-memory staging is used: each thread would only read back the slot
// it wrote itself, so a buffer plus __syncthreads() would exchange no data
// between threads. With no barrier, threads that have no work can exit early.
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernelForLastDim(const T* __restrict__ in_data,
                                            const IndexT cols,
                                            const IndexT rows,
                                            const IndexT tile_x_num,
                                            ArrayT array) {
  // The column handled by this thread is fixed for the whole kernel.
  const IndexT col_idx =
      static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;
  if (col_idx >= cols) return;
  T* out = array.data[col_idx];
  if (out == nullptr) return;

  for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) {
    const IndexT row_idx = tile_x * blockDim.x + threadIdx.x;
    if (row_idx < rows) {
      out[row_idx] = in_data[row_idx * cols + col_idx];
    }
  }
}

// Host-side launcher that splits `x` into `outs`.
//
// `x` is treated as [out_row, split_dim, out_col] and is cut into `num_splits`
// groups along split_dim. Output pointers are gathered with PointerArraySetter
// (capacity selected by `Size`).
// When out_col == 1 (splitting along the innermost axis) a dedicated 2D
// kernel, UnStackCudaKernelForLastDim, is used. Otherwise the generic 1D
// UnStackCudaKernel runs. Both launch on ctx's stream.
template <typename Context,
          typename T,
          typename IndexT,
          SegmentedArraySize Size>
void LaunchUnStackKernel(const Context& ctx,
                         const IndexT out_row,
                         const IndexT split_dim,
                         const IndexT out_col,
                         const IndexT num_splits,
                         const DenseTensor& x,
                         std::vector<DenseTensor*>* outs) {
  // each tensor in outs should have same shape.
  VLOG(6) << "out_row=" << out_row << ", split_dim=" << split_dim
          << ", out_col=" << out_col << ", num_splits=" << num_splits;

  auto x_ptr = x.data<T>();
  PointerArraySetter<Context, T, Size> setter(ctx, outs);

  if (out_col == 1) {
    // For the case axis == (x.dims().size() - 1)
    constexpr int kThreads = 512;
    constexpr int kWarpSize = 32;
    constexpr int kMaxOut = 16;

    int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
    if (split_dim < kMaxOut) {
      // Few outputs: one block covers all of them along y. x is widened as
      // far as the kThreads-per-block budget allows.
      tid_y = split_dim;
      tid_x =
          std::min(backends::gpu::RoundToNextHighPowOfTwo(out_row, kWarpSize),
                   kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y));
    } else {
      // Many outputs: fixed kWarpSize x kMaxOut blocks, tiled over the
      // outputs along grid y.
      tid_y = kMaxOut;
      tid_x = kWarpSize;
      bid_y = backends::gpu::DivUp<int>(split_dim, kMaxOut);
    }
    // Compute the tile count in 64 bits. On the int64 index path out_row can
    // exceed INT_MAX, which DivUp<int> would silently truncate.
    const int64_t tile_x_num = backends::gpu::DivUp<int64_t>(
        static_cast<int64_t>(out_row), static_cast<int64_t>(tid_x));
    bid_x = static_cast<int>(
        std::min<int64_t>(tile_x_num, backends::gpu::kMultiDimslimit));
    dim3 blocks(tid_x, tid_y, 1);
    dim3 grids(bid_x, bid_y, 1);

    // tile_x_num <= out_row, so it fits in IndexT.
    UnStackCudaKernelForLastDim<T, IndexT, decltype(setter.array)>
        <<<grids, blocks, 0, ctx.stream()>>>(x_ptr,
                                             split_dim,
                                             out_row,
                                             static_cast<IndexT>(tile_x_num),
                                             setter.array);
  } else {
    GeneralDivMod<IndexT> col_divmoder(out_col);
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
        ctx, out_row * split_dim * out_col);

    UnStackCudaKernel<T, IndexT, decltype(setter.array)>
        <<<config.block_per_grid.x,
           config.thread_per_block.x,
           0,
           ctx.stream()>>>(x_ptr,
                           out_row,
                           split_dim,
                           out_col,
                           num_splits,
                           col_divmoder,
                           setter.array);
  }
}

// Splits `x` along `axis` into x.dims()[axis] tensors written to `outs`.
//
// `x` is viewed as [out_row, split_dim, out_col], where out_row is the product
// of the dims before `axis` and split_dim = x.dims()[axis]. Each outs[i] is
// viewed as [out_row, out_col]. A negative `axis` is counted from the end of
// x's rank. int32 indexing is used when x has fewer than INT32_MAX elements,
// else int64.
template <typename T, typename Context>
void UnStackRawKernel(const Context& ctx,
                      const DenseTensor& x,
                      int axis,
                      std::vector<DenseTensor*>* outs) {
  auto x_dims = x.dims();
  // Normalize a negative axis (counted from the end), as StackRawKernel does.
  // Indexing x_dims with a negative axis is not valid.
  if (axis < 0) axis += x_dims.size();

  // Empty input: nothing to copy. Returning early also keeps the division by
  // (split_dim * out_row) below from dividing by zero.
  if (x.numel() == 0) {
    for (DenseTensor* out : *outs) {
      if (out != nullptr) ctx.template Alloc<T>(out);
    }
    return;
  }

  // Input tensor is split into split_dim tensors along the split_dim
  // dimension.
  int64_t split_dim = x_dims[axis];

  // Treat outs[i] as [out_row, out_col], and x as [out_row, split_dim,
  // out_col].
  int64_t out_row = 1;
  for (int i = 0; i < axis; ++i) {
    out_row *= x_dims[i];
  }

  int64_t out_col = x.numel() / (split_dim * out_row);

  if (x.numel() < std::numeric_limits<int32_t>::max()) {
    switch (CalcArraySize(split_dim)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          LaunchUnStackKernel<Context, T, int32_t, kArraySize>(
              ctx, out_row, split_dim, out_col, split_dim, x, outs));
    }
  } else {
    switch (CalcArraySize(split_dim)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          LaunchUnStackKernel<Context, T, int64_t, kArraySize>(
              ctx, out_row, split_dim, out_col, split_dim, x, outs));
    }
  }
}

}  // namespace funcs
}  // namespace phi