/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

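// Indexing convention used by the kernels below: each input tensor is viewed
// as a 2-D matrix of [rows, cols], where rows is the product of the
// dimensions before the concatenation axis and cols is numel / rows.
// Concatenating along the axis then amounts to stitching the inputs together
// column-wise. For example (illustrative shapes only), concatenating tensors
// of shape [2, 3] and [2, 2] along axis 1 gives output_rows = 2,
// output_cols = 5 and the column-offset table {0, 3, 5} (input_cols, with a
// trailing entry marking the end).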
template <typename T>
__global__ void ConcatKernel(const T** inputs, const int* input_cols,
                             int col_size, const int output_rows,
                             const int output_cols, T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = input_cols[0];
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = input_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = input_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;

    const T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

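// ConcatKernelDetail is the fixed-width fast path: when every input has the
// same column width (fixed_in_col), the segment search used above reduces to
// a single division, so no column-offset table is needed.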
template <typename T>
__device__ void ConcatKernelDetail(const T** inputs_data,
                                   const int fixed_in_col, const int out_rows,
                                   const int out_cols, T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    const T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

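// The two-pointer overload below serves the common binary-concat case: the
// input addresses travel as kernel arguments, so the host side does not have
// to stage a device-side pointer array (see the in_num != 2 check in
// ConcatFunctor).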
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
                             const int fixed_in_col, const int out_rows,
                             const int out_cols, T* output_data) {
  const T* inputs_data[2];
  inputs_data[0] = input_addr0;
  inputs_data[1] = input_addr1;
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

template <typename T>
__global__ void ConcatKernel(const T** inputs_data, const int in_num,
                             const int fixed_in_col, const int out_rows,
                             const int out_cols, T* output_data) {
  ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
                        output_data);
}

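// SplitKernel is the inverse of ConcatKernel: columns of the input matrix are
// scattered into the output segments described by out_cols. A nullptr entry
// in outputs_data marks an output that is not wanted; its columns are simply
// skipped.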
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int* out_cols,
                            int out_cols_size, T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = out_cols[0];
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset = out_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = out_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * segment_width + local_col] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

template <typename T>
__device__ void SplitKernelDetail(const T* input_data, const int in_row,
                                  const int in_col, const int fixed_out_col,
                                  T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_out_col;
    int in_offset = tid_x - split * fixed_out_col;
    T* output_ptr = outputs_data[split];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * fixed_out_col + in_offset] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int fixed_out_col,
                            T** outputs_data) {
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int fixed_out_col,
                            T* outputs_addr0, T* outputs_addr1) {
  T* outputs_data[2];
  outputs_data[0] = outputs_addr0;
  outputs_data[1] = outputs_addr1;
  SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}

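// Illustrative example of the launch configuration below (hypothetical
// sizes): for num_cols = 80 and num_rows = 1000, block_cols is rounded up to
// 96, block_rows = 1024 / 96 = 10, grid_cols = min(ceil(80 / 96), max_blocks)
// = 1, and grid_rows = min(max_blocks, 1000 / 10); max_blocks itself depends
// on the device's maximum physical thread count.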
static inline void GetBlockDims(const platform::CUDADeviceContext& context,
                                int num_rows, int num_cols, dim3* block_dims,
                                dim3* grid_dims) {
  // Choose the thread block and grid sizes according to the current device.
  const int kThreadsPerBlock = 1024;
  int block_cols = kThreadsPerBlock;
  if (num_cols < kThreadsPerBlock) {
    // Round block_cols up to a multiple of 32 (the warp size).
    block_cols = ((num_cols + 31) >> 5) << 5;
  }
  int block_rows = kThreadsPerBlock / block_cols;
  *block_dims = dim3(block_cols, block_rows, 1);

  int max_threads = context.GetMaxPhysicalThreadCount();
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  int grid_cols =
      std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(num_rows / block_rows, 1));
  *grid_dims = dim3(grid_cols, grid_rows, 1);
}

/*
 * All input tensors must have the same rank, and all of their dimensions
 * must match except along the concatenation axis.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    std::vector<const T*> inputs_data(in_num);
    std::vector<int> inputs_col(in_num + 1);

    inputs_col[0] = 0;
    bool has_same_shape = true;
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (has_same_shape) {
        if (t_cols != in_col) has_same_shape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_data[i] = input[i].data<T>();
    }

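    // What follows: derive the launch configuration from the flattened
    // out_row x out_col output, and copy the per-input pointer table to the
    // device only when the general kernel will be used (unequal widths or
    // more than two inputs); the fixed-width two-input kernel takes its
    // pointers directly as arguments.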
    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);

    memory::allocation::AllocationPtr tmp_dev_ins_data;
    const T** dev_ins_data = nullptr;
    if (!has_same_shape || (in_num != 2)) {
      tmp_dev_ins_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              inputs_data.size() * sizeof(T*));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_ins_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_data.data()),
                   inputs_data.size() * sizeof(T*), context.stream());
      dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
    }

    if (has_same_shape) {
      if (in_num == 2) {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            inputs_data[0], inputs_data[1], in_col, out_row, out_col,
            output->data<T>());
      } else {
        ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
      }
    } else {
      auto tmp_dev_ins_col_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              inputs_col.size() * sizeof(int));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_col.data()),
                   inputs_col.size() * sizeof(int), context.stream());
      int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());

      ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
  }
};
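
// Illustrative call site for ConcatFunctor (hypothetical names, not part of
// this file): given device tensors t0 and t1 and an output tensor out that
// already owns memory for the concatenated result,
//   std::vector<framework::Tensor> ins = {t0, t1};
//   ConcatFunctor<platform::CUDADeviceContext, float> concat;
//   concat(dev_ctx, ins, /*axis=*/1, &out);
// The functor only launches kernels; it does not resize or allocate out.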

/*
 * All input tensors must have the same rank, and all of their dimensions
 * must match except along the split (axis) dimension.
 */
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  int axis, std::vector<framework::Tensor*>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    int out_row = 1;
    auto dim_0 = ref_inputs[0]->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    int out0_col = ref_inputs[0]->numel() / out_row;
    int in_col = 0, in_row = out_row;
    bool has_same_shape = true;

    std::vector<T*> outputs_data(o_num);
    std::vector<int> outputs_cols(o_num + 1);

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int t_col = ref_inputs.at(i)->numel() / out_row;
      if (has_same_shape) {
        if (t_col != out0_col) has_same_shape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      if (outputs->at(i) != nullptr) {
        outputs_data[i] = outputs->at(i)->data<T>();
      } else {
        outputs_data[i] = nullptr;
      }
    }

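    // Mirror of the concat path: derive the launch configuration from the
    // flattened in_row x in_col input, and copy the output-pointer table to
    // the device only when the general kernel will be used.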
    dim3 block_dims;
    dim3 grid_dims;
    GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);

    memory::allocation::AllocationPtr tmp_dev_outs_data;
    T** dev_out_gpu_data = nullptr;
    if (!has_same_shape || (o_num != 2)) {
      tmp_dev_outs_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              outputs_data.size() * sizeof(T*));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_outs_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_data.data()),
                   outputs_data.size() * sizeof(T*), context.stream());
      dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
    }

    if (has_same_shape) {
      if (o_num == 2) {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
            outputs_data[1]);
      } else {
        SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
            input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
      }
    } else {
      auto tmp_dev_ins_col_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              outputs_cols.size() * sizeof(int));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_cols.data()),
                   outputs_cols.size() * sizeof(int), context.stream());
      int* dev_outs_col_data =
          reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr());

      SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
  }
};
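
// Illustrative call site for SplitFunctor (hypothetical names): ref_inputs
// holds pointers to tensors whose shapes describe the split sections, and
// outputs holds pointers to the destination tensors. An entry in outputs may
// be nullptr, in which case that section is simply not written.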

#define DEFINE_FUNCTOR(type)                                       \
  template class ConcatFunctor<platform::CUDADeviceContext, type>; \
  template class SplitFunctor<platform::CUDADeviceContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);

}  // namespace math
}  // namespace operators
}  // namespace paddle