/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/pooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

// Forward 2-D pooling kernel: one thread computes one output element.
//
// The flat thread index is decomposed as (sample, channel, row, col) over an
// output of shape batch x channels x output_height x output_width (assumes
// contiguous NCHW storage for input and output, as the index arithmetic
// implies -- TODO confirm with callers). The thread reduces the input window
//   rows [row * stride_height - padding_height, ... + ksize_height)
//   cols [col * stride_width  - padding_width,  ... + ksize_width)
// clipped to the input plane, using pool_process.initial / process /
// finalize. finalize receives the number of in-bounds input elements.
//
// Expects a 1-D launch with at least `nthreads` threads; surplus threads
// return without touching memory.
template <typename PoolProcess, typename T>
__global__ void KernelPool2dForward(
    const int nthreads, const T* input_data, T* output_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, PoolProcess pool_process) {
  const int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (out_idx >= nthreads) return;

  // Peel the flat output index apart one dimension at a time.
  int rest = out_idx;
  const int out_col = rest % output_width;
  rest /= output_width;
  const int out_row = rest % output_height;
  rest /= output_height;
  const int chan = rest % channels;
  const int sample = rest / channels;

  // Window bounds: the upper bounds are clipped before the lower ones are
  // clamped, so both are derived from the unclipped start.
  const int row_lo_raw = out_row * stride_height - padding_height;
  const int col_lo_raw = out_col * stride_width - padding_width;
  const int row_hi = min(row_lo_raw + ksize_height, input_height);
  const int col_hi = min(col_lo_raw + ksize_width, input_width);
  const int row_lo = max(row_lo_raw, 0);
  const int col_lo = max(col_lo_raw, 0);

  // Start of the (sample, chan) input plane.
  const T* plane =
      input_data + (sample * channels + chan) * input_height * input_width;

  T acc = pool_process.initial();
  for (int r = row_lo; r < row_hi; ++r) {
    const T* row_ptr = plane + r * input_width;
    for (int col = col_lo; col < col_hi; ++col) {
      pool_process.process(acc, row_ptr[col]);
    }
  }

  const int window_count = (row_hi - row_lo) * (col_hi - col_lo);
  pool_process.finalize(acc, static_cast<T>(window_count));
  output_data[out_idx] = acc;
}

// Backward 2-D pooling kernel: one thread computes the gradient of one input
// element.
//
// The flat thread index addresses an input of shape
// batch x channels x input_height x input_width (assumes contiguous NCHW
// storage, as the index arithmetic implies -- TODO confirm with callers).
// The thread visits every output position whose pooling window covers its
// input element. For each one it calls pool_process.gradProcess with
//   (input value, output value, output gradient, accumulated gradient,
//    1 / number of in-bounds input elements in that window).
// The in-bounds count matches the count used by KernelPool2dForward.
//
// Expects a 1-D launch with at least `nthreads` threads; surplus threads
// return without touching memory.
template <typename PoolProcess, typename T>
__global__ void KernelPool2dBackward(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, PoolProcess pool_process) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < nthreads) {
    // Position of this input element in padded coordinates.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Output positions [phstart, phend) x [pwstart, pwend) whose windows
    // contain this input element.
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int phend = min(offsetH / stride_height + 1, output_height);
    int pwend = min(offsetW / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];

    // Start of the matching (batch, channel) output plane. Kept in locals
    // rather than advancing the kernel parameters.
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;
    const T* output_plane = output_data + output_idx;
    const T* output_grad_plane = output_grad + output_idx;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // In-bounds size of this output position's window.
        int hstart = ph * stride_height - padding_height;
        int wstart = pw * stride_width - padding_width;
        int hend = min(hstart + ksize_height, input_height);
        int wend = min(wstart + ksize_width, input_width);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
        int pool_size = (hend - hstart) * (wend - wstart);
        int output_sub_idx = ph * output_width + pw;
        // The reciprocal is computed in T. A `1.0` literal would force an
        // FP64 division in the float instantiation. Correctly rounded
        // float/double division gives the same value as the old
        // double-then-cast path.
        pool_process.gradProcess(input, output_plane[output_sub_idx],
                                 output_grad_plane[output_sub_idx], gradient,
                                 static_cast<T>(1) / static_cast<T>(pool_size));
      }
    }
    input_grad[index] = gradient;
  }
}

// GPU specialization of the 2-D pooling forward functor.
//
// Reads (N, C, H, W) from input.dims() and (C, H, W) from output.dims()[1..3].
// The ksize / strides / paddings vectors supply (height, width) pairs.
// The functor allocates `output` on the context's place and launches
// KernelPool2dForward on the context's CUDA stream, with one thread per
// output element. The launch is asynchronous; this function neither
// synchronizes nor checks for launch errors.
template <typename PoolProcess, typename T>
class Pool2dForwardFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& output,
                  std::vector<int>& ksize, std::vector<int>& strides,
                  std::vector<int>& paddings, PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    T* output_data = output.mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    // An empty output has nothing to compute. A zero-block grid would also
    // be an invalid launch configuration, so skip the launch.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    // Downcast from the base DeviceContext. static_cast is the correct
    // base-to-derived conversion; reinterpret_cast ignores any base offset.
    const auto& cuda_context =
        static_cast<const platform::CUDADeviceContext&>(context);
    KernelPool2dForward<PoolProcess,
                        T><<<grid, threads, 0, cuda_context.stream()>>>(
        nthreads, input_data, output_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width,
        pool_process);
  }
};

// GPU specialization of the 2-D pooling backward functor.
//
// Reads (N, C, H, W) from input.dims() and (H, W) from output.dims()[2..3].
// The ksize / strides / paddings vectors supply (height, width) pairs.
// The functor allocates `input_grad` on the context's place and launches
// KernelPool2dBackward on the context's CUDA stream, with one thread per
// input element. The launch is asynchronous; this function neither
// synchronizes nor checks for launch errors.
template <typename PoolProcess, typename T>
class Pool2dBackwardFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * input_channels * input_height * input_width;
    // An empty input has no gradient to compute. A zero-block grid would
    // also be an invalid launch configuration, so skip the launch.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    // Downcast from the base DeviceContext. static_cast is the correct
    // base-to-derived conversion; reinterpret_cast ignores any base offset.
    const auto& cuda_context =
        static_cast<const platform::CUDADeviceContext&>(context);
    KernelPool2dBackward<PoolProcess,
                         T><<<grid, threads, 0, cuda_context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_height, input_width, output_height, output_width,
        ksize_height, ksize_width, stride_height, stride_width, padding_height,
        padding_width, pool_process);
  }
};

// Explicit instantiations of the GPU 2-D pooling functors: forward and
// backward, for max and average pooling, over float and double.
// (This code already lives in paddle::operators::math, so `pool::` resolves
// to paddle::operators::math::pool.)

// Max pooling.
template class Pool2dForwardFunctor<platform::GPUPlace, pool::maxPool<float>,
                                    float>;
template class Pool2dBackwardFunctor<platform::GPUPlace, pool::maxPool<float>,
                                     float>;
template class Pool2dForwardFunctor<platform::GPUPlace, pool::maxPool<double>,
                                    double>;
template class Pool2dBackwardFunctor<platform::GPUPlace, pool::maxPool<double>,
                                     double>;

// Average pooling.
template class Pool2dForwardFunctor<platform::GPUPlace, pool::avePool<float>,
                                    float>;
template class Pool2dBackwardFunctor<platform::GPUPlace, pool::avePool<float>,
                                     float>;
template class Pool2dForwardFunctor<platform::GPUPlace, pool::avePool<double>,
                                    double>;
template class Pool2dBackwardFunctor<platform::GPUPlace, pool::avePool<double>,
                                     double>;

// Forward 3-D pooling kernel.
//
// Each output element is addressed by a flat index over an output of shape
// batch x channels x output_depth x output_height x output_width (assumes
// contiguous NCDHW storage, as the index arithmetic implies -- TODO confirm
// with callers). It reduces the input window starting at
// (pd, ph, pw) * stride - padding with extent ksize, clipped to the input
// volume, via pool_process.initial / process / finalize. finalize receives
// the number of in-bounds input elements in the window.
//
// Uses a grid-stride loop, so any 1-D grid size covers all `nthreads`
// outputs.
template <typename PoolProcess, typename T>
__global__ void KernelPool3DForward(
    const int nthreads, const T* input_data, T* output_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    // Window bounds, clipped to the input volume.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // Base of this (batch, channel) input volume. This must be a fresh local
    // on every iteration. Advancing `input_data` itself accumulates offsets
    // across grid-stride iterations, which reads the wrong volume whenever a
    // thread handles more than one output element.
    const T* input_volume =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          pool_process.process(
              ele, input_volume[(d * input_height + h) * input_width + w]);
        }
      }
    }
    int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
    pool_process.finalize(ele, static_cast<T>(pool_size));
    output_data[index] = ele;
  }
}

// Backward 3-D pooling kernel.
//
// Each thread handles one input element, addressed by a flat index over an
// input of shape batch x channels x input_depth x input_height x input_width
// (assumes contiguous NCDHW storage, as the index arithmetic implies -- TODO
// confirm with callers). The thread visits every output window that covers
// that element and accumulates the window's contribution through
// pool_process.gradProcess. The scale it passes is
// 1 / (number of in-bounds input elements in the window), the same count
// KernelPool3DForward hands to finalize.
//
// Uses a grid-stride loop, so any 1-D grid size covers all `nthreads`
// inputs.
template <typename PoolProcess, typename T>
__global__ void KernelPool3DBackward(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Position of this input element in padded coordinates.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetD =
        (index / input_width / input_height) % input_depth + padding_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // Output positions whose windows contain this element:
    // [pdstart, pdend) x [phstart, phend) x [pwstart, pwend).
    int pdstart = (offsetD < ksize_depth)
                      ? 0
                      : (offsetD - ksize_depth) / stride_depth + 1;
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int pdend = min(offsetD / stride_depth + 1, output_depth);
    int phend = min(offsetH / stride_height + 1, output_height);
    int pwend = min(offsetW / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];

    // Start of the matching (batch, channel) output volume. These must be
    // per-iteration locals. Advancing the `output_data` / `output_grad`
    // parameters would accumulate offsets across grid-stride iterations.
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;
    const T* output_volume = output_data + output_idx;
    const T* output_grad_volume = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // In-bounds size of this output position's window.
          int dstart = pd * stride_depth - padding_depth;
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int dend = min(dstart + ksize_depth, input_depth);
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          dstart = max(dstart, 0);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
          // The reciprocal is computed in T to avoid an FP64 division in the
          // float instantiation. The result equals the old double-then-cast
          // value under IEEE division.
          pool_process.gradProcess(
              input, output_volume[output_sub_idx],
              output_grad_volume[output_sub_idx], gradient,
              static_cast<T>(1) / static_cast<T>(pool_size));
        }
      }
    }
    input_grad[index] = gradient;
  }
}

// GPU specialization of the 3-D pooling forward functor.
//
// Reads (N, C, D, H, W) from input.dims() and (C, D, H, W) from
// output.dims()[1..4]. The ksize / strides / paddings vectors supply
// (depth, height, width) triples. The functor allocates `output` on the
// context's place and launches KernelPool3DForward on the context's CUDA
// stream, with one thread per output element. The launch is asynchronous;
// this function neither synchronizes nor checks for launch errors.
template <typename PoolProcess, class T>
class Pool3dForwardFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& output,
                  std::vector<int>& ksize, std::vector<int>& strides,
                  std::vector<int>& paddings, PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = output.mutable_data<T>(context.GetPlace());

    const int nthreads = batch_size * output_channels * output_depth *
                         output_height * output_width;
    // An empty output has nothing to compute. A zero-block grid would also
    // be an invalid launch configuration, so skip the launch.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    // Downcast from the base DeviceContext. static_cast is the correct
    // base-to-derived conversion; reinterpret_cast ignores any base offset.
    const auto& cuda_context =
        static_cast<const platform::CUDADeviceContext&>(context);
    KernelPool3DForward<PoolProcess,
                        T><<<grid, threads, 0, cuda_context.stream()>>>(
        nthreads, input_data, output_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width,
        pool_process);
  }
};

// GPU specialization of the 3-D pooling backward functor.
//
// Reads (N, C, D, H, W) from input.dims() and (D, H, W) from
// output.dims()[2..4]. The ksize / strides / paddings vectors supply
// (depth, height, width) triples. The functor allocates `input_grad` on the
// context's place and launches KernelPool3DBackward on the context's CUDA
// stream, with one thread per input element. The launch is asynchronous;
// this function neither synchronizes nor checks for launch errors.
template <typename PoolProcess, class T>
class Pool3dBackwardFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // An empty input has no gradient to compute. A zero-block grid would
    // also be an invalid launch configuration, so skip the launch.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    // Downcast from the base DeviceContext. static_cast is the correct
    // base-to-derived conversion; reinterpret_cast ignores any base offset.
    const auto& cuda_context =
        static_cast<const platform::CUDADeviceContext&>(context);
    KernelPool3DBackward<PoolProcess,
                         T><<<grid, threads, 0, cuda_context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_depth, input_height, input_width, output_depth,
        output_height, output_width, ksize_depth, ksize_height, ksize_width,
        stride_depth, stride_height, stride_width, padding_depth,
        padding_height, padding_width, pool_process);
  }
};

// Explicit instantiations of the GPU 3-D pooling functors: forward and
// backward, for max and average pooling, over float and double.
// (This code already lives in paddle::operators::math, so `pool::` resolves
// to paddle::operators::math::pool.)

// Max pooling.
template class Pool3dForwardFunctor<platform::GPUPlace, pool::maxPool<float>,
                                    float>;
template class Pool3dBackwardFunctor<platform::GPUPlace, pool::maxPool<float>,
                                     float>;
template class Pool3dForwardFunctor<platform::GPUPlace, pool::maxPool<double>,
                                    double>;
template class Pool3dBackwardFunctor<platform::GPUPlace, pool::maxPool<double>,
                                     double>;

// Average pooling.
template class Pool3dForwardFunctor<platform::GPUPlace, pool::avePool<float>,
                                    float>;
template class Pool3dBackwardFunctor<platform::GPUPlace, pool::avePool<float>,
                                     float>;
template class Pool3dForwardFunctor<platform::GPUPlace, pool::avePool<double>,
                                    double>;
template class Pool3dBackwardFunctor<platform::GPUPlace, pool::avePool<double>,
                                     double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle