// concat_and_split.cu
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <algorithm>
#include <vector>
17
#include "gflags/gflags.h"
C
chengduoZH 已提交
18
#include "paddle/fluid/framework/mixed_vector.h"
19
#include "paddle/fluid/memory/malloc.h"
C
chengduo 已提交
20
#include "paddle/fluid/operators/math/concat_and_split.h"
D
dzhwinter 已提交
21
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduoZH 已提交
22
#include "paddle/fluid/platform/float16.h"
C
chengduoZH 已提交
23 24 25 26 27 28

namespace paddle {
namespace operators {
namespace math {

// Concatenates inputs with differing column widths along the column axis.
//
// inputs      : device array of `col_size - 1` source pointers, one per tensor.
// input_cols  : device array of `col_size` prefix sums of column widths;
//               input_cols[0] == 0 and input_cols[i+1] is the output column
//               where input i+1 begins.
// output_rows / output_cols : logical 2-D shape of the flattened output.
//
// Launched on a 2-D grid; both axes use grid-stride loops, so any launch
// configuration is correct. Fix: offsets and indices are kept in int64_t to
// match the int64_t widths in `input_cols` — the previous `int` locals
// truncated for tensors with more than 2^31 columns/elements, and the row
// offset `tid_y * output_cols` overflowed 32-bit arithmetic.
template <typename T>
__global__ void ConcatKernel(const T** inputs, const int64_t* input_cols,
                             int col_size, const int64_t output_rows,
                             const int64_t output_cols, T* output) {
  int64_t tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int64_t curr_offset = input_cols[0];
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    // Advance to the segment (input tensor) containing column tid_x.
    int64_t curr_col_offset = input_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = input_cols[curr_segment + 1];
    }

    int64_t local_col = tid_x - curr_offset;
    int64_t segment_width = curr_col_offset - curr_offset;

    const T* input_ptr = inputs[curr_segment];
    int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

C
chengduoZH 已提交
54
// Core concat loop for the fast path where every input has the same column
// width `fixed_in_col`. Each output column maps to input `split` at local
// column `in_offset`. Both grid axes use grid-stride loops.
template <typename T>
__device__ void ConcatKernelDetail(const T** inputs_data,
                                   const int fixed_in_col, const int out_rows,
                                   const int out_cols, T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    // Fix: plain integer division. The previous `tid_x * 1.0 / fixed_in_col`
    // promoted through double, which is slower and can mis-round for large
    // indices; SplitKernelDetail already uses integer division here.
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    const T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

71 72
// Two-input specialization of the fixed-width concat: gathers the source
// pointers into a local array and forwards to ConcatKernelDetail.
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const int64_t fixed_in_col, const int64_t out_rows,
                             const int64_t out_cols, T* output_data) {
  const T* srcs[2] = {input_addr0, input_addr1};
  ConcatKernelDetail<T>(srcs, fixed_in_col, out_rows, out_cols, output_data);
}

82 83
// Three-input specialization of the fixed-width concat: gathers the source
// pointers into a local array and forwards to ConcatKernelDetail.
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const int64_t fixed_in_col,
                             const int64_t out_rows, const int64_t out_cols,
                             T* output_data) {
  const T* srcs[3] = {input_addr0, input_addr1, input_addr2};
  ConcatKernelDetail<T>(srcs, fixed_in_col, out_rows, out_cols, output_data);
}

// Four-input specialization of the fixed-width concat: gathers the source
// pointers into a local array and forwards to ConcatKernelDetail.
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const T* input_addr3,
                             const int64_t fixed_in_col, const int64_t out_rows,
                             const int64_t out_cols, T* output_data) {
  const T* srcs[4] = {input_addr0, input_addr1, input_addr2, input_addr3};
  ConcatKernelDetail<T>(srcs, fixed_in_col, out_rows, out_cols, output_data);
}

109 110
// General fixed-width concat entry point: the caller has already staged the
// source-pointer array `inputs_data` (length `in_num`) in device memory.
template <typename T>
__global__ void ConcatKernel(const T** inputs_data, const int in_num,
                             const int64_t fixed_in_col, const int64_t out_rows,
                             const int64_t out_cols, T* output_data) {
  // `in_num` is implied by out_cols / fixed_in_col; the detail routine only
  // needs the pointer table and the fixed width.
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

C
chengduoZH 已提交
117
// Splits `input_data` into outputs with differing column widths.
//
// out_cols      : device array of `out_cols_size` prefix sums of output
//                 column widths; out_cols[0] == 0.
// outputs_data  : device array of destination pointers; a nullptr entry means
//                 that output is skipped (its columns are not copied).
//
// Launched on a 2-D grid; both axes use grid-stride loops. Fix: offsets and
// indices are kept in int64_t to match the int64_t widths in `out_cols` —
// the previous `int` locals truncated for tensors with more than 2^31
// columns/elements, and `tid_y * in_col` overflowed 32-bit arithmetic.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t* out_cols,
                            int out_cols_size, T** outputs_data) {
  int64_t tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int64_t curr_offset = out_cols[0];
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    // Advance to the segment (output tensor) containing column tid_x.
    int64_t curr_col_offset = out_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = out_cols[curr_segment + 1];
    }

    int64_t local_col = tid_x - curr_offset;
    int64_t segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    if (output_ptr != nullptr) {
      int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * segment_width + local_col] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

// Core split loop for the fast path where every output has the same column
// width `fixed_out_col`. A nullptr entry in `outputs_data` skips that output.
template <typename T>
__device__ void SplitKernelDetail(const T* input_data, const int in_row,
                                  const int in_col, const int fixed_out_col,
                                  T** outputs_data) {
  // Grid-stride over the flattened column index of the input.
  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < in_col;
       col += blockDim.x * gridDim.x) {
    const int which = col / fixed_out_col;            // destination tensor
    const int dst_col = col - which * fixed_out_col;  // column inside it
    T* dst = outputs_data[which];
    if (dst == nullptr) continue;  // caller requested this output be skipped
    for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < in_row;
         row += blockDim.y * gridDim.y) {
      dst[row * fixed_out_col + dst_col] = input_data[row * in_col + col];
    }
  }
}

162
// General fixed-width split entry point: the caller has already staged the
// destination-pointer array `outputs_data` in device memory.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T** outputs_data) {
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col,
                       outputs_data);
}

// Two-output specialization of the fixed-width split: gathers the destination
// pointers into a local array and forwards to SplitKernelDetail.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1) {
  T* dsts[2] = {outputs_addr0, outputs_addr1};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, dsts);
}

179
// Three-output specialization of the fixed-width split: gathers the
// destination pointers into a local array and forwards to SplitKernelDetail.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2) {
  T* dsts[3] = {outputs_addr0, outputs_addr1, outputs_addr2};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, dsts);
}

// Four-output specialization of the fixed-width split: gathers the
// destination pointers into a local array and forwards to SplitKernelDetail.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2, T* outputs_addr3) {
  T* dsts[4] = {outputs_addr0, outputs_addr1, outputs_addr2, outputs_addr3};
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, dsts);
}

204
// Chooses a 2-D launch configuration for the concat/split kernels given the
// logical (num_rows x num_cols) shape of the work.
//
// Blocks hold 1024 threads: the x-dimension covers columns (rounded up to a
// warp multiple, capped at 1024) and the y-dimension takes the remainder.
// The grid is capped by the device's maximum physical thread count so both
// kernel axes rely on their grid-stride loops for full coverage.
static inline void GetBlockDims(const platform::CUDADeviceContext& context,
                                int64_t num_rows, int64_t num_cols,
                                dim3* block_dims, dim3* grid_dims) {
  const int kThreadsPerBlock = 1024;
  int block_cols =
      (num_cols < kThreadsPerBlock)
          // Round the column count up to a multiple of 32 (warp width).
          ? static_cast<int>(((num_cols + 31) >> 5) << 5)
          : kThreadsPerBlock;
  int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  // Cap the grid so total blocks roughly match what the device can run.
  int max_threads = context.GetMaxPhysicalThreadCount();
  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  int grid_cols = static_cast<int>(
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks));
  int grid_rows = static_cast<int>(std::min(
      max_blocks / grid_cols, std::max(num_rows / block_rows, (int64_t)1)));
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

C
chengduoZH 已提交
226
/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  // Concatenates `input` tensors along `axis` into `output`.
  // The tensors are viewed as 2-D: rows = product of dims before `axis`,
  // cols = remaining elements per row. A fast path handles 2-4 inputs of
  // identical shape; otherwise pointer/offset tables are copied to the device.
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int64_t in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int64_t in_col = input[0].numel() / in_row;
    int64_t out_row = in_row, out_col = 0;

    // Host staging buffers: per-input data pointers plus a prefix-sum table
    // of column widths (inputs_col[0] == 0).
    int inputs_col_num = in_num + 1;
    std::vector<const T*> inputs_data_vec(in_num);
    std::vector<int64_t> inputs_col_vec(inputs_col_num);
    const T** inputs_data = inputs_data_vec.data();
    int64_t* inputs_col = inputs_col_vec.data();

// There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from
// host to device, it will be automatically asynchronous.
// However, only pinned memory in hip can copy asynchronously
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device
// Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
    memory::AllocationPtr data_alloc, col_alloc;
    data_alloc =
        memory::Alloc(platform::CUDAPinnedPlace(), in_num * sizeof(T*));
    inputs_data = reinterpret_cast<const T**>(data_alloc->ptr());
    // Fix: the buffer is accessed through int64_t*, so it must be sized with
    // sizeof(int64_t). The previous sizeof(int) under-allocated the pinned
    // buffer and overflowed it when writing inputs_col[i].
    col_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
                              inputs_col_num * sizeof(int64_t));
    inputs_col = reinterpret_cast<int64_t*>(col_alloc->ptr());
#endif

    inputs_col[0] = 0;
    bool has_same_shape = true;
    for (int i = 0; i < in_num; ++i) {
      int64_t t_cols = input[i].numel() / in_row;
      if (has_same_shape) {
        if (t_cols != in_col) has_same_shape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_data[i] = input[i].data<T>();
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);

    // The device-side pointer table is only needed when we cannot pass the
    // 2-4 pointers directly as kernel arguments.
    memory::allocation::AllocationPtr tmp_dev_ins_data;
    const T** dev_ins_data = nullptr;
    if (!has_same_shape || in_num < 2 || in_num > 4) {
      tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*));
      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                   tmp_dev_ins_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_data), in_num * sizeof(T*),
                   context.stream());
      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
    }

    if (has_same_shape) {
      // Fast path: equal column widths, pointers passed by value.
      if (in_num == 2) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], in_col, out_row, out_col,
            output->data<T>());
      } else if (in_num == 3) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], in_col, out_row,
            out_col, output->data<T>());
      } else if (in_num == 4) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], inputs_data[3],
            in_col, out_row, out_col, output->data<T>());
      } else {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
      }
    } else {
      // Slow path: also stage the per-input column offsets on the device.
      auto tmp_dev_ins_col_data =
          memory::Alloc(context, inputs_col_num * sizeof(int64_t));
      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_col),
                   inputs_col_num * sizeof(int64_t), context.stream());
      int64_t* dev_ins_col_data =
          static_cast<int64_t*>(tmp_dev_ins_col_data->ptr());

      ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num),
          out_row, out_col, output->data<T>());
    }

#ifdef PADDLE_WITH_HIP
    // Prevent the pinned memory value from being covered and release the
    // memory after the launch kernel of the stream is executed (reapply
    // pinned memory next time)
    auto* data_alloc_released = data_alloc.release();
    auto* col_alloc_released = col_alloc.release();
    context.AddStreamCallback([data_alloc_released, col_alloc_released] {
      memory::allocation::AllocationDeleter deleter;
      deleter(data_alloc_released);
      deleter(col_alloc_released);
    });
#endif
  }
};

C
chengduoZH 已提交
343 344
/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
 public:
  SplitFunctor();
  // Splits `input` along `axis` into `outputs`, using `ref_inputs` only for
  // the target shapes. Tensors are viewed as 2-D: rows = product of dims
  // before `axis`, cols = remaining elements per row. A nullptr entry in
  // `outputs` means that slice is skipped. A fast path handles 2-4 outputs
  // of identical shape; otherwise pointer/offset tables go to the device.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  int axis, std::vector<framework::Tensor*>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    int64_t out_row = 1;
    auto dim_0 = ref_inputs[0]->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    // out0_col: column width of the first output, used to detect the
    // equal-width fast path.
    int64_t out0_col = ref_inputs[0]->numel() / out_row;
    int64_t in_col = 0, in_row = out_row;
    bool has_same_shape = true;

    // Host staging buffers: per-output data pointers plus a prefix-sum table
    // of column widths (outputs_cols[0] == 0).
    int outputs_cols_num = o_num + 1;
    std::vector<T*> outputs_data_vec(o_num);
    std::vector<int64_t> outputs_cols_vec(outputs_cols_num);
    T** outputs_data = outputs_data_vec.data();
    int64_t* outputs_cols = outputs_cols_vec.data();

// There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from
// host to device, it will be automatically asynchronous.
// However, only pinned memory in hip can copy asynchronously
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device
// Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
    memory::AllocationPtr data_alloc, cols_alloc;
    data_alloc = memory::Alloc(platform::CUDAPinnedPlace(), o_num * sizeof(T*));
    outputs_data = reinterpret_cast<T**>(data_alloc->ptr());
    cols_alloc = memory::Alloc(platform::CUDAPinnedPlace(),
                               (outputs_cols_num) * sizeof(int64_t));
    outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr());
#endif

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int64_t t_col = ref_inputs.at(i)->numel() / out_row;
      if (has_same_shape) {
        if (t_col != out0_col) has_same_shape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      // A null output tensor is recorded as nullptr so the kernel skips it.
      if (outputs->at(i) != nullptr) {
        outputs_data[i] = outputs->at(i)->data<T>();
      } else {
        outputs_data[i] = nullptr;
      }
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);

    // The device-side pointer table is only needed when we cannot pass the
    // 2-4 pointers directly as kernel arguments.
    memory::allocation::AllocationPtr tmp_dev_outs_data;
    T** dev_out_gpu_data = nullptr;
    if (!has_same_shape || o_num < 2 || o_num > 4) {
      tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*));
      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                   tmp_dev_outs_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*),
                   context.stream());
      dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
    }

    if (has_same_shape) {
      // Fast path: equal column widths, pointers passed by value.
      if (o_num == 2) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1]);
      } else if (o_num == 3) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1], outputs_data[2]);
      } else if (o_num == 4) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1], outputs_data[2], outputs_data[3]);
      } else {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
      }
    } else {
      // Slow path: also stage the per-output column offsets on the device.
      auto tmp_dev_ins_col_data =
          memory::Alloc(context, outputs_cols_num * sizeof(int64_t));
      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_cols),
                   outputs_cols_num * sizeof(int64_t), context.stream());
      int64_t* dev_outs_col_data =
          reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());

      SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols_num), dev_out_gpu_data);
    }
#ifdef PADDLE_WITH_HIP
    // Prevent the pinned memory value from being covered and release the
    // memory after the launch kernel of the stream is executed (reapply
    // pinned memory next time)
    auto* data_alloc_released = data_alloc.release();
    auto* cols_alloc_released = cols_alloc.release();
    context.AddStreamCallback([data_alloc_released, cols_alloc_released] {
      memory::allocation::AllocationDeleter deleter;
      deleter(data_alloc_released);
      deleter(cols_alloc_released);
    });
#endif
  }
};

C
chengduoZH 已提交
465 466
// Explicitly instantiate ConcatFunctor/SplitFunctor for one element type;
// FOR_ALL_TYPES expands this for every type the math library supports.
#define DEFINE_FUNCTOR(type)                                       \
  template class ConcatFunctor<platform::CUDADeviceContext, type>; \
  template class SplitFunctor<platform::CUDADeviceContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);
C
chengduoZH 已提交
470

C
chengduoZH 已提交
471 472 473
}  // namespace math
}  // namespace operators
}  // namespace paddle