/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/im2col.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <class T>
__global__ void im2col(const T* data_im, int num_outs, int height, int width,
                       int filter_height, int filter_width, int stride_height,
                       int stride_width, int padding_height, int padding_width,
                       int output_height, int output_width, T* data_col) {
  // One thread per (input channel, output row, output column). The 2-D grid
  // is linearized as blockIdx.x * gridDim.y + blockIdx.y.
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= num_outs) return;

  const int w_out = tid % output_width;
  const int rest = tid / output_width;
  const int h_out = rest % output_height;
  const int channel_in = rest / output_height;

  // Top-left corner of this output's receptive field, in unpadded image
  // coordinates (may be negative inside the padding border).
  const int row_origin = h_out * stride_height - padding_height;
  const int col_origin = w_out * stride_width - padding_width;

  // col layout: [channels, filter_h, filter_w, out_h, out_w]. Consecutive
  // filter taps of the same output pixel are one (out_h * out_w) plane apart.
  const int plane = output_height * output_width;
  T* dst = data_col +
           (channel_in * filter_height * filter_width * output_height + h_out) *
               output_width +
           w_out;
  const T* src = data_im + channel_in * height * width;

  for (int i = 0; i < filter_height; ++i) {
    const int r = row_origin + i;
    const bool row_inside = r >= 0 && r < height;
    for (int j = 0; j < filter_width; ++j) {
      const int c = col_origin + j;
      // Taps that land in the padding border read as zero.
      *dst = (row_inside && c >= 0 && c < width) ? src[r * width + c] : T(0);
      dst += plane;
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::GPUPlace, T> {
 public:
  // Unfolds `im` into `col` by launching `im2col` on the stream of `context`.
  // The launch is asynchronous; no synchronization happens here. Filter and
  // output sizes are taken from col's shape; padding is applied symmetrically
  // (padding_height rows top and bottom, padding_width columns left/right).
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& im, framework::Tensor& col,
                  int stride_height, int stride_width, int padding_height,
                  int padding_width) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int output_height = col.dims()[3];
    int output_width = col.dims()[4];

    int num_outputs = input_channels * output_height * output_width;
    // Nothing to write. Launching anyway would use a zero grid dimension,
    // which CUDA rejects as an invalid configuration.
    if (num_outputs <= 0) return;

    // One thread per output element. The 1-D block count is spread over a
    // 2-D grid (x capped at 512) to stay within per-dimension grid limits;
    // the kernel re-linearizes it as blockIdx.x * gridDim.y + blockIdx.y.
    const int kThreadsPerBlock = 1024;
    const int kMaxBlocksX = 512;
    int blocks = (num_outputs + kThreadsPerBlock - 1) / kThreadsPerBlock;
    // Do not launch 512 mostly idle blocks for small inputs.
    int block_x = std::min(blocks, kMaxBlocksX);
    int block_y = (blocks + block_x - 1) / block_x;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(block_x, block_y);
    im2col<T><<<grid, threads, 0,
                reinterpret_cast<const platform::CUDADeviceContext&>(context)
                    .stream()>>>(
        im.data<T>(), num_outputs, input_height, input_width, filter_height,
        filter_width, stride_height, stride_width, padding_height,
        padding_width, output_height, output_width, col.data<T>());
  }
};

template <class T>
__global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
                       size_t channels, size_t filter_height,
                       size_t filter_width, size_t stride_height,
                       size_t stride_width, size_t padding_height,
                       size_t padding_width, size_t output_height,
                       size_t output_width, T* data_im) {
  // One thread per pixel of the padded image: `height`/`width` are the padded
  // extents, while data_im stores the unpadded image. Threads that fall on the
  // padding border exit without writing.
  const size_t tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= n) return;

  const int pad_h = static_cast<int>(padding_height);
  const int pad_w = static_cast<int>(padding_width);
  const int ksize_h = static_cast<int>(filter_height);
  const int ksize_w = static_cast<int>(filter_width);
  const int step_h = static_cast<int>(stride_height);
  const int step_w = static_cast<int>(stride_width);
  // Extents of the unpadded image held in data_im.
  const int im_h = static_cast<int>(height) - 2 * pad_h;
  const int im_w = static_cast<int>(width) - 2 * pad_w;

  // Position in the padded image.
  const int w = static_cast<int>(tid % width);
  const int h = static_cast<int>((tid / width) % height);
  const int c = static_cast<int>(tid / (width * height));

  // Position in the unpadded image; skip padding pixels.
  const int y = h - pad_h;
  const int x = w - pad_w;
  if (x < 0 || x >= im_w || y < 0 || y >= im_h) return;

  // Range of output positions whose receptive field contains padded pixel
  // (h, w): out * stride <= pos < out * stride + filter.
  const int hc_first = (h < ksize_h) ? 0 : (h - ksize_h) / step_h + 1;
  const int hc_last = min(h / step_h + 1, static_cast<int>(output_height));
  const int wc_first = (w < ksize_w) ? 0 : (w - ksize_w) / step_w + 1;
  const int wc_last = min(w / step_w + 1, static_cast<int>(output_width));

  // Gather every col entry that maps onto this pixel; this thread is the only
  // writer of its pixel, so no atomics are needed.
  T acc = 0;
  for (int hc = hc_first; hc < hc_last; ++hc) {
    for (int wc = wc_first; wc < wc_last; ++wc) {
      const int c_col = c * ksize_h * ksize_w + (h - hc * step_h) * ksize_w +
                        (w - wc * step_w);
      acc += data_col[(c_col * output_height + hc) * output_width + wc];
    }
  }
  data_im[static_cast<size_t>(c) * im_h * im_w +
          static_cast<size_t>(y) * im_w + x] += acc;
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::GPUPlace, T> {
 public:
  // Folds `col` back into `im` by launching `col2im` on the stream of
  // `context` (asynchronous). Values are accumulated into im with `+=`, so
  // existing contents of im are added to. Callers presumably zero im first
  // when they want a plain fold; confirm at call sites.
  void operator()(const platform::DeviceContext& context, framework::Tensor& im,
                  const framework::Tensor& col, int stride_height,
                  int stride_width, int padding_height, int padding_width) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int output_height = col.dims()[3];
    int output_width = col.dims()[4];

    // One thread per pixel of the padded image. Widen to size_t before
    // multiplying so the product cannot overflow int.
    size_t num_kernels = static_cast<size_t>(input_channels) *
                         (input_height + 2 * padding_height) *
                         (input_width + 2 * padding_width);
    // Nothing to fold. A zero grid dimension would be an invalid launch
    // configuration.
    if (num_kernels == 0) return;

    // Spread the 1-D block count over a 2-D grid (x capped at 512); the
    // kernel re-linearizes it as blockIdx.x * gridDim.y + blockIdx.y.
    const size_t kThreadsPerBlock = 1024;
    const size_t kMaxBlocksX = 512;
    size_t blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
    size_t block_x = std::min(blocks, kMaxBlocksX);
    size_t block_y = (blocks + block_x - 1) / block_x;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(block_x, block_y);

    // To avoid atomic operations, launch one thread per image (bottom) pixel;
    // each thread sums all col (top) entries that map onto its pixel.
    col2im<T><<<grid, threads, 0,
                reinterpret_cast<const platform::CUDADeviceContext&>(context)
                    .stream()>>>(
        num_kernels, col.data<T>(), input_height + 2 * padding_height,
        input_width + 2 * padding_width, input_channels, filter_height,
        filter_width, stride_height, stride_width, padding_height,
        padding_width, output_height, output_width, im.data<T>());
  }
};

// Explicitly instantiate the kCFO GPU im2col/col2im functors for float and
// double element types.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, double>;

template <class T>
__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
                          int input_height, int input_width, int filter_height,
                          int filter_width, int stride_height, int stride_width,
                          int padding_height, int padding_width,
                          int output_height, int output_width, int row_begin,
                          int row_end) {
  // One block per output location: blockIdx.x is the output column and
  // blockIdx.y the output row (counted from row_begin). Threads stride over
  // (channel, filter row, filter column) using threadIdx.z/y/x.
  const int out_col = blockIdx.x;
  const int out_row = blockIdx.y;
  const int patch_size = filter_height * filter_width;

  // col layout: [out_h, out_w, channels, filter_h, filter_w]; this block owns
  // one [channels, filter_h, filter_w] patch.
  T* patch =
      col_data + (out_row * output_width + out_col) * (input_channels * patch_size);
  // Image coordinates of the patch's top-left tap.
  const int row_origin = (out_row + row_begin) * stride_height - padding_height;
  const int col_origin = out_col * stride_width - padding_width;

  for (int ch = threadIdx.z; ch < input_channels; ch += blockDim.z) {
    const T* plane = im_data + ch * input_height * input_width;
    T* dst = patch + ch * patch_size;
    for (int fy = threadIdx.y; fy < filter_height; fy += blockDim.y) {
      const int y = row_origin + fy;
      for (int fx = threadIdx.x; fx < filter_width; fx += blockDim.x) {
        const int x = col_origin + fx;
        const bool inside =
            y >= 0 && y < input_height && x >= 0 && x < input_width;
        // Taps outside the image read as zero padding.
        dst[fy * filter_width + fx] =
            inside ? plane[y * input_width + x] : T(0);
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::GPUPlace, T> {
 public:
  // Unfolds `im` into `col` by launching `im2colOCF` on the stream of
  // `context` (asynchronous).
  //
  // up_pad / down_pad: numbers of zero rows conceptually added above / below
  // im. No horizontal padding is applied. Output row r reads image rows
  // [r * stride_height - up_pad, r * stride_height - up_pad + filter_height);
  // rows outside [0, input_height) contribute zeros. The output width is taken
  // from col.dims()[1].
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& im, framework::Tensor& col,
                  int stride_height, int stride_width, int up_pad,
                  int down_pad) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);
    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
    int output_width = col.dims()[1];

    int padded_height = input_height + up_pad + down_pad;
    // No vertical filter position fits. Without this guard, truncating
    // division would produce a bogus output row when stride_height > 1.
    if (padded_height < filter_height) return;
    int output_height = (padded_height - filter_height) / stride_height + 1;
    // The kernel writes output_height rows of col; never run past its end.
    PADDLE_ENFORCE(output_height <= col.dims()[0]);
    // Empty work. Zero grid/block dimensions are invalid launch configs.
    if (output_width <= 0 || input_channels <= 0) return;

    // Top padding is expressed directly as padding_height = up_pad with
    // row_begin = 0, which is exact for every stride; rows below the image
    // (down_pad) fall out of the kernel's bounds check. The previous encoding
    // (padding = max(up_pad, down_pad), rows shifted by down_pad - up_pad)
    // read rows offset by (down_pad - up_pad) * (stride_height - 1) when
    // down_pad > up_pad, so it was correct only for stride_height == 1.
    int padding_height = up_pad;
    int padding_width = 0;
    int row_begin = 0;
    int row_end = output_height;

    // Square x/y tile just large enough for the filter; the rest of the
    // 1024-thread budget goes to channels along z.
    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
    } else {
      block_dim_x = 32;
      block_dim_y = 32;
    }

    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y,
                 std::min(block_dim_z, input_channels));
    dim3 grid(output_width, output_height);
    im2colOCF<T><<<grid, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
                       .stream()>>>(
        im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
        filter_height, filter_width, stride_height, stride_width,
        padding_height, padding_width, output_height, output_width, row_begin,
        row_end);
  }
};

