/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

// TODO(zcd): This could be replaced by a Tensor; if so, we may need to add
// int8 to VarType::Type. Alternatively, it could be replaced by a TensorArray.
// Capacity of every CUDADeviceArray. The column-offset arrays built by the
// functors below need (number of tensors + 1) slots, so they require the
// number of tensors to be strictly less than this value.
static constexpr int MaxSize = 8;

// Fixed-capacity array that is passed to the kernels by value as a kernel
// argument, so no explicit device allocation is needed for it.
// Only the first `size` entries of `data` are meaningful.
template <typename T>
struct CUDADeviceArray {
  T data[MaxSize];
  int size;
};

// Returns the index of the first element of the sorted range
// [first, first + count) that is strictly greater than `val`, or `count` if
// there is none -- i.e. the number of leading elements that are <= `val`.
// Same contract as std::upper_bound, but it returns an index instead of a
// pointer and is callable from both host and device code.
//
// The kernels in this file pass a prefix-sum array of column offsets, so
// `upper_bound(offsets, n + 1, col) - 1` is the segment that contains `col`
// (the last one when several zero-width segments share an offset).
template <typename T>
__host__ __device__ T upper_bound(const T* first, T count, T val) {
  const T* orig = first;
  while (count > 0) {
    T step = count / 2;
    const T* it = first + step;
    if (!(val < *it)) {
      // *it <= val: the answer lies strictly after `it`.
      first = it + 1;
      count -= step + 1;
    } else {
      // *it > val: the answer is `it` or somewhere before it.
      count = step;
    }
  }
  return first - orig;
}

// Concatenates inputs.size row-major matrices that all have `output_rows`
// rows but possibly different column counts into `output`
// (output_rows x output_cols, row-major).
//
// input_cols holds input_cols.size == inputs.size + 1 prefix offsets:
// input_cols.data[i] is the first output column taken from input i, and
// input_cols.data[inputs.size] == output_cols.
//
// Launch: 2-D grid; x covers output columns and y covers rows. Both loops
// advance by the full grid extent, so any grid size produces correct output.
template <typename T>
__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
                             const CUDADeviceArray<int> input_cols,
                             const int output_rows, const int output_cols,
                             T* output) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  // Locate the input that owns this thread's first column.
  int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1;

  int curr_offset = input_cols.data[segment];
  int curr_segment = segment;
  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
    // Column offsets are ints. Holding them in T would be wrong: for
    // T = float, offsets above 2^24 are not exactly representable and the
    // comparison below would pick the wrong segment.
    int curr_col_offset;
    // tid_x only increases, so walk forward to the segment that contains it.
    while ((curr_col_offset = input_cols.data[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    const T* input_ptr = inputs.data[curr_segment];

    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * segment_width + local_col];
  }
}

// Concatenates inputs.size row-major matrices that all have `output_rows`
// rows and the same `input_col` columns into `output`
// (output_rows x output_cols, row-major).
//
// Launch: 2-D grid; x covers output columns and y covers rows. Both loops
// advance by the full grid extent, so any grid size produces correct output.
template <typename T>
__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
                             const int input_col, const int output_rows,
                             const int output_cols, T* output) {
  for (int tid_x = blockIdx.x * blockDim.x + threadIdx.x; tid_x < output_cols;
       tid_x += blockDim.x * gridDim.x) {
    // Exact integer division on purpose: multiplying by a float reciprocal
    // can round k * input_col down to k - 1 (input_col = 41, tid_x = 41
    // gives 0.99999994), which would read the wrong element and can run past
    // the end of the last input.
    int split = tid_x / input_col;
    int in_offset = tid_x - split * input_col;
    const T* input_ptr = inputs.data[split];
    for (int tid_y = blockIdx.y * blockDim.y + threadIdx.y; tid_y < output_rows;
         tid_y += blockDim.y * gridDim.y) {
      output[tid_y * output_cols + tid_x] =
          input_ptr[tid_y * input_col + in_offset];
    }
  }
}

// Scatters the columns of `input` (input_row x input_col, row-major) into
// outputs.size row-major matrices of possibly different widths, i.e. the
// inverse of the offset-table KernelConcat.
//
// output_cols holds output_cols.size == outputs.size + 1 prefix offsets:
// output_cols.data[i] is the first input column that belongs to output i, and
// output_cols.data[outputs.size] == input_col.
//
// Launch: 2-D grid; x covers input columns and y covers rows. Both loops
// advance by the full grid extent, so any grid size produces correct output.
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
                                 const int input_col,
                                 CUDADeviceArray<int> output_cols,
                                 CUDADeviceArray<T*> outputs) {
  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
  // Locate the output that owns this thread's first column.
  int segment = upper_bound<int>(output_cols.data, output_cols.size, tid_x) - 1;
  int curr_offset = output_cols.data[segment];
  int curr_segment = segment;
  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
    // Column offsets are ints. Holding them in T would be wrong: for
    // T = float, offsets above 2^24 are not exactly representable and the
    // comparison below would pick the wrong segment.
    int curr_col_offset;
    // tid_x only increases, so walk forward to the segment that contains it.
    while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) {
      curr_offset = curr_col_offset;
      ++curr_segment;
    }

    int local_col = tid_x - curr_offset;
    int segment_width = curr_col_offset - curr_offset;
    T* output_ptr = outputs.data[curr_segment];
    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
      output_ptr[tid_y * segment_width + local_col] =
          input[tid_y * input_col + tid_x];
  }
}

// Splits `input` (input_row x input_col, row-major) into outputs.size
// row-major matrices that each have `output_cols` columns, i.e. the inverse of
// the equal-width KernelConcat. `output_cols` is the width of ONE output;
// ConcatGradFunctor passes input_col == outputs.size * output_cols.
//
// Launch: 2-D grid; x covers input columns and y covers rows. Both loops
// advance by the full grid extent, so any grid size produces correct output.
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
                                 const int input_col, const int output_cols,
                                 CUDADeviceArray<T*> outputs) {
  for (int tid_x = blockIdx.x * blockDim.x + threadIdx.x; tid_x < input_col;
       tid_x += blockDim.x * gridDim.x) {
    // The owning output is found by dividing by the width of a single output.
    // Dividing by the total width (input_col) always yields 0, which would
    // send every column to outputs[0] and write past its end. Integer
    // division also avoids the rounding error of a float reciprocal.
    int split = tid_x / output_cols;
    int in_offset = tid_x - split * output_cols;
    T* output_ptr = outputs.data[split];
    for (int tid_y = blockIdx.y * blockDim.y + threadIdx.y; tid_y < input_row;
         tid_y += blockDim.y * gridDim.y) {
      output_ptr[tid_y * output_cols + in_offset] =
          input[tid_y * input_col + tid_x];
    }
  }
}

/*
 * Concatenates `input` along `axis` into `output`, launched on
 * context.stream().
 *
 * Each tensor is viewed as a row-major matrix whose row count is the product
 * of input[0]'s dimensions before `axis` and whose column count is
 * numel / rows. Inputs must therefore agree on the dimensions before `axis`.
 * Only output->data<T>() is called here, so `output` is expected to be
 * allocated with the concatenated shape already. At most MaxSize - 1 inputs
 * are supported.
 */
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::vector<framework::Tensor>& input, const int axis,
                  framework::Tensor* output) {
    // The inputs are packed into fixed-size by-value kernel arguments, and the
    // column-offset array needs num + 1 slots, hence the strict bound.
    int num = input.size();
    PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
                      MaxSize);
    if (num == 0) return;

    // Rows of the matrix view: product of the dimensions before `axis`.
    int rows = 1;
    auto dim_0 = input[0].dims();
    for (int i = 0; i < axis; ++i) {
      rows *= dim_0[i];
    }
    // Zero rows means every tensor is empty: nothing to copy, and the
    // divisions by `rows` below would be undefined.
    if (rows == 0) return;

    int cols = input[0].numel() / rows;
    int out_rows = rows, out_cols = 0;
    bool sameShape = true;

    CUDADeviceArray<const T*> inputs_data;
    CUDADeviceArray<int> inputs_cols;
    inputs_data.size = num;
    inputs_cols.size = num + 1;
    inputs_cols.data[0] = 0;
    // View each input as a matrix, record its data pointer and the prefix
    // column offsets, and detect whether all inputs share one width.
    for (int i = 0; i < num; ++i) {
      int t_cols = input[i].numel() / rows;
      if (sameShape && t_cols != cols) sameShape = false;
      out_cols += t_cols;
      inputs_cols.data[i + 1] = out_cols;
      inputs_data.data[i] = input[i].data<T>();
    }
    // All inputs have zero columns: nothing to copy, and block_cols below
    // would be zero (division by zero).
    if (out_cols == 0) return;

    // Block shape: up to kThreadsPerBlock columns; leftover threads go to rows.
    const int kThreadsPerBlock = 1024;
    int block_cols = std::min(out_cols, kThreadsPerBlock);
    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
    dim3 block_size = dim3(block_cols, block_rows, 1);

    // Cap the grid so it holds no more threads than the current device can
    // keep resident (SM count x max threads per SM). The kernels loop over
    // the grid, so a smaller grid still covers the whole matrix.
    int dev_id = paddle::platform::GetCurrentDeviceId();
    int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id);
    int max_threads_per_mp =
        paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id);
    int max_threads = multi_process * max_threads_per_mp;
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
    int grid_rows =
        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    // Equal-width inputs use the division-based kernel; otherwise the kernel
    // that searches the prefix column offsets.
    if (sameShape) {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          inputs_data, cols, out_rows, out_cols, output->data<T>());
    } else {
      KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
          inputs_data, inputs_cols, out_rows, out_cols, output->data<T>());
    }
  }
};

/*
 * Splits `input` along `axis` into `outputs` (the backward pass of concat),
 * launched on context.stream().
 *
 * Each tensor is viewed as a row-major matrix whose row count is the product
 * of outputs[0]'s dimensions before `axis` and whose column count is
 * numel / rows. Only outputs[i].data<T>() is called here, so each output is
 * expected to be allocated with its target shape already. At most
 * MaxSize - 1 outputs are supported.
 */
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const int axis,
                  std::vector<framework::Tensor>& outputs) {
    // The outputs are packed into fixed-size by-value kernel arguments, and
    // the column-offset array needs num + 1 slots, hence the strict bound.
    int num = outputs.size();
    PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d",
                      MaxSize);
    if (num == 0) return;

    // Rows of the matrix view: product of the dimensions before `axis`.
    int input_row = 1;
    auto dim_0 = outputs[0].dims();
    for (int i = 0; i < axis; ++i) {
      input_row *= dim_0[i];
    }
    // Zero rows means every tensor is empty: nothing to copy, and the
    // divisions by `input_row` below would be undefined.
    if (input_row == 0) return;

    int output_col_0 = outputs[0].numel() / input_row;
    int input_col = 0;
    bool sameShape = true;

    CUDADeviceArray<T*> outputs_data;
    CUDADeviceArray<int> outputs_cols;
    outputs_data.size = num;
    outputs_cols.size = num + 1;
    outputs_cols.data[0] = 0;
    // View each output as a matrix, record its data pointer and the prefix
    // column offsets, and detect whether all outputs share one width.
    for (int i = 0; i < num; ++i) {
      int t_col = outputs[i].numel() / input_row;
      if (sameShape && t_col != output_col_0) sameShape = false;
      input_col += t_col;
      outputs_cols.data[i + 1] = input_col;
      outputs_data.data[i] = outputs[i].data<T>();
    }
    // All outputs have zero columns: nothing to copy, and block_cols below
    // would be zero (division by zero).
    if (input_col == 0) return;

    // Block shape: up to kThreadsPerBlock columns; leftover threads go to rows.
    const int kThreadsPerBlock = 256;
    int block_cols = std::min(input_col, kThreadsPerBlock);
    int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
    dim3 block_size = dim3(block_cols, block_rows, 1);

    // Cap the grid so it holds no more threads than the current device can
    // keep resident (SM count x max threads per SM). The kernels loop over
    // the grid, so a smaller grid still covers the whole matrix. The cap also
    // keeps gridDim.y far below its 65535 limit, which an uncapped
    // ceil(input_row / block_rows) exceeds for tall inputs (launch failure).
    int dev_id = paddle::platform::GetCurrentDeviceId();
    int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id);
    int max_threads_per_mp =
        paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id);
    int max_threads = multi_process * max_threads_per_mp;
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    int grid_cols =
        std::min((input_col + block_cols - 1) / block_cols, max_blocks);
    int grid_rows = std::min(max_blocks / grid_cols,
                             (input_row + block_rows - 1) / block_rows);
    dim3 grid_size = dim3(grid_cols, grid_rows, 1);

    // Equal-width outputs use the division-based kernel; otherwise the kernel
    // that searches the prefix column offsets.
    if (sameShape) {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), input_row, input_col, output_col_0, outputs_data);
    } else {
      KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
          input.data<T>(), input_row, input_col, outputs_cols, outputs_data);
    }
  }
};

// Explicit instantiations for the element types compiled for CUDA.
template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;

template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle