concat_and_split.cu 19.6 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <algorithm>
#include <vector>
17
#include "gflags/gflags.h"
C
chengduoZH 已提交
18
#include "paddle/fluid/framework/mixed_vector.h"
19
#include "paddle/fluid/memory/malloc.h"
C
chengduo 已提交
20
#include "paddle/fluid/operators/math/concat_and_split.h"
21
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
22
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
C
chengduoZH 已提交
23
#include "paddle/fluid/platform/float16.h"
C
chengduoZH 已提交
24 25 26 27 28 29

namespace paddle {
namespace operators {
namespace math {

/*
 * Concatenates row-major matrices of (possibly) different column widths
 * along the column dimension.
 *
 * inputs:      device-side table of input base pointers (copied to the
 *              device by ConcatFunctor before launch).
 * input_cols:  cumulative column offsets; input_cols[0] == 0 and
 *              input_cols[i+1] - input_cols[i] is the width of input i.
 * col_size:    entry count of input_cols; accepted for completeness but not
 *              read here — the segment walk is bounded by output_cols.
 * output_rows / output_cols: shape of the concatenated output.
 *
 * Launch layout: 2-D grid-stride loops — x spans output columns, y spans
 * output rows.
 */
template <typename T>
__global__ void ConcatKernel(const T** inputs, const int64_t* input_cols,
                             int col_size, const int64_t output_rows,
                             const int64_t output_cols, T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  // Forward-only walk over the cumulative offsets: consecutive tid_x values
  // handled by one thread visit non-decreasing segments, so the segment
  // cursor never has to move backwards.
  int curr_segment = 0;
  int curr_offset = input_cols[0];
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = input_cols[curr_segment + 1];
    // Advance until curr_segment is the input that owns column tid_x.
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = input_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;  // column index inside the segment
    int segment_width = curr_col_offset - curr_offset;

    const T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

C
chengduoZH 已提交
55
/*
 * Concatenates equal-width inputs: every input matrix has exactly
 * `fixed_in_col` columns, so the input that owns an output column can be
 * computed directly instead of walking an offset table.
 *
 * Fix: the split index was previously computed as
 * `tid_x * 1.0 / fixed_in_col`, which promotes the operands to double and
 * performs a double-precision division on the device. Plain integer
 * division produces the same index for these non-negative int operands
 * (double holds ints of this range exactly, and the quotient truncates the
 * same way), avoids the expensive double-precision path, and matches the
 * integer division already used by SplitKernelDetail.
 */
template <typename T>
__device__ void ConcatKernelDetail(const T** inputs_data,
                                   const int fixed_in_col, const int out_rows,
                                   const int out_cols, T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_in_col;              // which input owns tid_x
    int in_offset = tid_x - split * fixed_in_col;  // column within that input
    const T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

72 73
// Concat specialization for exactly two equal-width inputs: stage both base
// pointers in a local array and delegate the copy to ConcatKernelDetail,
// avoiding a host-to-device transfer of the pointer table.
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const int64_t fixed_in_col, const int64_t out_rows,
                             const int64_t out_cols, T* output_data) {
  const T* inputs_data[2] = {input_addr0, input_addr1};
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

83 84
// Concat specialization for exactly three equal-width inputs; the pointer
// table lives in local memory, so no device-side staging copy is needed.
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const int64_t fixed_in_col,
                             const int64_t out_rows, const int64_t out_cols,
                             T* output_data) {
  const T* inputs_data[3] = {input_addr0, input_addr1, input_addr2};
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

// Concat specialization for exactly four equal-width inputs; the pointer
// table lives in local memory, so no device-side staging copy is needed.
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const T* input_addr3,
                             const int64_t fixed_in_col, const int64_t out_rows,
                             const int64_t out_cols, T* output_data) {
  const T* inputs_data[4] = {input_addr0, input_addr1, input_addr2,
                             input_addr3};
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

110 111
// Equal-width concat entry used by ConcatFunctor when the input count is
// outside 2..4; the pointer table has already been copied to device memory.
// `in_num` is accepted for signature uniformity but not read here — the
// detail routine derives each split index from the column position alone.
template <typename T>
__global__ void ConcatKernel(const T** inputs_data, const int in_num,
                             const int64_t fixed_in_col, const int64_t out_rows,
                             const int64_t out_cols, T* output_data) {
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

C
chengduoZH 已提交
118
/*
 * Splits one row-major matrix into outputs of (possibly) different column
 * widths — the inverse of the variable-width ConcatKernel above.
 *
 * out_cols:       cumulative column offsets (out_cols[0] == 0);
 *                 out_cols[i+1] - out_cols[i] is the width of output i.
 * out_cols_size:  entry count of out_cols; accepted for completeness but
 *                 not read here — the segment walk is bounded by in_col.
 * outputs_data:   device-side table of output base pointers. A nullptr
 *                 entry marks an output the caller did not request; its
 *                 columns are skipped.
 */
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t* out_cols,
                            int out_cols_size, T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  // Forward-only walk over the cumulative offsets, as in ConcatKernel.
  int curr_segment = 0;
  int curr_offset = out_cols[0];
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = out_cols[curr_segment + 1];
    // Advance until curr_segment is the output that owns column tid_x.
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = out_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;  // column index inside the segment
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * segment_width + local_col] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

/*
 * Splits the input into equal-width outputs: every output matrix has exactly
 * `fixed_out_col` columns, so the owning output of an input column is a
 * direct division rather than an offset-table walk. A nullptr entry in
 * outputs_data marks an output the caller did not request.
 */
template <typename T>
__device__ void SplitKernelDetail(const T* input_data, const int in_row,
                                  const int in_col, const int fixed_out_col,
                                  T** outputs_data) {
  const int stride_x = blockDim.x * gridDim.x;
  const int stride_y = blockDim.y * gridDim.y;
  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < in_col;
       col += stride_x) {
    const int seg = col / fixed_out_col;           // which output owns col
    const int seg_col = col - seg * fixed_out_col;  // column within it
    T* dst = outputs_data[seg];
    if (dst == nullptr) continue;  // output not requested; skip its columns
    for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < in_row;
         row += stride_y) {
      dst[row * fixed_out_col + seg_col] = input_data[row * in_col + col];
    }
  }
}

163
// Equal-width split entry used by SplitFunctor when the output count is
// outside 2..4; the output pointer table has already been copied to device
// memory.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T** outputs_data) {
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

// Split specialization for exactly two equal-width outputs: stage both base
// pointers in a local array and delegate to SplitKernelDetail, avoiding a
// host-to-device transfer of the pointer table.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1) {
  T* outputs_data[2] = {outputs_addr0, outputs_addr1};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

180
// Split specialization for exactly three equal-width outputs; the pointer
// table lives in local memory, so no device-side staging copy is needed.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2) {
  T* outputs_data[3] = {outputs_addr0, outputs_addr1, outputs_addr2};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

// Split specialization for exactly four equal-width outputs; the pointer
// table lives in local memory, so no device-side staging copy is needed.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2, T* outputs_addr3) {
  T* outputs_data[4] = {outputs_addr0, outputs_addr1, outputs_addr2,
                        outputs_addr3};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

205
// Chooses a 2-D launch configuration for the concat/split kernels:
// x covers columns, y covers rows. Block width is the column count rounded
// up to a multiple of 32 (warp size), capped at 1024 threads per block; the
// grid is capped by the device's physical thread capacity, relying on the
// kernels' grid-stride loops to cover any remainder.
static inline void GetBlockDims(const platform::CUDADeviceContext& context,
                                int64_t num_rows, int64_t num_cols,
                                dim3* block_dims, dim3* grid_dims) {
  // Set the thread block and grid according to CurrentDeviceId
  constexpr int kThreadsPerBlock = 1024;
  int block_cols = (num_cols < kThreadsPerBlock)
                       ? static_cast<int>(((num_cols + 31) >> 5) << 5)
                       : kThreadsPerBlock;
  int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  int max_threads = context.GetMaxPhysicalThreadCount();
  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  int grid_cols =
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows = std::min(max_blocks / grid_cols,
                           std::max(num_rows / block_rows, (int64_t)1));
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

C
chengduoZH 已提交
227
/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  // Concatenates `input` tensors along `axis` into `output`. The tensors
  // are viewed as 2-D: rows = product of the dims before `axis`, columns =
  // remaining elements per row, reducing the copy to a column-wise concat.
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int64_t in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int64_t in_col = input[0].numel() / in_row;
    int64_t out_row = in_row, out_col = 0;

    // Host-side staging for the input pointer table and the cumulative
    // column offsets (inputs_col has one extra leading 0 entry).
    int inputs_col_num = in_num + 1;
    std::vector<const T*> inputs_data_vec(in_num);
    std::vector<int64_t> inputs_col_vec(inputs_col_num);
    const T** inputs_data = inputs_data_vec.data();
    int64_t* inputs_col = inputs_col_vec.data();

// There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from
// hosttodevice, it will be automatically asynchronous.
// However, only pinned memory in hip can copy asynchronously
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device
// Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
    memory::AllocationPtr data_alloc, col_alloc;
    data_alloc =
        memory::Alloc(platform::CUDAPinnedPlace(), in_num * sizeof(T*));
    inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
    // BUGFIX: the buffer is written through an int64_t* below, so each entry
    // needs sizeof(int64_t) bytes; the previous sizeof(int) under-allocated
    // the pinned buffer and overflowed it when filling the offsets.
    col_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
                              inputs_col_num * sizeof(int64_t));
    inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
#endif

    inputs_col[0] = 0;
    bool has_same_shape = true;
    for (int i = 0; i < in_num; ++i) {
      int64_t t_cols = input[i].numel() / in_row;
      if (has_same_shape) {
        if (t_cols != in_col) has_same_shape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_data[i] = input[i].data<T>();
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);

    // The pointer table only needs to live on the device when a kernel that
    // reads it from global memory is launched (unequal column widths, or an
    // input count outside 2..4).
    memory::allocation::AllocationPtr tmp_dev_ins_data;
    const T** dev_ins_data = nullptr;
    if (!has_same_shape || in_num < 2 || in_num > 4) {
      tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
      {
        platform::SkipCUDAGraphCaptureGuard guard;
        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                     tmp_dev_ins_data->ptr(), platform::CPUPlace(),
                     static_cast<void*>(inputs_data), in_num * sizeof(T*),
                     context.stream());
      }
      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
    }

    if (has_same_shape) {
      // For 2–4 equal-width inputs, pass the pointers directly as kernel
      // arguments and skip the host-to-device table copy entirely.
      if (in_num == 2) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], in_col, out_row, out_col,
            output->data<T>());
      } else if (in_num == 3) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], in_col, out_row,
            out_col, output->data<T>());
      } else if (in_num == 4) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], inputs_data[3],
            in_col, out_row, out_col, output->data<T>());
      } else {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
      }
    } else {
      // Unequal widths: the cumulative offset table must also be staged on
      // the device for the table-walking kernel.
      auto tmp_dev_ins_col_data =
          memory::Alloc(context, inputs_col_num * sizeof(int64_t));
      {
        platform::SkipCUDAGraphCaptureGuard guard;
        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                     tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                     static_cast<void*>(inputs_col),
                     inputs_col_num * sizeof(int64_t), context.stream());
      }
      int64_t* dev_ins_col_data =
          static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());

      ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num),
          out_row, out_col, output->data<T>());
    }

#ifdef PADDLE_WITH_HIP
    // Prevent the pinned memory value from being covered and release the memory
    // after the launch kernel of the stream is executed (reapply pinned memory
    // next time)
    auto* data_alloc_released = data_alloc.release();
    auto* col_alloc_released = col_alloc.release();
    context.AddStreamCallback([data_alloc_released, col_alloc_released] {
      memory::allocation::AllocationDeleter deleter;
      deleter(data_alloc_released);
      deleter(col_alloc_released);
    });
#endif
  }
};

C
chengduoZH 已提交
350 351
/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
 public:
  SplitFunctor();
  // Splits `input` along `axis` into `outputs`, using `ref_inputs` only for
  // the per-output shapes. A nullptr entry in `outputs` means that slice is
  // not wanted; its columns are skipped by the kernels.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  int axis, std::vector<framework::Tensor*>* outputs) {
    // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
    // tensors of shape [0,1,4]
    if (input.numel() == 0) {
      return;
    }

    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    // View tensors as 2-D: rows = product of dims before `axis`.
    int64_t out_row = 1;
    auto dim_0 = ref_inputs[0]->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    int64_t out0_col = ref_inputs[0]->numel() / out_row;
    int64_t in_col = 0, in_row = out_row;
    bool has_same_shape = true;

    // Host-side staging for the output pointer table and the cumulative
    // column offsets (outputs_cols has one extra leading 0 entry).
    int outputs_cols_num = o_num + 1;
    std::vector<T*> outputs_data_vec(o_num);
    std::vector<int64_t> outputs_cols_vec(outputs_cols_num);
    T** outputs_data = outputs_data_vec.data();
    int64_t* outputs_cols = outputs_cols_vec.data();

// There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from
// hosttodevice, it will be automatically asynchronous.
// However, only pinned memory in hip can copy asynchronously
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device
// Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
    memory::AllocationPtr data_alloc, cols_alloc;
    data_alloc = memory::Alloc(platform::CUDAPinnedPlace(), o_num * sizeof(T*));
    outputs_data = reinterpret_cast<T**>(data_alloc->ptr());
    cols_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
                               (outputs_cols_num) * sizeof(int64_t));
    outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr());
#endif

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int64_t t_col = ref_inputs.at(i)->numel() / out_row;
      if (has_same_shape) {
        if (t_col != out0_col) has_same_shape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      if (outputs->at(i) != nullptr) {
        outputs_data[i] = outputs->at(i)->data<T>();
      } else {
        // Caller does not want this slice; kernels skip nullptr outputs.
        outputs_data[i] = nullptr;
      }
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);

    // The pointer table only needs to live on the device when a kernel that
    // reads it from global memory is launched (unequal column widths, or an
    // output count outside 2..4).
    memory::allocation::AllocationPtr tmp_dev_outs_data;
    T** dev_out_gpu_data = nullptr;
    if (!has_same_shape || o_num < 2 || o_num > 4) {
      tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
      {
        platform::SkipCUDAGraphCaptureGuard guard;
        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                     tmp_dev_outs_data->ptr(), platform::CPUPlace(),
                     reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
                     context.stream());
      }
      dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
    }

    if (has_same_shape) {
      // For 2–4 equal-width outputs, pass the pointers directly as kernel
      // arguments and skip the host-to-device table copy entirely.
      if (o_num == 2) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1]);
      } else if (o_num == 3) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1], outputs_data[2]);
      } else if (o_num == 4) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1], outputs_data[2], outputs_data[3]);
      } else {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
      }
    } else {
      // Unequal widths: the cumulative offset table must also be staged on
      // the device for the table-walking kernel.
      auto tmp_dev_ins_col_data =
          memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
      {
        platform::SkipCUDAGraphCaptureGuard guard;
        memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                     tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                     reinterpret_cast<void*>(outputs_cols),
                     outputs_cols_num * sizeof(int64_t), context.stream());
      }
      int64_t* dev_outs_col_data =
          reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());

      SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols_num), dev_out_gpu_data);
    }
#ifdef PADDLE_WITH_HIP
    // Prevent the pinned memory value from being covered and release the memory
    // after the launch kernel of the stream is executed (reapply pinned memory
    // next time)
    auto* data_alloc_released = data_alloc.release();
    auto* cols_alloc_released = cols_alloc.release();
    context.AddStreamCallback([data_alloc_released, cols_alloc_released] {
      memory::allocation::AllocationDeleter deleter;
      deleter(data_alloc_released);
      deleter(cols_alloc_released);
    });
#endif
  }
};

C
chengduoZH 已提交
484 485
// Explicitly instantiate ConcatFunctor and SplitFunctor for every element
// type enumerated by FOR_ALL_TYPES (declared in concat_and_split.h).
#define DEFINE_FUNCTOR(type)                                       \
  template class ConcatFunctor<platform::CUDADeviceContext, type>; \
  template class SplitFunctor<platform::CUDADeviceContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);
C
chengduoZH 已提交
489

C
chengduoZH 已提交
490 491 492
}  // namespace math
}  // namespace operators
}  // namespace paddle