/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

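// Concatenates `col_size - 1` row-major input matrices with (possibly)
// different column widths into one output matrix. `input_cols` holds the
// prefix sums of the input widths, so entry i is the first output column
// written from input i.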
template <typename T>
__global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
                             const int output_rows, const int output_cols,
                             T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = input_cols[0];
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
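    // Advance curr_segment until it is the input that owns output column
    // tid_x; curr_offset is the first output column of that input.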
    int curr_col_offset = input_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = input_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;

    T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

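// Concatenates input matrices that all share the same column width
// `fixed_in_col`, so each thread can locate its source tensor with a
// division instead of searching a prefix-sum table.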
template <typename T>
__global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
                             const int out_rows, const int out_cols,
                             T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

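// Splits one row-major input matrix into `out_cols_size - 1` output matrices
// with (possibly) different column widths. `out_cols` holds the prefix sums
// of the output widths; a null entry in `outputs_data` skips that output.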
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int* out_cols,
                            int out_cols_size, T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int curr_segment = 0;
  int curr_offset = out_cols[0];
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
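    // Advance curr_segment until it is the output that receives input column
    // tid_x; curr_offset is the first input column of that output.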
    int curr_col_offset = out_cols[curr_segment + 1];
    while (curr_col_offset <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
      curr_col_offset = out_cols[curr_segment + 1];
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * segment_width + local_col] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

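// Splits the input into outputs that all share the same column width
// `fixed_out_col`; the target tensor is found by a simple division.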
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int fixed_out_col,
                            T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_out_col;
    int in_offset = tid_x - split * fixed_out_col;
    T* output_ptr = outputs_data[split];
    if (output_ptr != nullptr) {
      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
      for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
        output_ptr[tid_y * fixed_out_col + in_offset] =
            input_data[tid_y * in_col + tid_x];
    }
  }
}

/*
 * All tensors must have the same number of dimensions, and every
 * dimension size must match, except along the concatenation axis.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    std::vector<const T*> inputs_data;
    std::vector<int> inputs_col(in_num + 1);
    inputs_data.reserve(in_num);

    inputs_col[0] = 0;
    bool sameShape = true;
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (sameShape) {
        if (t_cols != in_col) sameShape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_data.emplace_back(input[i].data<T>());
    }

    // computation
    // Set the thread block and grid sizes according to the current device.
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (out_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((out_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

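    // Stage the host-side array of input pointers in temporary device memory
    // so the kernel can index it.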
    auto tmp_dev_ins_data =
        platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
            inputs_data.size() * sizeof(T*));
    memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                 tmp_dev_ins_data->ptr(), platform::CPUPlace(),
                 static_cast<void*>(inputs_data.data()),
                 inputs_data.size() * sizeof(T*), context.stream());
    T** dev_ins_data = reinterpret_cast<T**>(tmp_dev_ins_data->ptr());

    if (sameShape) {
      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, in_col, out_row, out_col, output->data<T>());
    } else {
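      // Inputs have different widths, so the prefix-sum column table must
      // also be copied to the device.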
      auto tmp_dev_ins_col_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              inputs_col.size() * sizeof(int));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_col.data()),
                   inputs_col.size() * sizeof(int), context.stream());
      int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());

      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
  }
};

/*
 * All tensors must have the same number of dimensions, and every
 * dimension size must match, except along the split axis.
 */
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  int axis, std::vector<framework::Tensor*>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    int out_row = 1;
    auto dim_0 = ref_inputs[0]->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    int out0_col = ref_inputs[0]->numel() / out_row;
    int in_col = 0, in_row = out_row;
    bool sameShape = true;

    std::vector<T*> outputs_data(o_num);
    std::vector<int> outputs_cols(o_num + 1);

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int t_col = ref_inputs.at(i)->numel() / out_row;
      if (sameShape) {
        if (t_col != out0_col) sameShape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      if (outputs->at(i) != nullptr) {
        outputs_data[i] = outputs->at(i)->data<T>();
      } else {
        outputs_data[i] = nullptr;
      }
    }

    // computation
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (in_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((in_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

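    // Stage the host-side array of output pointers in temporary device memory
    // so the kernel can index it.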
    auto tmp_dev_outs_data =
        platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
            outputs_data.size() * sizeof(T*));
    memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                 tmp_dev_outs_data->ptr(), platform::CPUPlace(),
                 reinterpret_cast<void*>(outputs_data.data()),
                 outputs_data.size() * sizeof(T*), context.stream());
    T** dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());

    if (sameShape) {
      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
    } else {
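      // Outputs have different widths, so the prefix-sum column table must
      // also be copied to the device.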
      auto tmp_dev_ins_col_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              outputs_cols.size() * sizeof(int));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_cols.data()),
                   outputs_cols.size() * sizeof(int), context.stream());
      int* dev_outs_col_data =
          reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr());

      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
  }
};

#define DEFINE_FUNCTOR(type)                                       \
  template class ConcatFunctor<platform::CUDADeviceContext, type>; \
  template class SplitFunctor<platform::CUDADeviceContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);

}  // namespace math
}  // namespace operators
}  // namespace paddle