template <class T>
__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
                          int input_height, int input_width, int filter_height,
                          int filter_width, int stride_height, int stride_width,
                          int padding_height, int padding_width,
                          int output_height, int output_width, int row_begin,
                          int row_end) {
  // One block per output location: blockIdx.x is the output column and
  // blockIdx.y the output row (counted from row_begin). Threads stride over
  // (channel, filter row, filter column) using threadIdx.z/y/x.
  const int out_col = blockIdx.x;
  const int out_row = blockIdx.y;
  const int patch_size = filter_height * filter_width;

  // col layout: [out_h, out_w, channels, filter_h, filter_w]; this block reads
  // one [channels, filter_h, filter_w] patch.
  const T* patch =
      col_data + (out_row * output_width + out_col) * (input_channels * patch_size);
  // Image coordinates of the patch's top-left tap.
  const int row_origin = (out_row + row_begin) * stride_height - padding_height;
  const int col_origin = out_col * stride_width - padding_width;

  for (int ch = threadIdx.z; ch < input_channels; ch += blockDim.z) {
    T* plane = im_data + ch * input_height * input_width;
    const T* src = patch + ch * patch_size;
    for (int fy = threadIdx.y; fy < filter_height; fy += blockDim.y) {
      const int y = row_origin + fy;
      for (int fx = threadIdx.x; fx < filter_width; fx += blockDim.x) {
        const int x = col_origin + fx;
        // Overlapping patches from different blocks hit the same image pixel,
        // so the accumulation must be atomic. Padding taps are dropped.
        if (y >= 0 && y < input_height && x >= 0 && x < input_width) {
          paddle::platform::CudaAtomicAdd(plane + y * input_width + x,
                                          src[fy * filter_width + fx]);
        }
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::GPUPlace, T> {
 public:
  // Folds `col` back into `im` by launching `col2imOCF` on the stream of
  // `context` (asynchronous). Contributions are atomically added into im, so
  // existing contents of im are accumulated into. Callers presumably zero im
  // first; confirm at call sites.
  //
  // up_pad / down_pad: numbers of zero rows conceptually added above / below
  // im. No horizontal padding is applied. Output row r covers image rows
  // [r * stride_height - up_pad, r * stride_height - up_pad + filter_height);
  // contributions to rows outside [0, input_height) are dropped. The output
  // width is taken from col.dims()[1].
  void operator()(const platform::DeviceContext& context, framework::Tensor& im,
                  const framework::Tensor& col, int stride_height,
                  int stride_width, int up_pad, int down_pad) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);
    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
    int output_width = col.dims()[1];

    int padded_height = input_height + up_pad + down_pad;
    // No vertical filter position fits. Without this guard, truncating
    // division would produce a bogus output row when stride_height > 1.
    if (padded_height < filter_height) return;
    int output_height = (padded_height - filter_height) / stride_height + 1;
    // The kernel reads output_height rows of col; never run past its end.
    PADDLE_ENFORCE(output_height <= col.dims()[0]);
    // Empty work. Zero grid/block dimensions are invalid launch configs.
    if (output_width <= 0 || input_channels <= 0) return;

    // Top padding is expressed directly as padding_height = up_pad with
    // row_begin = 0, which is exact for every stride; rows below the image
    // (down_pad) fall out of the kernel's bounds check. The previous encoding
    // (padding = max(up_pad, down_pad), rows shifted by down_pad - up_pad)
    // mapped rows offset by (down_pad - up_pad) * (stride_height - 1) when
    // down_pad > up_pad, so it was correct only for stride_height == 1.
    int padding_height = up_pad;
    int padding_width = 0;
    int row_begin = 0;
    int row_end = output_height;

    // Square x/y tile just large enough for the filter; the rest of the
    // 1024-thread budget goes to channels along z.
    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
    } else {
      block_dim_x = 32;
      block_dim_y = 32;
    }

    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y,
                 std::min(block_dim_z, input_channels));
    dim3 grid(output_width, output_height);
    col2imOCF<T><<<grid, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
                       .stream()>>>(
        im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
        filter_height, filter_width, stride_height, stride_width,
        padding_height, padding_width, output_height, output_width, row_begin,
        row_end);
  }
};

// Explicitly instantiate the kOCF GPU im2col/col2im functors for float and
// double element types.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle