/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

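// General concat kernel: the inputs may have different widths. `input_cols`
// holds the exclusive prefix sums of the input widths (col_size entries), so
// each output column tid_x is first mapped to the segment that contains it,
// then copied row by row with grid-stride loops in both dimensions.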
template <typename T>
__global__ void ConcatKernel(const T** inputs, const int* input_cols,
                             int col_size, const int output_rows,
                             const int output_cols, T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = input_cols[0];
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = input_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = input_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;

    const T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

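// Concat helper for the common case where every input has the same width
// (`fixed_in_col` columns); the segment index is then a plain division
// instead of a prefix-sum scan.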
template <typename T>
__device__ void ConcatKernelDetail(const T** inputs_data,
                                   const int fixed_in_col, const int out_rows,
                                   const int out_cols, T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    const T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

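// Fixed-arity kernels for 2, 3 and 4 equally shaped inputs. Passing the
// pointers as kernel arguments avoids allocating and copying a device-side
// pointer array for the most common cases.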
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const int fixed_in_col, const int out_rows,
                             const int out_cols, T* output_data) {
  const T* inputs_data[2];
  inputs_data[0] = input_addr0;
  inputs_data[1] = input_addr1;
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const int fixed_in_col,
                             const int out_rows, const int out_cols,
                             T* output_data) {
  const T* inputs_data[3];
  inputs_data[0] = input_addr0;
  inputs_data[1] = input_addr1;
  inputs_data[2] = input_addr2;
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const T* input_addr3,
                             const int fixed_in_col, const int out_rows,
                             const int out_cols, T* output_data) {
  const T* inputs_data[4];
  inputs_data[0] = input_addr0;
  inputs_data[1] = input_addr1;
  inputs_data[2] = input_addr2;
  inputs_data[3] = input_addr3;
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

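// Equal-width concat for an arbitrary number of inputs; `inputs_data` must
// already point to device memory.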
template <typename T>
__global__ void ConcatKernel(const T** inputs_data, const int in_num,
                             const int fixed_in_col, const int out_rows,
                             const int out_cols, T* output_data) {
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

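// General split kernel: the outputs may have different widths. `out_cols`
// holds the exclusive prefix sums of the output widths; a nullptr entry in
// `outputs_data` means the corresponding output is skipped.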
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t* out_cols,
                            int out_cols_size, T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = out_cols[0];
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = out_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = out_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * segment_width + local_col] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

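// Split helper for the common case where every output has the same width
// (`fixed_out_col` columns).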
template <typename T>
__device__ void SplitKernelDetail(const T* input_data, const int in_row,
                                  const int in_col, const int fixed_out_col,
                                  T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_out_col;
    int in_offset = tid_x - split * fixed_out_col;
    T* output_ptr = outputs_data[split];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * fixed_out_col + in_offset] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

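// Equal-width split for an arbitrary number of outputs; `outputs_data` must
// already point to device memory.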
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T** outputs_data) {
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

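// Fixed-arity kernels for 2, 3 and 4 equally shaped outputs, mirroring the
// concat overloads above.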
template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1) {
  T* outputs_data[2];
  outputs_data[0] = outputs_addr0;
  outputs_data[1] = outputs_addr1;
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2) {
  T* outputs_data[3];
  outputs_data[0] = outputs_addr0;
  outputs_data[1] = outputs_addr1;
  outputs_data[2] = outputs_addr2;
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

template <typename T>
__global__ void SplitKernel(const T* input_data, const int64_t in_row,
                            const int64_t in_col, const int64_t fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2, T* outputs_addr3) {
  T* outputs_data[4];
  outputs_data[0] = outputs_addr0;
  outputs_data[1] = outputs_addr1;
  outputs_data[2] = outputs_addr2;
  outputs_data[3] = outputs_addr3;
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

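// Pick launch dimensions from the flattened (rows, cols) problem size: the
// block x-extent is the column count rounded up to a multiple of 32 and
// capped at 1024 threads, the remaining threads go to the y-extent, and the
// grid is bounded by the device's maximum number of resident threads. For
// example, num_cols = 100 yields block_dims = (128, 8, 1).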
static inline void GetBlockDims(const platform::CUDADeviceContext& context,
                                int64_t num_rows, int64_t num_cols,
                                dim3* block_dims, dim3* grid_dims) {
  // Set the thread block and grid according to CurrentDeviceId
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
    block_cols = ((num_cols + 31) >> 5) << 5;
  }
  int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  int max_threads = context.GetMaxPhysicalThreadCount();
  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  int grid_cols =
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows = std::min(max_blocks / grid_cols,
                           std::max(num_rows / block_rows, (int64_t)1));
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

/*
 * All input tensors must have the same rank, and every dimension except
 * the axis dimension must match.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    std::vector<const T*> inputs_data(in_num);
    std::vector<int> inputs_col(in_num + 1);

    inputs_col[0] = 0;
    bool has_same_shape = true;
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (has_same_shape) {
        if (t_cols != in_col) has_same_shape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_data[i] = input[i].data<T>();
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);

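    // A device-side copy of the input-pointer table is needed only when the
    // general kernel is used (mismatched shapes, a single input, or more
    // than four inputs).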
    memory::allocation::AllocationPtr tmp_dev_ins_data;
    const T** dev_ins_data = nullptr;
    if (!has_same_shape || in_num < 2 || in_num > 4) {
      tmp_dev_ins_data =
          memory::Alloc(context, inputs_data.size() * sizeof(T*));
      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                   tmp_dev_ins_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_data.data()),
                   inputs_data.size() * sizeof(T*), context.stream());
      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
    }

    if (has_same_shape) {
      if (in_num == 2) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], in_col, out_row, out_col,
            output->data<T>());
      } else if (in_num == 3) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], in_col, out_row,
            out_col, output->data<T>());
      } else if (in_num == 4) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], inputs_data[3],
            in_col, out_row, out_col, output->data<T>());
      } else {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
      }
    } else {
      auto tmp_dev_ins_col_data =
          memory::Alloc(context, inputs_col.size() * sizeof(int));
      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_col.data()),
                   inputs_col.size() * sizeof(int), context.stream());
      int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());

      ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
  }
};
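
// Illustrative usage (a minimal sketch, not part of this file; `ctx`, `a`,
// `b` and `out` are hypothetical objects prepared by the caller, and `out`
// is assumed to be allocated with the concatenated shape):
//
//   ConcatFunctor<platform::CUDADeviceContext, float> concat;
//   concat(ctx, {a, b}, /*axis=*/1, &out);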

/*
 * All output tensors must have the same rank, and every dimension except
 * the axis dimension must match.
 */
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  int axis, std::vector<framework::Tensor*>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    int64_t out_row = 1;
    auto dim_0 = ref_inputs[0]->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    int64_t out0_col = ref_inputs[0]->numel() / out_row;
    int64_t in_col = 0, in_row = out_row;
    bool has_same_shape = true;

    std::vector<T*> outputs_data(o_num);
    std::vector<int64_t> outputs_cols(o_num + 1);

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int64_t t_col = ref_inputs.at(i)->numel() / out_row;
      if (has_same_shape) {
        if (t_col != out0_col) has_same_shape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      if (outputs->at(i) != nullptr) {
        outputs_data[i] = outputs->at(i)->data<T>();
      } else {
        outputs_data[i] = nullptr;
      }
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);

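    // A device-side copy of the output-pointer table is needed only when the
    // general kernel is used (mismatched shapes, a single output, or more
    // than four outputs).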
    memory::allocation::AllocationPtr tmp_dev_outs_data;
    T** dev_out_gpu_data = nullptr;
356
    if (!has_same_shape || o_num < 2 || o_num > 4) {
357
      tmp_dev_outs_data =
358
          memory::Alloc(context, outputs_data.size() * sizeof(T*));
359
      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
360 361 362 363
                   tmp_dev_outs_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_data.data()),
                   outputs_data.size() * sizeof(T*), context.stream());
      dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
364
    }

    if (has_same_shape) {
      if (o_num == 2) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1]);
      } else if (o_num == 3) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1], outputs_data[2]);
      } else if (o_num == 4) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1], outputs_data[2], outputs_data[3]);
      } else {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
      }
    } else {
      auto tmp_dev_ins_col_data =
          memory::Alloc(context, outputs_cols.size() * sizeof(int64_t));
      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_cols.data()),
                   outputs_cols.size() * sizeof(int64_t), context.stream());
      int64_t* dev_outs_col_data =
          reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());

      SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
  }
};
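
// Illustrative usage (a minimal sketch, not part of this file; `ctx`, `in`,
// `out0` and `out1` are hypothetical objects prepared by the caller, and the
// output tensors are assumed to be pre-allocated; `ref_inputs` supplies the
// per-output shapes):
//
//   SplitFunctor<platform::CUDADeviceContext, float> split;
//   std::vector<const framework::Tensor*> ref_inputs = {&out0, &out1};
//   std::vector<framework::Tensor*> outs = {&out0, &out1};
//   split(ctx, in, ref_inputs, /*axis=*/1, &outs);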

#define DEFINE_FUNCTOR(type)                                       \
  template class ConcatFunctor<platform::CUDADeviceContext, type>; \
  template class SplitFunctor<platform::CUDADeviceContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);

}  // namespace math
}  // namespace operators
}  // namespace paddle