/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

// Device-side binary search, analogous to std::upper_bound: returns the
// index of the first element in [first, first + count) that is greater than
// val. The kernels below use it to locate the column segment that a global
// column index falls into.
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
  const T* orig = first;
  const T* it = nullptr;
  T step = 0;
  while (count > 0) {
    it = first;
    step = count / 2;
    it += step;
    if (!(val < *it)) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  return first - orig;
}

// General concat kernel: the inputs may have different widths; input_cols
// holds the cumulative column offsets of the inputs (its last entry is the
// total output width).
template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
                             const int output_rows, const int output_cols,
                             T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;

  int curr_offset = input_cols[segment];
  int curr_segment = segment;
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset;
    while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

// Specialized concat kernel for the case where every input has the same
// width (fixed_in_col), so the source segment can be computed directly.
template <typename T>
__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
                             const int out_rows, const int out_cols,
                             T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

// General concat-gradient kernel: scatters columns of input_data back into
// the output tensors described by the cumulative offsets in out_cols.
template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
                                 const int in_col, const int* out_cols,
                                 int out_cols_size, T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int segment = upper_bound<int>(out_cols, out_cols_size, tid_x) - 1;
  int curr_offset = out_cols[segment];
  int curr_segment = segment;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset;
    while ((curr_col_offset = out_cols[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * segment_width + local_col] =
          input_data[tid_y * in_col + tid_x];
  }
}

// Specialized concat-gradient kernel for the case where every output has the
// same width (fixed_out_col).
template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
                                 const int in_col, const int fixed_out_col,
                                 T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_out_col;
    int in_offset = tid_x - split * fixed_out_col;
    T* output_ptr = outputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * fixed_out_col + in_offset] =
          input_data[tid_y * in_col + tid_x];
  }
}

/*
 * All tensors must have the same rank, and the sizes of every
 * dimension except the axis dimension must match.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, const int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    // Reuse a Vector<int16_t> as a raw byte buffer large enough to hold
    // in_num device pointers (in_num * sizeof(T*) bytes), so the pointer
    // array can be copied to the GPU below.
    framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
    framework::Vector<int> inputs_col(in_num + 1);
    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());

    inputs_col[0] = 0;
    bool sameShape = true;
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (sameShape) {
        if (t_cols != in_col) sameShape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
    }

    T** dev_ins_data =
        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));

    // computation
    // set the thread block and grid according to CurrentDeviceId
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (out_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((out_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
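    // Illustrative example of this launch configuration (the actual numbers
    // depend on the device): with out_col = 100 and out_row = 4096,
    // block_cols = ((100 + 31) >> 5) << 5 = 128 and block_rows = 1024 / 128
    // = 8, so each block covers a 128 x 8 tile; grid_cols = 1 and
    // grid_rows = min(max_blocks, 4096 / 8).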

    if (sameShape) {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, in_col, out_row, out_col, output->data<T>());
    } else {
      const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
  }
};
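
// A minimal usage sketch for ConcatFunctor (illustrative only; `ctx`, `a`,
// `b`, and `out` are assumed to be prepared elsewhere, with `out` already
// resized and allocated on the same CUDA place):
//
//   std::vector<framework::Tensor> inputs = {a, b};
//   ConcatFunctor<platform::CUDADeviceContext, float> concat;
//   concat(ctx, inputs, 1 /* axis */, &out);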

/*
 * All tensors must have the same rank, and the sizes of every
 * dimension except the axis dimension must match.
 */
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const int axis,
                  std::vector<framework::Tensor>* outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs->size();
    int out_row = 1;
    auto dim_0 = outputs->at(0).dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    int out_col = outputs->at(0).numel() / out_row;
    int in_col = 0, in_row = out_row;
    bool sameShape = true;

    framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
    framework::Vector<int> outputs_cols(o_num + 1);
    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int t_col = outputs->at(i).numel() / out_row;
      if (sameShape) {
        if (t_col != out_col) sameShape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      outputs_ptr[i] = outputs->at(i).data<T>();
    }

    T** dev_out_gpu_data =
        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));

    // computation
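    // The launch configuration below mirrors the one in ConcatFunctor above,
    // except that the x-dimension is driven by in_col (the concatenated width).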
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (in_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
      block_cols = ((in_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
    } else {
      const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
  }
};
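
// A minimal usage sketch for ConcatGradFunctor (illustrative only; `ctx`,
// `d_out`, `d_a`, and `d_b` are assumed to exist, with the per-input gradient
// tensors already allocated with the shapes of the original inputs):
//
//   std::vector<framework::Tensor> grads = {d_a, d_b};
//   ConcatGradFunctor<platform::CUDADeviceContext, float> concat_grad;
//   concat_grad(ctx, d_out, 1 /* axis */, &grads);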

template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;

template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle