/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/im2col.h"

#include <algorithm>

#include "paddle/platform/cuda_helper.h"

namespace paddle {
19
namespace operators {
20
namespace math {
H
hedaoyuan 已提交
21 22

// Unfolds a [channels, height, width] image into the kCFO column buffer
// [channels, filter_height, filter_width, output_height, output_width].
//
// Launch: any grid whose linearised thread id
//   (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x
// covers [0, num_outs); only threadIdx.x is read, so blocks are 1-D.
// Each thread owns one (channel, output_row, output_col) triple and writes the
// filter_height * filter_width values of that receptive field; consecutive
// writes are output_height * output_width elements apart. Receptive-field
// positions that fall into the padding are written as zero.
template <class T>
__global__ void im2col(const T* data_im, int num_outs, int height, int width,
                       int filter_height, int filter_width, int stride_height,
                       int stride_width, int padding_height, int padding_width,
                       int output_height, int output_width, T* data_col) {
  const int tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= num_outs) return;

  // Split the flat id into (channel, output_row, output_col).
  const int out_col = tid % output_width;
  const int out_row = (tid / output_width) % output_height;
  const int channel = tid / output_width / output_height;

  // Top-left corner of the receptive field in unpadded image coordinates;
  // negative when the window overlaps the padding.
  const int row_origin = out_row * stride_height - padding_height;
  const int col_origin = out_col * stride_width - padding_width;

  const int plane_size = output_height * output_width;
  T* dst = data_col +
           (channel * filter_height * filter_width * output_height + out_row) *
               output_width +
           out_col;
  const T* src = data_im + channel * height * width;

  for (int fy = 0; fy < filter_height; ++fy) {
    const int r = row_origin + fy;
    for (int fx = 0; fx < filter_width; ++fx) {
      const int c = col_origin + fx;
      const bool inside = r >= 0 && r < height && c >= 0 && c < width;
      *dst = inside ? src[r * width + c] : T(0);
      dst += plane_size;
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T>
class Im2ColFunctor<kCFO, platform::GPUPlace, T> {
 public:
  // Unfolds `im` into `col` (layouts above) on the stream of `context`.
  // One thread is launched per (channel, output_row, output_col); each thread
  // writes filter_height * filter_width entries of `col`, with zeros for
  // positions that fall into the padding.
  // `context` is cast to platform::CUDADeviceContext* to obtain the stream,
  // so it must really be a CUDA device context.
  void operator()(const framework::Tensor& im, framework::Tensor& col,
                  int stride_height, int stride_width, int padding_height,
                  int padding_width, platform::DeviceContext* context) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int output_height = col.dims()[3];
    int output_width = col.dims()[4];

    int num_outputs = input_channels * output_height * output_width;
    // Nothing to unfold. Launching anyway would use a zero grid dimension,
    // which CUDA rejects as an invalid configuration.
    if (num_outputs == 0) return;

    const int kThreadsPerBlock = 1024;
    const int kMaxGridX = 512;
    int blocks = (num_outputs + kThreadsPerBlock - 1) / kThreadsPerBlock;
    // The kernel linearises the grid as blockIdx.x * gridDim.y + blockIdx.y,
    // so any grid with grid.x * grid.y >= blocks covers every output. Only
    // launch as many x-blocks as are needed instead of always 512.
    int block_x = std::min(blocks, kMaxGridX);
    int block_y = (blocks + block_x - 1) / block_x;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(block_x, block_y);

    // Downcast within the DeviceContext hierarchy: static_cast, not
    // reinterpret_cast.
    auto* cuda_context = static_cast<platform::CUDADeviceContext*>(context);
    im2col<T><<<grid, threads, 0, cuda_context->stream()>>>(
        im.data<T>(), num_outputs, input_height, input_width, filter_height,
        filter_width, stride_height, stride_width, padding_height,
        padding_width, output_height, output_width, col.data<T>());
  }
};

// Folds a kCFO column buffer back into a [channels, H, W] image.
//
// `height` / `width` describe the *padded* image (H + 2 * padding_height,
// W + 2 * padding_width) and `n` is the number of padded-image elements.
// Launch: any grid whose linearised thread id
//   (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x
// covers [0, n). Each thread maps to one padded pixel. Threads on the padding
// border do nothing. The others sum every col entry whose receptive field
// covers their pixel and add the total to data_im with `+=`. Each data_im
// element is touched by exactly one thread, so no atomics are used.
// `channels` is accepted but not read.
template <class T>
__global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
                       size_t channels, size_t filter_height,
                       size_t filter_width, size_t stride_height,
                       size_t stride_width, size_t padding_height,
                       size_t padding_width, size_t output_height,
                       size_t output_width, T* data_im) {
  const size_t tid =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= n) return;

  // Signed copies of the window geometry for the arithmetic below.
  const int pad_h = (int)padding_height;
  const int pad_w = (int)padding_width;
  const int kh = (int)filter_height;
  const int kw = (int)filter_width;
  const int sh = (int)stride_height;
  const int sw = (int)stride_width;

  // Extents of the real (unpadded) image.
  const size_t im_h = height - 2 * padding_height;
  const size_t im_w = width - 2 * padding_width;

  // Position of this thread inside the padded image.
  const int px = int(tid % width);
  const int py = int((tid / width) % height);
  const int ch = int(tid / (width * height));

  // Position inside the unpadded image; padding-border threads exit.
  const int x = px - pad_w;
  const int y = py - pad_h;
  if (x < 0 || (size_t)x >= im_w || y < 0 || (size_t)y >= im_h) return;

  // Output rows/columns whose receptive field contains (py, px).
  const int wc_begin = (px < kw) ? 0 : (px - kw) / sw + 1;
  const int wc_end = min(px / sw + 1, (int)output_width);
  const int hc_begin = (py < kh) ? 0 : (py - kh) / sh + 1;
  const int hc_end = min(py / sh + 1, (int)output_height);

  const int channel_base = int(ch * filter_height * filter_width);
  T acc = 0;
  for (int hc = hc_begin; hc < hc_end; ++hc) {
    const int ky = py - hc * sh;  // filter row that lands on this pixel
    for (int wc = wc_begin; wc < wc_end; ++wc) {
      const int kx = px - wc * sw;  // filter column that lands on this pixel
      const int c_col = channel_base + ky * kw + kx;
      acc += data_col[(c_col * output_height + hc) * output_width + wc];
    }
  }

  data_im[ch * (im_w * im_h) + y * im_w + x] += acc;
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T>
class Col2ImFunctor<kCFO, platform::GPUPlace, T> {
 public:
  // Folds `col` (layouts above) back into `im` on the stream of `context`.
  // The col2im kernel adds into `im` with `+=`, so the existing contents of
  // `im` are accumulated into rather than overwritten.
  // `context` is cast to platform::CUDADeviceContext* to obtain the stream,
  // so it must really be a CUDA device context.
  void operator()(framework::Tensor& im, const framework::Tensor& col,
                  int stride_height, int stride_width, int padding_height,
                  int padding_width, platform::DeviceContext* context) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int output_height = col.dims()[3];
    int output_width = col.dims()[4];

    // One thread per element of the *padded* image. Each thread gathers and
    // sums every col entry that maps onto its pixel, so no atomic operations
    // are needed. The product is formed in size_t rather than int.
    size_t num_kernels = static_cast<size_t>(input_channels) *
                         (input_height + 2 * padding_height) *
                         (input_width + 2 * padding_width);
    // Nothing to fold. Launching anyway would use a zero grid dimension,
    // which CUDA rejects as an invalid configuration.
    if (num_kernels == 0) return;

    const size_t kThreadsPerBlock = 1024;
    const size_t kMaxGridX = 512;
    size_t blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
    // The kernel linearises the grid as blockIdx.x * gridDim.y + blockIdx.y,
    // so any grid with grid.x * grid.y >= blocks is sufficient. Only launch
    // as many x-blocks as are needed instead of always 512.
    size_t block_x = std::min(blocks, kMaxGridX);
    size_t block_y = (blocks + block_x - 1) / block_x;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(block_x, block_y);

    // Downcast within the DeviceContext hierarchy: static_cast, not
    // reinterpret_cast.
    auto* cuda_context = static_cast<platform::CUDADeviceContext*>(context);
    col2im<T><<<grid, threads, 0, cuda_context->stream()>>>(
        num_kernels, col.data<T>(), input_height + 2 * padding_height,
        input_width + 2 * padding_width, input_channels, filter_height,
        filter_width, stride_height, stride_width, padding_height,
        padding_width, output_height, output_width, im.data<T>());
  }
};

// Explicit instantiations of the kCFO GPU functors, grouped by element type.
template class Im2ColFunctor<kCFO, platform::GPUPlace, float>;
template class Col2ImFunctor<kCFO, platform::GPUPlace, float>;
template class Im2ColFunctor<kCFO, platform::GPUPlace, double>;
template class Col2ImFunctor<kCFO, platform::GPUPlace, double>;

// Unfolds a [channels, height, width] image into the kOCF column buffer
// [output_height, output_width, channels, filter_height, filter_width].
//
// Launch: gridDim = (output_width, output_height). Each block owns one output
// location (blockIdx.x = column, blockIdx.y = row) and fills its whole
// [input_channels, filter_height, filter_width] patch. Threads stride over
// channels by blockDim.z, filter rows by blockDim.y and filter columns by
// blockDim.x. Positions that fall into the padding are written as zero.
template <class T>
__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
                          int input_height, int input_width, int filter_height,
                          int filter_width, int stride_height, int stride_width,
                          int padding_height, int padding_width,
                          int output_height, int output_width) {
  const int out_x = blockIdx.x;
  const int out_y = blockIdx.y;
  const int window = filter_height * filter_width;

  // First element of this output location's patch inside col_data.
  T* patch =
      col_data + (out_y * output_width + out_x) * (input_channels * window);
  // Image coordinates of the patch's top-left corner (negative in padding).
  const int x0 = out_x * stride_width - padding_width;
  const int y0 = out_y * stride_height - padding_height;

  for (int ch = threadIdx.z; ch < input_channels; ch += blockDim.z) {
    const T* plane = im_data + ch * input_height * input_width;
    T* patch_ch = patch + ch * window;
    for (int ky = threadIdx.y; ky < filter_height; ky += blockDim.y) {
      const int y = y0 + ky;
      for (int kx = threadIdx.x; kx < filter_width; kx += blockDim.x) {
        const int x = x0 + kx;
        const bool inside =
            y >= 0 && y < input_height && x >= 0 && x < input_width;
        patch_ch[ky * filter_width + kx] =
            inside ? plane[y * input_width + x] : T(0);
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T>
class Im2ColFunctor<kOCF, platform::GPUPlace, T> {
 public:
  // Unfolds `im` into `col` (layouts above) on the stream of `context`.
  // Grid: one block per output location (x = output column, y = output row).
  // Block: a square thread tile over the filter window, times a slice of
  // channels. The kernel strides over any remainder, so filters wider than
  // 32 and many channels are still covered.
  // `context` is cast to platform::CUDADeviceContext* to obtain the stream,
  // so it must really be a CUDA device context.
  void operator()(const framework::Tensor& im, framework::Tensor& col,
                  int stride_height, int stride_width, int padding_height,
                  int padding_width, platform::DeviceContext* context) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);
    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
    int output_height = col.dims()[0];
    int output_width = col.dims()[1];

    // Nothing to unfold. These extents become grid/block dimensions, and a
    // zero dimension is an invalid launch configuration.
    if (input_channels == 0 || output_height == 0 || output_width == 0) {
      return;
    }

    // Smallest square tile from {4, 8, 16} covering the filter, else 32.
    int tile = 32;
    if (filter_height <= 4 && filter_width <= 4) {
      tile = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      tile = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      tile = 16;
    }
    // Fill the rest of a 1024-thread block with channels.
    int block_dim_z = 1024 / tile / tile;
    dim3 threads(tile, tile, std::min(block_dim_z, input_channels));
    dim3 grid(output_width, output_height);

    // Downcast within the DeviceContext hierarchy: static_cast, not
    // reinterpret_cast.
    auto* cuda_context = static_cast<platform::CUDADeviceContext*>(context);
    im2colOCF<T><<<grid, threads, 0, cuda_context->stream()>>>(
        im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
        filter_height, filter_width, stride_height, stride_width,
        padding_height, padding_width, output_height, output_width);
  }
};

// Folds a kOCF column buffer back into a [channels, height, width] image.
//
// Launch: gridDim = (output_width, output_height). Each block reads the
// [input_channels, filter_height, filter_width] patch of one output location
// (blockIdx.x = column, blockIdx.y = row). Threads stride over channels by
// blockDim.z, filter rows by blockDim.y and filter columns by blockDim.x.
// Every in-bounds patch value is added to its source pixel with
// paddle::platform::CudaAtomicAdd, because overlapping windows from different
// blocks can hit the same pixel; padding positions are skipped.
// NOTE(review): atomic add on double presumably relies on SM60+ or a CAS
// fallback in cuda_helper.h — confirm.
template <class T>
__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
                          int input_height, int input_width, int filter_height,
                          int filter_width, int stride_height, int stride_width,
                          int padding_height, int padding_width,
                          int output_height, int output_width) {
  const int out_x = blockIdx.x;
  const int out_y = blockIdx.y;
  const int window = filter_height * filter_width;

  // First element of this output location's patch inside col_data.
  const T* patch =
      col_data + (out_y * output_width + out_x) * (input_channels * window);
  // Image coordinates of the patch's top-left corner (negative in padding).
  const int x0 = out_x * stride_width - padding_width;
  const int y0 = out_y * stride_height - padding_height;

  for (int ch = threadIdx.z; ch < input_channels; ch += blockDim.z) {
    T* plane = im_data + ch * input_height * input_width;
    const T* patch_ch = patch + ch * window;
    for (int ky = threadIdx.y; ky < filter_height; ky += blockDim.y) {
      const int y = y0 + ky;
      for (int kx = threadIdx.x; kx < filter_width; kx += blockDim.x) {
        const int x = x0 + kx;
        const bool outside =
            y < 0 || y >= input_height || x < 0 || x >= input_width;
        if (outside) continue;
        paddle::platform::CudaAtomicAdd(plane + y * input_width + x,
                                        patch_ch[ky * filter_width + kx]);
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T>
class Col2ImFunctor<kOCF, platform::GPUPlace, T> {
 public:
  // Folds `col` (layouts above) back into `im` on the stream of `context`.
  // col2imOCF adds every patch value into `im` with
  // paddle::platform::CudaAtomicAdd. The existing contents of `im` are
  // therefore accumulated into, and the order of the additions is not fixed.
  // Launch shape: one block per output location, with a square thread tile
  // over the filter window times a slice of channels.
  // `context` is cast to platform::CUDADeviceContext* to obtain the stream,
  // so it must really be a CUDA device context.
  void operator()(framework::Tensor& im, const framework::Tensor& col,
                  int stride_height, int stride_width, int padding_height,
                  int padding_width, platform::DeviceContext* context) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);
    int input_channels = im.dims()[0];
    int input_height = im.dims()[1];
    int input_width = im.dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
    int output_height = col.dims()[0];
    int output_width = col.dims()[1];

    // Nothing to fold. These extents become grid/block dimensions, and a
    // zero dimension is an invalid launch configuration.
    if (input_channels == 0 || output_height == 0 || output_width == 0) {
      return;
    }

    // Smallest square tile from {4, 8, 16} covering the filter, else 32.
    int tile = 32;
    if (filter_height <= 4 && filter_width <= 4) {
      tile = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      tile = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      tile = 16;
    }
    // Fill the rest of a 1024-thread block with channels.
    int block_dim_z = 1024 / tile / tile;
    dim3 threads(tile, tile, std::min(block_dim_z, input_channels));
    dim3 grid(output_width, output_height);

    // Downcast within the DeviceContext hierarchy: static_cast, not
    // reinterpret_cast.
    auto* cuda_context = static_cast<platform::CUDADeviceContext*>(context);
    col2imOCF<T><<<grid, threads, 0, cuda_context->stream()>>>(
        im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
        filter_height, filter_width, stride_height, stride_width,
        padding_height, padding_width, output_height, output_width);
  }
};

// Explicit instantiations of the kOCF GPU functors, grouped by element type.
template class Im2ColFunctor<kOCF, platform::GPUPlace, float>;
template class Col2ImFunctor<kOCF, platform::GPUPlace, float>;
template class Im2ColFunctor<kOCF, platform::GPUPlace, double>;
template class Col2ImFunctor<kOCF, platform::GPUPlace, double>;

}  // namespace math
361
}  // namespace operators
H
hedaoyuan 已提交
362
}  // namespace paddle