concat_and_split_functor.cu 21.6 KB
Newer Older
L
Leo Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
16 17 18
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"

L
Leo Chen 已提交
19 20 21 22 23 24 25 26 27 28
namespace phi {
namespace funcs {

/*
 * Column concatenation for inputs whose column counts differ.
 *
 * Each input is viewed as a row-major [output_rows, width_i] matrix. Input i
 * owns output columns [input_cols[i], input_cols[i + 1]), so `input_cols`
 * holds prefix offsets starting at input_cols[0]. `col_size` is not read here.
 *
 * The x dimension walks output columns through CUDA_KERNEL_LOOP_TYPE. The y
 * dimension walks rows with a stride of blockDim.y * gridDim.y.
 */
template <typename T>
__global__ void ConcatKernel_(const T** inputs,
                              const int64_t* input_cols,
                              int col_size,
                              const int64_t output_rows,
                              const int64_t output_cols,
                              T* output) {
  // The segment search state is kept across column iterations. The linear
  // forward search below assumes CUDA_KERNEL_LOOP_TYPE hands this thread
  // increasing tid_x values (TODO confirm against the macro definition).
  int64_t seg = 0;
  int64_t seg_begin = input_cols[0];
  CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, int64_t) {
    int64_t seg_end = input_cols[seg + 1];
    for (; seg_end <= tid_x; seg_end = input_cols[seg + 1]) {
      seg_begin = seg_end;
      ++seg;
    }

    const int64_t col_in_seg = tid_x - seg_begin;
    const int64_t seg_width = seg_end - seg_begin;
    const T* src = inputs[seg];

    for (int64_t row = blockIdx.y * blockDim.y + threadIdx.y;
         row < output_rows;
         row += blockDim.y * gridDim.y) {
      output[row * output_cols + tid_x] = src[row * seg_width + col_in_seg];
    }
  }
}

/*
 * Shared body of the fixed-width concat kernels.
 *
 * Every input is a row-major [out_rows, fixed_in_col] matrix. Output column
 * tid_x comes from input tid_x / fixed_in_col, at column
 * tid_x % fixed_in_col of that input.
 *
 * The x dimension walks output columns through CUDA_KERNEL_LOOP_TYPE. The y
 * dimension walks rows with a stride of blockDim.y * gridDim.y.
 */
template <typename T>
__device__ void ConcatKernelDetail(const T** inputs_data,
                                   const int64_t fixed_in_col,
                                   const int64_t out_rows,
                                   const int64_t out_cols,
                                   T* output_data) {
  CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, int64_t) {
    // Use exact integer division, matching SplitKernelDetail. The previous
    // `tid_x * 1.0 / fixed_in_col` went through double precision, which is
    // slow on many GPUs and only exact while values fit in 53 bits.
    int64_t split = tid_x / fixed_in_col;
    int64_t in_offset = tid_x - split * fixed_in_col;
    const T* input_ptr = inputs_data[split];
    int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

/*
 * Two-input fixed-width concat.
 *
 * The input pointers arrive as kernel arguments, so the caller does not need
 * to build a device-side pointer table.
 */
template <typename T>
__global__ void ConcatKernel_(const T* input_addr0,
                              const T* input_addr1,
                              const int64_t fixed_in_col,
                              const int64_t out_rows,
                              const int64_t out_cols,
                              T* output_data) {
  const T* srcs[] = {input_addr0, input_addr1};
  ConcatKernelDetail<T>(srcs, fixed_in_col, out_rows, out_cols, output_data);
}

/*
 * Three-input fixed-width concat.
 *
 * The input pointers arrive as kernel arguments, so the caller does not need
 * to build a device-side pointer table.
 */
template <typename T>
__global__ void ConcatKernel_(const T* input_addr0,
                              const T* input_addr1,
                              const T* input_addr2,
                              const int64_t fixed_in_col,
                              const int64_t out_rows,
                              const int64_t out_cols,
                              T* output_data) {
  const T* srcs[] = {input_addr0, input_addr1, input_addr2};
  ConcatKernelDetail<T>(srcs, fixed_in_col, out_rows, out_cols, output_data);
}

/*
 * Four-input fixed-width concat.
 *
 * The input pointers arrive as kernel arguments, so the caller does not need
 * to build a device-side pointer table.
 */
template <typename T>
__global__ void ConcatKernel_(const T* input_addr0,
                              const T* input_addr1,
                              const T* input_addr2,
                              const T* input_addr3,
                              const int64_t fixed_in_col,
                              const int64_t out_rows,
                              const int64_t out_cols,
                              T* output_data) {
  const T* srcs[] = {input_addr0, input_addr1, input_addr2, input_addr3};
  ConcatKernelDetail<T>(srcs, fixed_in_col, out_rows, out_cols, output_data);
}

/*
 * Fixed-width concat for any number of inputs.
 *
 * `inputs_data` must be a device-resident table of input pointers. `in_num`
 * is accepted for signature symmetry but is not read: the column count and
 * `fixed_in_col` already determine which table entries are used.
 */
template <typename T>
__global__ void ConcatKernel_(const T** inputs_data,
                              const int in_num,
                              const int64_t fixed_in_col,
                              const int64_t out_rows,
                              const int64_t out_cols,
                              T* output_data) {
  ConcatKernelDetail<T>(inputs_data,
                        fixed_in_col,
                        out_rows,
                        out_cols,
                        output_data);
}

/*
 * Column split into outputs whose column counts differ.
 *
 * The input is a row-major [in_row, in_col] matrix. Output i receives input
 * columns [out_cols[i], out_cols[i + 1]) as a row-major
 * [in_row, out_cols[i + 1] - out_cols[i]] matrix. `out_cols_size` is not read
 * here. Null entries in `outputs_data` are skipped.
 *
 * The x dimension walks input columns through CUDA_KERNEL_LOOP_TYPE. The y
 * dimension walks rows with a stride of blockDim.y * gridDim.y.
 */
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t* out_cols,
                             int out_cols_size,
                             T** outputs_data) {
  // The segment search state is kept across column iterations. The linear
  // forward search below assumes CUDA_KERNEL_LOOP_TYPE hands this thread
  // increasing tid_x values (TODO confirm against the macro definition).
  int64_t seg = 0;
  int64_t seg_begin = out_cols[0];
  CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
    int64_t seg_end = out_cols[seg + 1];
    for (; seg_end <= tid_x; seg_end = out_cols[seg + 1]) {
      seg_begin = seg_end;
      ++seg;
    }

    T* dst = outputs_data[seg];
    if (dst != nullptr) {
      const int64_t col_in_seg = tid_x - seg_begin;
      const int64_t seg_width = seg_end - seg_begin;
      for (int64_t row = blockIdx.y * blockDim.y + threadIdx.y; row < in_row;
           row += blockDim.y * gridDim.y) {
        dst[row * seg_width + col_in_seg] = input_data[row * in_col + tid_x];
      }
    }
  }
}

