/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <vector>

#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {
template <class T>
C
chengduoZH 已提交
25 26
__global__ void im2col(const T* data_im, int num_outs, int im_height,
                       int im_width, int dilation_h, int dilation_w,
H
hedaoyuan 已提交
27 28
                       int filter_height, int filter_width, int stride_height,
                       int stride_width, int padding_height, int padding_width,
29 30 31 32
                       int col_height, int col_width, T* data_col,
                       const DataLayout data_layout) {
  int input_channels = num_outs / col_height / col_width;
  int channels_col = input_channels * filter_height * filter_width;
C
chengduoZH 已提交
33 34
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
H
hedaoyuan 已提交
35
  if (index < num_outs) {
36
    int w_out = (data_layout != DataLayout::kNHWC
37 38
                     ? index % col_width
                     : (index / input_channels) % col_width);
39
    int h_out = (data_layout != DataLayout::kNHWC
40 41 42
                     ? (index / col_width) % col_height
                     : (index / input_channels / col_width) % col_height);
    int channel_in =
43
        (data_layout != DataLayout::kNHWC ? index / col_width / col_height
44
                                          : index % input_channels);
H
hedaoyuan 已提交
45
    int channel_out = channel_in * filter_height * filter_width;
C
chengduoZH 已提交
46 47
    int h_in = h_out * stride_height - padding_height;
    int w_in = w_out * stride_width - padding_width;
H
hedaoyuan 已提交
48

C
chengduoZH 已提交
49
    data_col += (channel_out * col_height + h_out) * col_width + w_out;
H
hedaoyuan 已提交
50 51
    for (int i = 0; i < filter_height; ++i) {
      for (int j = 0; j < filter_width; ++j) {
C
chengduoZH 已提交
52 53
        int rIdx = h_in + i * dilation_h;
        int cIdx = w_in + j * dilation_w;
54
        int im_idx;
55
        if (data_layout != DataLayout::kNHWC) {
56 57 58 59
          im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
        } else {
          im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
        }
C
chengduoZH 已提交
60 61 62
        *data_col =
            (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
                ? 0
63
                : data_im[im_idx];
C
chengduoZH 已提交
64
        data_col += col_height * col_width;
H
hedaoyuan 已提交
65 66 67 68 69 70
      }
    }
  }
}

/*
H
hedaoyuan 已提交
71 72 73
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
74 75
 */
template <class T>
H
hedaoyuan 已提交
76
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
77
                    platform::CUDADeviceContext, T> {
H
hedaoyuan 已提交
78
 public:
Q
QI JUN 已提交
79
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
80 81
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
82 83 84 85 86 87 88
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
                      "The dimension of col should be 5.");

    int im_channels =
89
        (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
90
    int im_height =
91
        (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
92
    int im_width =
93
        (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
C
chengduoZH 已提交
94 95 96 97 98
    int filter_height = col->dims()[1];
    int filter_width = col->dims()[2];
    int col_height = col->dims()[3];
    int col_width = col->dims()[4];

C
chengduoZH 已提交
99
    int num_outputs = im_channels * col_height * col_width;
H
hedaoyuan 已提交
100 101 102
    int blocks = (num_outputs + 1024 - 1) / 1024;
    int block_x = 512;
    int block_y = (blocks + 512 - 1) / 512;
H
hedaoyuan 已提交
103
    dim3 threads(1024, 1);
H
hedaoyuan 已提交
104
    dim3 grid(block_x, block_y);
Q
QI JUN 已提交
105
    im2col<T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
106 107
        im.data<T>(), num_outputs, im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
108 109
        padding[0], padding[1], col_height, col_width, col->data<T>(),
        data_layout);
H
hedaoyuan 已提交
110 111 112 113
  }
};

template <class T>
C
chengduoZH 已提交
114 115 116 117
__global__ void col2im(int n, const T* data_col, int im_height, int im_width,
                       int dilation_h, int dilation_w, int filter_height,
                       int filter_width, int stride_height, int stride_width,
                       int padding_height, int padding_width, int col_height,
118 119
                       int col_width, T* data_im,
                       const DataLayout data_layout) {
C
chengduoZH 已提交
120
  const int index =
H
hedaoyuan 已提交
121
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
C
chengduoZH 已提交
122 123 124 125

  const int d_filter_height = dilation_h * (filter_height - 1) + 1;
  const int d_filter_width = dilation_w * (filter_width - 1) + 1;

126 127
  int input_channels = n / im_height / im_width;

H
hedaoyuan 已提交
128 129
  if (index < n) {
    T val = 0;
130
    int w = (data_layout != DataLayout::kNHWC
131 132
                 ? index % im_width + padding_width
                 : (index / input_channels) % im_width + padding_width);
133
    int h = (data_layout != DataLayout::kNHWC
134 135 136
                 ? (index / im_width) % im_height + padding_height
                 : (index / input_channels / im_width) % im_height +
                       padding_height);
137
    int c = (data_layout != DataLayout::kNHWC ? index / im_width / im_height
138
                                              : index % input_channels);
C
chengduoZH 已提交
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160

    // compute the start and end of the output
    int w_col_start =
        (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
    int w_col_end = min(w / stride_width + 1, col_width);
    int h_col_start =
        (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
    int h_col_end = min(h / stride_height + 1, col_height);

    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
        int h_off = (h - h_col * stride_height);
        int w_off = (w - w_col * stride_width);
        if (h_off % dilation_h == 0 && w_off % dilation_w == 0) {
          h_off /= dilation_h;
          w_off /= dilation_w;
          int data_col_index =
              (((c * filter_height + h_off) * filter_width + w_off) *
                   col_height +
               h_col) *
                  col_width +
              w_col;
C
chengduoZH 已提交
161

C
chengduoZH 已提交
162
          val += data_col[data_col_index];
H
hedaoyuan 已提交
163 164 165
        }
      }
    }
C
chengduoZH 已提交
166
    data_im[index] = val;
H
hedaoyuan 已提交
167 168 169 170
  }
}

/*
H
hedaoyuan 已提交
171 172 173
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
H
hedaoyuan 已提交
174 175
 */
template <class T>
H
hedaoyuan 已提交
176
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
177
                    platform::CUDADeviceContext, T> {
H
hedaoyuan 已提交
178
 public:
Q
QI JUN 已提交
179
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
180 181 182
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
183 184 185 186 187 188 189
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
                      "The dimension of col should be 5.");

    int im_channels =
190
        (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
191
    int im_height =
192
        (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
193
    int im_width =
194
        (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
H
hedaoyuan 已提交
195 196
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
C
chengduoZH 已提交
197 198 199
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];

C
chengduoZH 已提交
200 201 202
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
C
chengduoZH 已提交
203 204 205 206
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
C
chengduoZH 已提交
207 208 209
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
C
chengduoZH 已提交
210 211 212 213 214 215
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    size_t num_kernels = im_channels * im_height * im_width;
H
hedaoyuan 已提交
216

H
hedaoyuan 已提交
217 218 219
    size_t blocks = (num_kernels + 1024 - 1) / 1024;
    size_t block_x = 512;
    size_t block_y = (blocks + 512 - 1) / 512;
H
hedaoyuan 已提交
220
    dim3 threads(1024, 1);
H
hedaoyuan 已提交
221
    dim3 grid(block_x, block_y);
H
hedaoyuan 已提交
222 223 224

    // To avoid involving atomic operations, we will launch one kernel per
    // bottom dimension, and then in the kernel add up the top dimensions.
Q
QI JUN 已提交
225
    col2im<T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
226 227
        num_kernels, col.data<T>(), im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
228 229
        padding[0], padding[1], col_height, col_width, im->data<T>(),
        data_layout);
H
hedaoyuan 已提交
230 231 232
  }
};

H
hedaoyuan 已提交
233
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
234
                             platform::CUDADeviceContext, float>;
H
hedaoyuan 已提交
235
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
236
                             platform::CUDADeviceContext, double>;
H
hedaoyuan 已提交
237
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
238
                             platform::CUDADeviceContext, float>;
H
hedaoyuan 已提交
239
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
Q
QI JUN 已提交
240
                             platform::CUDADeviceContext, double>;
H
hedaoyuan 已提交
241 242

template <class T>
C
chengduoZH 已提交
243 244 245
__global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
C
chengduoZH 已提交
246
                          int padding_height, int padding_width, int col_height,
C
chengduoZH 已提交
247
                          int col_width, T* col_data) {
H
hedaoyuan 已提交
248 249
  int swid = blockIdx.x;
  int shid = blockIdx.y;
C
chengduoZH 已提交
250
  for (int channelid = threadIdx.z; channelid < im_channels;
H
hedaoyuan 已提交
251 252 253 254
       channelid += blockDim.z) {
    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
        int width_offset = idx + swid * stride_width - padding_width;
C
chengduoZH 已提交
255
        int height_offset = idy + shid * stride_height - padding_height;
C
chengduoZH 已提交
256 257
        int im_offset = width_offset + height_offset * im_width +
                        channelid * im_height * im_width;
H
hedaoyuan 已提交
258

H
hedaoyuan 已提交
259 260
        int col_offset = idx + idy * filter_width +
                         channelid * filter_height * filter_width +
C
chengduoZH 已提交
261 262 263 264 265 266 267 268
                         (shid * col_width + swid) *
                             (im_channels * filter_height * filter_width);

        col_data[col_offset] =
            (height_offset >= im_height || height_offset < 0 ||
             width_offset >= im_width || width_offset < 0)
                ? T(0)
                : im_data[im_offset];
H
hedaoyuan 已提交
269 270 271 272 273 274
      }
    }
  }
}

/*
H
hedaoyuan 已提交
275 276 277
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
278 279
 */
template <class T>
H
hedaoyuan 已提交
280
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
281
                    platform::CUDADeviceContext, T> {
H
hedaoyuan 已提交
282
 public:
Q
QI JUN 已提交
283
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
284 285
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
286 287 288 289 290 291
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im.dims().size(), 3, "The dimension of im should be 3.");
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
                      "The dimension of col should be 5.");

C
chengduoZH 已提交
292 293 294
    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
C
chengduoZH 已提交
295 296 297 298 299
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];

H
hedaoyuan 已提交
300 301 302 303 304 305 306 307 308 309 310
    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
H
hedaoyuan 已提交
311
    } else {
H
hedaoyuan 已提交
312 313
      block_dim_x = 32;
      block_dim_y = 32;
H
hedaoyuan 已提交
314 315
    }

H
hedaoyuan 已提交
316
    int block_dim_z = 1024 / block_dim_x / block_dim_y;
C
chengduoZH 已提交
317 318
    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
Q
QI JUN 已提交
319
    im2colOCF<T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
320 321 322
        im.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, col->data<T>());
H
hedaoyuan 已提交
323 324 325 326
  }
};

template <class T>
C
chengduoZH 已提交
327 328 329
__global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
C
chengduoZH 已提交
330
                          int padding_height, int padding_width, int col_height,
C
chengduoZH 已提交
331
                          int col_width, T* im_data) {
H
hedaoyuan 已提交
332 333
  int swid = blockIdx.x;
  int shid = blockIdx.y;
C
chengduoZH 已提交
334
  for (int channelid = threadIdx.z; channelid < im_channels;
H
hedaoyuan 已提交
335 336 337 338
       channelid += blockDim.z) {
    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
        int width_offset = idx + swid * stride_width - padding_width;
C
chengduoZH 已提交
339
        int height_offset = idy + shid * stride_height - padding_height;
C
chengduoZH 已提交
340 341
        int im_offset = width_offset + height_offset * im_width +
                        channelid * im_height * im_width;
H
hedaoyuan 已提交
342

H
hedaoyuan 已提交
343 344
        int col_offset = idx + idy * filter_width +
                         channelid * filter_height * filter_width +
C
chengduoZH 已提交
345 346
                         (shid * col_width + swid) *
                             (im_channels * filter_height * filter_width);
H
hedaoyuan 已提交
347

C
chengduoZH 已提交
348 349
        if (height_offset >= 0 && height_offset < im_height &&
            width_offset >= 0 && width_offset < im_width) {
H
hedaoyuan 已提交
350 351
          paddle::platform::CudaAtomicAdd(im_data + im_offset,
                                          col_data[col_offset]);
H
hedaoyuan 已提交
352 353 354 355 356 357 358
        }
      }
    }
  }
}

/*
H
hedaoyuan 已提交
359 360 361
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
H
hedaoyuan 已提交
362 363
 */
template <class T>
H
hedaoyuan 已提交
364
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
365
                    platform::CUDADeviceContext, T> {
H
hedaoyuan 已提交
366
 public:
Q
QI JUN 已提交
367
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
368 369 370
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
371 372 373 374 375 376
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im->dims().size(), 3, "The dimension of im should be 3.");
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
                      "The dimension of col should be 5.");

C
chengduoZH 已提交
377 378 379
    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
H
hedaoyuan 已提交
380 381
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
C
chengduoZH 已提交
382 383 384
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];

C
chengduoZH 已提交
385 386 387
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
C
chengduoZH 已提交
388 389 390 391
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
C
chengduoZH 已提交
392 393 394
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
C
chengduoZH 已提交
395 396 397 398
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");
C
chengduoZH 已提交
399

H
hedaoyuan 已提交
400 401 402 403 404 405 406 407 408 409 410
    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
H
hedaoyuan 已提交
411
    } else {
H
hedaoyuan 已提交
412 413
      block_dim_x = 32;
      block_dim_y = 32;
H
hedaoyuan 已提交
414 415
    }

H
hedaoyuan 已提交
416
    int block_dim_z = 1024 / block_dim_x / block_dim_y;
C
chengduoZH 已提交
417 418
    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
Q
QI JUN 已提交
419
    col2imOCF<T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
420 421 422
        col.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, im->data<T>());
H
hedaoyuan 已提交
423 424 425
  }
};

H
hedaoyuan 已提交
426
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
427
                             platform::CUDADeviceContext, float>;
H
hedaoyuan 已提交
428
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
429
                             platform::CUDADeviceContext, double>;
H
hedaoyuan 已提交
430
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
431
                             platform::CUDADeviceContext, float>;
H
hedaoyuan 已提交
432
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
Q
QI JUN 已提交
433
                             platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle