/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

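// Binary search with the same contract as std::upper_bound: returns the
// index of the first element in [first, first + count) that is greater than
// val. For example, with cols = {0, 3, 5, 9}, upper_bound(cols, 4, 4)
// returns 2, so upper_bound(cols, 4, 4) - 1 selects the segment whose
// columns start at offset 3.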
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
  const T* orig = first;
  const T* it = nullptr;
  T step = 0;
  while (count > 0) {
    it = first;
    step = count / 2;
    it += step;
    if (!(val < *it)) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  return first - orig;
}

// General case: the inputs may have different numbers of columns. For each
// output column, the owning input segment is located with upper_bound.
template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
                             const int output_rows, const int output_cols,
                             T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;

  int curr_offset = input_cols[segment];
  int curr_segment = segment;
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset;
    // Advance to the input segment that contains output column tid_x.
    while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

// Special case: every input has the same number of columns (fixed_in_col),
// so the owning input can be computed by integer division.
template <typename T>
__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
                             const int out_rows, const int out_cols,
                             T* output_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_in_col;
    int in_offset = tid_x - split * fixed_in_col;
    T* input_ptr = inputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
      output_data[tid_y * out_cols + tid_x] =
          input_ptr[tid_y * fixed_in_col + in_offset];
    }
  }
}

// General gradient case: the outputs may have different numbers of columns;
// each input column is scattered back to the output segment that owns it.
template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
                                 const int in_col, const int* out_cols,
                                 int out_cols_size, T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int segment = upper_bound<int>(out_cols, out_cols_size, tid_x) - 1;
  int curr_offset = out_cols[segment];
  int curr_segment = segment;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int curr_col_offset;
    // Advance to the output segment that contains input column tid_x.
    while ((curr_col_offset = out_cols[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs_data[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * segment_width + local_col] =
          input_data[tid_y * in_col + tid_x];
  }
}

// Special gradient case: every output has the same number of columns
// (fixed_out_col).
template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
                                 const int in_col, const int fixed_out_col,
                                 T** outputs_data) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x / fixed_out_col;
    int in_offset = tid_x - split * fixed_out_col;
    T* output_ptr = outputs_data[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * fixed_out_col + in_offset] =
          input_data[tid_y * in_col + tid_x];
  }
}

/*
 * All tensors must have the same number of dimensions, and all dimension
 * sizes must match except along the concatenation axis.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, const int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int in_num = input.size();
    int in_row = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      in_row *= dim_0[i];
    }
    int in_col = input[0].numel() / in_row;
    int out_row = in_row, out_col = 0;

    // Stage the per-input device pointers and the column offsets on the host;
    // the int16_t vector is only a raw byte buffer large enough to hold
    // in_num T* pointers (in_num * sizeof(T*) bytes).
    framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
    framework::Vector<int> inputs_col(in_num + 1);
    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());

    inputs_col[0] = 0;
    bool sameShape = true;
    for (int i = 0; i < in_num; ++i) {
      int t_cols = input[i].numel() / in_row;
      if (sameShape) {
        if (t_cols != in_col) sameShape = false;
      }
      out_col += t_cols;
      inputs_col[i + 1] = out_col;
      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
    }

    T** dev_ins_data =
        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));

    // Set the thread block and grid according to the current device.
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (out_col < kThreadsPerBlock) {  // block_cols is rounded up to a
                                       // multiple of 32 (the warp size).
      block_cols = ((out_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((out_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, in_col, out_row, out_col, output->data<T>());
    } else {
      const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
          out_row, out_col, output->data<T>());
    }
  }
};
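
// A minimal usage sketch (illustrative only): `ctx` is assumed to be a
// CUDADeviceContext, `t0` and `t1` are CUDA tensors whose shapes match in
// every dimension except axis 1, and `out` is assumed to be pre-sized to the
// concatenated shape and allocated on the same device.
//
//   std::vector<framework::Tensor> inputs = {t0, t1};
//   ConcatFunctor<platform::CUDADeviceContext, float> concat;
//   concat(ctx, inputs, 1 /* axis */, &out);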

/*
 * All tensors must have the same number of dimensions, and all dimension
 * sizes must match except along the concatenation axis.
 */
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const int axis,
                  std::vector<framework::Tensor>& outputs) {
    // TODO(zcd): Add input data validity checking
    int o_num = outputs.size();
    int out_row = 1;
    auto dim_0 = outputs[0].dims();
    for (int i = 0; i < axis; ++i) {
      out_row *= dim_0[i];
    }

    int out_col = outputs[0].numel() / out_row;
    int in_col = 0, in_row = out_row;
    bool sameShape = true;

    // Stage the per-output device pointers and the column offsets on the
    // host; the int16_t vector is only a raw byte buffer large enough to hold
    // o_num T* pointers (o_num * sizeof(T*) bytes).
    framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
    framework::Vector<int> outputs_cols(o_num + 1);
    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());

    outputs_cols[0] = 0;
    for (int i = 0; i < o_num; ++i) {
      int t_col = outputs[i].numel() / out_row;
      if (sameShape) {
        if (t_col != out_col) sameShape = false;
      }
      in_col += t_col;
      outputs_cols[i + 1] = in_col;
      outputs_ptr[i] = outputs[i].data<T>();
    }

    T** dev_out_gpu_data =
        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));

    // Set the thread block and grid according to the current device.
    const int kThreadsPerBlock = 1024;
    int block_cols = kThreadsPerBlock;
    if (in_col < kThreadsPerBlock) {  // block_cols is rounded up to a
                                      // multiple of 32 (the warp size).
      block_cols = ((in_col + 31) >> 5) << 5;
    }
    int block_rows = kThreadsPerBlock / block_cols;
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int max_threads = context.GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((in_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
    } else {
      const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), in_row, in_col, dev_outs_col_data,
          static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
    }
  }
};
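
// A minimal usage sketch (illustrative only): `ctx`, `d_out`, `d_x0`, and
// `d_x1` are hypothetical names; the output tensors must already be sized
// like the forward inputs of the concat and allocated on the same device.
//
//   std::vector<framework::Tensor> grads = {d_x0, d_x1};
//   ConcatGradFunctor<platform::CUDADeviceContext, float> concat_grad;
//   concat_grad(ctx, d_out, 1 /* axis */, grads);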

template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;

template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle