concat_and_split_functor.cu 35.5 KB
Newer Older
L
Leo Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
16
#include "paddle/fluid/memory/malloc.h"
17
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
18
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
19

L
Leo Chen 已提交
20 21 22
namespace phi {
namespace funcs {

MarDino's avatar
MarDino 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
// Chooses a 2-D launch configuration for the concat/split copy kernels.
//
// Threads along x cover columns: block_cols is num_cols rounded up to a
// multiple of 32 and capped at 1024. The remaining threads of the 1024-thread
// block are stacked along y to cover rows. The grid is capped at
// max(GetMaxPhysicalThreadCount() / 1024, 1) blocks in total. The copy kernels
// in this file stride over any columns/rows the grid does not cover.
//
// A zero extent (e.g. num_cols == 0) still yields a valid, non-zero
// configuration instead of dividing by zero.
static inline void GetBlockDims(const phi::GPUContext& context,
                                int64_t num_rows,
                                int64_t num_cols,
                                dim3* block_dims,
                                dim3* grid_dims) {
  // Set the thread block and grid according to CurrentDeviceId
  constexpr int kThreadsPerBlock = 1024;
  constexpr int kWarpSize = 32;
  int block_cols = kThreadsPerBlock;
  if (num_cols < kThreadsPerBlock) {
    // block_cols is aligned up to a multiple of 32. Keep at least one warp so
    // that num_cols == 0 cannot make block_cols zero (which would divide by
    // zero just below and produce an invalid launch).
    block_cols = static_cast<int>(((num_cols + 31) >> 5) << 5);
    block_cols = std::max(block_cols, kWarpSize);
  }
  int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  constexpr int waves = 1;
  int max_threads = context.GetMaxPhysicalThreadCount() * waves;
  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  // At least one block along x: keeps max_blocks / grid_cols well defined and
  // the grid valid when there are no columns to cover.
  int grid_cols = static_cast<int>(std::max<int64_t>(
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks), 1));
  int grid_rows = std::min(max_blocks / grid_cols,
                           std::max(num_rows / block_rows, (int64_t)1));
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

48 49 50
template <typename T, int Size>
struct PointerWrapper {
 public:
  // Fixed-capacity table of input base addresses stored inline in the struct.
  // The concat kernels take the wrapper by value, so the table travels with
  // the kernel arguments and needs no device allocation.
  const void* ins_addr[Size];

  __device__ inline const void* operator[](int i) const { return ins_addr[i]; }

  PointerWrapper() {}

  // Records ins[k].data() for every input. `ctx` and `pre_alloced_host_ptr`
  // are not used by this by-value variant. No bound check is done: callers are
  // expected to pick Size >= ins.size().
  PointerWrapper(const phi::GPUContext& ctx,
                 const std::vector<phi::DenseTensor>& ins,
                 const T** pre_alloced_host_ptr) {
    const size_t count = ins.size();
    for (size_t k = 0; k < count; ++k) {
      ins_addr[k] = ins[k].data();
    }
  }
};

L
Leo Chen 已提交
64
template <typename T>
struct PointerToPointer {
 public:
  // Device-resident array of input base addresses (no capacity limit).
  void** ins_addr{nullptr};

  __device__ inline const void* operator[](int i) const { return ins_addr[i]; }

  PointerToPointer() {}

  // Fills the caller's host scratch array with ins[k].data<T>(). It then
  // allocates a device buffer of the same size on ctx's place/stream, owned
  // by *dev_ins_ptr, and copies the addresses into it on ctx.stream(). The
  // caller must keep *dev_ins_ptr alive across the kernel launch.
  PointerToPointer(const phi::GPUContext& ctx,
                   const std::vector<phi::DenseTensor>& ins,
                   const T** pre_alloced_host_ptr,
                   paddle::memory::AllocationPtr* dev_ins_ptr) {
    const auto in_num = ins.size();
    const auto table_bytes = in_num * sizeof(T*);
    for (size_t k = 0; k < in_num; ++k) {
      pre_alloced_host_ptr[k] = ins[k].data<T>();
    }
    *dev_ins_ptr = paddle::memory::Alloc(
        ctx.GetPlace(),
        table_bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    // NOTE(review): presumably returns a host buffer that stays valid while a
    // CUDA graph is being captured (otherwise the input pointer) — confirm.
    auto* host_src = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        pre_alloced_host_ptr, in_num);
    paddle::memory::Copy(ctx.GetPlace(),
                         (*dev_ins_ptr)->ptr(),
                         phi::CPUPlace(),
                         host_src,
                         table_bytes,
                         ctx.stream());
    ins_addr = reinterpret_cast<void**>((*dev_ins_ptr)->ptr());
  }
};

template <typename T, typename IndexT, int Size>
struct PointerAndColWrapper {
 public:
  // Cumulative column offsets (segment boundaries), stored inline so they are
  // passed to the kernel by value together with the input addresses.
  IndexT col_length[Size];

  // Captures the input addresses in a by-value PointerWrapper and copies the
  // first inputs_col_num offsets from inputs_col. Assumes inputs_col_num <= Size.
  PointerAndColWrapper(const phi::GPUContext& ctx,
                       const std::vector<phi::DenseTensor>& ins,
                       const IndexT& inputs_col_num,
                       const T** pre_alloced_host_ptr,
                       IndexT* inputs_col)
      : ins_ptr_wrapper(ctx, ins, pre_alloced_host_ptr) {
    for (IndexT k = 0; k < inputs_col_num; ++k) {
      col_length[k] = inputs_col[k];
    }
  }

  __device__ inline const void* operator[](int i) const {
    return ins_ptr_wrapper[i];
  }

 private:
  PointerWrapper<T, Size> ins_ptr_wrapper;
};

template <typename T, typename IndexT>
struct PointerToPointerAndCol {
 public:
  // Device copy of the cumulative column offsets.
  IndexT* col_length{nullptr};

  // Uploads the inputs_col_num offsets in inputs_col to a new device buffer
  // owned by *dev_col_ptr. It then builds a PointerToPointer whose device
  // address table is owned by *dev_ins_ptr. Both uploads are issued on
  // ctx.stream(), and the caller must keep both AllocationPtrs alive across
  // the kernel launch.
  PointerToPointerAndCol(const phi::GPUContext& ctx,
                         const std::vector<phi::DenseTensor>& ins,
                         const IndexT inputs_col_num,
                         const T** pre_alloced_host_ptr,
                         IndexT* inputs_col,
                         paddle::memory::AllocationPtr* dev_ins_ptr,
                         paddle::memory::AllocationPtr* dev_col_ptr) {
    const auto col_bytes = inputs_col_num * sizeof(IndexT);
    *dev_col_ptr = paddle::memory::Alloc(
        ctx.GetPlace(),
        col_bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    auto* host_cols = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        inputs_col, inputs_col_num);
    paddle::memory::Copy(ctx.GetPlace(),
                         (*dev_col_ptr)->ptr(),
                         phi::CPUPlace(),
                         host_cols,
                         col_bytes,
                         ctx.stream());
    col_length = static_cast<IndexT*>((*dev_col_ptr)->ptr());

    // Upload the input address table second, after the column offsets.
    ins_ptr_wrapper =
        PointerToPointer<T>(ctx, ins, pre_alloced_host_ptr, dev_ins_ptr);
  }

  __device__ inline const void* operator[](int i) const {
    return ins_ptr_wrapper[i];
  }

 private:
  PointerToPointer<T> ins_ptr_wrapper;
};

MarDino's avatar
MarDino 已提交
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
// Opaque chunk of MovSize bytes, aligned to MovSize bytes. The copy kernels
// in this file reinterpret the input/output buffers as arrays of
// Packed<MovSize>, so each element assignment moves MovSize bytes at once.
template <int MovSize>
struct alignas(MovSize) Packed {
  __device__ Packed() {
    // Intentionally empty: buf is left uninitialized.
  }
  union {
    char buf[MovSize];
  };
};

// Concatenates inputs of differing column widths into `output`.
//
// `output` is viewed as an output_rows x output_cols row-major matrix of
// Packed<MovSize>. Input k owns output columns
// [col_length[k], col_length[k + 1]) and is itself a row-major matrix with
// col_length[k + 1] - col_length[k] columns. Columns are distributed via
// CUDA_KERNEL_LOOP_TYPE over x, rows via a blockDim.y * gridDim.y stride over y.
// The segment index persists across loop iterations, so the lookup resumes
// where the previous column left off. This assumes CUDA_KERNEL_LOOP_TYPE
// visits tid_x in increasing order. `col_size` is unused.
template <typename IndexT, int MovSize, typename PointerAndColWrapperT>
__global__ void ConcatTensorWithDifferentShape(
    const PointerAndColWrapperT ins_datas,
    int col_size,
    const IndexT output_rows,
    const IndexT output_cols,
    void* output) {
  using VecT = Packed<MovSize>;
  VecT* out_vec = reinterpret_cast<VecT*>(output);

  // Current input segment and the output column where it begins.
  IndexT seg = 0;
  IndexT seg_begin = ins_datas.col_length[0];

  CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, IndexT) {
    // Advance to the segment containing tid_x (skips zero-width segments).
    IndexT seg_end = ins_datas.col_length[seg + 1];
    for (; seg_end <= tid_x; seg_end = ins_datas.col_length[seg + 1]) {
      seg_begin = seg_end;
      ++seg;
    }

    const IndexT col_in_seg = tid_x - seg_begin;
    const IndexT seg_cols = seg_end - seg_begin;
    const VecT* in_vec = reinterpret_cast<const VecT*>(ins_datas[seg]);

    for (IndexT row = blockIdx.y * blockDim.y + threadIdx.y; row < output_rows;
         row += blockDim.y * gridDim.y) {
      out_vec[row * output_cols + tid_x] = in_vec[row * seg_cols + col_in_seg];
    }
  }
}

MarDino's avatar
MarDino 已提交
200 201
// Concatenates inputs that all have `fixed_in_col` columns into `output_data`.
// The output is viewed as an out_rows x out_cols row-major matrix of
// Packed<MovSize>. Output column c comes from input c / fixed_in_col, at
// column c % fixed_in_col, and each input is a row-major matrix with
// fixed_in_col columns. Columns are distributed via CUDA_KERNEL_LOOP_TYPE over
// x, rows via a blockDim.y * gridDim.y stride over y.
template <typename IndexT, int MovSize, typename PointerWrapperT>
__global__ void ConcatTensorWithSameShape(const PointerWrapperT ins_data,
                                          const IndexT fixed_in_col,
                                          const IndexT out_rows,
                                          const IndexT out_cols,
                                          void* output_data) {
  using VecT = Packed<MovSize>;
  VecT* out_vec = reinterpret_cast<VecT*>(output_data);
  CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, IndexT) {
    const IndexT which = tid_x / fixed_in_col;
    const IndexT col_in_input = tid_x - which * fixed_in_col;
    const VecT* in_vec = reinterpret_cast<const VecT*>(ins_data[which]);
    for (IndexT row = blockIdx.y * blockDim.y + threadIdx.y; row < out_rows;
         row += blockDim.y * gridDim.y) {
      out_vec[row * out_cols + tid_x] =
          in_vec[row * fixed_in_col + col_in_input];
    }
  }
}

MarDino's avatar
MarDino 已提交
220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465
// Expands `func_impl(N, ...)` for N = 4, 8, 16, 32, 64 and 128. The concat
// dispatchers below use it to generate one `switch` case per supported
// capacity of the by-value pointer wrappers. Any other switch value falls
// through to their `default:` branch.
#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \
  func_impl(4, ##__VA_ARGS__);                          \
  func_impl(8, ##__VA_ARGS__);                          \
  func_impl(16, ##__VA_ARGS__);                         \
  func_impl(32, ##__VA_ARGS__);                         \
  func_impl(64, ##__VA_ARGS__);                         \
  func_impl(128, ##__VA_ARGS__);

// Launches ConcatTensorWithDifferentShape with MovSize-byte moves on
// ctx.stream().
//
// The switch value is RoundToNextHighPowOfTwo(limit_num, 4). If it equals one
// of the generated cases (4, 8, ..., 128), the input addresses and column
// offsets travel by value in a PointerAndColWrapper of that capacity.
// Otherwise they are uploaded to temporary device buffers through
// PointerToPointerAndCol. Those buffers are owned by the AllocationPtrs
// declared in the default branch and are released when that scope ends, right
// after the launch is enqueued.
// NOTE(review): this relies on the allocator's stream-ordered release (the
// buffers are allocated with ctx.stream()) — confirm.
//
// inputs_col holds inputs_col_num cumulative column offsets, expressed in
// Packed<MovSize> units by the caller. inputs_data is a host scratch array
// with one slot per input. `in_num` is not used here, and the launch result
// is not checked.
template <typename T, typename IndexT, int MovSize>
void DispatchConcatWithDifferentShapeKernelLimitNum(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const IndexT inputs_col_num,
    const T** inputs_data,
    IndexT* inputs_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT in_num,
    const IndexT limit_num) {
  dim3 block_dims;
  dim3 grid_dims;
  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);

// One case per by-value capacity `size_`: build the wrapper, then run the
// launch statement passed through __VA_ARGS__.
#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...)    \
  case size_: {                                             \
    PointerAndColWrapper<T, IndexT, size_> ptr_col_array(   \
        ctx, ins, inputs_col_num, inputs_data, inputs_col); \
    __VA_ARGS__;                                            \
  } break;
  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
    IMPL_CONCATE_CUDA_KERNEL_HELPER(
        IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
        ConcatTensorWithDifferentShape<IndexT, MovSize, decltype(ptr_col_array)>
        <<<grid_dims, block_dims, 0, ctx.stream()>>>(
            ptr_col_array, inputs_col_num, out_row, out_col, output->data()));
    // Too many inputs for a by-value wrapper: use device-side tables.
    default: {
      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
      paddle::memory::AllocationPtr dev_col_ptr{nullptr};
      PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
                                                      ins,
                                                      inputs_col_num,
                                                      inputs_data,
                                                      inputs_col,
                                                      &dev_ins_ptr,
                                                      &dev_col_ptr);
      ConcatTensorWithDifferentShape<IndexT, MovSize, decltype(ptr_col_array)>
          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
              ptr_col_array, inputs_col_num, out_row, out_col, output->data());
    }
  }
#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE
}

// Maps the runtime move width `mov_size` (in bytes) onto the compile-time
// MovSize argument of DispatchConcatWithDifferentShapeKernelLimitNum.
// Widths 16, 8, 4 and 2 map to themselves; every other value uses 1-byte moves.
template <typename T, typename IndexT>
void DispatchConcatWithDifferentShapeMovsize(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const IndexT inputs_col_num,
    const T** inputs_data,
    IndexT* inputs_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT mov_size,
    const IndexT in_num,
    const IndexT limit_num) {
  switch (mov_size) {
    case 16:
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 16>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
    case 8:
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 8>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
    case 4:
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 4>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
    case 2:
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 2>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
    default:
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 1>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
  }
}

// Launches ConcatTensorWithSameShape with MovSize-byte moves on ctx.stream().
//
// The switch value is RoundToNextHighPowOfTwo(limit_num, 4). If it equals one
// of the generated cases (4, 8, ..., 128), the input addresses travel by value
// in a PointerWrapper of that capacity. Otherwise they are uploaded to a
// temporary device buffer via PointerToPointer. That buffer is owned by
// `dev_ins_ptr` in the default branch and released when that scope ends, right
// after the launch is enqueued.
// NOTE(review): this relies on the allocator's stream-ordered release —
// confirm.
//
// in_col is the shared per-input column count (in Packed<MovSize> units).
// `in_num` is not used here, and the launch result is not checked.
template <typename T, typename IndexT, int MovSize>
void DispatchConcatWithSameShapeKernelLimitNum(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const T** inputs_data,
    IndexT in_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT in_num,
    const IndexT limit_num) {
  dim3 block_dims;
  dim3 grid_dims;
  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);

// One case per by-value capacity `size_`: build the wrapper, then run the
// launch statement passed through __VA_ARGS__.
#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...)               \
  case size_: {                                                \
    PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data); \
    __VA_ARGS__;                                               \
  } break;

  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
    IMPL_CONCATE_CUDA_KERNEL_HELPER(
        IMPL_CONCAT_CUDA_KERNEL_CASE,
        ConcatTensorWithSameShape<IndexT, MovSize, decltype(ptr_array)>
        <<<grid_dims, block_dims, 0, ctx.stream()>>>(
            ptr_array, in_col, out_row, out_col, output->data()));
    // Too many inputs for a by-value wrapper: use a device-side table.
    default: {
      paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
      PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
      ConcatTensorWithSameShape<IndexT, MovSize, decltype(ptr_array)>
          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
              ptr_array, in_col, out_row, out_col, output->data());
    }
  }
#undef IMPL_CONCAT_CUDA_KERNEL_CASE
}

// The case-generator macro is only used by the two LimitNum dispatchers above.
#undef IMPL_CONCATE_CUDA_KERNEL_HELPER

// Maps the runtime move width `mov_size` (in bytes) onto the compile-time
// MovSize argument of DispatchConcatWithSameShapeKernelLimitNum.
// Widths 16, 8, 4 and 2 map to themselves; every other value uses 1-byte moves.
template <typename T, typename IndexT>
void DispatchConcatWithSameShapeMovsize(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const T** inputs_data,
    IndexT in_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT mov_size,
    const IndexT in_num,
    const IndexT limit_num) {
  switch (mov_size) {
    case 16:
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 16>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
    case 8:
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 8>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
    case 4:
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 4>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
    case 2:
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 2>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
    default:
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 1>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
  }
}

// Picks the widest vectorized move that every input and the output allow. It
// then rescales the column bookkeeping to that width and forwards to the
// same-shape or different-shape dispatcher.
//
// A width `vec_size` (elements of T, at most 16 bytes per move) qualifies when
// all of the following hold:
//   - for every input k, inputs_col[k + 1] - inputs_col[k] is a multiple of
//     vec_size;
//   - inputs_data[k] is aligned to vec_size * sizeof(T) bytes;
//   - the output base address is aligned to vec_size * sizeof(T) bytes.
// The input addresses are read from inputs_data, so the caller must have
// filled it with the inputs' base addresses. inputs_col is divided in place by
// the chosen width.
template <typename T, typename IndexT>
void DispatchConcatKernel(const phi::GPUContext& ctx,
                          const std::vector<phi::DenseTensor>& ins,
                          const IndexT inputs_col_num,
                          const T** inputs_data,
                          IndexT* inputs_col,
                          const IndexT out_row,
                          const IndexT out_col,
                          phi::DenseTensor* output,
                          const IndexT in_num,
                          const IndexT limit_num,
                          bool has_same_shape) {
  constexpr IndexT MaxVecSize = 16 / sizeof(T);
  const auto out_addr = reinterpret_cast<std::uintptr_t>(output->data());

  // True when every input (and the output) can be moved in chunks of
  // `vec_size` elements. With no inputs no width qualifies, which leaves the
  // default width of 1.
  auto fits = [&](IndexT vec_size) -> bool {
    if (in_num < 1) {
      return false;
    }
    const IndexT mov_size = vec_size * sizeof(T);
    for (IndexT k = 0; k < in_num; ++k) {
      const auto in_addr = reinterpret_cast<std::uintptr_t>(inputs_data[k]);
      // inputs_col[0] is 0, so input k spans [inputs_col[k], inputs_col[k + 1]).
      const IndexT cols = inputs_col[k + 1] - inputs_col[k];
      if (cols % vec_size != 0 || out_addr % mov_size != 0 ||
          in_addr % mov_size != 0) {
        return false;
      }
    }
    return true;
  };

  IndexT dispatch_vec_size = 1;
  for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) {
    if (fits(vec_size)) {
      dispatch_vec_size = vec_size;
      break;
    }
  }

  // Express every column quantity in units of the chosen vector width.
  const int64_t vectorized_out_col = out_col / dispatch_vec_size;
  for (IndexT idx = 0; idx < in_num + 1; idx++) {
    inputs_col[idx] /= dispatch_vec_size;
  }
  const IndexT mov_size = sizeof(T) * dispatch_vec_size;
  if (has_same_shape) {
    // In same shape situation, each input's col are equal, so here we select to
    // use inputs_col[1].
    DispatchConcatWithSameShapeMovsize<T, IndexT>(ctx,
                                                  ins,
                                                  inputs_data,
                                                  inputs_col[1],
                                                  out_row,
                                                  vectorized_out_col,
                                                  output,
                                                  mov_size,
                                                  in_num,
                                                  limit_num);
  } else {
    DispatchConcatWithDifferentShapeMovsize<T, IndexT>(ctx,
                                                       ins,
                                                       inputs_col_num,
                                                       inputs_data,
                                                       inputs_col,
                                                       out_row,
                                                       vectorized_out_col,
                                                       output,
                                                       mov_size,
                                                       in_num,
                                                       limit_num);
  }
}

L
Leo Chen 已提交
522 523 524 525 526 527 528
// Splits `input_data` (in_row x in_col, row-major) column-wise into the
// buffers in `outputs_data`. Output k receives input columns
// [out_cols[k], out_cols[k + 1]) and is written as a row-major matrix with
// out_cols[k + 1] - out_cols[k] columns. Null output pointers are skipped.
// Columns are distributed via CUDA_KERNEL_LOOP_TYPE over x, rows via a
// blockDim.y * gridDim.y stride over y. `out_cols_size` is unused.
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t* out_cols,
                             int out_cols_size,
                             T** outputs_data) {
  // Current output segment and the input column where it begins.
  int64_t seg = 0;
  int64_t seg_begin = out_cols[0];
  CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
    int64_t seg_end = out_cols[seg + 1];
    for (; seg_end <= tid_x; seg_end = out_cols[seg + 1]) {
      seg_begin = seg_end;
      ++seg;
    }

    const int64_t col_in_seg = tid_x - seg_begin;
    const int64_t seg_cols = seg_end - seg_begin;

    T* dst = outputs_data[seg];
    if (dst != nullptr) {
      for (int64_t row = blockIdx.y * blockDim.y + threadIdx.y; row < in_row;
           row += blockDim.y * gridDim.y) {
        dst[row * seg_cols + col_in_seg] = input_data[row * in_col + tid_x];
      }
    }
  }
}

// Device routine shared by the fixed-width SplitKernel_ overloads. Every
// output has exactly fixed_out_col columns, so input column c goes to output
// c / fixed_out_col at column c % fixed_out_col. Null output pointers are
// skipped.
template <typename T>
__device__ void SplitKernelDetail(const T* input_data,
                                  const int64_t in_row,
                                  const int64_t in_col,
                                  const int64_t fixed_out_col,
                                  T** outputs_data) {
  CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
    const int64_t which = tid_x / fixed_out_col;
    const int64_t col_in_out = tid_x - which * fixed_out_col;
    T* dst = outputs_data[which];
    if (dst != nullptr) {
      for (int64_t row = blockIdx.y * blockDim.y + threadIdx.y; row < in_row;
           row += blockDim.y * gridDim.y) {
        dst[row * fixed_out_col + col_in_out] =
            input_data[row * in_col + tid_x];
      }
    }
  }
}

// Fixed-width split whose output pointers are supplied as an array that the
// kernel indexes on the device; forwards to SplitKernelDetail.
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t fixed_out_col,
                             T** outputs_data) {
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

// Fixed-width split into exactly two outputs. The addresses arrive as kernel
// arguments, so no device-side pointer table is needed.
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t fixed_out_col,
                             T* outputs_addr0,
                             T* outputs_addr1) {
  T* outputs_data[2] = {outputs_addr0, outputs_addr1};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

// Fixed-width split into exactly three outputs passed as kernel arguments.
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t fixed_out_col,
                             T* outputs_addr0,
                             T* outputs_addr1,
                             T* outputs_addr2) {
  T* outputs_data[3] = {outputs_addr0, outputs_addr1, outputs_addr2};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

// Fixed-width split into exactly four outputs passed as kernel arguments.
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t fixed_out_col,
                             T* outputs_addr0,
                             T* outputs_addr1,
                             T* outputs_addr2,
                             T* outputs_addr3) {
  T* outputs_data[4] = {
      outputs_addr0, outputs_addr1, outputs_addr2, outputs_addr3};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

/*
 * All input tensors must have the same rank, and their sizes must match in
 * every dimension except the concatenation axis.
 */
628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647
// Concatenates `ins` along `axis` into `output` using IndexT for kernel
// indexing.
//
// Each tensor is viewed as a 2-D row-major matrix. The row count in_row is the
// product of ins[0]'s dimensions before `axis`; the column count is the
// remaining element count. The inputs are then laid side by side along the
// columns of the output.
template <typename T, typename IndexT>
void ConcatFunctorWithIndexType(const phi::GPUContext& ctx,
                                const std::vector<phi::DenseTensor>& ins,
                                int axis,
                                phi::DenseTensor* output) {
  // TODO(zcd): Add input data validity checking
  IndexT in_num = ins.size();
  IndexT in_row = 1;
  auto dim_0 = ins[0].dims();
  for (int i = 0; i < axis; ++i) {
    in_row *= dim_0[i];
  }
  IndexT in_col = ins[0].numel() / in_row;
  IndexT out_row = in_row, out_col = 0;

  // inputs_data[k] is the base address of ins[k]. inputs_col[k] is the first
  // output column of ins[k], and inputs_col[in_num] is the total output
  // column count.
  IndexT inputs_col_num = in_num + 1;
  std::vector<const T*> inputs_data_vec(in_num, nullptr);
  std::vector<IndexT> inputs_col_vec(inputs_col_num, 0);
  const T** inputs_data = inputs_data_vec.data();
  IndexT* inputs_col = inputs_col_vec.data();

#ifdef PADDLE_WITH_HIP
  // On HIP the host-side tables are staged in pinned memory instead.
  // TODO(chentianyu03): try to find a method to remove the Alloc function
  paddle::memory::AllocationPtr data_alloc = paddle::memory::Alloc(
      paddle::platform::CUDAPinnedPlace(), in_num * sizeof(T*));
  inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
  paddle::memory::AllocationPtr col_alloc = paddle::memory::Alloc(
      paddle::platform::CUDAPinnedPlace(), inputs_col_num * sizeof(IndexT));
  inputs_col = reinterpret_cast<IndexT*>(col_alloc->ptr());
#endif

  // Nothing here initializes the pinned buffers above, and the loop below
  // only writes inputs_col[1..in_num], so set the leading offset explicitly.
  // (For the zero-filled std::vector this is a no-op.)
  inputs_col[0] = 0;

  bool has_same_shape = true;
  for (int i = 0; i < in_num; ++i) {
    IndexT t_cols = ins[i].numel() / in_row;
    if (has_same_shape) {
      has_same_shape &= (t_cols == in_col);
    }
    out_col += t_cols;
    inputs_col[i + 1] = out_col;
    // DispatchConcatKernel checks these addresses to decide how wide the
    // vectorized moves may be. Without this store it would test nullptr (or
    // uninitialized pinned memory) rather than the real inputs, and could
    // select a width the inputs are not aligned for. The untyped data()
    // mirrors what PointerWrapper records.
    inputs_data[i] = static_cast<const T*>(ins[i].data());
  }
  IndexT limit_num = has_same_shape ? in_num : inputs_col_num;

  DispatchConcatKernel<T, IndexT>(ctx,
                                  ins,
                                  inputs_col_num,
                                  inputs_data,
                                  inputs_col,
                                  out_row,
                                  out_col,
                                  output,
                                  in_num,
                                  limit_num,
                                  has_same_shape);

#ifdef PADDLE_WITH_HIP
  // Keep the pinned buffers alive until the work queued on the stream has
  // run, then free them from a stream callback (the next call allocates fresh
  // pinned memory).
  auto* data_alloc_released = data_alloc.release();
  auto* col_alloc_released = col_alloc.release();
  ctx.AddStreamCallback([data_alloc_released, col_alloc_released] {
    VLOG(4) << "Delete cuda pinned at " << data_alloc_released;
    VLOG(4) << "Delete cuda pinned at " << col_alloc_released;
    paddle::memory::allocation::Allocator::AllocationDeleter(
        data_alloc_released);
    paddle::memory::allocation::Allocator::AllocationDeleter(
        col_alloc_released);
  });
#endif
}

template <typename T>
struct ConcatFunctor<phi::GPUContext, T> {
  // Concatenates every tensor in `input` along `axis` into `output`.
  //
  // The heavy lifting is done by ConcatFunctorWithIndexType. This wrapper only
  // chooses the integer type used for index arithmetic:
  // - 32-bit indices when the output has fewer than INT32_MAX elements;
  // - 64-bit indices otherwise, so large outputs cannot overflow the offsets.
  void operator()(const phi::GPUContext& context,
                  const std::vector<phi::DenseTensor>& input,
                  int axis,
                  phi::DenseTensor* output) {
    const bool fits_in_int32 =
        output->numel() < std::numeric_limits<int32_t>::max();
    if (!fits_in_int32) {
      ConcatFunctorWithIndexType<T, int64_t>(context, input, axis, output);
      return;
    }
    ConcatFunctorWithIndexType<T, int32_t>(context, input, axis, output);
  }
};

template <typename T>
class SplitFunctor<phi::GPUContext, T> {
 public:
  // Splits `input` along `axis` into the tensors pointed to by `*outputs`.
  //
  // The data is treated as a 2-D matrix:
  // - The row count is the product of the dims of ref_inputs[0] before `axis`.
  // - Output k receives ref_inputs[k]->numel() / rows columns.
  // - A nullptr entry in `*outputs` is forwarded to the kernel as a nullptr
  //   destination pointer.
  //
  // All device work (pointer/offset uploads and the kernel launch) is enqueued
  // on context.stream(). Nothing here synchronizes the stream.
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const std::vector<const phi::DenseTensor*>& ref_inputs,
                  int axis,
                  std::vector<phi::DenseTensor*>* outputs) {
    // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
    // tensors of shape [0,1,4]
    if (input.numel() == 0) return;

    // TODO(zcd): Add input data validity checking
    const int num_outs = outputs->size();
    const int num_offsets = num_outs + 1;

    auto ref_dims = ref_inputs[0]->dims();
    int64_t rows = 1;
    for (int d = 0; d < axis; ++d) {
      rows *= ref_dims[d];
    }
    const int64_t first_cols = ref_inputs[0]->numel() / rows;

    // Host-side staging arrays:
    // - out_ptrs: one destination pointer per output.
    // - offsets: running (prefix-sum) column offsets, with offsets[0] == 0 and
    //   offsets[k + 1] == end column of output k.
    std::vector<T*> out_ptrs_buf(num_outs);
    std::vector<int64_t> offsets_buf(num_offsets);
    T** out_ptrs = out_ptrs_buf.data();
    int64_t* offsets = offsets_buf.data();

#ifdef PADDLE_WITH_HIP
    // On ROCm the staging arrays are moved into pinned host memory. HIP only
    // performs host-to-device copies asynchronously when the source is pinned.
    // CUDA, by contrast, also treats pageable host-to-device copies of 64 KB or
    // less as asynchronous (CUDA C++ Programming Guide, "Concurrent Execution
    // between Host and Device").
    // TODO(chentianyu03): try to find a method to remove the Alloc function
    paddle::memory::AllocationPtr data_alloc = paddle::memory::Alloc(
        paddle::platform::CUDAPinnedPlace(), num_outs * sizeof(T*));
    out_ptrs = reinterpret_cast<T**>(data_alloc->ptr());
    // TODO(chentianyu03): try to find a method to remove the Alloc function
    paddle::memory::AllocationPtr cols_alloc = paddle::memory::Alloc(
        paddle::platform::CUDAPinnedPlace(), num_offsets * sizeof(int64_t));
    offsets = reinterpret_cast<int64_t*>(cols_alloc->ptr());
#endif

    // Fill the staging arrays and detect whether every output has the same
    // column count as the first one.
    bool same_shape = true;
    int64_t total_cols = 0;
    offsets[0] = 0;
    for (int k = 0; k < num_outs; ++k) {
      const int64_t cols_k = ref_inputs.at(k)->numel() / rows;
      same_shape = same_shape && (cols_k == first_cols);
      total_cols += cols_k;
      offsets[k + 1] = total_cols;
      phi::DenseTensor* out_k = outputs->at(k);
      out_ptrs[k] = (out_k == nullptr) ? nullptr : out_k->data<T>();
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, rows, total_cols, &block_dims, &grid_dims);
    const auto stream = context.stream();

    // Two to four equally-sized outputs are handed to the kernel as
    // individual pointer arguments. Every other case (ragged split, a single
    // output, or more than four outputs) reads the destination pointers from
    // a device-side copy of out_ptrs.
    const bool ptrs_by_value = same_shape && num_outs >= 2 && num_outs <= 4;
    paddle::memory::allocation::AllocationPtr dev_ptrs_alloc;
    T** dev_out_ptrs = nullptr;
    if (!ptrs_by_value) {
      dev_ptrs_alloc = UploadHostArray(context, out_ptrs, num_outs);
      dev_out_ptrs = reinterpret_cast<T**>(dev_ptrs_alloc->ptr());
    }

    if (!same_shape) {
      // Ragged split: the kernel also needs the column offsets on device.
      auto dev_offsets_alloc = UploadHostArray(context, offsets, num_offsets);
      int64_t* dev_offsets =
          reinterpret_cast<int64_t*>(dev_offsets_alloc->ptr());
      SplitKernel_<<<grid_dims, block_dims, 0, stream>>>(input.data<T>(),
                                                         rows,
                                                         total_cols,
                                                         dev_offsets,
                                                         num_offsets,
                                                         dev_out_ptrs);
    } else {
      switch (num_outs) {
        case 2:
          SplitKernel_<<<grid_dims, block_dims, 0, stream>>>(input.data<T>(),
                                                             rows,
                                                             total_cols,
                                                             first_cols,
                                                             out_ptrs[0],
                                                             out_ptrs[1]);
          break;
        case 3:
          SplitKernel_<<<grid_dims, block_dims, 0, stream>>>(input.data<T>(),
                                                             rows,
                                                             total_cols,
                                                             first_cols,
                                                             out_ptrs[0],
                                                             out_ptrs[1],
                                                             out_ptrs[2]);
          break;
        case 4:
          SplitKernel_<<<grid_dims, block_dims, 0, stream>>>(input.data<T>(),
                                                             rows,
                                                             total_cols,
                                                             first_cols,
                                                             out_ptrs[0],
                                                             out_ptrs[1],
                                                             out_ptrs[2],
                                                             out_ptrs[3]);
          break;
        default:
          SplitKernel_<<<grid_dims, block_dims, 0, stream>>>(
              input.data<T>(), rows, total_cols, first_cols, dev_out_ptrs);
          break;
      }
    }

#ifdef PADDLE_WITH_HIP
    // Hand ownership of the pinned staging buffers to a stream callback. They
    // are then freed only after the work queued above has executed, rather
    // than being recycled (and overwritten) while that work may still need
    // them.
    auto* data_alloc_released = data_alloc.release();
    auto* cols_alloc_released = cols_alloc.release();
    context.AddStreamCallback([data_alloc_released, cols_alloc_released] {
      paddle::memory::allocation::Allocator::AllocationDeleter(
          data_alloc_released);
      paddle::memory::allocation::Allocator::AllocationDeleter(
          cols_alloc_released);
    });
#endif
  }

 private:
  // Allocates `count` elements of U on the context's device, bound to the
  // context's stream, and enqueues a host-to-device copy of `host_array` into
  // it on that same stream. The host pointer is first passed through
  // RestoreHostMemIfCapturingCUDAGraph to support CUDA-graph capture.
  // Returns the owning device allocation.
  template <typename U>
  static paddle::memory::allocation::AllocationPtr UploadHostArray(
      const phi::GPUContext& context, U* host_array, int count) {
    // TODO(chentianyu03): try to find a method to remove the Alloc function
    paddle::memory::allocation::AllocationPtr device_buf =
        paddle::memory::Alloc(
            context.GetPlace(),
            count * sizeof(U),
            phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
    auto* source =
        phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(host_array,
                                                               count);
    paddle::memory::Copy(context.GetPlace(),
                         device_buf->ptr(),
                         phi::CPUPlace(),
                         source,
                         count * sizeof(U),
                         context.stream());
    return device_buf;
  }
};

// Explicitly instantiates the GPU specializations of ConcatFunctor and
// SplitFunctor defined above for each element type that FOR_ALL_TYPES expands
// to. FOR_ALL_TYPES is defined outside this chunk, presumably in the functor
// header (TODO confirm its exact type list there).
#define DEFINE_FUNCTOR(type)                           \
  template class ConcatFunctor<phi::GPUContext, type>; \
  template class SplitFunctor<phi::GPUContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);

}  // namespace funcs
}  // namespace phi