im2col.cu 17.2 KB
Newer Older
H
hedaoyuan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

H
hedaoyuan 已提交
15 16
#include "paddle/operators/math/im2col.h"
#include "paddle/platform/cuda_helper.h"
H
hedaoyuan 已提交
17 18

namespace paddle {
19
namespace operators {
20
namespace math {
H
hedaoyuan 已提交
21 22

/*
 * Gathers image patches into the column buffer.
 *
 * Each thread handles one (channel, h_out, w_out) triple decoded from its
 * flattened index and writes filter_height * filter_width values, one per
 * filter tap. Consecutive taps are col_height * col_width elements apart in
 * data_col. Taps that fall outside the image are written as zero.
 *
 * The flattened index is built from a 2-D grid (blockIdx.x, blockIdx.y) and
 * the x component of the thread index; threads with index >= num_outs do
 * nothing.
 */
template <class T>
__global__ void im2col(const T* data_im, int num_outs, int im_height,
                       int im_width, int dilation_h, int dilation_w,
                       int filter_height, int filter_width, int stride_height,
                       int stride_width, int padding_height, int padding_width,
                       int col_height, int col_width, T* data_col) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= num_outs) {
    return;
  }

  // Decode the flattened index into output column, output row and channel.
  const int out_col = index % col_width;
  const int out_row = (index / col_width) % col_height;
  const int channel = index / col_width / col_height;

  // First column-buffer channel that belongs to this input channel.
  const int first_col_channel = channel * filter_height * filter_width;

  // Top-left corner of the receptive field; negative inside the padding.
  const int top = out_row * stride_height - padding_height;
  const int left = out_col * stride_width - padding_width;

  T* dst =
      data_col + ((first_col_channel * col_height + out_row) * col_width +
                  out_col);
  const T* src = data_im + ((channel * im_height + top) * im_width + left);
  const int col_plane = col_height * col_width;

  for (int kh = 0; kh < filter_height; ++kh) {
    const int row = top + kh * dilation_h;
    for (int kw = 0; kw < filter_width; ++kw) {
      const int col = left + kw * dilation_w;
      const bool inside =
          row >= 0 && row < im_height && col >= 0 && col < im_width;
      if (inside) {
        *dst = src[kh * dilation_h * im_width + kw * dilation_w];
      } else {
        *dst = 0;
      }
      // Advance to the same output position in the next filter tap's plane.
      dst += col_plane;
    }
  }
}

/*
 * GPU im2col functor for the kCFO column layout.
 *
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 *
 * padding is indexed as [0]/[2] for the height direction (up, down) and
 * [1]/[3] for the width direction (left, right); dilation and stride are
 * indexed as [0] = height, [1] = width.
 */
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col->dims().size() == 5);

    // Image extents.
    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];

    // Filter and output extents are read from the column tensor's shape.
    int filter_height = col->dims()[1];
    int filter_width = col->dims()[2];
    int col_height = col->dims()[3];
    int col_width = col->dims()[4];

    // The column tensor's output extents must agree with the convolution
    // geometry implied by the image size, padding, dilation and stride.
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // One thread per (channel, output row, output column). Blocks hold 1024
    // threads and are arranged in a grid that is 512 wide and as tall as
    // needed to cover every thread.
    const int kThreadsPerBlock = 1024;
    const int kGridWidth = 512;
    const int total_threads = im_channels * col_height * col_width;
    const int blocks_needed =
        (total_threads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    const int grid_height = (blocks_needed + kGridWidth - 1) / kGridWidth;

    dim3 block_shape(kThreadsPerBlock, 1);
    dim3 grid_shape(kGridWidth, grid_height);

    const platform::CUDADeviceContext& cuda_context =
        reinterpret_cast<const platform::CUDADeviceContext&>(context);
    im2col<T><<<grid_shape, block_shape, 0, cuda_context.stream()>>>(
        im.data<T>(), total_threads, im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, col->data<T>());
  }
};

/*
 * Scatters the column buffer back into the image by gathering.
 *
 * Each thread owns one image element (index < n), sums every column-buffer
 * entry that maps onto it, and then stores the sum into data_im[index]
 * (a plain store, not an accumulation). The flattened index is built from a
 * 2-D grid (blockIdx.x, blockIdx.y) and the x component of the thread index.
 */
template <class T>
__global__ void col2im(int n, const T* data_col, int im_height, int im_width,
                       int dilation_h, int dilation_w, int filter_height,
                       int filter_width, int stride_height, int stride_width,
                       int padding_height, int padding_width, int col_height,
                       int col_width, T* data_im) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= n) {
    return;
  }

  // Extent covered by the dilated filter along each axis.
  const int extent_h = dilation_h * (filter_height - 1) + 1;
  const int extent_w = dilation_w * (filter_width - 1) + 1;

  // Position of this element in padded image coordinates, plus its channel.
  const int padded_x = index % im_width + padding_width;
  const int padded_y = (index / im_width) % im_height + padding_height;
  const int channel = index / (im_width * im_height);

  // Half-open range of output positions whose window can cover this element.
  int first_col_x = 0;
  if (padded_x >= extent_w) {
    first_col_x = (padded_x - extent_w) / stride_width + 1;
  }
  const int last_col_x = min(padded_x / stride_width + 1, col_width);

  int first_col_y = 0;
  if (padded_y >= extent_h) {
    first_col_y = (padded_y - extent_h) / stride_height + 1;
  }
  const int last_col_y = min(padded_y / stride_height + 1, col_height);

  T sum = 0;
  for (int col_y = first_col_y; col_y < last_col_y; ++col_y) {
    const int tap_y = padded_y - col_y * stride_height;
    for (int col_x = first_col_x; col_x < last_col_x; ++col_x) {
      const int tap_x = padded_x - col_x * stride_width;
      // Only offsets that are exact multiples of the dilation hit a tap.
      if (tap_y % dilation_h != 0 || tap_x % dilation_w != 0) {
        continue;
      }
      const int kh = tap_y / dilation_h;
      const int kw = tap_x / dilation_w;
      const int col_index =
          (((channel * filter_height + kh) * filter_width + kw) * col_height +
           col_y) *
              col_width +
          col_x;
      sum += data_col[col_index];
    }
  }
  data_im[index] = sum;
}

/*
 * GPU col2im functor for the kCFO column layout.
 *
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 *
 * padding is indexed as [0]/[2] for the height direction (up, down) and
 * [1]/[3] for the width direction (left, right); dilation and stride are
 * indexed as [0] = height, [1] = width.
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im) {
    PADDLE_ENFORCE(im->dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    // Image extents.
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];

    // Filter and output extents are read from the column tensor's shape.
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];

    // The column tensor's output extents must agree with the convolution
    // geometry implied by the image size, padding, dilation and stride.
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // One thread per image element. Blocks hold 1024 threads and are arranged
    // in a grid that is 512 wide and as tall as needed to cover every thread.
    size_t num_kernels = im_channels * im_height * im_width;

    size_t blocks = (num_kernels + 1024 - 1) / 1024;
    size_t block_x = 512;
    size_t block_y = (blocks + 512 - 1) / 512;
    dim3 threads(1024, 1);
    dim3 grid(block_x, block_y);

    // To avoid involving atomic operations, we will launch one kernel per
    // bottom dimension, and then in the kernel add up the top dimensions.
    //
    // The kernel's padding_height / padding_width are the up and left
    // paddings, i.e. padding[0] and padding[1]. padding[2] is the *down*
    // padding and must not be passed as the width padding.
    col2im<T><<<grid, threads, 0,
                reinterpret_cast<const platform::CUDADeviceContext&>(context)
                    .stream()>>>(
        num_kernels, col.data<T>(), im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, im->data<T>());
  }
};

H
hedaoyuan 已提交
215 216 217 218 219 220 221 222
// Explicit instantiations of the kCFO GPU functors for float and double.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, double>;
H
hedaoyuan 已提交
223 224

/*
 * Gathers image patches into a column buffer whose outermost axes are the
 * output position.
 *
 * blockIdx.x / blockIdx.y select the output column / row. Within a block the
 * threads stride over (channel, filter row, filter column) using
 * threadIdx.z / .y / .x, so any block shape covers the whole patch. Taps that
 * fall outside the image are written as T(0).
 *
 * This kernel takes no dilation arguments, and col_height is not read.
 */
template <class T>
__global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* col_data) {
  const int out_x = blockIdx.x;
  const int out_y = blockIdx.y;

  // Offset of this output position's patch inside the column buffer.
  const int patch_base = (out_y * col_width + out_x) *
                         (im_channels * filter_height * filter_width);

  for (int ch = threadIdx.z; ch < im_channels; ch += blockDim.z) {
    for (int fy = threadIdx.y; fy < filter_height; fy += blockDim.y) {
      const int src_y = fy + out_y * stride_height - padding_height;
      for (int fx = threadIdx.x; fx < filter_width; fx += blockDim.x) {
        const int src_x = fx + out_x * stride_width - padding_width;

        const int src_index =
            src_x + src_y * im_width + ch * im_height * im_width;
        const int dst_index = fx + fy * filter_width +
                              ch * filter_height * filter_width + patch_base;

        const bool inside = src_y >= 0 && src_y < im_height && src_x >= 0 &&
                            src_x < im_width;
        if (inside) {
          col_data[dst_index] = im_data[src_index];
        } else {
          col_data[dst_index] = T(0);
        }
      }
    }
  }
}

/*
 * GPU im2col functor for the kOCF column layout.
 *
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 *
 * padding is indexed as [0]/[2] for the height direction (up, down) and
 * [1]/[3] for the width direction (left, right); stride is indexed as
 * [0] = height, [1] = width.
 *
 * NOTE(review): dilation takes part in the shape checks below but is not
 * passed to im2colOCF, which has no dilation parameters. Confirm that callers
 * only use dilation == 1 with this layout.
 */
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col->dims().size() == 5);

    // Image extents.
    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];

    // Filter and output extents are read from the column tensor's shape.
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];

    // The column tensor's output extents must agree with the convolution
    // geometry implied by the image size, padding, dilation and stride.
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // Pick the smallest square tile (4, 8, 16 or 32 threads per side) that
    // covers the filter; filters larger than 16 in either dimension use 32.
    const int longest_side =
        filter_height > filter_width ? filter_height : filter_width;
    int tile = 32;
    if (longest_side <= 4) {
      tile = 4;
    } else if (longest_side <= 8) {
      tile = 8;
    } else if (longest_side <= 16) {
      tile = 16;
    }

    // Spend the rest of a 1024-thread block on channels, capped at the number
    // of channels actually present.
    const int max_depth = 1024 / tile / tile;
    dim3 block_shape(tile, tile, std::min(max_depth, im_channels));

    // One block per output position.
    dim3 grid_shape(col_width, col_height);

    const platform::CUDADeviceContext& cuda_context =
        reinterpret_cast<const platform::CUDADeviceContext&>(context);
    im2colOCF<T><<<grid_shape, block_shape, 0, cuda_context.stream()>>>(
        im.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, col->data<T>());
  }
};

/*
 * Accumulates a column buffer (output position outermost) back into the image.
 *
 * blockIdx.x / blockIdx.y select the output column / row. Within a block the
 * threads stride over (channel, filter row, filter column) using
 * threadIdx.z / .y / .x. Every tap that lands inside the image is added to
 * im_data with CudaAtomicAdd; taps in the padding region are skipped.
 *
 * This kernel takes no dilation arguments, and col_height is not read.
 */
template <class T>
__global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* im_data) {
  const int out_x = blockIdx.x;
  const int out_y = blockIdx.y;

  // Offset of this output position's patch inside the column buffer.
  const int patch_base = (out_y * col_width + out_x) *
                         (im_channels * filter_height * filter_width);

  for (int ch = threadIdx.z; ch < im_channels; ch += blockDim.z) {
    for (int fy = threadIdx.y; fy < filter_height; fy += blockDim.y) {
      const int dst_y = fy + out_y * stride_height - padding_height;
      for (int fx = threadIdx.x; fx < filter_width; fx += blockDim.x) {
        const int dst_x = fx + out_x * stride_width - padding_width;

        const int dst_index =
            dst_x + dst_y * im_width + ch * im_height * im_width;
        const int src_index = fx + fy * filter_width +
                              ch * filter_height * filter_width + patch_base;

        const bool outside = dst_y < 0 || dst_y >= im_height || dst_x < 0 ||
                             dst_x >= im_width;
        if (outside) {
          continue;
        }
        paddle::platform::CudaAtomicAdd(im_data + dst_index,
                                        col_data[src_index]);
      }
    }
  }
}

/*
 * GPU col2im functor for the kOCF column layout.
 *
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 *
 * padding is indexed as [0]/[2] for the height direction (up, down) and
 * [1]/[3] for the width direction (left, right); stride is indexed as
 * [0] = height, [1] = width.
 *
 * NOTE(review): col2imOCF adds into im with atomic adds and this functor does
 * not clear im first, so im presumably has to be zeroed by the caller --
 * confirm against callers.
 *
 * NOTE(review): dilation takes part in the shape checks below but is not
 * passed to col2imOCF, which has no dilation parameters. Confirm that callers
 * only use dilation == 1 with this layout.
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im) {
    PADDLE_ENFORCE(im->dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    // Image extents.
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];

    // Filter and output extents are read from the column tensor's shape.
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];

    // The column tensor's output extents must agree with the convolution
    // geometry implied by the image size, padding, dilation and stride.
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // Pick the smallest square tile (4, 8, 16 or 32 threads per side) that
    // covers the filter; filters larger than 16 in either dimension use 32.
    const int longest_side =
        filter_height > filter_width ? filter_height : filter_width;
    int tile = 32;
    if (longest_side <= 4) {
      tile = 4;
    } else if (longest_side <= 8) {
      tile = 8;
    } else if (longest_side <= 16) {
      tile = 16;
    }

    // Spend the rest of a 1024-thread block on channels, capped at the number
    // of channels actually present.
    const int max_depth = 1024 / tile / tile;
    dim3 block_shape(tile, tile, std::min(max_depth, im_channels));

    // One block per output position.
    dim3 grid_shape(col_width, col_height);

    const platform::CUDADeviceContext& cuda_context =
        reinterpret_cast<const platform::CUDADeviceContext&>(context);
    col2imOCF<T><<<grid_shape, block_shape, 0, cuda_context.stream()>>>(
        col.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, im->data<T>());
  }
};

H
hedaoyuan 已提交
421 422 423 424 425 426 427 428
// Explicit instantiations of the kOCF GPU functors for float and double.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, double>;
H
hedaoyuan 已提交
429

430
}  // namespace math
431
}  // namespace operators
H
hedaoyuan 已提交
432
}  // namespace paddle