concat_and_split.cu 9.8 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include <algorithm>
#include <vector>
C
chengduoZH 已提交
17
#include "paddle/fluid/framework/mixed_vector.h"
C
chengduo 已提交
18
#include "paddle/fluid/operators/math/concat_and_split.h"
D
dzhwinter 已提交
19
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduoZH 已提交
20
#include "paddle/fluid/platform/float16.h"
C
chengduoZH 已提交
21 22 23 24 25 26

namespace paddle {
namespace operators {
namespace math {

/*
 * Column-wise concatenation of row-major matrices whose widths may differ.
 *
 * inputs      : device array of base pointers; inputs[k] is a matrix with
 *               output_rows rows and (input_cols[k + 1] - input_cols[k])
 *               columns.
 * input_cols  : segment offsets in the output's column space. Segment k
 *               occupies output columns [input_cols[k], input_cols[k + 1]).
 *               The ConcatFunctor in this file fills input_cols[0] = 0 and
 *               the running prefix sums after it.
 * col_size    : length of input_cols; the kernel itself never reads it.
 * output      : output_rows x output_cols destination matrix.
 *
 * Launch shape: x threads walk output columns and y threads walk rows, each
 * advancing by the full grid extent in its dimension, so every launch
 * configuration produces the complete result.
 */
template <typename T>
__global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
                             const int output_rows, const int output_cols,
                             T* output) {
  // Segment cursor kept across column iterations of this thread.
  int seg = 0;
  int seg_begin = input_cols[0];

  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < output_cols;
       col += blockDim.x * gridDim.x) {
    // The columns a single thread visits only grow, so the cursor from the
    // previous iteration is a valid starting point and only moves forward.
    int seg_end = input_cols[seg + 1];
    while (seg_end <= col) {
      seg_begin = seg_end;
      ++seg;
      seg_end = input_cols[seg + 1];
    }

    const int col_in_seg = col - seg_begin;
    const int width = seg_end - seg_begin;
    T* src = inputs[seg];

    for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < output_rows;
         row += blockDim.y * gridDim.y) {
      output[row * output_cols + col] = src[row * width + col_in_seg];
    }
  }
}

C
chengduoZH 已提交
52
/*
 * Column-wise concatenation of equal-width row-major matrices.
 *
 * inputs_data  : device array of base pointers; every input is
 *                out_rows x fixed_in_col.
 * fixed_in_col : width shared by all inputs.
 * out_cols     : total output width (fixed_in_col times the input count, as
 *                computed by the ConcatFunctor in this file).
 * output_data  : out_rows x out_cols destination matrix.
 *
 * Output column c is read from input c / fixed_in_col, at column
 * c - (c / fixed_in_col) * fixed_in_col. x threads walk columns and y threads
 * walk rows, each advancing by the full grid extent, so any launch shape is
 * valid.
 * NOTE(review): offsets are 32-bit ints; assumes out_rows * out_cols fits in
 * int.
 */
template <typename T>
__global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
                             const int out_rows, const int out_cols,
                             T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    // Plain integer division. Multiplying by 1.0 here would promote the
    // computation to a double-precision divide, which is slow on most GPUs
    // and gives the same truncated quotient for non-negative ints.
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

/*
 * Column-wise split of one row-major matrix into pieces whose widths may
 * differ. This is the inverse of the variable-width ConcatKernel.
 *
 * input_data    : in_row x in_col source matrix.
 * out_cols      : piece offsets in the input's column space. Piece k covers
 *                 input columns [out_cols[k], out_cols[k + 1]).
 * out_cols_size : length of out_cols; the kernel itself never reads it.
 * outputs_data  : device array of destination pointers, one per piece. A
 *                 nullptr entry means that piece is skipped.
 *
 * x threads walk input columns and y threads walk rows, each advancing by
 * the full grid extent, so correctness does not depend on the launch shape.
 */
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int* out_cols,
                            int out_cols_size, T** outputs_data) {
  // Piece cursor kept across column iterations of this thread.
  int piece = 0;
  int piece_begin = out_cols[0];

  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < in_col;
       col += blockDim.x * gridDim.x) {
    // col only grows for a given thread, so the cursor only moves forward.
    int piece_end = out_cols[piece + 1];
    while (piece_end <= col) {
      piece_begin = piece_end;
      ++piece;
      piece_end = out_cols[piece + 1];
    }

    const int col_in_piece = col - piece_begin;
    const int width = piece_end - piece_begin;
    T* dst = outputs_data[piece];
    if (dst == nullptr) continue;  // this piece was not requested

    for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < in_row;
         row += blockDim.y * gridDim.y) {
      dst[row * width + col_in_piece] = input_data[row * in_col + col];
    }
  }
}

/*
 * Column-wise split of an in_row x in_col row-major matrix into pieces that
 * are all fixed_out_col columns wide.
 *
 * Input column c goes to piece c / fixed_out_col, at column
 * c - (c / fixed_out_col) * fixed_out_col. A nullptr entry in outputs_data
 * means that piece is skipped. x threads walk columns and y threads walk
 * rows, each advancing by the full grid extent.
 */
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int fixed_out_col,
                            T** outputs_data) {
  const int row_begin = blockIdx.y * blockDim.y + threadIdx.y;

  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < in_col;
       col += blockDim.x * gridDim.x) {
    const int piece = col / fixed_out_col;
    const int col_in_piece = col - piece * fixed_out_col;
    T* dst = outputs_data[piece];
    if (dst == nullptr) continue;  // this piece was not requested

    for (int row = row_begin; row < in_row; row += blockDim.y * gridDim.y) {
      dst[row * fixed_out_col + col_in_piece] = input_data[row * in_col + col];
    }
  }
}

C
chengduoZH 已提交
114
/*
 * All tensors must have the same rank, and their extents must match in
 * every dimension except the axis dimension.
 *
 * CUDA ConcatFunctor: concatenates `input` along `axis` into `output`.
 * Each tensor is viewed as a row-major 2-D matrix with
 *   rows = dims[0] * ... * dims[axis - 1]
 *   cols = numel() / rows,
 * which turns concatenation along `axis` into a column-wise join of those
 * matrices. `output` is written through output->data<T>(), so it is
 * expected to be allocated with the concatenated shape beforehand
 * (presumably by the caller via mutable_data — TODO confirm).
 *
 * Degenerate cases are no-ops: no inputs, a zero extent before `axis`, or
 * zero total columns. After a launch, the call blocks on context.Wait()
 * before returning.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    // Nothing to concatenate; also keeps input[0] below in bounds.
    if (input.empty()) return;

    int in_num = input.size();
    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    // A zero extent before `axis` leaves every tensor empty. Returning here
    // also avoids the integer division by zero in `numel() / in_row`.
    if (in_row == 0) return;

    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    // Host staging for the table of input pointers. The int16_t element
    // type only serves as raw storage: in_num * sizeof(T*) bytes that are
    // reinterpreted as T*[in_num].
    framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
    // Prefix sums of the input widths (column offsets in the output).
    framework::Vector<int> inputs_col(in_num + 1);
    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());

    inputs_col[0] = 0;
    bool sameShape = true;
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (sameShape) {
        if (t_cols != in_col) sameShape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
    }

    // Every input is empty along `axis`, so there is nothing to copy. A zero
    // out_col would also make block_cols and grid_cols zero, causing
    // divisions by zero in the launch configuration below.
    if (out_col == 0) return;

    T** dev_ins_data =
        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));

    // Launch configuration. A block has kThreadsPerBlock threads, spread over
    // columns first (rounded up to a multiple of 32) and over rows with the
    // remainder. The total block count is capped by the context's physical
    // thread capacity. The kernels' stride loops cover any columns or rows
    // beyond the grid.
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (out_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((out_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      // Equal widths: the kernel derives the source input by division.
      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, in_col, out_row, out_col, output->data<T>());
    } else {
      // Mixed widths: the kernel walks the prefix-offset table.
      const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
    // Wait() must be called because `inputs_data` may be destructed before
    // kernel ends
    context.Wait();
  }
};

C
chengduoZH 已提交
187 188
/*
 * All tensors must have the same rank, and their extents must match in
 * every dimension except the axis dimension.
 *
 * CUDA SplitFunctor: the inverse of ConcatFunctor. Splits `input` along
 * `axis` into `*outputs`. The extent of piece i along `axis` is taken from
 * ref_inputs[i]. Tensors are viewed as row-major 2-D matrices with
 *   rows = dims[0] * ... * dims[axis - 1]   (taken from ref_inputs[0])
 *   cols = numel() / rows.
 * ref_inputs must hold at least outputs->size() entries; .at() throws
 * std::out_of_range otherwise. A nullptr entry in *outputs means that piece
 * is skipped.
 *
 * Degenerate cases are no-ops: no outputs, a zero extent before `axis`, or
 * zero total columns. After a launch, the call blocks on context.Wait()
 * before returning.
 */
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  int axis, std::vector<framework::Tensor*>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    // No pieces requested. This also avoids a zero-column launch
    // configuration (division by zero below).
    if (o_num == 0) return;

    // Bounds-checked like the ref_inputs.at(i) calls in the loop below.
    const framework::Tensor* ref0 = ref_inputs.at(0);
    int out_row = 1;
    auto dim_0 = ref0->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }
    // A zero extent before `axis` leaves every piece empty. Returning here
    // also avoids the integer division by zero in `numel() / out_row`.
    if (out_row == 0) return;

    int out0_col = ref0->numel() / out_row;
    int in_col = 0, in_row = out_row;
    bool sameShape = true;

    // Host staging for the table of output pointers. The int16_t element
    // type only serves as raw storage: o_num * sizeof(T*) bytes that are
    // reinterpreted as T*[o_num].
    framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
    // Prefix sums of the piece widths (column offsets in the input).
    framework::Vector<int> outputs_cols(o_num + 1);
    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int t_col = ref_inputs.at(i)->numel() / out_row;
      if (sameShape) {
        if (t_col != out0_col) sameShape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      if (outputs->at(i) != nullptr) {
        outputs_ptr[i] = outputs->at(i)->data<T>();
      } else {
        outputs_ptr[i] = nullptr;
      }
    }

    // Every piece is empty along `axis`, so there is nothing to copy. A zero
    // in_col would also make block_cols and grid_cols zero, causing
    // divisions by zero in the launch configuration below.
    if (in_col == 0) return;

    T** dev_out_gpu_data =
        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));

    // Launch configuration. A block has kThreadsPerBlock threads, spread over
    // columns first (rounded up to a multiple of 32) and over rows with the
    // remainder. The total block count is capped by the context's physical
    // thread capacity. The kernels' stride loops cover any columns or rows
    // beyond the grid.
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (in_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((in_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      // Equal widths: the kernel derives the destination piece by division.
      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
    } else {
      // Mixed widths: the kernel walks the prefix-offset table.
      const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
    // Wait() must be called because `outputs_data` may be destructed before
    // kernel ends
    context.Wait();
  }
};

C
chengduoZH 已提交
265 266
// Explicitly instantiate the CUDA ConcatFunctor and SplitFunctor for every
// element type enumerated by FOR_ALL_TYPES (defined outside this file —
// presumably in concat_and_split.h; TODO confirm).
#define DEFINE_FUNCTOR(type)                                       \
  template class ConcatFunctor<platform::CUDADeviceContext, type>; \
  template class SplitFunctor<platform::CUDADeviceContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);

// The helper macro is only needed for the instantiations above; undefine it
// so it does not leak into the rest of the translation unit.
#undef DEFINE_FUNCTOR
C
chengduoZH 已提交
270

C
chengduoZH 已提交
271 272 273
}  // namespace math
}  // namespace operators
}  // namespace paddle