concat_and_split.cu 19.4 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <algorithm>
#include <vector>
17
#include "gflags/gflags.h"
C
chengduoZH 已提交
18
#include "paddle/fluid/framework/mixed_vector.h"
19
#include "paddle/fluid/memory/malloc.h"
C
chengduo 已提交
20
#include "paddle/fluid/operators/math/concat_and_split.h"
21
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
22
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
C
chengduoZH 已提交
23
#include "paddle/fluid/platform/float16.h"
C
chengduoZH 已提交
24 25 26 27 28 29

namespace paddle {
namespace operators {
namespace math {

// Concatenates inputs of differing widths along the column axis.
//
// Every input i is read as a row-major [output_rows x width_i] matrix, where
// width_i = input_cols[i + 1] - input_cols[i]; the output is a row-major
// [output_rows x output_cols] matrix.
//
// Launch layout: 2-D grid; x covers output columns and y covers output rows.
// Both dimensions use grid-stride loops, so any grid size is valid.
//
//   inputs      - device array holding one data pointer per input
//   input_cols  - device array of column offsets; input_cols[0] is the start
//                 offset and input_cols[i + 1] is where input i ends
//   col_size    - number of entries in input_cols (not read by the kernel)
//   output_rows - number of rows shared by all inputs and the output
//   output_cols - total number of output columns
//   output      - destination buffer
//
// All index arithmetic is done in int64_t so offsets are not truncated and
// row * width products do not overflow for large tensors.
template <typename T>
__global__ void ConcatKernel(const T** inputs, const int64_t* input_cols,
                             int col_size, const int64_t output_rows,
                             const int64_t output_cols, T* output) {
  const int64_t stride_x = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t stride_y = static_cast<int64_t>(blockDim.y) * gridDim.y;

  int64_t tid_x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int64_t curr_offset = input_cols[0];
  for (; tid_x < output_cols; tid_x += stride_x) {
    // tid_x only grows, so the segment search resumes where it stopped.
    int64_t curr_col_offset = input_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = input_cols[curr_segment + 1];
    }

    const int64_t local_col = tid_x - curr_offset;
    const int64_t segment_width = curr_col_offset - curr_offset;

    const T* input_ptr = inputs[curr_segment];
    int64_t tid_y =
        static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += stride_y) {
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
    }
  }
}

C
chengduoZH 已提交
55
// Device-side body shared by the ConcatKernel overloads whose inputs all have
// the same width.
//
// Every input is read as a row-major [out_rows x fixed_in_col] matrix and the
// output is a row-major [out_rows x out_cols] matrix. Output column c is
// taken from input c / fixed_in_col at local column c % fixed_in_col.
//
// Indexing: x covers output columns and y covers output rows; both loops
// advance by blockDim * gridDim.
//
//   inputs_data  - array holding one data pointer per input
//   fixed_in_col - width shared by every input
//   out_rows     - number of rows
//   out_cols     - total number of output columns
//   output_data  - destination buffer
//
// Sizes and indices are int64_t so the values passed by the int64_t kernel
// wrappers are not narrowed. The split index uses integer division, matching
// SplitKernelDetail.
template <typename T>
__device__ void ConcatKernelDetail(const T** inputs_data,
                                   const int64_t fixed_in_col,
                                   const int64_t out_rows,
                                   const int64_t out_cols, T* output_data) {
  const int64_t stride_x = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t stride_y = static_cast<int64_t>(blockDim.y) * gridDim.y;

  int64_t tid_x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += stride_x) {
    const int64_t split = tid_x / fixed_in_col;
    const int64_t in_offset = tid_x - split * fixed_in_col;
    const T* input_ptr = inputs_data[split];
    int64_t tid_y =
        static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += stride_y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

72 73
// Concatenates two equally wide inputs along the column axis.
// The two pointers arrive as kernel arguments and are gathered into a local
// table that is handed to ConcatKernelDetail.
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const int64_t fixed_in_col, const int64_t out_rows,
                             const int64_t out_cols, T* output_data) {
  const T* sources[] = {input_addr0, input_addr1};
  ConcatKernelDetail<T>(sources, fixed_in_col, out_rows, out_cols,
                        output_data);
}

83 84
// Concatenates three equally wide inputs along the column axis.
// The three pointers arrive as kernel arguments and are gathered into a local
// table that is handed to ConcatKernelDetail.
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const int64_t fixed_in_col,
                             const int64_t out_rows, const int64_t out_cols,
                             T* output_data) {
  const T* sources[] = {input_addr0, input_addr1, input_addr2};
  ConcatKernelDetail<T>(sources, fixed_in_col, out_rows, out_cols,
                        output_data);
}

// Concatenates four equally wide inputs along the column axis.
// The four pointers arrive as kernel arguments and are gathered into a local
// table that is handed to ConcatKernelDetail.
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const T* input_addr3,
                             const int64_t fixed_in_col, const int64_t out_rows,
                             const int64_t out_cols, T* output_data) {
  const T* sources[] = {input_addr0, input_addr1, input_addr2, input_addr3};
  ConcatKernelDetail<T>(sources, fixed_in_col, out_rows, out_cols,
                        output_data);
}

110 111
// Concatenates an arbitrary number of equally wide inputs along the column
// axis. `inputs_data` is the table of input pointers dereferenced by
// ConcatKernelDetail; `in_num` is accepted but not read by this kernel.
template <typename T>
__global__ void ConcatKernel(const T** inputs_data, const int in_num,
                             const int64_t fixed_in_col, const int64_t out_rows,
                             const int64_t out_cols, T* output_data) {
  ConcatKernelDetail<T>(
      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
}

C
chengduoZH 已提交
118
// Splits a row-major [in_row x in_col] matrix along the column axis into
// outputs of differing widths.
//
// Output i receives input columns [out_cols[i], out_cols[i + 1]) and is
// written as a row-major [in_row x width_i] matrix. A null entry in
// outputs_data means that slice is skipped.
//
// Launch layout: 2-D grid; x covers input columns and y covers input rows.
// Both dimensions use grid-stride loops, so any grid size is valid.
//
//   input_data    - source buffer
//   in_row        - number of rows
//   in_col        - total number of input columns
//   out_cols      - device array of column offsets; out_cols[0] is the start
//                   offset and out_cols[i + 1] is where output i ends
//   out_cols_size - number of entries in out_cols (not read by the kernel)
//   outputs_data  - device array holding one data pointer per output
//
// All index arithmetic is done in int64_t so offsets are not truncated and
// row * width products do not overflow for large tensors.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t* out_cols,
                            int out_cols_size, T** outputs_data) {
  const int64_t stride_x = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t stride_y = static_cast<int64_t>(blockDim.y) * gridDim.y;

  int64_t tid_x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int64_t curr_offset = out_cols[0];
  for (; tid_x < in_col; tid_x += stride_x) {
    // tid_x only grows, so the segment search resumes where it stopped.
    int64_t curr_col_offset = out_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = out_cols[curr_segment + 1];
    }

    const int64_t local_col = tid_x - curr_offset;
    const int64_t segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    if (output_ptr != nullptr) {
      int64_t tid_y =
          static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += stride_y) {
        output_ptr[tid_y * segment_width + local_col] =
            input_data[tid_y * in_col + tid_x];
      }
    }
  }
}

// Device-side body shared by the SplitKernel overloads whose outputs all have
// the same width.
//
// The input is read as a row-major [in_row x in_col] matrix. Input column c
// is written to output c / fixed_out_col at local column c % fixed_out_col;
// each output is a row-major [in_row x fixed_out_col] matrix. A null output
// pointer means that slice is skipped.
//
// Indexing: x covers input columns and y covers input rows; both loops
// advance by blockDim * gridDim.
//
//   input_data    - source buffer
//   in_row        - number of rows
//   in_col        - total number of input columns
//   fixed_out_col - width shared by every output
//   outputs_data  - array holding one data pointer per output
//
// Sizes and indices are int64_t so the values passed by the int64_t kernel
// wrappers are not narrowed and row * width products do not overflow.
template <typename T>
__device__ void SplitKernelDetail(const T* input_data, const int64_t in_row,
                                  const int64_t in_col,
                                  const int64_t fixed_out_col,
                                  T** outputs_data) {
  const int64_t stride_x = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t stride_y = static_cast<int64_t>(blockDim.y) * gridDim.y;

  int64_t tid_x = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (; tid_x < in_col; tid_x += stride_x) {
    const int64_t split = tid_x / fixed_out_col;
    const int64_t in_offset = tid_x - split * fixed_out_col;
    T* output_ptr = outputs_data[split];
    if (output_ptr != nullptr) {
      int64_t tid_y =
          static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += stride_y) {
        output_ptr[tid_y * fixed_out_col + in_offset] =
            input_data[tid_y * in_col + tid_x];
      }
    }
  }
}

163
// Splits the input into equally wide outputs along the column axis.
// `outputs_data` is the table of output pointers dereferenced by
// SplitKernelDetail.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T** outputs_data) {
  SplitKernelDetail<T>(
      input_data, in_row, in_col, fixed_out_col, outputs_data);
}

// Splits the input into two equally wide outputs along the column axis.
// The two pointers arrive as kernel arguments and are gathered into a local
// table that is handed to SplitKernelDetail.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1) {
  T* targets[] = {outputs_addr0, outputs_addr1};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, targets);
}

180
// Splits the input into three equally wide outputs along the column axis.
// The three pointers arrive as kernel arguments and are gathered into a local
// table that is handed to SplitKernelDetail.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2) {
  T* targets[] = {outputs_addr0, outputs_addr1, outputs_addr2};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, targets);
}

// Splits the input into four equally wide outputs along the column axis.
// The four pointers arrive as kernel arguments and are gathered into a local
// table that is handed to SplitKernelDetail.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2, T* outputs_addr3) {
  T* targets[] = {outputs_addr0, outputs_addr1, outputs_addr2, outputs_addr3};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, targets);
}

205
// Chooses a 2-D launch configuration for a [num_rows x num_cols] copy.
//
//   context    - device context; only GetMaxPhysicalThreadCount() is used
//   num_rows   - number of rows to cover (grid/block y)
//   num_cols   - number of columns to cover (grid/block x)
//   block_dims - receives the block shape; x * y never exceeds 1024 threads
//   grid_dims  - receives the grid shape
//
// The block width is num_cols rounded up to a multiple of 32 and capped at
// 1024; the remaining threads of the block go to rows. The grid is capped at
// max(GetMaxPhysicalThreadCount() / 1024, 1) blocks in total.
//
// A non-positive num_cols is clamped to one 32-wide block column and one grid
// column so the divisions below are never by zero; results for num_cols >= 1
// are unchanged.
static inline void GetBlockDims(const platform::CUDADeviceContext& context,
                                int64_t num_rows, int64_t num_cols,
                                dim3* block_dims, dim3* grid_dims) {
  // Set the thread block and grid according to CurrentDeviceId
  const int kThreadsPerBlock = 1024;
  const int kWarpSize = 32;
  int block_cols = kThreadsPerBlock;
  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
    block_cols =
        std::max(static_cast<int>(((num_cols + 31) >> 5) << 5), kWarpSize);
  }
  int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  int max_threads = context.GetMaxPhysicalThreadCount();
  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  // At least one grid column, so max_blocks / grid_cols is well-defined.
  int grid_cols = std::max(
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks),
      static_cast<int64_t>(1));
  int grid_rows = std::min(max_blocks / grid_cols,
                           std::max(num_rows / block_rows, (int64_t)1));
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

