concat_and_split_functor.cu 35.0 KB
Newer Older
L
Leo Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
16
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
17
#include "paddle/phi/common/memory_utils.h"
18
#include "paddle/phi/kernels/funcs/segmented_array.h"
19

L
Leo Chen 已提交
20 21 22
namespace phi {
namespace funcs {

MarDino's avatar
MarDino 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
// Picks a 2-D launch configuration (block and grid) for the concat/split
// kernels in this file.
//
//   context    - device context; only GetMaxPhysicalThreadCount() is queried.
//   num_rows   - number of rows the kernel iterates over (y dimension).
//   num_cols   - number of columns the kernel iterates over (x dimension).
//   block_dims - [out] threads per block; x is a multiple of 32 and x * y is
//                at most 1024.
//   grid_dims  - [out] blocks per grid; x * y is limited to one "wave" of the
//                device's physical thread capacity.
//
// The grid may be smaller than the problem in either dimension, so kernels
// launched with this configuration have to loop over the remaining work.
static inline void GetBlockDims(const phi::GPUContext& context,
                                int64_t num_rows,
                                int64_t num_cols,
                                dim3* block_dims,
                                dim3* grid_dims) {
  // Set the thread block and grid according to CurrentDeviceId
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
    block_cols = ((num_cols + 31) >> 5) << 5;
    // num_cols <= 0 would round down to 0 and make the division below trap;
    // keep at least one warp-wide column group.
    block_cols = std::max(block_cols, 32);
  }
  int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  constexpr int waves = 1;
  int max_threads = context.GetMaxPhysicalThreadCount() * waves;
  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  // Clamp to >= 1 so that an empty column range neither divides by zero
  // below nor produces an invalid zero-sized grid.
  int grid_cols = std::max<int64_t>(
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks), 1);
  int grid_rows = std::min(max_blocks / grid_cols,
                           std::max(num_rows / block_rows, (int64_t)1));
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

48 49 50 51 52 53
// PADDLE_ALIGN(x) requests x-byte alignment for the type it decorates.
// It expands to the GNU-style `aligned` attribute everywhere except on
// Windows builds (_WIN32 defined), where it expands to nothing, i.e. no
// extra alignment is requested there.
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x)
#endif

54 55 56
// Fixed-capacity table of input base addresses, indexable from device code.
//
// `Size` is the compile-time capacity of the table. SetInputAddr() writes one
// entry per element of `ins` without checking against `Size`, so the caller
// must pick a `Size` that is at least ins.size().
template <typename T, int Size>
struct PointerWrapper {
 public:
  const void* ins_addr[Size];
  // Returns the base address stored for the i-th input.
  __device__ inline const void* operator[](int i) const { return ins_addr[i]; }

  // Leaves ins_addr uninitialized.
  PointerWrapper() {}
  // Records the data address of every tensor in `ins`. `ctx` and
  // `pre_alloced_host_ptr` are not used by this wrapper; they are accepted so
  // that the constructor has the same shape as the other pointer wrappers.
  PointerWrapper(const phi::GPUContext& ctx,
                 const std::vector<phi::DenseTensor>& ins,
                 const T** pre_alloced_host_ptr) {
    SetInputAddr(ins);
  }

 protected:
  // Copies ins[i].data() into ins_addr[i] for every input tensor.
  void SetInputAddr(const std::vector<phi::DenseTensor>& ins) {
    for (size_t i = 0; i < ins.size(); ++i) {
      ins_addr[i] = ins[i].data();
    }
  }
};

75 76 77 78 79 80 81 82 83 84 85 86
// Variant of PointerWrapper that is additionally decorated with
// PADDLE_ALIGN(256). It stores the same per-input address table and offers
// the same constructors.
template <typename T, int Size>
struct PADDLE_ALIGN(256) AlignedPointerWrapper
    : public PointerWrapper<T, Size> {
 public:
  // Leaves the inherited address table uninitialized.
  AlignedPointerWrapper() {}

  // Fills the inherited address table from `ins`; `ctx` and
  // `pre_alloced_host_ptr` are accepted but not used.
  AlignedPointerWrapper(const phi::GPUContext& ctx,
                        const std::vector<phi::DenseTensor>& ins,
                        const T** pre_alloced_host_ptr) {
    PointerWrapper<T, Size>::SetInputAddr(ins);
  }
};

L
Leo Chen 已提交
87
// Table of input base addresses that lives in device memory. It is used when
// the number of inputs is not covered by one of the fixed-size wrappers.
template <typename T>
struct PointerToPointer {
 public:
  // Device-side array holding one address per input tensor.
  void** ins_addr{nullptr};
  // Returns the base address stored for the i-th input.
  __device__ inline const void* operator[](int i) const { return ins_addr[i]; }

  PointerToPointer() {}
  // Builds the device-side address table.
  //
  //   ctx                  - supplies the place and stream used for the
  //                          allocation and the copy.
  //   ins                  - input tensors whose data<T>() addresses are
  //                          recorded.
  //   pre_alloced_host_ptr - host staging buffer with room for ins.size()
  //                          pointers; it is overwritten here.
  //   dev_ins_ptr          - [out] receives ownership of the device
  //                          allocation. The caller has to keep it alive for
  //                          as long as ins_addr is used.
  PointerToPointer(const phi::GPUContext& ctx,
                   const std::vector<phi::DenseTensor>& ins,
                   const T** pre_alloced_host_ptr,
                   phi::Allocator::AllocationPtr* dev_ins_ptr) {
    auto in_num = ins.size();
    for (size_t i = 0; i < in_num; ++i) {
      pre_alloced_host_ptr[i] = ins[i].data<T>();
    }
    *dev_ins_ptr = phi::memory_utils::Alloc(
        ctx.GetPlace(),
        in_num * sizeof(T*),
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    // NOTE(review): presumably returns a host copy that stays valid while a
    // CUDA graph is being captured, and the original buffer otherwise --
    // confirm against the helper's definition.
    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        pre_alloced_host_ptr, in_num);
    paddle::memory::Copy(ctx.GetPlace(),
                         (*dev_ins_ptr)->ptr(),
                         phi::CPUPlace(),
                         restored,
                         in_num * sizeof(T*),
                         ctx.stream());
    ins_addr = reinterpret_cast<void**>((*dev_ins_ptr)->ptr());
  }
};

// Bundles a fixed-capacity table of input addresses with a fixed-capacity
// table of column offsets. `Size` bounds both tables; no bounds checking is
// done, so `inputs_col_num` and ins.size() must not exceed it.
template <typename T, typename IndexT, int Size>
struct PADDLE_ALIGN(256) PointerAndColWrapper {
 public:
  // Column offsets copied from `inputs_col`; entries at index
  // >= inputs_col_num are left uninitialized.
  IndexT col_length[Size];

  // Copies the first `inputs_col_num` entries of `inputs_col` and records
  // the data address of every tensor in `ins`. `ctx` and
  // `pre_alloced_host_ptr` are only forwarded to PointerWrapper.
  PointerAndColWrapper(const phi::GPUContext& ctx,
                       const std::vector<phi::DenseTensor>& ins,
                       const IndexT& inputs_col_num,
                       const T** pre_alloced_host_ptr,
                       IndexT* inputs_col) {
    for (auto i = 0; i < inputs_col_num; ++i) {
      col_length[i] = inputs_col[i];
    }
    ins_ptr_wrapper = PointerWrapper<T, Size>(ctx, ins, pre_alloced_host_ptr);
  }

  // Returns the base address stored for the i-th input.
  __device__ inline const void* operator[](int i) const {
    return ins_ptr_wrapper[i];
  }

 private:
  PointerWrapper<T, Size> ins_ptr_wrapper;
};

// Bundles a device-resident table of input addresses with a device-resident
// table of column offsets. Used when the number of inputs is not covered by
// one of the fixed-size PointerAndColWrapper instantiations.
template <typename T, typename IndexT>
struct PointerToPointerAndCol {
 public:
  // Device-side array with `inputs_col_num` column offsets.
  IndexT* col_length{nullptr};

  // Allocates device storage for the column offsets and the input addresses
  // and enqueues the host-to-device copies on ctx.stream().
  //
  //   inputs_col_num       - number of entries in `inputs_col`.
  //   pre_alloced_host_ptr - host staging buffer for the input addresses,
  //                          forwarded to PointerToPointer.
  //   inputs_col           - host array of column offsets.
  //   dev_ins_ptr          - [out] owns the device address table.
  //   dev_col_ptr          - [out] owns the device column table.
  // Both output allocations have to outlive every use of this object.
  PointerToPointerAndCol(const phi::GPUContext& ctx,
                         const std::vector<phi::DenseTensor>& ins,
                         const IndexT inputs_col_num,
                         const T** pre_alloced_host_ptr,
                         IndexT* inputs_col,
                         phi::Allocator::AllocationPtr* dev_ins_ptr,
                         phi::Allocator::AllocationPtr* dev_col_ptr) {
    *dev_col_ptr = phi::memory_utils::Alloc(
        ctx.GetPlace(),
        inputs_col_num * sizeof(IndexT),
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
        inputs_col, inputs_col_num);
    paddle::memory::Copy(ctx.GetPlace(),
                         (*dev_col_ptr)->ptr(),
                         phi::CPUPlace(),
                         restored,
                         inputs_col_num * sizeof(IndexT),
                         ctx.stream());
    col_length = static_cast<IndexT*>((*dev_col_ptr)->ptr());
    ins_ptr_wrapper =
        PointerToPointer<T>(ctx, ins, pre_alloced_host_ptr, dev_ins_ptr);
  }

  // Returns the base address stored for the i-th input.
  __device__ inline const void* operator[](int i) const {
    return ins_ptr_wrapper[i];
  }

 private:
  PointerToPointer<T> ins_ptr_wrapper;
};

177 178
#undef PADDLE_ALIGN

MarDino's avatar
MarDino 已提交
179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
// Opaque chunk of MovSize bytes, aligned to MovSize. The concat kernels in
// this file reinterpret their input and output buffers as arrays of
// Packed<MovSize> and copy one chunk per assignment, so the element type of
// the tensors does not appear in the kernels.
template <int MovSize>
struct alignas(MovSize) Packed {
  __device__ Packed() {
    // Intentionally empty: the payload is left uninitialized.
  }
  union {
    // Raw storage; never interpreted, only copied.
    char buf[MovSize];
  };
};

// Concatenates inputs of differing column counts along the column dimension.
//
// All sizes are expressed in units of Packed<MovSize> chunks.
//
//   ins_datas   - wrapper providing operator[](i) (base address of input i)
//                 and col_length[] (cumulative column offsets, so that input
//                 k covers output columns [col_length[k], col_length[k+1])).
//   col_size    - not used by the kernel body.
//   output_rows - number of rows of the output (and of every input).
//   output_cols - number of columns of the output; the kernel relies on the
//                 last used col_length entry being >= output_cols so that the
//                 segment search terminates.
//   output      - destination buffer, row-major output_rows x output_cols.
//
// Columns are distributed over the x dimension via CUDA_KERNEL_LOOP_TYPE and
// rows over the y dimension with a stride of blockDim.y * gridDim.y.
template <typename IndexT, int MovSize, typename PointerAndColWrapperT>
__global__ void ConcatTensorWithDifferentShape(
    const PointerAndColWrapperT ins_datas,
    int col_size,
    const IndexT output_rows,
    const IndexT output_cols,
    void* output) {
  Packed<MovSize>* dst = reinterpret_cast<Packed<MovSize>*>(output);

  // The segment cursor is kept across iterations of the column loop: tid_x
  // only grows, so the search never has to restart from segment 0.
  IndexT curr_segment = 0;
  IndexT curr_offset = ins_datas.col_length[0];

  CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, IndexT) {
    IndexT curr_col_offset = ins_datas.col_length[curr_segment + 1];

    // Advance to the input whose column range contains tid_x.
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = ins_datas.col_length[curr_segment + 1];
    }

    IndexT local_col = tid_x - curr_offset;
    IndexT segment_width = curr_col_offset - curr_offset;

    const Packed<MovSize>* input_ptr =
        reinterpret_cast<const Packed<MovSize>*>(ins_datas[curr_segment]);

    IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;

    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
      dst[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
    }
  }
}

MarDino's avatar
MarDino 已提交
225 226
// Concatenates inputs that all have the same column count along the column
// dimension.
//
// All sizes are expressed in units of Packed<MovSize> chunks.
//
//   ins_data     - wrapper providing operator[](i), the base address of
//                  input i.
//   fixed_in_col - number of columns of every input.
//   out_rows     - number of rows of the output (and of every input).
//   out_cols     - number of columns of the output.
//   output_data  - destination buffer, row-major out_rows x out_cols.
//
// Columns are distributed over the x dimension via CUDA_KERNEL_LOOP_TYPE and
// rows over the y dimension with a stride of blockDim.y * gridDim.y.
template <typename IndexT, int MovSize, typename PointerWrapperT>
__global__ void ConcatTensorWithSameShape(const PointerWrapperT ins_data,
                                          const IndexT fixed_in_col,
                                          const IndexT out_rows,
                                          const IndexT out_cols,
                                          void* output_data) {
  Packed<MovSize>* dst = reinterpret_cast<Packed<MovSize>*>(output_data);
  CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, IndexT) {
    // Which input this output column comes from, and the column inside it.
    IndexT split = tid_x / fixed_in_col;
    IndexT in_offset = tid_x - split * fixed_in_col;
    const Packed<MovSize>* input_ptr =
        reinterpret_cast<const Packed<MovSize>*>(ins_data[split]);
    IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      dst[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

MarDino's avatar
MarDino 已提交
245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
// Expands `func_impl(size, extra...)` once for each supported fixed wrapper
// capacity: 4, 8, 16, 32, 64 and 128. The trailing arguments are forwarded
// unchanged to every expansion. In this file `func_impl` is a macro that
// produces a `case size:` label, so the helper is placed inside a `switch`.
#define IMPL_CONCATE_CUDA_KERNEL_HELPER(func_impl, ...) \
  func_impl(4, ##__VA_ARGS__);                          \
  func_impl(8, ##__VA_ARGS__);                          \
  func_impl(16, ##__VA_ARGS__);                         \
  func_impl(32, ##__VA_ARGS__);                         \
  func_impl(64, ##__VA_ARGS__);                         \
  func_impl(128, ##__VA_ARGS__);

// Launches ConcatTensorWithDifferentShape<IndexT, MovSize, ...> on
// ctx.stream(), choosing how the input addresses and column offsets are
// handed to the kernel.
//
// `limit_num` is rounded by RoundToNextHighPowOfTwo(limit_num, 4)
// (presumably up to a power of two with a floor of 4 -- confirm against the
// helper). If the result is one of 4/8/16/32/64/128, the tables are embedded
// in a PointerAndColWrapper of that capacity. Otherwise they are placed in
// device memory through PointerToPointerAndCol.
//
//   inputs_col_num - number of entries in `inputs_col`.
//   inputs_data    - host staging buffer for the input addresses.
//   inputs_col     - host array of cumulative column offsets.
//   out_row/out_col- output extent, in units of MovSize-byte chunks.
//   in_num         - not used by this function.
template <typename T, typename IndexT, int MovSize>
void DispatchConcatWithDifferentShapeKernelLimitNum(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const IndexT inputs_col_num,
    const T** inputs_data,
    IndexT* inputs_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT in_num,
    const IndexT limit_num) {
  dim3 block_dims;
  dim3 grid_dims;
  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);

#define IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE(size_, ...)    \
  case size_: {                                             \
    PointerAndColWrapper<T, IndexT, size_> ptr_col_array(   \
        ctx, ins, inputs_col_num, inputs_data, inputs_col); \
    __VA_ARGS__;                                            \
  } break;
  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
    IMPL_CONCATE_CUDA_KERNEL_HELPER(
        IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE,
        ConcatTensorWithDifferentShape<IndexT, MovSize, decltype(ptr_col_array)>
        <<<grid_dims, block_dims, 0, ctx.stream()>>>(
            ptr_col_array, inputs_col_num, out_row, out_col, output->data()));
    default: {
      // These own the device-side tables referenced by ptr_col_array and are
      // released when this scope ends, i.e. after the launch was enqueued.
      phi::Allocator::AllocationPtr dev_ins_ptr{nullptr};
      phi::Allocator::AllocationPtr dev_col_ptr{nullptr};
      PointerToPointerAndCol<T, IndexT> ptr_col_array(ctx,
                                                      ins,
                                                      inputs_col_num,
                                                      inputs_data,
                                                      inputs_col,
                                                      &dev_ins_ptr,
                                                      &dev_col_ptr);
      ConcatTensorWithDifferentShape<IndexT, MovSize, decltype(ptr_col_array)>
          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
              ptr_col_array, inputs_col_num, out_row, out_col, output->data());
    }
  }
#undef IMPL_COMPLEX_CONCAT_CUDA_KERNEL_CASE
}

// Turns the run-time chunk size `mov_size` (in bytes) into the compile-time
// MovSize parameter of DispatchConcatWithDifferentShapeKernelLimitNum.
// Sizes 16, 8, 4 and 2 select the matching instantiation; any other value
// selects the 1-byte instantiation. All remaining arguments are forwarded
// unchanged.
template <typename T, typename IndexT>
void DispatchConcatWithDifferentShapeMovsize(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const IndexT inputs_col_num,
    const T** inputs_data,
    IndexT* inputs_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT mov_size,
    const IndexT in_num,
    const IndexT limit_num) {
  switch (mov_size) {
    case 16:
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 16>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
    case 8:
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 8>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
    case 4:
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 4>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
    case 2:
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 2>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
    default:
      // Byte-wise copy is the fallback for every other chunk size.
      DispatchConcatWithDifferentShapeKernelLimitNum<T, IndexT, 1>(
          ctx, ins, inputs_col_num, inputs_data, inputs_col,
          out_row, out_col, output, in_num, limit_num);
      break;
  }
}

// Launches ConcatTensorWithSameShape<IndexT, MovSize, ...> on ctx.stream(),
// choosing how the input addresses are handed to the kernel.
//
// `limit_num` is rounded by RoundToNextHighPowOfTwo(limit_num, 4)
// (presumably up to a power of two with a floor of 4 -- confirm against the
// helper). If the result is one of 4/8/16/32/64/128, the addresses are
// embedded in an AlignedPointerWrapper of that capacity. Otherwise they are
// placed in device memory through PointerToPointer.
//
//   inputs_data     - host staging buffer for the input addresses.
//   in_col          - column count shared by every input, in MovSize chunks.
//   out_row/out_col - output extent, in units of MovSize-byte chunks.
//   in_num          - not used by this function.
template <typename T, typename IndexT, int MovSize>
void DispatchConcatWithSameShapeKernelLimitNum(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const T** inputs_data,
    IndexT in_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT in_num,
    const IndexT limit_num) {
  dim3 block_dims;
  dim3 grid_dims;
  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);

#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...)                      \
  case size_: {                                                       \
    AlignedPointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data); \
    __VA_ARGS__;                                                      \
  } break;

  switch (phi::backends::gpu::RoundToNextHighPowOfTwo(limit_num, 4)) {
    IMPL_CONCATE_CUDA_KERNEL_HELPER(
        IMPL_CONCAT_CUDA_KERNEL_CASE,
        ConcatTensorWithSameShape<IndexT, MovSize, decltype(ptr_array)>
        <<<grid_dims, block_dims, 0, ctx.stream()>>>(
            ptr_array, in_col, out_row, out_col, output->data()));
    default: {
      // Owns the device-side address table referenced by ptr_array; released
      // when this scope ends, i.e. after the launch was enqueued.
      phi::Allocator::AllocationPtr dev_ins_ptr{nullptr};
      PointerToPointer<T> ptr_array(ctx, ins, inputs_data, &dev_ins_ptr);
      ConcatTensorWithSameShape<IndexT, MovSize, decltype(ptr_array)>
          <<<grid_dims, block_dims, 0, ctx.stream()>>>(
              ptr_array, in_col, out_row, out_col, output->data());
    }
  }
#undef IMPL_CONCAT_CUDA_KERNEL_CASE
}

#undef IMPL_CONCATE_CUDA_KERNEL_HELPER

// Turns the run-time chunk size `mov_size` (in bytes) into the compile-time
// MovSize parameter of DispatchConcatWithSameShapeKernelLimitNum. Sizes 16,
// 8, 4 and 2 select the matching instantiation; any other value selects the
// 1-byte instantiation. All remaining arguments are forwarded unchanged.
template <typename T, typename IndexT>
void DispatchConcatWithSameShapeMovsize(
    const phi::GPUContext& ctx,
    const std::vector<phi::DenseTensor>& ins,
    const T** inputs_data,
    IndexT in_col,
    const IndexT out_row,
    const IndexT out_col,
    phi::DenseTensor* output,
    const IndexT mov_size,
    const IndexT in_num,
    const IndexT limit_num) {
  switch (mov_size) {
    case 16:
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 16>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
    case 8:
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 8>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
    case 4:
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 4>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
    case 2:
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 2>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
    default:
      // Byte-wise copy is the fallback for every other chunk size.
      DispatchConcatWithSameShapeKernelLimitNum<T, IndexT, 1>(
          ctx, ins, inputs_data, in_col, out_row, out_col, output, in_num,
          limit_num);
      break;
  }
}

// Picks the widest vector width usable for this concat and dispatches to the
// same-shape or different-shape launcher.
//
// A width of `vec_size` elements (vec_size * sizeof(T) bytes, at most 16
// bytes) is usable when the output address, every input address and every
// input's column count are multiples of it. Candidates are tried from the
// widest down, halving each time; width 1 is the fallback.
//
//   ins            - input tensors; must contain at least `in_num` entries.
//   inputs_col_num - number of entries in `inputs_col` (in_num + 1).
//   inputs_data    - host staging buffer for input addresses; only forwarded.
//   inputs_col     - cumulative column offsets with inputs_col[0] == 0.
//                    MODIFIED IN PLACE: every entry is divided by the chosen
//                    vector width.
//   out_row/out_col- output extent in elements of T.
//   limit_num      - forwarded; selects the pointer-wrapper capacity.
//   has_same_shape - true when every input has the same column count.
template <typename T, typename IndexT>
void DispatchConcatKernel(const phi::GPUContext& ctx,
                          const std::vector<phi::DenseTensor>& ins,
                          const IndexT inputs_col_num,
                          const T** inputs_data,
                          IndexT* inputs_col,
                          const IndexT out_row,
                          const IndexT out_col,
                          phi::DenseTensor* output,
                          const IndexT in_num,
                          const IndexT limit_num,
                          bool has_same_shape) {
  constexpr IndexT MaxVecSize = 16 / sizeof(T);
  bool find_vecsize_flag = false;
  IndexT dispatch_vec_size = 1;

  auto output_data = reinterpret_cast<std::uintptr_t>(output->data());
  for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) {
    const IndexT mov_size = vec_size * sizeof(T);
    for (IndexT idx = 1; idx < in_num + 1; idx++) {
      // Probe the tensor's real address. `inputs_data` is only a staging
      // buffer that may not have been filled yet at this point, so reading
      // it here would make the alignment test meaningless.
      auto input_data =
          reinterpret_cast<std::uintptr_t>(ins[idx - 1].data());
      // Since input_cols[0] is 0, we need to jump.
      const IndexT input_col = inputs_col[idx] - inputs_col[idx - 1];
      if (input_col % vec_size == 0 && output_data % mov_size == 0 &&
          input_data % mov_size == 0) {
        if (idx == in_num) {
          // Every input passed the test for this width.
          find_vecsize_flag = true;
        }
      } else {
        break;
      }
    }
    if (find_vecsize_flag) {
      dispatch_vec_size = vec_size;
      break;
    }
  }

  // From here on, column counts are expressed in vector-width chunks.
  const int64_t vectorized_out_col = out_col / dispatch_vec_size;
  for (IndexT idx = 0; idx < in_num + 1; idx++) {
    inputs_col[idx] /= dispatch_vec_size;
  }
  const IndexT mov_size = sizeof(T) * dispatch_vec_size;
  if (has_same_shape) {
    // In same shape situation, each input's col are equal, so here we select to
    // use inputs_col[1].
    DispatchConcatWithSameShapeMovsize<T, IndexT>(ctx,
                                                  ins,
                                                  inputs_data,
                                                  inputs_col[1],
                                                  out_row,
                                                  vectorized_out_col,
                                                  output,
                                                  mov_size,
                                                  in_num,
                                                  limit_num);
  } else {
    DispatchConcatWithDifferentShapeMovsize<T, IndexT>(ctx,
                                                       ins,
                                                       inputs_col_num,
                                                       inputs_data,
                                                       inputs_col,
                                                       out_row,
                                                       vectorized_out_col,
                                                       output,
                                                       mov_size,
                                                       in_num,
                                                       limit_num);
  }
}

L
Leo Chen 已提交
547 548 549 550
/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570
// Concatenates `ins` along `axis` into `output`, using IndexT for all index
// arithmetic.
//
// Every input is viewed as a 2-D matrix with in_row rows (product of the
// dimensions before `axis`, taken from ins[0]) and numel / in_row columns.
// The function builds the cumulative column offsets, detects whether all
// inputs have the same column count, and hands over to DispatchConcatKernel.
//
// Nothing is launched when `ins` is empty or when in_row is 0.
template <typename T, typename IndexT>
void ConcatFunctorWithIndexType(const phi::GPUContext& ctx,
                                const std::vector<phi::DenseTensor>& ins,
                                int axis,
                                phi::DenseTensor* output) {
  // TODO(zcd): Add input data validity checking
  if (ins.empty()) {
    // ins[0] below would be out of range.
    return;
  }
  IndexT in_num = ins.size();
  IndexT in_row = 1;
  auto dim_0 = ins[0].dims();
  for (int i = 0; i < axis; ++i) {
    in_row *= dim_0[i];
  }
  if (in_row == 0) {
    // A zero-sized leading dimension means there is nothing to copy, and the
    // divisions by in_row below would be undefined.
    return;
  }
  IndexT in_col = ins[0].numel() / in_row;
  IndexT out_row = in_row, out_col = 0;

  IndexT inputs_col_num = in_num + 1;
  std::vector<const T*> inputs_data_vec(in_num, nullptr);
  std::vector<IndexT> inputs_col_vec(inputs_col_num, 0);
  const T** inputs_data = inputs_data_vec.data();
  IndexT* inputs_col = inputs_col_vec.data();
#ifdef PADDLE_WITH_HIP
  // TODO(chentianyu03): try to find a method to remove the Alloc function
  phi::Allocator::AllocationPtr data_alloc = phi::memory_utils::Alloc(
      paddle::platform::CUDAPinnedPlace(), in_num * sizeof(T*));
  inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
  phi::Allocator::AllocationPtr col_alloc = phi::memory_utils::Alloc(
      paddle::platform::CUDAPinnedPlace(), inputs_col_num * sizeof(IndexT));
  inputs_col = reinterpret_cast<IndexT*>(col_alloc->ptr());
  // The pinned buffer is not zero-initialized like the vector above, while
  // the first offset is read as-is further down the pipeline.
  inputs_col[0] = 0;
#endif

  bool has_same_shape = true;
  for (int i = 0; i < in_num; ++i) {
    IndexT t_cols = ins[i].numel() / in_row;
    if (has_same_shape) {
      has_same_shape &= (t_cols == in_col);
    }
    out_col += t_cols;
    inputs_col[i + 1] = out_col;
    // Record the real input address so that any consumer of inputs_data
    // (e.g. an alignment check) sees valid pointers rather than null or
    // uninitialized entries.
    inputs_data[i] = static_cast<const T*>(ins[i].data());
  }
  // The different-shape path also has to carry in_num + 1 column offsets.
  IndexT limit_num = has_same_shape ? in_num : inputs_col_num;

  DispatchConcatKernel<T, IndexT>(ctx,
                                  ins,
                                  inputs_col_num,
                                  inputs_data,
                                  inputs_col,
                                  out_row,
                                  out_col,
                                  output,
                                  in_num,
                                  limit_num,
                                  has_same_shape);

#ifdef PADDLE_WITH_HIP
  // Prevent pinned memory from being covered and release the memory after
  // kernel launch of the stream is executed (reapply pinned memory next time)
  auto* data_alloc_released = data_alloc.release();
  auto* col_alloc_released = col_alloc.release();
  ctx.AddStreamCallback([data_alloc_released, col_alloc_released] {
    VLOG(4) << "Delete cuda pinned at " << data_alloc_released;
    VLOG(4) << "Delete cuda pinned at " << col_alloc_released;
    phi::memory_utils::AllocationDeleter(data_alloc_released);
    phi::memory_utils::AllocationDeleter(col_alloc_released);
  });
#endif
}

// GPU specialization of ConcatFunctor: concatenates `input` along `axis`
// into `output`.
//
// 32-bit indices are used when the output has fewer than INT32_MAX elements;
// larger outputs fall back to 64-bit indices.
template <typename T>
struct ConcatFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext& context,
                  const std::vector<phi::DenseTensor>& input,
                  int axis,
                  phi::DenseTensor* output) {
    if (output->numel() < std::numeric_limits<int32_t>::max()) {
      ConcatFunctorWithIndexType<T, int32_t>(context, input, axis, output);
    } else {
      ConcatFunctorWithIndexType<T, int64_t>(context, input, axis, output);
    }
  }
};

632 633 634
// Extends funcs::PointerArraySetter with an array of column values so that a
// split kernel receives both the output pointers and the per-output column
// information.
//
// NOTE(review): the semantics of PointerArraySetter, ValueArray::Set and
// AllocAndCopy come from segmented_array.h and are not visible here; the
// comments below only describe what this struct passes to them.
template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
struct PointerAndColArray
    : public funcs::PointerArraySetter<phi::GPUContext, T, Size> {
 public:
  funcs::ValueArray<IndexT, Size> val_array;

  PointerAndColArray() {}
  // Builds the pointer array from the tensors in `t` (through the base class,
  // with need_alloc = false and use_cuda_graph = true) and fills `val_array`
  // from the `out_col_num` entries of `out_cols`.
  //
  // For SegmentedArraySize::kVariableLength the column values are first
  // copied through AllocAndCopy and `val_array` is set from the returned
  // pointer; for every other size `val_array` is set directly from the host
  // array `out_cols`.
  PointerAndColArray(const phi::GPUContext& ctx,
                     const int out_col_num,
                     IndexT* out_cols,
                     std::vector<DenseTensor*>* t,
                     T** pre_alloc_host_buf = nullptr)
      : funcs::PointerArraySetter<phi::GPUContext, T, Size>(
            ctx,
            t,
            /*need_alloc=*/false,
            /*use_cuda_graph=*/true,
            pre_alloc_host_buf) {
    IndexT* dev_ptr = nullptr;
    if (Size == SegmentedArraySize::kVariableLength) {
      size_t num_bytes = out_col_num * sizeof(IndexT);
      dev_ptr = reinterpret_cast<IndexT*>(this->AllocAndCopy(
          ctx, reinterpret_cast<void*>(out_cols), num_bytes, true));
      val_array.Set(dev_ptr, out_col_num);
    } else {
      val_array.Set(out_cols, out_col_num);
    }
  }
};

