/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "Im2Col.h"
#include "hl_device_functions.cuh"

namespace paddle {

/*
 * CFO im2col kernel: one thread per (input channel, output row, output
 * column). Each thread copies its blockH x blockW input patch into data_col,
 * laid out as [channels, blockH, blockW, height_col, width_col]; taps that
 * fall into the zero padding are written as 0.
 * Blocks are numbered blockIdx.x * gridDim.y + blockIdx.y; threads whose
 * linear id is >= numOuts do nothing.
 */
template<class T>
__global__
void im2col(const T* data_im, int numOuts, int height, int width,
            int blockH, int blockW,
            int strideH, int strideW,
            int paddingH, int paddingW,
            int height_col, int width_col,
            T* data_col) {
  int tid =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (tid >= numOuts) {
    return;
  }

  // Decompose the linear id into (channel, output row, output column).
  const int outCol = tid % width_col;
  const int outRow = (tid / width_col) % height_col;
  const int channel = (tid / width_col) / height_col;

  // Top-left tap of the patch in unpadded image coordinates (negative when
  // the patch overlaps the top/left padding).
  const int rowBase = outRow * strideH - paddingH;
  const int colBase = outCol * strideW - paddingW;
  // Distance in data_col between consecutive filter taps of one position.
  const int colPlane = height_col * width_col;

  const T* src = data_im + channel * height * width;
  T* dst = data_col +
      (channel * blockH * blockW * height_col + outRow) * width_col + outCol;

  for (int fy = 0; fy < blockH; ++fy) {
    const int y = rowBase + fy;
    const bool rowOk = (y >= 0) && (y < height);
    for (int fx = 0; fx < blockW; ++fx, dst += colPlane) {
      const int x = colBase + fx;
      *dst = (rowOk && x >= 0 && x < width) ? src[y * width + x] : T(0);
    }
  }
}

/*
 * imShape = [inputChannels, inputHeight, inputWidth]
 * colShape =
 *   [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
 */
template <class T>
class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, T> {
public:
  /*
   * Expands imData into colData on the GPU using the im2col kernel.
   * Every element of colData is written; taps that fall into the padding
   * are written as 0. The launch is on STREAM_DEFAULT and is followed by
   * CHECK_SYNC.
   */
  void operator()(const T* imData,
                  const TensorShape& imShape,
                  T* colData,
                  const TensorShape& colShape,
                  int strideHeight,
                  int strideWidth,
                  int paddingHeight,
                  int paddingWidth) {
    int inputChannels = imShape[0];
    int inputHeight = imShape[1];
    int inputWidth = imShape[2];
    int filterHeight = colShape[1];
    int filterWidth = colShape[2];
    int outputHeight = colShape[3];
    int outputWidth = colShape[4];

    // One thread per (channel, output row, output column); each thread
    // copies a whole filterHeight x filterWidth patch.
    int numKernels = inputChannels * outputHeight * outputWidth;
    if (numKernels == 0) {
      // colData is empty; a zero-sized grid would be an invalid launch.
      return;
    }

    const int kThreads = 1024;
    const int kMaxBlockX = 512;
    int blocks = (numKernels + kThreads - 1) / kThreads;
    // The kernel numbers blocks as blockIdx.x * gridDim.y + blockIdx.y, so
    // any grid with blockX * blockY >= blocks covers every element; do not
    // launch 512 x-blocks when fewer are needed.
    int blockX = blocks < kMaxBlockX ? blocks : kMaxBlockX;
    int blockY = (blocks + kMaxBlockX - 1) / kMaxBlockX;
    dim3 threads(kThreads, 1);
    dim3 grid(blockX, blockY);
    im2col<T><<< grid, threads, 0, STREAM_DEFAULT >>>
        (imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth,
         strideHeight, strideWidth, paddingHeight, paddingWidth,
         outputHeight, outputWidth, colData);
    CHECK_SYNC("Im2ColFunctor GPU failed");
  }
};

/*
 * CFO col2im kernel: one thread per element of the zero-padded image
 * (n = channels * height * width, where height/width include 2 * padding).
 * Each thread sums every data_col entry that maps onto its element, so no
 * atomics are needed. Only elements outside the padding write to data_im,
 * and they accumulate into it (+=) rather than overwrite.
 */
template<class T>
__global__
void col2im(size_t n, const T* data_col, size_t height,
            size_t width, size_t channels,
            size_t blockH, size_t blockW,
            size_t strideH, size_t strideW,
            size_t paddingH, size_t paddingW,
            size_t height_col, size_t width_col,
            T* data_im) {
  size_t index =
    (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= n) {
    return;
  }

  // Coordinates of this thread's element in the padded image.
  const int w = int(index % width);
  const int h = int((index / width) % height);
  const int c = int(index / (width * height));

  // Same element in unpadded-image coordinates; padding elements have no
  // destination in data_im.
  const size_t imW = width - 2 * paddingW;
  const size_t imH = height - 2 * paddingH;
  const int x = w - static_cast<int>(paddingW);
  const int y = h - static_cast<int>(paddingH);
  if (x < 0 || y < 0 ||
      static_cast<size_t>(x) >= imW || static_cast<size_t>(y) >= imH) {
    return;
  }

  const int kH = static_cast<int>(blockH);
  const int kW = static_cast<int>(blockW);
  const int sH = static_cast<int>(strideH);
  const int sW = static_cast<int>(strideW);

  // Output columns (end exclusive) whose receptive field covers (h, w).
  const int wColBegin = (w < kW) ? 0 : (w - kW) / sW + 1;
  const int wColEnd = min(w / sW + 1, static_cast<int>(width_col));
  const int hColBegin = (h < kH) ? 0 : (h - kH) / sH + 1;
  const int hColEnd = min(h / sH + 1, static_cast<int>(height_col));

  T sum = 0;
  for (int hc = hColBegin; hc < hColEnd; ++hc) {
    for (int wc = wColBegin; wc < wColEnd; ++wc) {
      // data_col row holding filter tap (h - hc*sH, w - wc*sW) of channel c.
      const int colRow = c * kH * kW + (h - hc * sH) * kW + (w - wc * sW);
      sum += data_col[(colRow * height_col + hc) * width_col + wc];
    }
  }
  data_im[(c * imH + y) * imW + x] += sum;
}

/*
 * imShape = [inputChannels, inputHeight, inputWidth]
 * colShape =
 *   [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
 */
template <class T>
class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, T> {
public:
  /*
   * Folds colData back into imData on the GPU using the col2im kernel.
   * Values are accumulated into imData (+=); imData is not cleared here.
   * The launch is on STREAM_DEFAULT and is followed by CHECK_SYNC.
   */
  void operator()(T* imData,
                  const TensorShape& imShape,
                  const T* colData,
                  const TensorShape& colShape,
                  int strideHeight,
                  int strideWidth,
                  int paddingHeight,
                  int paddingWidth) {
    int inputChannels = imShape[0];
    int inputHeight = imShape[1];
    int inputWidth = imShape[2];
    int filterHeight = colShape[1];
    int filterWidth = colShape[2];
    int outputHeight = colShape[3];
    int outputWidth = colShape[4];

    // The kernel works in zero-padded image coordinates.
    size_t paddedHeight = inputHeight + 2 * paddingHeight;
    size_t paddedWidth = inputWidth + 2 * paddingWidth;
    // Compute the element count in size_t rather than as an int product.
    size_t numKernels =
        static_cast<size_t>(inputChannels) * paddedHeight * paddedWidth;
    if (numKernels == 0) {
      // Nothing to fold; a zero-sized grid would be an invalid launch.
      return;
    }

    const size_t kThreads = 1024;
    const size_t kMaxBlockX = 512;
    size_t blocks = (numKernels + kThreads - 1) / kThreads;
    // The kernel numbers blocks as blockIdx.x * gridDim.y + blockIdx.y, so
    // any grid with blockX * blockY >= blocks covers every element.
    size_t blockX = blocks < kMaxBlockX ? blocks : kMaxBlockX;
    size_t blockY = (blocks + kMaxBlockX - 1) / kMaxBlockX;
    dim3 threads(kThreads, 1);
    dim3 grid(blockX, blockY);

    // One thread per padded-image element: each thread gathers every column
    // entry that maps onto its element, so no atomic operations are needed.
    col2im<T><<< grid, threads, 0, STREAM_DEFAULT >>>
             (numKernels,
              colData,
              paddedHeight,
              paddedWidth,
              inputChannels,
              filterHeight,
              filterWidth,
              strideHeight,
              strideWidth,
              paddingHeight,
              paddingWidth,
              outputHeight,
              outputWidth,
              imData);
    CHECK_SYNC("Col2ImFunctor GPU failed");
  }
};

// Explicit GPU instantiations of the kCFO functors for float and double.
template class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, double>;

/*
 * OCF im2col kernel: one block per output position (blockIdx.x = output
 * column, blockIdx.y = output row). Threads stride over channels (z),
 * filter rows (y) and filter columns (x), writing colData laid out as
 * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth].
 * Taps that fall into the padding are written as 0.
 */
template<class T>
__global__
void im2colOCF(const T* imData, T* colData,
               int inputChannels,
               int inputHeight, int inputWidth,
               int filterHeight, int filterWidth,
               int strideHeight, int strideWidth,
               int paddingHeight, int paddingWidth,
               int outputHeight, int outputWidth) {
  int swId = blockIdx.x;
  int shId = blockIdx.y;
  int patchSize = filterHeight * filterWidth;
  // Top-left tap of this patch in unpadded image coordinates (negative when
  // it overlaps the top/left padding).
  int originX = swId * strideWidth - paddingWidth;
  int originY = shId * strideHeight - paddingHeight;
  T* patch = colData +
      (shId * outputWidth + swId) * (inputChannels * patchSize);

  for (int channelId = threadIdx.z;
       channelId < inputChannels;
       channelId += blockDim.z) {
    const T* imPlane = imData + channelId * inputHeight * inputWidth;
    T* patchPlane = patch + channelId * patchSize;
    for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
      int y = originY + idy;
      bool rowInside = y >= 0 && y < inputHeight;
      for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
        int x = originX + idx;
        patchPlane[idy * filterWidth + idx] =
            (rowInside && x >= 0 && x < inputWidth)
                ? imPlane[y * inputWidth + x]
                : T(0);
      }
    }
  }
}

/*
 * imShape = [inputChannels, inputHeight, inputWidth]
 * colShape =
 *   [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
 */
template <class T>
class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, T> {
public:
  /*
   * Expands imData into colData (OCF layout) on the GPU using im2colOCF:
   * one block per output position, threads tiled over
   * (filterWidth, filterHeight, inputChannels). The launch is on
   * STREAM_DEFAULT and is followed by CHECK_SYNC.
   */
  void operator()(const T* imData,
                  const TensorShape& imShape,
                  T* colData,
                  const TensorShape& colShape,
                  int strideHeight,
                  int strideWidth,
                  int paddingHeight,
                  int paddingWidth) {
    int inputChannels = imShape[0];
    int inputHeight = imShape[1];
    int inputWidth = imShape[2];
    int filterHeight = colShape[3];
    int filterWidth = colShape[4];
    int outputHeight = colShape[0];
    int outputWidth = colShape[1];

    // colData is empty; a zero grid or block dimension would make the launch
    // itself fail.
    if (inputChannels == 0 || outputHeight == 0 || outputWidth == 0) {
      return;
    }

    // Size the thread tile per filter axis (x covers filterWidth, y covers
    // filterHeight) so non-square filters do not launch mostly idle threads.
    // Square filters get the same tile as before.
    int blockDimX = filterWidth <= 4 ? 4
                  : filterWidth <= 8 ? 8
                  : filterWidth <= 16 ? 16 : 32;
    int blockDimY = filterHeight <= 4 ? 4
                  : filterHeight <= 8 ? 8
                  : filterHeight <= 16 ? 16 : 32;

    // blockDimX * blockDimY is in [16, 1024], so blockDimZ is in [1, 64]
    // (the blockDim.z limit) and the block never exceeds 1024 threads.
    int blockDimZ = 1024 / blockDimX / blockDimY;
    dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
    dim3 grid(outputWidth, outputHeight);
    im2colOCF<T><<< grid, threads, 0, STREAM_DEFAULT >>>
        (imData, colData, inputChannels, inputHeight, inputWidth,
         filterHeight, filterWidth, strideHeight, strideWidth,
         paddingHeight, paddingWidth, outputHeight, outputWidth);
    CHECK_SYNC("Im2ColFunctor GPU failed");
  }
};

/*
 * OCF col2im kernel: one block per output position (blockIdx.x = output
 * column, blockIdx.y = output row). Threads stride over channels (z),
 * filter rows (y) and filter columns (x). Each in-image tap of the patch
 * stored in colData ([outputHeight, outputWidth, inputChannels,
 * filterHeight, filterWidth]) is added into imData with
 * paddle::paddleAtomicAdd, because overlapping patches from different
 * blocks can hit the same image element. Padding taps are skipped.
 */
template<class T>
__global__
void col2imOCF(T* imData, const T* colData,
               int inputChannels,
               int inputHeight, int inputWidth,
               int filterHeight, int filterWidth,
               int strideHeight, int strideWidth,
               int paddingHeight, int paddingWidth,
               int outputHeight, int outputWidth) {
  const int outX = blockIdx.x;
  const int outY = blockIdx.y;
  const int patchSize = filterHeight * filterWidth;
  // Top-left tap of this patch in unpadded image coordinates.
  const int originX = outX * strideWidth - paddingWidth;
  const int originY = outY * strideHeight - paddingHeight;
  const T* patch =
      colData + (outY * outputWidth + outX) * (inputChannels * patchSize);

  for (int ch = threadIdx.z; ch < inputChannels; ch += blockDim.z) {
    T* imPlane = imData + ch * inputHeight * inputWidth;
    const T* patchPlane = patch + ch * patchSize;
    for (int fy = threadIdx.y; fy < filterHeight; fy += blockDim.y) {
      const int y = originY + fy;
      if (y < 0 || y >= inputHeight) {
        continue;  // whole filter row lies in the padding
      }
      for (int fx = threadIdx.x; fx < filterWidth; fx += blockDim.x) {
        const int x = originX + fx;
        if (x < 0 || x >= inputWidth) {
          continue;
        }
        paddle::paddleAtomicAdd(imPlane + y * inputWidth + x,
                                patchPlane[fy * filterWidth + fx]);
      }
    }
  }
}

/*
 * imShape = [inputChannels, inputHeight, inputWidth]
 * colShape =
 *   [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
 */
template <class T>
class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, T> {
public:
  /*
   * Folds colData (OCF layout) back into imData on the GPU using col2imOCF.
   * The kernel adds into imData with atomic adds; imData is not cleared
   * here. The launch is on STREAM_DEFAULT and is followed by CHECK_SYNC.
   */
  void operator()(T* imData,
                  const TensorShape& imShape,
                  const T* colData,
                  const TensorShape& colShape,
                  int strideHeight,
                  int strideWidth,
                  int paddingHeight,
                  int paddingWidth) {
    int inputChannels = imShape[0];
    int inputHeight = imShape[1];
    int inputWidth = imShape[2];
    int filterHeight = colShape[3];
    int filterWidth = colShape[4];
    int outputHeight = colShape[0];
    int outputWidth = colShape[1];

    // colData is empty; a zero grid or block dimension would make the launch
    // itself fail.
    if (inputChannels == 0 || outputHeight == 0 || outputWidth == 0) {
      return;
    }

    // Size the thread tile per filter axis (x covers filterWidth, y covers
    // filterHeight) so non-square filters do not launch mostly idle threads.
    // Square filters get the same tile as before.
    int blockDimX = filterWidth <= 4 ? 4
                  : filterWidth <= 8 ? 8
                  : filterWidth <= 16 ? 16 : 32;
    int blockDimY = filterHeight <= 4 ? 4
                  : filterHeight <= 8 ? 8
                  : filterHeight <= 16 ? 16 : 32;

    // blockDimX * blockDimY is in [16, 1024], so blockDimZ is in [1, 64]
    // (the blockDim.z limit) and the block never exceeds 1024 threads.
    int blockDimZ = 1024 / blockDimX / blockDimY;
    dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
    dim3 grid(outputWidth, outputHeight);
    col2imOCF<T><<< grid, threads, 0, STREAM_DEFAULT >>>
        (imData, colData, inputChannels, inputHeight, inputWidth,
         filterHeight, filterWidth, strideHeight, strideWidth,
         paddingHeight, paddingWidth, outputHeight, outputWidth);
    CHECK_SYNC("Col2ImFunctor GPU failed");
  }
};

// Explicit GPU instantiations of the kOCF functors for float and double.
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, double>;

}  // namespace paddle