C
chengduoZH 已提交
227
/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  /*
   * Concatenates `input` along `axis` into `output`, launching the copy
   * kernel on `context.stream()`.
   *
   * Every tensor is viewed as a row-major matrix whose row count is the
   * product of the dimensions before `axis` (taken from input[0]) and whose
   * column count is numel / rows. When all inputs have the same column count
   * and there are 2 to 4 of them, their pointers are passed directly as
   * kernel arguments; otherwise the pointer table (and, for differing widths,
   * the column-offset table) is copied to device memory first.
   *
   * Returns without launching anything when the row count or the total
   * column count is zero.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int64_t in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    // A zero-sized leading dimension means there is nothing to copy, and
    // in_row is used as a divisor below.
    if (in_row == 0) {
      return;
    }
    int64_t in_col = input[0].numel() / in_row;
    int64_t out_row = in_row, out_col = 0;

    int inputs_col_num = in_num + 1;
    std::vector<const T*> inputs_data_vec(in_num);
    std::vector<int64_t> inputs_col_vec(inputs_col_num);
    const T** inputs_data = inputs_data_vec.data();
    int64_t* inputs_col = inputs_col_vec.data();

// There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from
// hosttodevice, it will be automatically asynchronous.
// However, only pinned memory in hip can copy asynchronously
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device
// Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
    memory::AllocationPtr data_alloc, col_alloc;
    data_alloc =
        memory::Alloc(platform::CUDAPinnedPlace(), in_num * sizeof(T*));
    inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
    // The offsets are int64_t, so the buffer must be sized accordingly.
    col_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
                              inputs_col_num * sizeof(int64_t));
    inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
#endif

    // Build the prefix sums of the input widths and collect the data
    // pointers; has_same_shape stays true only if every width equals in_col.
    inputs_col[0] = 0;
    bool has_same_shape = true;
    for (int i = 0; i < in_num; ++i) {
      int64_t t_cols = input[i].numel() / in_row;
      if (has_same_shape) {
        if (t_cols != in_col) has_same_shape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_data[i] = input[i].data<T>();
    }

    // Nothing to copy; also keeps a zero column count away from the launch
    // configuration.
    if (out_col == 0) {
      return;
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);

    // The pointer table is only needed on the device when the pointers are
    // not passed as individual kernel arguments.
    memory::allocation::AllocationPtr tmp_dev_ins_data;
    const T** dev_ins_data = nullptr;
    if (!has_same_shape || in_num < 2 || in_num > 4) {
      tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
      auto* restored =
          platform::RestoreHostMemIfCapturingCUDAGraph(inputs_data, in_num);
      memory::Copy(context.GetPlace(), tmp_dev_ins_data->ptr(),
                   platform::CPUPlace(), restored, in_num * sizeof(T*),
                   context.stream());
      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
    }

    if (has_same_shape) {
      if (in_num == 2) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], in_col, out_row, out_col,
            output->data<T>());
      } else if (in_num == 3) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], in_col, out_row,
            out_col, output->data<T>());
      } else if (in_num == 4) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], inputs_data[3],
            in_col, out_row, out_col, output->data<T>());
      } else {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
      }
    } else {
      // Differing widths: the kernel also needs the column offsets.
      auto tmp_dev_ins_col_data =
          memory::Alloc(context, inputs_col_num * sizeof(int64_t));

      auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph(
          inputs_col, inputs_col_num);
      memory::Copy(context.GetPlace(), tmp_dev_ins_col_data->ptr(),
                   platform::CPUPlace(), restored,
                   inputs_col_num * sizeof(int64_t), context.stream());
      int64_t* dev_ins_col_data =
          static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());

      ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num),
          out_row, out_col, output->data<T>());
    }

#ifdef PADDLE_WITH_HIP
    // Prevent the pinned memory value from being covered and release the memory
    // after the launch kernel of the stream is executed (reapply pinned memory
    // next time)
    auto* data_alloc_released = data_alloc.release();
    auto* col_alloc_released = col_alloc.release();
    context.AddStreamCallback([data_alloc_released, col_alloc_released] {
      memory::allocation::Allocator::AllocationDeleter(data_alloc_released);
      memory::allocation::Allocator::AllocationDeleter(col_alloc_released);
    });
#endif
  }
};

C
chengduoZH 已提交
346 347
/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
 public:
  SplitFunctor();

  /*
   * Splits `input` along `axis` into `outputs`, launching the copy kernel on
   * `context.stream()`.
   *
   * The slice widths come from `ref_inputs`: slice i is
   * ref_inputs[i]->numel() / rows columns wide, where rows is the product of
   * the dimensions of ref_inputs[0] before `axis`. A null entry in `outputs`
   * is forwarded to the kernel as a null pointer, which skips that slice.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  int axis, std::vector<framework::Tensor*>* outputs) {
    // NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
    // tensors of shape [0,1,4]
    if (input.numel() == 0) return;

    // TODO(zcd): Add input data validity checking
    int num_outputs = outputs->size();
    int64_t rows = 1;
    auto ref_dims = ref_inputs[0]->dims();
    for (int d = 0; d < axis; ++d) rows *= ref_dims[d];

    int64_t first_width = ref_inputs[0]->numel() / rows;
    int64_t total_cols = 0;
    bool uniform_width = true;

    // Host-side staging tables: one pointer per output, one offset per
    // slice boundary.
    int num_offsets = num_outputs + 1;
    std::vector<T*> host_ptrs(num_outputs);
    std::vector<int64_t> host_offsets(num_offsets);
    T** out_ptrs = host_ptrs.data();
    int64_t* col_offsets = host_offsets.data();

// There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from
// hosttodevice, it will be automatically asynchronous.
// However, only pinned memory in hip can copy asynchronously
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device
// Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
    memory::AllocationPtr pinned_ptrs, pinned_offsets;
    pinned_ptrs =
        memory::Alloc(platform::CUDAPinnedPlace(), num_outputs * sizeof(T*));
    out_ptrs = reinterpret_cast<T**>(pinned_ptrs->ptr());
    pinned_offsets = memory::Alloc(platform::CUDAPinnedPlace(),
                                   (num_offsets) * sizeof(int64_t));
    col_offsets = reinterpret_cast<int64_t*>(pinned_offsets->ptr());
#endif

    // Prefix sums of the slice widths, plus the destination pointers.
    col_offsets[0] = 0;
    for (int i = 0; i < num_outputs; ++i) {
      int64_t width = ref_inputs.at(i)->numel() / rows;
      uniform_width = uniform_width && (width == first_width);
      total_cols += width;
      col_offsets[i + 1] = total_cols;
      framework::Tensor* target = outputs->at(i);
      out_ptrs[i] = (target != nullptr) ? target->data<T>() : nullptr;
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, rows, total_cols, &block_dims, &grid_dims);

    // The pointer table is only needed on the device when the pointers are
    // not passed as individual kernel arguments.
    memory::allocation::AllocationPtr dev_ptrs_holder;
    T** dev_out_ptrs = nullptr;
    const bool pass_ptrs_by_value =
        uniform_width && num_outputs >= 2 && num_outputs <= 4;
    if (!pass_ptrs_by_value) {
      dev_ptrs_holder = memory::Alloc(context, num_outputs * sizeof(T*));
      auto* restored =
          platform::RestoreHostMemIfCapturingCUDAGraph(out_ptrs, num_outputs);
      memory::Copy(context.GetPlace(), dev_ptrs_holder->ptr(),
                   platform::CPUPlace(), restored, num_outputs * sizeof(T*),
                   context.stream());
      dev_out_ptrs = reinterpret_cast<T**>(dev_ptrs_holder->ptr());
    }

    if (!uniform_width) {
      // Differing widths: the kernel also needs the column offsets.
      auto dev_offsets_holder =
          memory::Alloc(context, num_offsets * sizeof(int64_t));
      auto* restored = platform::RestoreHostMemIfCapturingCUDAGraph(
          col_offsets, num_offsets);
      memory::Copy(context.GetPlace(), dev_offsets_holder->ptr(),
                   platform::CPUPlace(), restored,
                   num_offsets * sizeof(int64_t), context.stream());
      int64_t* dev_col_offsets =
          reinterpret_cast<int64_t*>(dev_offsets_holder->ptr());

      SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          input.data<T>(), rows, total_cols, dev_col_offsets,
          static_cast<int>(num_offsets), dev_out_ptrs);
    } else {
      switch (num_outputs) {
        case 2:
          SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
              input.data<T>(), rows, total_cols, first_width, out_ptrs[0],
              out_ptrs[1]);
          break;
        case 3:
          SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
              input.data<T>(), rows, total_cols, first_width, out_ptrs[0],
              out_ptrs[1], out_ptrs[2]);
          break;
        case 4:
          SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
              input.data<T>(), rows, total_cols, first_width, out_ptrs[0],
              out_ptrs[1], out_ptrs[2], out_ptrs[3]);
          break;
        default:
          SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
              input.data<T>(), rows, total_cols, first_width, dev_out_ptrs);
          break;
      }
    }
#ifdef PADDLE_WITH_HIP
    // Prevent the pinned memory value from being covered and release the memory
    // after the launch kernel of the stream is executed (reapply pinned memory
    // next time)
    auto* pinned_ptrs_released = pinned_ptrs.release();
    auto* pinned_offsets_released = pinned_offsets.release();
    context.AddStreamCallback([pinned_ptrs_released, pinned_offsets_released] {
      memory::allocation::Allocator::AllocationDeleter(pinned_ptrs_released);
      memory::allocation::Allocator::AllocationDeleter(
          pinned_offsets_released);
    });
#endif
  }
};

C
chengduoZH 已提交
475 476
// Explicitly instantiates the CUDA ConcatFunctor and SplitFunctor for one
// element type. FOR_ALL_TYPES is defined outside this file; it is expected to
// apply the macro once per supported element type.
#define DEFINE_FUNCTOR(dtype)                                       \
  template class ConcatFunctor<platform::CUDADeviceContext, dtype>; \
  template class SplitFunctor<platform::CUDADeviceContext, dtype>

FOR_ALL_TYPES(DEFINE_FUNCTOR);
C
chengduoZH 已提交
480

C
chengduoZH 已提交
481 482 483
}  // namespace math
}  // namespace operators
}  // namespace paddle