// Splits a row-major matrix column-wise into outputs that all have
// `fixed_out_col` columns.
//
//   input_data     - source matrix with out_row rows and cumulative_col
//                    columns.
//   out_row        - number of rows of the source and of every output.
//   cumulative_col - total number of columns of the source.
//   fixed_out_col  - number of columns of each output.
//   data_array     - provides data[k], the destination pointer of output k;
//                    a null entry means that output is skipped.
//
// Columns are distributed over the x dimension via CUDA_KERNEL_LOOP_TYPE and
// rows over the y dimension with a stride of blockDim.y * gridDim.y.
template <typename T, typename IndexT, typename DataArrayT>
__global__ void SplitTensorWithSameShape(const T* input_data,
                                         const IndexT out_row,
                                         const IndexT cumulative_col,
                                         const IndexT fixed_out_col,
                                         DataArrayT data_array) {
  CUDA_KERNEL_LOOP_TYPE(tid_x, cumulative_col, IndexT) {
    // Destination tensor of this source column and the column inside it.
    const IndexT which_out = tid_x / fixed_out_col;
    const IndexT col_in_out = tid_x - which_out * fixed_out_col;

    T* dst = data_array.data[which_out];
    if (dst != nullptr) {
      IndexT row = blockIdx.y * blockDim.y + threadIdx.y;
      while (row < out_row) {
        dst[row * fixed_out_col + col_in_out] =
            input_data[row * cumulative_col + tid_x];
        row += blockDim.y * gridDim.y;
      }
    }
  }
}

// Scatters the columns of a row-major [out_row, cumulative_col] input into
// several outputs whose widths may differ.
//
// input_data      source matrix, indexed as row * cumulative_col + column.
// out_row         number of rows shared by the input and every output.
// cumulative_col  total number of input columns.
// data_array      holder whose `data[k]` is the destination pointer of output
//                 k; a null entry means that output is skipped.
// col_array       holder whose `data[k]` is the first input column of output
//                 k and `data[k + 1]` is one past its last column.
//
// Each thread remembers the segment it found for its previous column and only
// searches forward from there.
// NOTE(review): this relies on CUDA_KERNEL_LOOP_TYPE (defined elsewhere)
// visiting columns in increasing order within a thread -- confirm.
template <typename T, typename IndexT, typename DataArrayT, typename ValArrayT>
__global__ void SplitTensorWithDifferentShape(const T* input_data,
                                              const IndexT out_row,
                                              const IndexT cumulative_col,
                                              DataArrayT data_array,
                                              ValArrayT col_array) {
  IndexT seg = 0;
  IndexT seg_begin = col_array.data[0];
  CUDA_KERNEL_LOOP_TYPE(tid_x, cumulative_col, IndexT) {
    // Advance until the segment [seg_begin, seg_end) contains tid_x.
    IndexT seg_end = col_array.data[seg + 1];
    for (; seg_end <= tid_x; seg_end = col_array.data[seg + 1]) {
      seg_begin = seg_end;
      ++seg;
    }

    const IndexT col_in_seg = tid_x - seg_begin;
    const IndexT seg_width = seg_end - seg_begin;
    T* dst = data_array.data[seg];
    if (dst != nullptr) {
      const unsigned int row_stride = blockDim.y * gridDim.y;
      IndexT row = blockIdx.y * blockDim.y + threadIdx.y;
      while (row < out_row) {
        dst[row * seg_width + col_in_seg] =
            input_data[row * cumulative_col + tid_x];
        row += row_stride;
      }
    }
  }
}

// Launches SplitTensorWithSameShape on ctx.stream() for outputs that all have
// the same width.
//
// ctx                 GPU context supplying the stream and launch limits.
// out_col             width of every output.
// out_row             number of rows shared by the input and every output.
// cumulative_col      total number of input columns.
// input_data          device pointer to the input matrix.
// outs                output tensors; their data pointers are gathered by
//                     PointerArraySetter.
// pre_alloc_host_buf  optional host staging buffer forwarded to the setter.
template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
void SplitFunctionDispatchWithSameShape(const phi::GPUContext& ctx,
                                        const IndexT out_col,
                                        const IndexT out_row,
                                        const IndexT cumulative_col,
                                        const T* input_data,
                                        std::vector<phi::DenseTensor*>* outs,
                                        T** pre_alloc_host_buf) {
  // Launch shape is derived from the input extent (rows x total columns).
  dim3 blocks_in_grid;
  dim3 threads_in_block;
  GetBlockDims(ctx, out_row, cumulative_col, &threads_in_block, &blocks_in_grid);

  using SetterT = funcs::PointerArraySetter<phi::GPUContext, T, Size>;
  SetterT out_ptrs(ctx,
                   outs,
                   /*need_alloc=*/false,
                   /*use_cuda_graph=*/true,
                   pre_alloc_host_buf);

  using DataArrayT = decltype(out_ptrs.array);
  SplitTensorWithSameShape<T, IndexT, DataArrayT>
      <<<blocks_in_grid, threads_in_block, 0, ctx.stream()>>>(
          input_data, out_row, cumulative_col, out_col, out_ptrs.array);
}

// Launches SplitTensorWithDifferentShape on ctx.stream() for outputs whose
// widths may differ.
//
// ctx                 GPU context supplying the stream and launch limits.
// out_col_num         number of entries in `output_cols`.
// out_row             number of rows shared by the input and every output.
// cumulative_col      total number of input columns.
// input_data          device pointer to the input matrix.
// outs                output tensors; their data pointers are gathered by
//                     PointerAndColArray.
// output_cols         host array of column offsets, one boundary per entry.
// pre_alloc_host_buf  optional host staging buffer forwarded to the setter.
template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
void SplitFunctionDispatchWithDifferentShape(
    const phi::GPUContext& ctx,
    const int out_col_num,
    const IndexT out_row,
    const IndexT cumulative_col,
    const T* input_data,
    std::vector<phi::DenseTensor*>* outs,
    IndexT* output_cols,
    T** pre_alloc_host_buf) {
  // Launch shape is derived from the input extent (rows x total columns).
  dim3 blocks_in_grid;
  dim3 threads_in_block;
  GetBlockDims(ctx, out_row, cumulative_col, &threads_in_block, &blocks_in_grid);

  // Bundles both kernel arguments: output pointers and column offsets.
  PointerAndColArray<T, IndexT, Size> packed(
      ctx, out_col_num, output_cols, outs, pre_alloc_host_buf);

  using DataArrayT = decltype(packed.array);
  using ColArrayT = decltype(packed.val_array);
  SplitTensorWithDifferentShape<T, IndexT, DataArrayT, ColArrayT>
      <<<blocks_in_grid, threads_in_block, 0, ctx.stream()>>>(
          input_data, out_row, cumulative_col, packed.array, packed.val_array);
}
L
Leo Chen 已提交
755

756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772
// Splits `input` along `axis` into `outs`, using `IndexT` for all index
// arithmetic on host and device.
//
// The input is treated as a row-major [out_row, cumulative_col] matrix:
// out_row is the product of the dimensions before `axis` (taken from
// ref_ins[0]), and output i contributes ref_ins[i]->numel() / out_row
// columns. When every output has the same width the cheaper
// SplitTensorWithSameShape kernel is used; otherwise the prefix-sum column
// offsets are passed to SplitTensorWithDifferentShape.
//
// ctx      GPU context supplying the stream.
// axis     dimension along which the input is split.
// input    tensor to split.
// ref_ins  reference tensors describing the shape of each output; must hold
//          at least outs->size() entries and ref_ins[0] must exist.
// outs     destination tensors.
template <typename T, typename IndexT>
void SplitFunctorDispatchWithIndexType(
    const phi::GPUContext& ctx,
    int axis,
    const phi::DenseTensor& input,
    const std::vector<const phi::DenseTensor*>& ref_ins,
    std::vector<phi::DenseTensor*>* outs) {
  // TODO(zcd): Add input data validity checking
  int out_num = outs->size();
  IndexT out_row = 1;
  auto ref_dim = ref_ins[0]->dims();
  for (int i = 0; i < axis; ++i) {
    out_row *= ref_dim[i];
  }
  // Width of the first output; the reference for the uniform-width check.
  IndexT out_col = ref_ins[0]->numel() / out_row;
  IndexT cumulative_col = 0;
  bool has_same_shape = true;

  // One more entry than outputs: outs_cols[i] .. outs_cols[i + 1] bounds
  // output i.
  int out_cols_num = out_num + 1;
  std::vector<IndexT> outputs_cols_vec(out_cols_num, 0);
  IndexT* outs_cols = outputs_cols_vec.data();
  T** outs_data = nullptr;

// There are some differences between the HIP runtime and the NV runtime.
// In NV, when pageable memory of less than 64K is transferred from host to
// device, the copy is automatically asynchronous.
// However, only pinned memory in HIP can be copied asynchronously.
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device
// Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
  phi::Allocator::AllocationPtr data_alloc, cols_alloc;
  // TODO(chentianyu03): try to find a method to remove the Alloc function
  data_alloc = phi::memory_utils::Alloc(paddle::platform::CUDAPinnedPlace(),
                                        out_num * sizeof(T*));
  outs_data = reinterpret_cast<T**>(data_alloc->ptr());
  // TODO(chentianyu03): try to find a method to remove the Alloc function
  cols_alloc = phi::memory_utils::Alloc(paddle::platform::CUDAPinnedPlace(),
                                        (out_cols_num) * sizeof(IndexT));
  outs_cols = reinterpret_cast<IndexT*>(cols_alloc->ptr());
#endif

  outs_cols[0] = 0;
  for (int i = 0; i < out_num; ++i) {
    IndexT t_col = ref_ins.at(i)->numel() / out_row;
    // Outputs share a shape only if every width equals the first output's
    // width. Comparing against the running sum would reject any non-empty
    // split on the first iteration and make the same-shape kernel
    // unreachable.
    has_same_shape = has_same_shape && (t_col == out_col);
    cumulative_col += t_col;
    outs_cols[i + 1] = cumulative_col;
  }

  // The same-shape kernel only carries the output pointers; the general
  // kernel also carries the out_num + 1 column offsets.
  int limit_num = has_same_shape ? out_num : out_cols_num;
  if (has_same_shape) {
    switch (funcs::CalcArraySize(limit_num)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          SplitFunctionDispatchWithSameShape<T, IndexT, kArraySize>(
              ctx,
              out_col,
              out_row,
              cumulative_col,
              input.data<T>(),
              outs,
              outs_data));
    }
  } else {
    switch (funcs::CalcArraySize(limit_num)) {
      SEGMENTED_ARRAY_KERNEL_HELPER(
          SplitFunctionDispatchWithDifferentShape<T, IndexT, kArraySize>(
              ctx,
              out_cols_num,
              out_row,
              cumulative_col,
              input.data<T>(),
              outs,
              outs_cols,
              outs_data));
    }
  }

#ifdef PADDLE_WITH_HIP
  // Prevent pinned memory from being covered and release the memory after
  // kernel launch of the stream is executed (reapply pinned memory next time)
  auto* data_alloc_released = data_alloc.release();
  auto* cols_alloc_released = cols_alloc.release();
  ctx.AddStreamCallback([data_alloc_released, cols_alloc_released] {
    phi::memory_utils::AllocationDeleter(data_alloc_released);
    phi::memory_utils::AllocationDeleter(cols_alloc_released);
  });
#endif
}

// GPU specialization of SplitFunctor: splits `input` along `axis` into
// `outputs`, whose shapes are described by `ref_inputs`.
template <typename T>
class SplitFunctor<phi::GPUContext, T> {
 public:
  // context     GPU context forwarded to the dispatcher.
  // input       tensor to split.
  // ref_inputs  reference tensors describing each output's shape.
  // axis        dimension along which the input is split.
  // outputs     destination tensors.
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const std::vector<const phi::DenseTensor*>& ref_inputs,
                  int axis,
                  std::vector<phi::DenseTensor*>* outputs) {
    const int64_t numel = input.numel();
    // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in
    // 3 tensors of shape [0,1,4]
    if (numel == 0) {
      return;
    }

    // Use 32-bit indexing when the element count is below INT32_MAX, and
    // 64-bit indexing otherwise.
    const bool needs_int64 = !(numel < std::numeric_limits<int32_t>::max());
    if (needs_int64) {
      SplitFunctorDispatchWithIndexType<T, int64_t>(
          context, axis, input, ref_inputs, outputs);
      return;
    }
    SplitFunctorDispatchWithIndexType<T, int32_t>(
        context, axis, input, ref_inputs, outputs);
  }
};

// Explicitly instantiates the phi::GPUContext versions of ConcatFunctor and
// SplitFunctor for one element type. The trailing semicolon is supplied at
// the point of use.
#define DEFINE_FUNCTOR(type)                           \
  template class ConcatFunctor<phi::GPUContext, type>; \
  template class SplitFunctor<phi::GPUContext, type>

// FOR_ALL_TYPES is defined outside this chunk; presumably it expands
// DEFINE_FUNCTOR once per supported element type -- confirm in the header.
FOR_ALL_TYPES(DEFINE_FUNCTOR);

}  // namespace funcs
}  // namespace phi