/*
 * Shared body of the fixed-width split kernels.
 *
 * Input column tid_x goes to output tid_x / fixed_out_col, at column
 * tid_x % fixed_out_col. Each output is a row-major [in_row, fixed_out_col]
 * matrix. Null entries in `outputs_data` are skipped.
 *
 * The x dimension walks input columns through CUDA_KERNEL_LOOP_TYPE. The y
 * dimension walks rows with a stride of blockDim.y * gridDim.y.
 */
template <typename T>
__device__ void SplitKernelDetail(const T* input_data,
                                  const int64_t in_row,
                                  const int64_t in_col,
                                  const int64_t fixed_out_col,
                                  T** outputs_data) {
  CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
    const int64_t out_idx = tid_x / fixed_out_col;
    const int64_t col_in_out = tid_x % fixed_out_col;

    T* dst = outputs_data[out_idx];
    if (dst != nullptr) {
      for (int64_t row = blockIdx.y * blockDim.y + threadIdx.y; row < in_row;
           row += blockDim.y * gridDim.y) {
        dst[row * fixed_out_col + col_in_out] =
            input_data[row * in_col + tid_x];
      }
    }
  }
}

/*
 * Fixed-width split for any number of outputs.
 *
 * `outputs_data` must be a device-resident table of output pointers; null
 * entries are skipped.
 */
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t fixed_out_col,
                             T** outputs_data) {
  SplitKernelDetail<T>(input_data,
                       in_row,
                       in_col,
                       fixed_out_col,
                       outputs_data);
}

/*
 * Two-output fixed-width split.
 *
 * The output pointers arrive as kernel arguments, so the caller does not need
 * to build a device-side pointer table.
 */
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t fixed_out_col,
                             T* outputs_addr0,
                             T* outputs_addr1) {
  T* dsts[] = {outputs_addr0, outputs_addr1};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, dsts);
}

/*
 * Three-output fixed-width split.
 *
 * The output pointers arrive as kernel arguments, so the caller does not need
 * to build a device-side pointer table.
 */
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t fixed_out_col,
                             T* outputs_addr0,
                             T* outputs_addr1,
                             T* outputs_addr2) {
  T* dsts[] = {outputs_addr0, outputs_addr1, outputs_addr2};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, dsts);
}

/*
 * Four-output fixed-width split.
 *
 * The output pointers arrive as kernel arguments, so the caller does not need
 * to build a device-side pointer table.
 */
template <typename T>
__global__ void SplitKernel_(const T* input_data,
                             const int64_t in_row,
                             const int64_t in_col,
                             const int64_t fixed_out_col,
                             T* outputs_addr0,
                             T* outputs_addr1,
                             T* outputs_addr2,
                             T* outputs_addr3) {
  T* dsts[] = {outputs_addr0, outputs_addr1, outputs_addr2, outputs_addr3};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, dsts);
}

/*
 * Chooses the launch shape for the concat/split kernels over a
 * [num_rows, num_cols] view.
 *
 * Block shape:
 * - The block always has 1024 threads.
 * - x is 1024 when num_cols >= 1024; otherwise it is num_cols rounded up to a
 *   multiple of 32.
 * - y is 1024 / x.
 *
 * Grid shape:
 * - x covers the columns and is capped at
 *   max(GetMaxPhysicalThreadCount() / 1024, 1) blocks.
 * - y gets the remaining block budget, bounded by max(num_rows / y, 1).
 *
 * Precondition: num_cols > 0. With num_cols == 0 the block x extent is 0 and
 * the divisions below divide by zero.
 */
static inline void GetBlockDims(const phi::GPUContext& context,
                                int64_t num_rows,
                                int64_t num_cols,
                                dim3* block_dims,
                                dim3* grid_dims) {
  constexpr int kThreadsPerBlock = 1024;

  const int block_cols =
      num_cols < kThreadsPerBlock
          ? static_cast<int>((num_cols + 31) & ~static_cast<int64_t>(31))
          : kThreadsPerBlock;
  const int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  const int max_threads = context.GetMaxPhysicalThreadCount();
  const int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  const int64_t col_blocks = (num_cols + block_cols - 1) / block_cols;
  const int grid_cols = std::min(col_blocks, max_blocks);
  const int64_t row_blocks =
      std::max(num_rows / block_rows, static_cast<int64_t>(1));
  const int grid_rows = std::min(max_blocks / grid_cols, row_blocks);
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

/*
 * All input tensors must have the same rank, and every dimension except the
 * concatenation axis must have the same size.
 */

template <typename T>
struct ConcatFunctor<phi::GPUContext, T> {
  /*
   * Concatenates `input` along `axis` into `output` on context.stream().
   *
   * Each input is viewed as a 2-D [rows, cols] matrix:
   * - rows is the product of input[0]'s dimensions before `axis`;
   * - cols is numel / rows.
   * The output has the same row count and the summed column count. Only
   * output->data<T>() is used, so `output` is presumably already allocated by
   * the caller (TODO confirm).
   *
   * Kernel selection:
   *   - equal column counts and 2..4 inputs: pointers passed as kernel args;
   *   - equal column counts otherwise: a pointer table is copied to the GPU;
   *   - differing column counts: pointer table plus column prefix offsets.
   *
   * Returns without launching anything when there are no inputs or the
   * output would have no elements (zero rows or zero total columns).
   */
  void operator()(const phi::GPUContext& context,
                  const std::vector<phi::DenseTensor>& input,
                  int axis,
                  phi::DenseTensor* output) {
    // TODO(zcd): Add input data validity checking
    // Nothing to concatenate; this also keeps input[0] below in bounds.
    if (input.empty()) {
      return;
    }

    int64_t in_num = input.size();
    int64_t in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    // Zero rows means the output is empty. Return before the division by
    // in_row below, which would otherwise divide by zero.
    if (in_row == 0) {
      return;
    }
    int64_t in_col = input[0].numel() / in_row;
    int64_t out_row = in_row, out_col = 0;

    int64_t inputs_col_num = in_num + 1;
    std::vector<const T*> inputs_data_vec(in_num);
    std::vector<int64_t> inputs_col_vec(inputs_col_num);
    const T** inputs_data = inputs_data_vec.data();
    int64_t* inputs_col = inputs_col_vec.data();

    // The HIP and NVIDIA runtimes differ here. On NVIDIA, host-to-device
    // copies of 64 KB or less from pageable memory are already asynchronous;
    // see the CUDA C++ Programming Guide, 3.2.6.1 "Concurrent Execution
    // between Host and Device":
    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
    // On HIP only pinned memory copies asynchronously, so the host-side tables
    // are staged in pinned buffers there.
#ifdef PADDLE_WITH_HIP
    paddle::memory::AllocationPtr data_alloc, col_alloc;
    // TODO(chentianyu03): try to find a method to remove the Alloc function
    data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
                                       in_num * sizeof(T*));
    inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
    // TODO(chentianyu03): try to find a method to remove the Alloc function
    // inputs_col holds int64_t entries. Sizing this buffer with sizeof(int)
    // made the writes below and the H2D copy overrun it.
    col_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
                                      inputs_col_num * sizeof(int64_t));
    inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
#endif

    // Build the column prefix offsets and the host pointer table, and detect
    // whether every input has the same column count.
    inputs_col[0] = 0;
    bool has_same_shape = true;
    for (int64_t i = 0; i < in_num; ++i) {
      int64_t t_cols = input[i].numel() / in_row;
      if (has_same_shape) {
        if (t_cols != in_col) has_same_shape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_data[i] = input[i].data<T>();
    }

    // All inputs have zero columns, so the output is empty. GetBlockDims
    // would divide by zero on num_cols == 0.
    if (out_col == 0) {
      return;
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);

    // The pointer table only has to live on the device when the fixed-arity
    // (2..4 input) kernels cannot be used.
    paddle::memory::allocation::AllocationPtr tmp_dev_ins_data;
    const T** dev_ins_data = nullptr;
    if (!has_same_shape || in_num < 2 || in_num > 4) {
      tmp_dev_ins_data = paddle::memory::Alloc(context, in_num * sizeof(T*));
      auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
          inputs_data, in_num);
      paddle::memory::Copy(context.GetPlace(),
                           tmp_dev_ins_data->ptr(),
                           paddle::platform::CPUPlace(),
                           restored,
                           in_num * sizeof(T*),
                           context.stream());
      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
    }

    if (has_same_shape) {
      if (in_num == 2) {
        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0],
            inputs_data[1],
            in_col,
            out_row,
            out_col,
            output->data<T>());
      } else if (in_num == 3) {
        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0],
            inputs_data[1],
            inputs_data[2],
            in_col,
            out_row,
            out_col,
            output->data<T>());
      } else if (in_num == 4) {
        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0],
            inputs_data[1],
            inputs_data[2],
            inputs_data[3],
            in_col,
            out_row,
            out_col,
            output->data<T>());
      } else {
        ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
      }
    } else {
      auto tmp_dev_ins_col_data =
          paddle::memory::Alloc(context, inputs_col_num * sizeof(int64_t));

      auto* restored = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
          inputs_col, inputs_col_num);
      paddle::memory::Copy(context.GetPlace(),
                           tmp_dev_ins_col_data->ptr(),
                           paddle::platform::CPUPlace(),
                           restored,
                           inputs_col_num * sizeof(int64_t),
                           context.stream());
      int64_t* dev_ins_col_data =
          static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());

      ConcatKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
          dev_ins_data,
          dev_ins_col_data,
          static_cast<int>(inputs_col_num),
          out_row,
          out_col,
          output->data<T>());
    }

#ifdef PADDLE_WITH_HIP
    // Keep the pinned buffers alive until the work queued on the stream has
    // run. They are released from a stream callback instead of at scope exit,
    // so their contents cannot be overwritten while a copy may still read
    // them. Pinned memory is requested again on the next call.
    auto* data_alloc_released = data_alloc.release();
    auto* col_alloc_released = col_alloc.release();
    context.AddStreamCallback([data_alloc_released, col_alloc_released] {
      VLOG(4) << "Delete cuda pinned at " << data_alloc_released;
      VLOG(4) << "Delete cuda pinned at " << col_alloc_released;
      paddle::memory::allocation::Allocator::AllocationDeleter(
          data_alloc_released);
      paddle::memory::allocation::Allocator::AllocationDeleter(
          col_alloc_released);
    });
#endif
  }
};

template <typename T>
class SplitFunctor<phi::GPUContext, T> {
 public:
  /*
   * Splits `input` along `axis` into `outputs` on context.stream().
   *
   * The split sizes come from `ref_inputs`. rows is the product of
   * ref_inputs[0]'s dimensions before `axis`. Output i receives
   * ref_inputs[i]->numel() / rows consecutive columns of the input's 2-D
   * [rows, cols] view. A null entry in `outputs` is passed to the kernels as
   * nullptr, and the kernels skip it.
   *
   * Kernel selection:
   *   - equal column counts and 2..4 outputs: pointers passed as kernel args;
   *   - equal column counts otherwise: a pointer table is copied to the GPU;
   *   - differing column counts: pointer table plus column prefix offsets.
   */
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const std::vector<const phi::DenseTensor*>& ref_inputs,
                  int axis,
                  std::vector<phi::DenseTensor*>* outputs) {
    // NOTE(zhiqiu): splitting a tensor of shape [0,3,4] at axis=1 yields 3
    // tensors of shape [0,1,4], so there is nothing to copy.
    if (input.numel() == 0) return;

    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();

    auto ref_dims = ref_inputs[0]->dims();
    int64_t out_row = 1;
    for (int d = 0; d < axis; ++d) out_row *= ref_dims[d];

    const int64_t out0_col = ref_inputs[0]->numel() / out_row;
    const int64_t in_row = out_row;
    int64_t in_col = 0;
    bool has_same_shape = true;

    const int outputs_cols_num = o_num + 1;
    std::vector<T*> host_out_ptrs(o_num);
    std::vector<int64_t> host_out_cols(outputs_cols_num);
    T** outputs_data = host_out_ptrs.data();
    int64_t* outputs_cols = host_out_cols.data();

    // On NVIDIA, host-to-device copies of 64 KB or less from pageable memory
    // are already asynchronous; see the CUDA C++ Programming Guide, 3.2.6.1
    // "Concurrent Execution between Host and Device":
    // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
    // HIP only copies asynchronously from pinned memory, so there the
    // host-side tables are staged in pinned buffers instead.
#ifdef PADDLE_WITH_HIP
    // TODO(chentianyu03): try to find a method to remove the Alloc function
    paddle::memory::AllocationPtr pinned_ptrs = paddle::memory::Alloc(
        paddle::platform::CUDAPinnedPlace(), o_num * sizeof(T*));
    outputs_data = reinterpret_cast<T**>(pinned_ptrs->ptr());
    // TODO(chentianyu03): try to find a method to remove the Alloc function
    paddle::memory::AllocationPtr pinned_cols = paddle::memory::Alloc(
        paddle::platform::CUDAPinnedPlace(), outputs_cols_num * sizeof(int64_t));
    outputs_cols = reinterpret_cast<int64_t*>(pinned_cols->ptr());
#endif

    // Column prefix offsets plus the host pointer table. A null output
    // tensor becomes a null table entry.
    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      const int64_t cols_i = ref_inputs.at(i)->numel() / out_row;
      has_same_shape = has_same_shape && (cols_i == out0_col);
      in_col += cols_i;
      outputs_cols[i + 1] = in_col;
      phi::DenseTensor* out_t = outputs->at(i);
      outputs_data[i] = out_t != nullptr ? out_t->data<T>() : nullptr;
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);

    // The 2..4-output kernels take pointers by value. Every other case reads
    // the table from device memory.
    paddle::memory::allocation::AllocationPtr tmp_dev_outs_data;
    T** dev_out_gpu_data = nullptr;
    const bool use_fixed_arity = has_same_shape && o_num >= 2 && o_num <= 4;
    if (!use_fixed_arity) {
      // TODO(chentianyu03): try to find a method to remove the Alloc function
      tmp_dev_outs_data = paddle::memory::Alloc(context, o_num * sizeof(T*));
      auto* host_src = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
          outputs_data, o_num);
      paddle::memory::Copy(context.GetPlace(),
                           tmp_dev_outs_data->ptr(),
                           paddle::platform::CPUPlace(),
                           host_src,
                           o_num * sizeof(T*),
                           context.stream());
      dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
    }

    if (has_same_shape) {
      switch (o_num) {
        case 2:
          SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
              input.data<T>(),
              in_row,
              in_col,
              out0_col,
              outputs_data[0],
              outputs_data[1]);
          break;
        case 3:
          SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
              input.data<T>(),
              in_row,
              in_col,
              out0_col,
              outputs_data[0],
              outputs_data[1],
              outputs_data[2]);
          break;
        case 4:
          SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
              input.data<T>(),
              in_row,
              in_col,
              out0_col,
              outputs_data[0],
              outputs_data[1],
              outputs_data[2],
              outputs_data[3]);
          break;
        default:
          SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
              input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
          break;
      }
    } else {
      // TODO(chentianyu03): try to find a method to remove the Alloc function
      auto tmp_dev_ins_col_data =
          paddle::memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
      auto* host_cols = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
          outputs_cols, outputs_cols_num);
      paddle::memory::Copy(context.GetPlace(),
                           tmp_dev_ins_col_data->ptr(),
                           paddle::platform::CPUPlace(),
                           host_cols,
                           outputs_cols_num * sizeof(int64_t),
                           context.stream());
      int64_t* dev_outs_col_data =
          reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());

      SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
          input.data<T>(),
          in_row,
          in_col,
          dev_outs_col_data,
          static_cast<int>(outputs_cols_num),
          dev_out_gpu_data);
    }

#ifdef PADDLE_WITH_HIP
    // Hand the pinned buffers to a stream callback so they are freed only
    // after the queued work has run. This keeps their contents from being
    // overwritten early. Pinned memory is requested again on the next call.
    auto* pinned_ptrs_raw = pinned_ptrs.release();
    auto* pinned_cols_raw = pinned_cols.release();
    context.AddStreamCallback([pinned_ptrs_raw, pinned_cols_raw] {
      paddle::memory::allocation::Allocator::AllocationDeleter(
          pinned_ptrs_raw);
      paddle::memory::allocation::Allocator::AllocationDeleter(
          pinned_cols_raw);
    });
#endif
  }
};

// Explicitly instantiates the GPU ConcatFunctor and SplitFunctor
// specializations defined above, one pair per element type passed to the
// macro. FOR_ALL_TYPES is defined outside this file. Presumably it applies
// DEFINE_FUNCTOR to every supported dtype (TODO confirm in its header).
#define DEFINE_FUNCTOR(type)                           \
  template class ConcatFunctor<phi::GPUContext, type>; \
  template class SplitFunctor<phi::GPUContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);

}  // namespace funcs
}  // namespace phi