/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

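// General concat kernel for inputs of differing widths. `input_cols` is the
// prefix-sum table of input column counts (col_size entries, starting at 0);
// each x-thread walks the table to find which input its output column maps
// to, then copies that column for every row handled by its y-threads.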
template <typename T>
__global__ void ConcatKernel(const T** inputs, const int* input_cols,
                             int col_size, const int output_rows,
                             const int output_cols, T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = input_cols[0];
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = input_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = input_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;

    const T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

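// Fast path shared by the kernels below: every input has the same width
// (fixed_in_col), so the source tensor and offset are computed directly
// instead of searching a prefix-sum table.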
template <typename T>
__device__ void ConcatKernelDetail(const T** inputs_data,
                                   const int fixed_in_col, const int out_rows,
                                   const int out_cols, T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    const T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

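// Fixed-arity wrappers for 2/3/4 equally sized inputs. Passing the input
// pointers as kernel arguments avoids the host-to-device copy of a pointer
// table that the general case needs.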
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const int fixed_in_col, const int out_rows,
                             const int out_cols, T* output_data) {
  const T* inputs_data[2];
  inputs_data[0] = input_addr0;
  inputs_data[1] = input_addr1;
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const int fixed_in_col,
                             const int out_rows, const int out_cols,
                             T* output_data) {
  const T* inputs_data[3];
  inputs_data[0] = input_addr0;
  inputs_data[1] = input_addr1;
  inputs_data[2] = input_addr2;
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const T* input_addr2, const T* input_addr3,
                             const int fixed_in_col, const int out_rows,
                             const int out_cols, T* output_data) {
  const T* inputs_data[4];
  inputs_data[0] = input_addr0;
  inputs_data[1] = input_addr1;
  inputs_data[2] = input_addr2;
  inputs_data[3] = input_addr3;
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

template <typename T>
__global__ void ConcatKernel(const T** inputs_data, const int in_num,
                             const int fixed_in_col, const int out_rows,
                             const int out_cols, T* output_data) {
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

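// General split kernel, the mirror of ConcatKernel above: `out_cols` is the
// prefix-sum table of output column counts, and a nullptr entry in
// `outputs_data` means that output is skipped.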
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int* out_cols,
                            int out_cols_size, T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = out_cols[0];
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = out_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = out_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * segment_width + local_col] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

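// Fast path shared by the kernels below: every output has the same width
// (fixed_out_col), so the destination tensor and offset are computed directly.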
template <typename T>
__device__ void SplitKernelDetail(const T* input_data, const int in_row,
                                  const int in_col, const int fixed_out_col,
                                  T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_out_col;
    int in_offset = tid_x - split * fixed_out_col;
    T* output_ptr = outputs_data[split];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * fixed_out_col + in_offset] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

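// Same-width wrappers: a pointer-table version for the general case and
// fixed-arity versions for 2/3/4 outputs whose pointers are passed as
// kernel arguments.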
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int fixed_out_col,
                            T** outputs_data) {
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1) {
  T* outputs_data[2];
  outputs_data[0] = outputs_addr0;
  outputs_data[1] = outputs_addr1;
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2) {
  T* outputs_data[3];
  outputs_data[0] = outputs_addr0;
  outputs_data[1] = outputs_addr1;
  outputs_data[2] = outputs_addr2;
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1,
                            T* outputs_addr2, T* outputs_addr3) {
  T* outputs_data[4];
  outputs_data[0] = outputs_addr0;
  outputs_data[1] = outputs_addr1;
  outputs_data[2] = outputs_addr2;
  outputs_data[3] = outputs_addr3;
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

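// Picks a 2-D launch configuration: x covers columns (rounded up to a
// multiple of 32, at most 1024 threads per block), y covers rows, and the
// grid is capped by the device's maximum number of resident threads.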
static inline void GetBlockDims(const platform::CUDADeviceContext& context,
                                int num_rows, int num_cols, dim3* block_dims,
                                dim3* grid_dims) {
  // Set the thread block and grid according to CurrentDeviceId
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
    block_cols = ((num_cols + 31) >> 5) << 5;
  }
  int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  int max_threads = context.GetMaxPhysicalThreadCount();
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  int grid_cols =
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(num_rows / block_rows, 1));
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

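// The functor below dispatches at run time: when all inputs have the same
// width it uses the fixed-arity kernels for 2-4 inputs and the pointer-table
// kernel otherwise; when widths differ it uploads the column prefix sums and
// launches the general kernel.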
/*
 * All input tensors must have the same rank, and every dimension except
 * the axis dimension must match.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    std::vector<const T*> inputs_data(in_num);
    std::vector<int> inputs_col(in_num + 1);

    inputs_col[0] = 0;
    bool has_same_shape = true;
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (has_same_shape) {
        if (t_cols != in_col) has_same_shape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_data[i] = input[i].data<T>();
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);

    memory::allocation::AllocationPtr tmp_dev_ins_data;
    const T** dev_ins_data = nullptr;
    if (!has_same_shape || in_num < 2 || in_num > 4) {
      tmp_dev_ins_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              inputs_data.size() * sizeof(T*));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_ins_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_data.data()),
                   inputs_data.size() * sizeof(T*), context.stream());
      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
    }

    if (has_same_shape) {
      if (in_num == 2) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], in_col, out_row, out_col,
            output->data<T>());
      } else if (in_num == 3) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], in_col, out_row,
            out_col, output->data<T>());
      } else if (in_num == 4) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], inputs_data[2], inputs_data[3],
            in_col, out_row, out_col, output->data<T>());
      } else {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
      }
    } else {
      auto tmp_dev_ins_col_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              inputs_col.size() * sizeof(int));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_col.data()),
                   inputs_col.size() * sizeof(int), context.stream());
      int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());

      ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
  }
};
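// Minimal usage sketch (illustrative only; `dev_ctx`, `t0`, `t1` and `out`
// are assumed to be an existing CUDADeviceContext and float tensors, not
// names defined in this file):
//
//   math::ConcatFunctor<platform::CUDADeviceContext, float> concat;
//   concat(dev_ctx, std::vector<framework::Tensor>{t0, t1}, /*axis=*/0, &out);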

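// Mirror of ConcatFunctor: scatters the input's columns across the outputs,
// using ref_inputs only to derive each output's shape, and skipping outputs
// that are nullptr.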
/*
 * All input tensors must have the same rank, and every dimension except
 * the axis dimension must match.
 */
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  int axis, std::vector<framework::Tensor*>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    int out_row = 1;
    auto dim_0 = ref_inputs[0]->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    int out0_col = ref_inputs[0]->numel() / out_row;
    int in_col = 0, in_row = out_row;
    bool has_same_shape = true;

    std::vector<T*> outputs_data(o_num);
    std::vector<int> outputs_cols(o_num + 1);

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int t_col = ref_inputs.at(i)->numel() / out_row;
      if (has_same_shape) {
        if (t_col != out0_col) has_same_shape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      if (outputs->at(i) != nullptr) {
        outputs_data[i] = outputs->at(i)->data<T>();
      } else {
        outputs_data[i] = nullptr;
      }
    }

    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);

    memory::allocation::AllocationPtr tmp_dev_outs_data;
    T** dev_out_gpu_data = nullptr;
    if (!has_same_shape || o_num < 2 || o_num > 4) {
      tmp_dev_outs_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              outputs_data.size() * sizeof(T*));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_outs_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_data.data()),
                   outputs_data.size() * sizeof(T*), context.stream());
      dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
    }

    if (has_same_shape) {
      if (o_num == 2) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1]);
      } else if (o_num == 3) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1], outputs_data[2]);
      } else if (o_num == 4) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1], outputs_data[2], outputs_data[3]);
      } else {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
      }
    } else {
      auto tmp_dev_ins_col_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              outputs_cols.size() * sizeof(int));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_cols.data()),
                   outputs_cols.size() * sizeof(int), context.stream());
      int* dev_outs_col_data =
          reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr());

      SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
  }
};

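// Explicitly instantiate ConcatFunctor and SplitFunctor for every element
// type covered by FOR_ALL_TYPES.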
#define DEFINE_FUNCTOR(type)                                       \
  template class ConcatFunctor<platform::CUDADeviceContext, type>; \
  template class SplitFunctor<platform::CUDADeviceContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);

}  // namespace math
}  // namespace operators
}  // namespace paddle