hl_perturbation_util.cu 10.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdlib.h>
L
liaogang 已提交
16
#include <cmath>
Z
zhangjinchao01 已提交
17
#include "hl_base.h"
L
liaogang 已提交
18
#include "hl_cuda.h"
Z
zhangjinchao01 已提交
19
#include "hl_perturbation_util.cuh"
L
liaogang 已提交
20
#include "hl_time.h"
Z
zhangjinchao01 已提交
21 22 23 24 25 26 27 28 29 30 31

#define _USE_MATH_DEFINES

/*
 * Map a pixel of the transformed (target) patch back to the source image.
 *
 * x, y:             column and row of the pixel in the target patch.
 * theta:            rotation angle; the inverse rotation (by -theta) is
 *                   applied to map target -> source.
 * scale:            scale factor; offsets from the image center are divided
 *                   by it.
 * tgtCenter:        center coordinate of the target patch (used for both
 *                   rows and columns).
 * imgCenter:        center coordinate of the source image (used for both
 *                   rows and columns).
 * centerR, centerC: row and column in the source image that the center of
 *                   the target patch is translated to.
 * sourceX, sourceY: output column and row in the source image, rounded to
 *                   the nearest integer. They may lie outside the image;
 *                   the caller must bounds-check them.
 */
__device__ void getTranformCoord(int x,
                                 int y,
                                 real theta,
                                 real scale,
                                 real tgtCenter,
                                 real imgCenter,
                                 real centerR,
                                 real centerC,
                                 int* sourceX,
                                 int* sourceY) {
  // Rotation by -theta. The generic cos/sin overloads keep the computation
  // in the precision of `real` (cosf/sinf would truncate double builds to
  // single precision).
  const real c = cos(-theta);
  const real s = sin(-theta);

  // Offset from the source image center after translating the target patch
  // so that its center lands on (centerR, centerC).
  const real dx = x - tgtCenter + centerC - imgCenter;
  const real dy = y - tgtCenter + centerR - imgCenter;

  // Apply the rotation matrix [c, -s; s, c], undo the scaling, and move back
  // to absolute source coordinates.
  const real xx = c * dx - s * dy;
  const real yy = s * dx + c * dy;
  *sourceX = __float2int_rn(xx / scale + imgCenter);
  *sourceY = __float2int_rn(yy / scale + imgCenter);
}

/*
 * Sample rotated/scaled square patches from a batch of images.
 *
 * imgs:     numImages images of imgSize x imgSize pixels
 *           (numImages, imgPixels * channels).
 * targets:  numImages * samplingRate output patches of tgtSize x tgtSize
 *           pixels (numPatches, tgtPixels * channels).
 * The channels of one pixel are stored contiguously in memory.
 * thetas, scales:     one rotation angle and one scale per image.
 * centerRs, centerCs: row / column of each patch center in its source image,
 *                     one entry per patch.
 * padValue: written for patch pixels that map outside the source image.
 *
 * Launch layout: the x dimension (blockIdx.x, threadIdx.x) indexes patches,
 * the y dimension (blockIdx.y, threadIdx.y) indexes target pixels.
 *
 * created by Wei Xu (genome), converted by Jiang Wang
 */
__global__ void kSamplingPatches(const real* imgs,
                                 real* targets,
                                 int imgSize,
                                 int tgtSize,
                                 const int channels,
                                 int samplingRate,
                                 const real* thetas,
                                 const real* scales,
                                 const int* centerRs,
                                 const int* centerCs,
                                 const real padValue,
                                 const int numImages) {
  // Derive indices from the actual block shape so the kernel stays correct
  // if the launch configuration changes.
  const int caseIdx = blockIdx.x * blockDim.x + threadIdx.x;
  const int pxIdx = blockIdx.y * blockDim.y + threadIdx.y;
  const int imgPixels = imgSize * imgSize;
  const int tgtPixels = tgtSize * tgtSize;
  const int numPatches = numImages * samplingRate;

  // Integer division is intentional: the host-side sampling centers are
  // integer pixel positions computed the same way ((imgSize - 1) / 2).
  real tgtCenter = (tgtSize - 1) / 2;
  real imgCenter = (imgSize - 1) / 2;

  if (pxIdx >= tgtPixels || caseIdx >= numPatches) return;

  const int imgIdx = caseIdx / samplingRate;

  // Column / row of this pixel inside the target patch.
  const int pxX = pxIdx % tgtSize;
  const int pxY = pxIdx / tgtSize;

  int srcPxX, srcPxY;
  // getTranformCoord expects (centerR, centerC): pass the row center first
  // so the patch is centered where the host-side sampler placed it.
  getTranformCoord(pxX,
                   pxY,
                   thetas[imgIdx],
                   scales[imgIdx],
                   tgtCenter,
                   imgCenter,
                   centerRs[caseIdx],
                   centerCs[caseIdx],
                   &srcPxX,
                   &srcPxY);

  real* dst = targets + (caseIdx * tgtPixels + pxIdx) * channels;
  if (srcPxX >= 0 && srcPxX < imgSize && srcPxY >= 0 && srcPxY < imgSize) {
    // Only form the source address once it is known to be in range.
    const real* src =
        imgs + (imgIdx * imgPixels + srcPxY * imgSize + srcPxX) * channels;
    for (int j = 0; j < channels; j++) dst[j] = src[j];
  } else {
    for (int j = 0; j < channels; j++) dst[j] = padValue;
  }
}

/*
 * Generate the perturbation (rotation and scaling) parameters and the
 * sampling locations on the host, then copy them to the device buffers.
 *
 * gpuAngle, gpuScaleRatio: receive one rotation angle (radians) and one
 *     scale ratio per image (numImages entries each).
 * gpuCenterR, gpuCenterC:  receive the row / column of each patch center in
 *     its source image (numImages * samplingRate entries each).
 * rotateAngle: rotation range in degrees; training angles are drawn
 *     uniformly from [-rotateAngle / 2, rotateAngle / 2).
 * scaleRatio:  training scale ratios are drawn uniformly from
 *     [1 - scaleRatio / 2, 1 + scaleRatio / 2).
 * isTrain:     true  -> random perturbation and random sampling locations;
 *              false -> no perturbation, every patch at the image center.
 *
 * Uses the C library rand()/srand(), so it is not thread safe.
 * NOTE(review): if scaleRatio >= 2 a ratio of exactly 0 is possible, in which
 * case no location passes the check below and the rejection loop never ends.
 *
 * created by Wei Xu
 */
void hl_generate_disturb_params(real*& gpuAngle,
                                real*& gpuScaleRatio,
                                int*& gpuCenterR,
                                int*& gpuCenterC,
                                int numImages,
                                int imgSize,
                                real rotateAngle,
                                real scaleRatio,
                                int samplingRate,
                                bool isTrain) {
  // The number of output samples.
  int numPatches = numImages * samplingRate;

  // Host-side staging buffers for the perturbation parameters.
  real* r_angle = new real[numImages];
  real* s_ratio = new real[numImages];
  int* center_r = new int[numPatches];
  int* center_c = new int[numPatches];

  if (isTrain) {  // random perturbation and sampling for training
    // TODO(yuyang18): Since it will initialize random seed here, we can use
    // rand_r instead of rand to make this method thread safe.
    // Seed exactly once per call: re-seeding inside the per-image loop would
    // replay the same random stream for every image.
    srand(getCurrentTimeStick());
    for (int i = 0; i < numImages; i++) {
      r_angle[i] =
          (rotateAngle * M_PI / 180.0) * (rand() / (RAND_MAX + 1.0)  // NOLINT
                                          - 0.5);
      s_ratio[i] =
          1 + (rand() / (RAND_MAX + 1.0) - 0.5) * scaleRatio;  // NOLINT
    }

    int imgCenter = (imgSize - 1) / 2;

    // Rejection-sample patch centers whose transformed position stays
    // inside the source image.
    for (int i = 0; i < numImages; i++) {
      // The transform only depends on the image, so build it once.
      const real H[4] = {cos(-r_angle[i]),
                         -sin(-r_angle[i]),
                         sin(-r_angle[i]),
                         cos(-r_angle[i])};
      int j = 0;
      while (j < samplingRate) {
        // Uniform over every pixel index [0, imgSize - 1]. The product is
        // formed in double so it stays strictly below imgSize.
        int pxX = (int)(imgSize * (rand() / (RAND_MAX + 1.0)));  // NOLINT
        int pxY = (int)(imgSize * (rand() / (RAND_MAX + 1.0)));  // NOLINT

        real x = pxX - imgCenter;
        real y = pxY - imgCenter;
        real xx = H[0] * x + H[1] * y;
        real yy = H[2] * x + H[3] * y;

        real srcPxX = xx / s_ratio[i] + imgCenter;
        real srcPxY = yy / s_ratio[i] + imgCenter;

        if (srcPxX >= 0 && srcPxX <= imgSize - 1 && srcPxY >= 0 &&
            srcPxY <= imgSize - 1) {
          center_r[i * samplingRate + j] = pxY;
          center_c[i * samplingRate + j] = pxX;
          j++;
        }
      }
    }
  } else {  // central crop for testing
    for (int i = 0; i < numImages; i++) {
      r_angle[i] = 0.0;
      s_ratio[i] = 1.0;

      for (int j = 0; j < samplingRate; j++) {
        center_r[i * samplingRate + j] = (imgSize - 1) / 2;
        center_c[i * samplingRate + j] = (imgSize - 1) / 2;
      }
    }
  }

  // copy disturbance sequence to gpu
  hl_memcpy_host2device(gpuAngle, r_angle, sizeof(real) * numImages);
  hl_memcpy_host2device(gpuScaleRatio, s_ratio, sizeof(real) * numImages);

  delete[] r_angle;
  delete[] s_ratio;

  // copy sampling location sequence to gpu
  hl_memcpy_host2device(gpuCenterR, center_r, sizeof(int) * numPatches);
  hl_memcpy_host2device(gpuCenterC, center_c, sizeof(int) * numPatches);

  delete[] center_r;
  delete[] center_c;
}

L
liaogang 已提交
214 215 216 217 218 219
/*
 * Sample the output patches with kSamplingPatches using precomputed
 * perturbation parameters (e.g. those written by hl_generate_disturb_params),
 * then synchronize the device.
 *
 * images: numImages images of imgSize x imgSize pixels, `channels` values per
 *         pixel stored contiguously; read by the kernel, so it must be
 *         device-accessible.
 * target: receives numImages * samplingRate patches of tgtSize x tgtSize
 *         pixels with the same layout.
 * gpuRotationAngle, gpuScaleRatio: one entry per image.
 * gpuCenterR, gpuCenterC:          one entry per patch.
 * paddingValue: value written for patch pixels outside the source image.
 */
void hl_conv_random_disturb_with_params(const real* images,
                                        int imgSize,
                                        int tgtSize,
                                        int channels,
                                        int numImages,
                                        int samplingRate,
                                        const real* gpuRotationAngle,
                                        const real* gpuScaleRatio,
                                        const int* gpuCenterR,
                                        const int* gpuCenterC,
                                        int paddingValue,
                                        real* target) {
  // The number of output samples.
  int numPatches = numImages * samplingRate;
  // The number of pixels in one output patch.
  int targetSize = tgtSize * tgtSize;

  // A zero grid dimension is an invalid launch configuration, so only launch
  // when there is something to produce.
  if (numPatches > 0 && targetSize > 0) {
    // x: patches (4 per block), y: target pixels (128 per block).
    dim3 threadsPerBlock(4, 128);
    dim3 numBlocks(DIVUP(numPatches, 4), DIVUP(targetSize, 128));

    kSamplingPatches<<<numBlocks, threadsPerBlock>>>(images,
                                                     target,
                                                     imgSize,
                                                     tgtSize,
                                                     channels,
                                                     samplingRate,
                                                     gpuRotationAngle,
                                                     gpuScaleRatio,
                                                     gpuCenterR,
                                                     gpuCenterC,
                                                     paddingValue,
                                                     numImages);
  }

  hl_device_synchronize();
}

L
liaogang 已提交
250 251 252 253 254 255 256 257 258 259 260 261 262 263 264
/*
 * Generate perturbation parameters into the given device buffers (random
 * when isTrain, central crop otherwise) and sample the output patches.
 *
 * gpu_r_angle, gpu_s_ratio:   buffers of numImages entries.
 * gpu_center_r, gpu_center_c: buffers of numImages * samplingRate entries
 *                             (patch-center rows and columns respectively).
 * targets: receives numImages * samplingRate patches of tgtSize x tgtSize
 *          pixels with `channels` values per pixel.
 */
void hl_conv_random_disturb(const real* images,
                            int imgSize,
                            int tgtSize,
                            int channels,
                            int numImages,
                            real scaleRatio,
                            real rotateAngle,
                            int samplingRate,
                            real* gpu_r_angle,
                            real* gpu_s_ratio,
                            int* gpu_center_r,
                            int* gpu_center_c,
                            int paddingValue,
                            bool isTrain,
                            real* targets) {
  // generate the random disturbance sequence and the sampling locations
  hl_generate_disturb_params(gpu_r_angle,
                             gpu_s_ratio,
                             gpu_center_r,
                             gpu_center_c,
                             numImages,
                             imgSize,
                             rotateAngle,
                             scaleRatio,
                             samplingRate,
                             isTrain);

  hl_conv_random_disturb_with_params(images,
                                     imgSize,
                                     tgtSize,
                                     channels,
                                     numImages,
                                     samplingRate,
                                     gpu_r_angle,
                                     gpu_s_ratio,
                                     gpu_center_r,
                                     // Column centers: passing gpu_center_r
                                     // here would reuse the row centers.
                                     gpu_center_c,
                                     paddingValue,
                                     targets);
}