/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>

#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

/*
 * CUDA kernel for the kCFO column layout: one thread per element of the
 * column tensor. Each thread copies the filter taps of its
 * (channel, h_out, w_out) position from the image, writing zeros for taps
 * that fall into the padding region. The launch flattens a 2-D grid into a
 * single linear thread index.
 */
template <class T>
__global__ void im2col(const T* data_im, int num_outs, int im_height,
                       int im_width, int dilation_h, int dilation_w,
                       int filter_height, int filter_width, int stride_height,
                       int stride_width, int padding_height, int padding_width,
                       int col_height, int col_width, T* data_col) {
  // Linear thread index over the flattened (gridDim.x, gridDim.y) grid.
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= num_outs) return;

  // Decompose into (input channel, output row, output column).
  const int w_out = index % col_width;
  const int h_out = (index / col_width) % col_height;
  const int channel_in = index / col_width / col_height;
  const int channel_out = channel_in * filter_height * filter_width;

  // Top-left corner of the receptive field in (possibly negative,
  // padding-adjusted) image coordinates.
  const int h_in = h_out * stride_height - padding_height;
  const int w_in = w_out * stride_width - padding_width;

  T* col_ptr = data_col + (channel_out * col_height + h_out) * col_width + w_out;
  const T* im_ptr = data_im + (channel_in * im_height + h_in) * im_width + w_in;

  for (int i = 0; i < filter_height; ++i) {
    for (int j = 0; j < filter_width; ++j) {
      const int im_row = h_in + i * dilation_h;
      const int im_col = w_in + j * dilation_w;
      // Taps outside the image read as zero (implicit zero padding).
      const bool inside =
          im_row >= 0 && im_row < im_height && im_col >= 0 && im_col < im_width;
      *col_ptr =
          inside ? im_ptr[i * dilation_h * im_width + j * dilation_w] : T(0);
      // Successive filter taps are col_height * col_width apart in col.
      col_ptr += col_height * col_width;
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::CUDADeviceContext, T> {
 public:
  // Expands the image tensor `im` into the column tensor `col` on the
  // CUDA stream owned by `context`.
  //   dilation: {dilation_h, dilation_w}
  //   stride:   {stride_h, stride_w}
  //   padding:  {pad_up, pad_left, pad_down, pad_right}
  //     (layout implied by the consistency checks below)
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col->dims().size() == 5);

    const int im_channels = im.dims()[0];
    const int im_height = im.dims()[1];
    const int im_width = im.dims()[2];
    const int filter_height = col->dims()[1];
    const int filter_width = col->dims()[2];
    const int col_height = col->dims()[3];
    const int col_width = col->dims()[4];

    // The output extent implied by image size, padding, dilation and stride
    // must match the col tensor the caller allocated.
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // One thread per col element; spread the launch over a 2-D grid so a
    // very large element count does not overflow a single grid dimension.
    const int num_outputs = im_channels * col_height * col_width;
    const int num_threads = 1024;
    const int blocks = (num_outputs + num_threads - 1) / num_threads;
    const int grid_x = 512;
    const int grid_y = (blocks + grid_x - 1) / grid_x;
    dim3 threads(num_threads, 1);
    dim3 grid(grid_x, grid_y);
    im2col<T><<<grid, threads, 0, context.stream()>>>(
        im.data<T>(), num_outputs, im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, col->data<T>());
  }
};

/*
 * CUDA kernel for the kCFO layout gradient: one thread per image element.
 * Each thread gathers every column-buffer entry that was produced from its
 * (channel, row, col) pixel and sums them back into the image. Because each
 * image element is owned by exactly one thread, no atomics are required.
 */
template <class T>
__global__ void col2im(int n, const T* data_col, int im_height, int im_width,
                       int dilation_h, int dilation_w, int filter_height,
                       int filter_width, int stride_height, int stride_width,
                       int padding_height, int padding_width, int col_height,
                       int col_width, T* data_im) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;

  // Effective (dilated) filter extent.
  const int d_filter_height = dilation_h * (filter_height - 1) + 1;
  const int d_filter_width = dilation_w * (filter_width - 1) + 1;

  if (index >= n) return;

  T val = 0;
  // Pixel position expressed in padded-image coordinates.
  const int w = index % im_width + padding_width;
  const int h = (index / im_width) % im_height + padding_height;
  const int c = index / (im_width * im_height);

  // Range of output positions whose receptive field covers (h, w).
  const int w_col_start =
      (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
  const int w_col_end = min(w / stride_width + 1, col_width);
  const int h_col_start =
      (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
  const int h_col_end = min(h / stride_height + 1, col_height);

  for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
      int h_off = h - h_col * stride_height;
      int w_off = w - w_col * stride_width;
      // Only offsets that land exactly on a dilated filter tap contribute.
      if (h_off % dilation_h == 0 && w_off % dilation_w == 0) {
        h_off /= dilation_h;
        w_off /= dilation_w;
        const int data_col_index =
            (((c * filter_height + h_off) * filter_width + w_off) * col_height +
             h_col) *
                col_width +
            w_col;
        val += data_col[data_col_index];
      }
    }
  }
  data_im[index] = val;
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::CUDADeviceContext, T> {
 public:
  // Accumulates the column tensor `col` back into the image tensor `im`
  // (the gradient of Im2ColFunctor<kCFO>) on the stream owned by `context`.
  //   dilation: {dilation_h, dilation_w}
  //   stride:   {stride_h, stride_w}
  //   padding:  {pad_up, pad_left, pad_down, pad_right}
  //     (layout implied by the consistency checks below)
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im) {
    PADDLE_ENFORCE(im->dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    const int im_channels = im->dims()[0];
    const int im_height = im->dims()[1];
    const int im_width = im->dims()[2];
    const int filter_height = col.dims()[1];
    const int filter_width = col.dims()[2];
    const int col_height = col.dims()[3];
    const int col_width = col.dims()[4];

    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    const size_t num_kernels = im_channels * im_height * im_width;
    const size_t blocks = (num_kernels + 1024 - 1) / 1024;
    const size_t block_x = 512;
    const size_t block_y = (blocks + 512 - 1) / 512;
    dim3 threads(1024, 1);
    dim3 grid(block_x, block_y);

    // To avoid involving atomic operations, we will launch one kernel per
    // bottom dimension, and then in the kernel add up the top dimensions.
    // BUGFIX: col2im expects (padding_height, padding_width) = (pad_up,
    // pad_left). The previous code passed padding[2] (pad_down) as
    // padding_width, which produced wrong gradients whenever the left and
    // down paddings differ. Pass padding[1], matching the im2col launch.
    col2im<T><<<grid, threads, 0, context.stream()>>>(
        num_kernels, col.data<T>(), im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, im->data<T>());
  }
};

// Explicit instantiations of the kCFO functors for the supported
// element types.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, double>;

/*
 * CUDA kernel for the kOCF column layout: each block handles one
 * (output_h, output_w) position — grid is (col_width, col_height) — and the
 * block's threads cooperatively copy the filter window of every channel
 * into the column buffer. NOTE(review): this kernel takes no dilation
 * parameters, so the OCF path appears to support dilation == 1 only.
 */
template <class T>
__global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* col_data) {
  const int w_out = blockIdx.x;  // output column handled by this block
  const int h_out = blockIdx.y;  // output row handled by this block
  for (int channel = threadIdx.z; channel < im_channels;
       channel += blockDim.z) {
    for (int fh = threadIdx.y; fh < filter_height; fh += blockDim.y) {
      for (int fw = threadIdx.x; fw < filter_width; fw += blockDim.x) {
        // Image coordinates of this filter tap (may be in the padding).
        const int width_offset = fw + w_out * stride_width - padding_width;
        const int height_offset = fh + h_out * stride_height - padding_height;
        const int im_offset = width_offset + height_offset * im_width +
                              channel * im_height * im_width;

        // col layout: [h_out, w_out, channel, fh, fw].
        const int col_offset = fw + fh * filter_width +
                               channel * filter_height * filter_width +
                               (h_out * col_width + w_out) *
                                   (im_channels * filter_height * filter_width);

        // Out-of-image taps read as zero (implicit zero padding).
        const bool inside = height_offset >= 0 && height_offset < im_height &&
                            width_offset >= 0 && width_offset < im_width;
        col_data[col_offset] = inside ? im_data[im_offset] : T(0);
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::CUDADeviceContext, T> {
 public:
  // Expands `im` into the kOCF-layout column tensor `col` on the stream
  // owned by `context`.
  //   padding: {pad_up, pad_left, pad_down, pad_right}
  // NOTE(review): the consistency checks below account for dilation, but
  // im2colOCF itself takes no dilation arguments — confirm callers only use
  // dilation == {1, 1} on this path.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col->dims().size() == 5);
    const int im_channels = im.dims()[0];
    const int im_height = im.dims()[1];
    const int im_width = im.dims()[2];
    const int filter_height = col->dims()[3];
    const int filter_width = col->dims()[4];
    const int col_height = col->dims()[0];
    const int col_width = col->dims()[1];

    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // Pick a square (x, y) thread tile just large enough to cover the
    // filter, then spend the remaining budget (1024 threads) on channels.
    int tile = 32;
    if (filter_height <= 4 && filter_width <= 4) {
      tile = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      tile = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      tile = 16;
    }

    const int block_dim_z = 1024 / tile / tile;
    dim3 threads(tile, tile, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);  // one block per output position
    im2colOCF<T><<<grid, threads, 0, context.stream()>>>(
        im.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, col->data<T>());
  }
};

/*
 * CUDA kernel for the kOCF layout gradient: each block handles one
 * (output_h, output_w) position and scatters the corresponding column
 * entries back into the image. Several output positions can touch the same
 * image element, so the accumulation uses CudaAtomicAdd.
 * NOTE(review): im_data is accumulated into, not overwritten — presumably
 * the caller zero-initializes it; verify at the call sites.
 */
template <class T>
__global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* im_data) {
  const int w_out = blockIdx.x;  // output column handled by this block
  const int h_out = blockIdx.y;  // output row handled by this block
  for (int channel = threadIdx.z; channel < im_channels;
       channel += blockDim.z) {
    for (int fh = threadIdx.y; fh < filter_height; fh += blockDim.y) {
      for (int fw = threadIdx.x; fw < filter_width; fw += blockDim.x) {
        // Image coordinates of this filter tap (may be in the padding).
        const int width_offset = fw + w_out * stride_width - padding_width;
        const int height_offset = fh + h_out * stride_height - padding_height;
        const int im_offset = width_offset + height_offset * im_width +
                              channel * im_height * im_width;

        // col layout: [h_out, w_out, channel, fh, fw].
        const int col_offset = fw + fh * filter_width +
                               channel * filter_height * filter_width +
                               (h_out * col_width + w_out) *
                                   (im_channels * filter_height * filter_width);

        // Taps that fell in the padding contribute nothing to the image.
        if (height_offset >= 0 && height_offset < im_height &&
            width_offset >= 0 && width_offset < im_width) {
          paddle::platform::CudaAtomicAdd(im_data + im_offset,
                                          col_data[col_offset]);
        }
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::CUDADeviceContext, T> {
 public:
  // Accumulates the kOCF-layout column tensor `col` back into the image
  // tensor `im` (the gradient of Im2ColFunctor<kOCF>).
  //   padding: {pad_up, pad_left, pad_down, pad_right}
  // NOTE(review): the consistency checks below account for dilation, but
  // col2imOCF itself takes no dilation arguments — confirm callers only use
  // dilation == {1, 1} on this path.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im) {
    PADDLE_ENFORCE(im->dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);
    const int im_channels = im->dims()[0];
    const int im_height = im->dims()[1];
    const int im_width = im->dims()[2];
    const int filter_height = col.dims()[3];
    const int filter_width = col.dims()[4];
    const int col_height = col.dims()[0];
    const int col_width = col.dims()[1];

    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // Same launch shape as the forward kOCF pass: a square thread tile
    // covering the filter, remaining threads across channels.
    int tile = 32;
    if (filter_height <= 4 && filter_width <= 4) {
      tile = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      tile = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      tile = 16;
    }

    const int block_dim_z = 1024 / tile / tile;
    dim3 threads(tile, tile, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);  // one block per output position
    col2imOCF<T><<<grid, threads, 0, context.stream()>>>(
        col.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, im->data<T>());
  }
};

// Explicit instantiations of the kOCF functors for the supported
// element types.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle