concat.cu
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

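// Device-side binary search, analogous to std::upper_bound: returns the
// offset of the first element in [first, first + count) that is greater
// than val. The kernels below use it to map a global column index to the
// segment (i.e. the input or output tensor) that owns that column.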
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
  const T* orig = first;
  const T* it = nullptr;
  T step = 0;
  while (count > 0) {
    it = first;
    step = count / 2;
    it += step;
    if (!(val < *it)) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  return first - orig;
}

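// Concatenates inputs whose column widths may differ. input_cols holds the
// running prefix sums of the per-input widths (starting at 0), so upper_bound
// finds the input that owns the first column handled by this thread;
// subsequent columns only advance curr_segment when they cross a boundary.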
template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
                             const int output_rows, const int output_cols,
                             T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;

  int curr_offset = input_cols[segment];
  int curr_segment = segment;
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    T curr_col_offset;
    while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* input_ptr = inputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

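// Fast path for the case where every input has the same width (input_col
// columns): the owning input is tid_x / input_col, computed here by
// multiplying with a precomputed reciprocal instead of an integer division.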
template <typename T>
__global__ void KernelConcat(T** inputs, const int input_col,
                             const int output_rows, const int output_cols,
                             T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  float inv_input_col = 1.0 / input_col;
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x * inv_input_col;
    int in_offset = tid_x - split * input_col;
    T* input_ptr = inputs[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * input_col + in_offset];
    }
  }
}

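// Gradient of concat for outputs whose widths may differ: each column of
// the concatenated gradient `input` is scattered back to the output that
// owns it, using the same prefix-sum plus upper_bound lookup as the forward
// KernelConcat above.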
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
                                 const int input_col, const int* output_cols,
                                 int col_size, T** outputs) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
  int curr_offset = output_cols[segment];
  int curr_segment = segment;
  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
    T curr_col_offset;
    while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * segment_width + local_col] =
          input[tid_y * input_col + tid_x];
  }
}

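// Gradient fast path for the case where every output has the same width
// (output_cols columns): the destination output is tid_x / output_cols.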
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
                                 const int input_col, const int output_cols,
                                 T** outputs) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  // Note: the divisor is the per-output width (output_cols), not the total
  // input width, so each output receives its own slice of columns.
  float inv_output_col = 1.0 / output_cols;
  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
    int split = tid_x * inv_output_col;
    int in_offset = tid_x - split * output_cols;
    T* output_ptr = outputs[split];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * output_cols + in_offset] =
          input[tid_y * input_col + tid_x];
  }
}

/*
 * All input tensors must have the same rank, and their dimensions
 * must match in every axis except the concatenation axis.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, const int axis,
                  framework::Tensor* output) {
    // TODO(zcd): Add input data validity checking
    int num = input.size();
    int rows = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      rows *= dim_0[i];
    }
    int cols = input[0].numel() / rows;
    int out_rows = rows, out_cols = 0;

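    // The raw device pointers of the inputs are staged in a Vector<int16_t>
    // sized to num * sizeof(T*) bytes (sizeof(int16_t) == 2), so they can be
    // copied to the GPU via CUDAMutableData() below. inputs_cols collects the
    // prefix sums of the per-input column counts.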
    paddle::framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
    paddle::framework::Vector<int> inputs_cols(num + 1);
    inputs_cols[0] = 0;
    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());

    bool sameShape = true;
    for (int i = 0; i < num; ++i) {
      int t_cols = input[i].numel() / rows;
      if (sameShape) {
        if (t_cols != cols) sameShape = false;
      }
      out_cols += t_cols;
      inputs_cols[i + 1] = out_cols;
      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
    }

    T** ins_gpu =
        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
    const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());

    // computation
    // set the thread block and grid according to CurrentDeviceId
    const int kThreadsPerBlock = 1024;
    int block_cols = std::min(out_cols, kThreadsPerBlock);
    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
    dim3 block_size = dim3(block_cols, block_rows, 1);

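    // Cap the grid at roughly the number of blocks the device can keep
    // resident (max_threads / kThreadsPerBlock) while still covering all
    // output columns and as many rows as that budget allows.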
    int dev_id = paddle::platform::GetCurrentDeviceId();
    int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id);
    int max_threads_per_mp =
        paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id);
    int max_threads = multi_process * max_threads_per_mp;
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          ins_gpu, cols, out_rows, out_cols, output->data<T>());
    } else {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
          out_cols, output->data<T>());
    }
  }
};
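
// Illustrative usage sketch (names such as dev_ctx, t0, t1 and out are
// hypothetical; the caller is assumed to have placed the tensors on the
// CUDA place and sized `out` to the concatenated shape):
//   math::ConcatFunctor<platform::CUDADeviceContext, float> concat;
//   concat(dev_ctx, std::vector<framework::Tensor>{t0, t1}, /*axis=*/1, &out);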

/*
 * All output tensors must have the same rank, and their dimensions
 * must match in every axis except the concatenation axis.
 */
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const int axis,
                  std::vector<framework::Tensor>& outputs) {
    // TODO(zcd): Add input data validity checking
    int num = outputs.size();
    int input_row = 1;
    auto dim_0 = outputs[0].dims();
    for (int i = 0; i < axis; ++i) {
      input_row *= dim_0[i];
    }

    int output_col_0 = outputs[0].numel() / input_row;
    int input_col = 0;
    bool sameShape = true;

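    // As in ConcatFunctor, the raw output pointers are staged in a
    // Vector<int16_t> sized to hold exactly num T* pointers, and outputs_cols
    // collects the prefix sums of the per-output column counts.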
    paddle::framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
    paddle::framework::Vector<int> outputs_cols(num + 1);
    outputs_cols[0] = 0;
    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());

    for (int i = 0; i < num; ++i) {
      int t_col = outputs[i].numel() / input_row;
      if (sameShape) {
        if (t_col != output_col_0) sameShape = false;
      }
      input_col += t_col;
      outputs_cols[i + 1] = input_col;
      outputs_ptr[i] = outputs[i].data<T>();
    }

    T** outs_gpu =
        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
    const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());

    // computation
    const int kThreadsPerBlock = 1024;
    int block_cols = std::min(input_col, kThreadsPerBlock);
    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
    dim3 block_size = dim3(block_cols, block_rows, 1);

    int grid_cols = (input_col + block_cols - 1) / block_cols;
    int grid_rows = (input_row + block_rows - 1) / block_rows;
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    if (sameShape) {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
    } else {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), input_row, input_col, outs_col_gpu,
          static_cast<int>(outputs_cols.size()), outs_gpu);
    }
  }
};

template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;

template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle