concat_and_split.cu 11.0 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

// Concatenates inputs whose widths may differ into one (output_rows x
// output_cols) row-major matrix.  `input_cols` holds prefix sums of the
// input widths (input_cols[0] == 0, input_cols[i+1] == first output column
// after input i); each output column is mapped to its source segment by a
// resumable linear search.  `col_size` (the length of `input_cols`) is kept
// for interface compatibility but is not read here.  Launched on a 2-D grid;
// both axes use grid-stride loops.
template <typename T>
__global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
                             const int output_rows, const int output_cols,
                             T* output) {
  const int col_stride = blockDim.x * gridDim.x;
  const int row_stride = blockDim.y * gridDim.y;
  int segment = 0;
  int seg_begin = input_cols[0];
  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < output_cols;
       col += col_stride) {
    int seg_end = input_cols[segment + 1];
    // Each thread visits columns in increasing order, so the segment search
    // resumes where the previous iteration stopped instead of restarting.
    while (seg_end <= col) {
      seg_begin = seg_end;
      ++segment;
      seg_end = input_cols[segment + 1];
    }

    const int local_col = col - seg_begin;
    const int segment_width = seg_end - seg_begin;
    T* input_ptr = inputs[segment];

    for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < output_rows;
         row += row_stride) {
      output[row * output_cols + col] =
          input_ptr[row * segment_width + local_col];
    }
  }
}
// Concatenates inputs that all share the same width `fixed_in_col` into one
// (out_rows x out_cols) row-major matrix; the source tensor for a column is
// found by division instead of a segment search.  Launched on a 2-D grid;
// both axes use grid-stride loops.
template <typename T>
__global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
                             const int out_rows, const int out_cols,
                             T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    // Plain integer division; the previous `tid_x * 1.0 / fixed_in_col`
    // computed the same truncated quotient through a needless
    // double-precision round trip (both operands are non-negative ints).
    // This also matches the fixed-width SplitKernel below.
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

// Scatters the columns of `input_data` (in_row x in_col, row-major) into the
// output tensors, whose boundaries are given as prefix sums in `out_cols`
// (out_cols[0] == 0, out_cols[i+1] == first input column after output i).
// A null entry in `outputs_data` means that segment is discarded.
// `out_cols_size` (the length of `out_cols`) is kept for interface
// compatibility but is not read here.  2-D grid-stride kernel.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int* out_cols,
                            int out_cols_size, T** outputs_data) {
  const int col_stride = blockDim.x * gridDim.x;
  const int row_stride = blockDim.y * gridDim.y;
  int segment = 0;
  int seg_begin = out_cols[0];
  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < in_col;
       col += col_stride) {
    int seg_end = out_cols[segment + 1];
    // Each thread visits columns in increasing order, so the segment search
    // resumes where the previous iteration stopped instead of restarting.
    while (seg_end <= col) {
      seg_begin = seg_end;
      ++segment;
      seg_end = out_cols[segment + 1];
    }

    const int local_col = col - seg_begin;
    const int segment_width = seg_end - seg_begin;
    T* output_ptr = outputs_data[segment];
    if (output_ptr != nullptr) {
      for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < in_row;
           row += row_stride) {
        output_ptr[row * segment_width + local_col] =
            input_data[row * in_col + col];
      }
    }
  }
}

// Scatters the columns of `input_data` (in_row x in_col, row-major) into
// equally wide outputs of `fixed_out_col` columns each; the destination
// tensor for a column is found by division.  A null entry in `outputs_data`
// means that segment is discarded.  2-D grid-stride kernel.
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
                            const int in_col, const int fixed_out_col,
                            T** outputs_data) {
  const int col_stride = blockDim.x * gridDim.x;
  const int row_stride = blockDim.y * gridDim.y;
  for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < in_col;
       col += col_stride) {
    const int split = col / fixed_out_col;           // destination tensor
    const int local_col = col - split * fixed_out_col;  // column within it
    T* output_ptr = outputs_data[split];
    if (output_ptr != nullptr) {
      for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < in_row;
           row += row_stride) {
        output_ptr[row * fixed_out_col + local_col] =
            input_data[row * in_col + col];
      }
    }
  }
}

/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
/*
 * Concatenates `input` tensors along `axis` into `output` on the GPU.
 * Every tensor is flattened to a 2-D view: the product of the dimensions
 * before `axis` forms the rows, the remaining elements form the columns.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    // Rows shared by every tensor: product of the dims before `axis`.
    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    // Per-input device pointers and prefix sums of the input widths
    // (inputs_col[i + 1] == first output column after input i).
    std::vector<T*> inputs_data(in_num);
    std::vector<int> inputs_col(in_num + 1);

    inputs_col[0] = 0;
    bool sameShape = true;  // do all inputs share the same width?
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (sameShape) {
        if (t_cols != in_col) sameShape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_data[i] = const_cast<T*>(input[i].data<T>());
    }

    // BUGFIX: when every input is empty along the concat width, the
    // launch-config math below would compute block_cols == 0 and divide by
    // zero on the host.  There is nothing to copy, so return early.
    if (out_col == 0) return;

    // computation
    // set the thread block and grid according to CurrentDeviceId
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (out_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((out_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    // Stage the array of input pointers in device memory.
    // NOTE(review): memory::Copy on context.stream() is asynchronous w.r.t.
    // the host while inputs_data/inputs_col are locals that die on return --
    // confirm the temporary allocator / stream semantics guarantee the copy
    // reads them before they are destroyed.
    auto tmp_dev_ins_data =
        platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
            inputs_data.size() * sizeof(T*));
    memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                 tmp_dev_ins_data->ptr(), platform::CPUPlace(),
                 static_cast<void*>(inputs_data.data()),
                 inputs_data.size() * sizeof(T*), context.stream());
    T** dev_ins_data = reinterpret_cast<T**>(tmp_dev_ins_data->ptr());

    if (sameShape) {
      // Uniform widths: the kernel locates the source tensor by division.
      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, in_col, out_row, out_col, output->data<T>());
    } else {
      // Variable widths: also stage the column prefix sums on the device.
      auto tmp_dev_ins_col_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              inputs_col.size() * sizeof(int));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                   static_cast<void*>(inputs_col.data()),
                   inputs_col.size() * sizeof(int), context.stream());
      int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());

      ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
  }
};

/*
 * All tensors' dimension should be the same and the values of
 * each dimension must be the same, except the axis dimension.
 */
/*
 * Splits `input` along `axis` into the tensors of `outputs` on the GPU.
 * `ref_inputs` supplies the target shapes; a null entry in `outputs` marks
 * a segment the caller wants discarded.  Every tensor is flattened to a 2-D
 * view: the product of the dims before `axis` forms the rows, the remaining
 * elements form the columns.
 */
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const std::vector<const framework::Tensor*>& ref_inputs,
                  int axis, std::vector<framework::Tensor*>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    // Rows shared by every tensor: product of the dims before `axis`.
    int out_row = 1;
    auto dim_0 = ref_inputs[0]->dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    int out0_col = ref_inputs[0]->numel() / out_row;
    int in_col = 0, in_row = out_row;
    bool sameShape = true;  // do all outputs share the same width?

    // Per-output device pointers and prefix sums of the output widths
    // (outputs_cols[i + 1] == first input column after output i).
    std::vector<T*> outputs_data(o_num);
    std::vector<int> outputs_cols(o_num + 1);

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int t_col = ref_inputs.at(i)->numel() / out_row;
      if (sameShape) {
        if (t_col != out0_col) sameShape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      // Null output => the kernel skips that segment entirely.
      if (outputs->at(i) != nullptr) {
        outputs_data[i] = outputs->at(i)->data<T>();
      } else {
        outputs_data[i] = nullptr;
      }
    }

    // BUGFIX: when every segment is empty along the split width, the
    // launch-config math below would compute block_cols == 0 and divide by
    // zero on the host.  There is nothing to copy, so return early.
    if (in_col == 0) return;

    // computation (same launch-configuration scheme as ConcatFunctor)
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (in_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((in_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    // Stage the array of output pointers in device memory.
    // NOTE(review): memory::Copy on context.stream() is asynchronous w.r.t.
    // the host while outputs_data/outputs_cols are locals that die on
    // return -- confirm the temporary allocator / stream semantics guarantee
    // the copy reads them before they are destroyed.
    auto tmp_dev_outs_data =
        platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
            outputs_data.size() * sizeof(T*));
    memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                 tmp_dev_outs_data->ptr(), platform::CPUPlace(),
                 reinterpret_cast<void*>(outputs_data.data()),
                 outputs_data.size() * sizeof(T*), context.stream());
    T** dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());

    if (sameShape) {
      // Uniform widths: the kernel locates the target tensor by division.
      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
    } else {
      // Variable widths: also stage the column prefix sums on the device.
      auto tmp_dev_outs_col_data =
          platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
              outputs_cols.size() * sizeof(int));
      memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
                   tmp_dev_outs_col_data->ptr(), platform::CPUPlace(),
                   reinterpret_cast<void*>(outputs_cols.data()),
                   outputs_cols.size() * sizeof(int), context.stream());
      int* dev_outs_col_data =
          reinterpret_cast<int*>(tmp_dev_outs_col_data->ptr());

      SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
  }
};

C
chengduoZH 已提交
286 287
// Explicitly instantiate both functors for every element type enumerated by
// FOR_ALL_TYPES (presumably declared in concat_and_split.h -- verify).
#define DEFINE_FUNCTOR(type)                                       \
  template class ConcatFunctor<platform::CUDADeviceContext, type>; \
  template class SplitFunctor<platform::CUDADeviceContext, type>

FOR_ALL_TYPES(DEFINE_FUNCTOR);

}  // namespace math
}  // namespace operators
}  // namespace paddle