/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
17

Y
Yi Wang 已提交
18
#include "paddle/fluid/operators/math/pooling.h"
19 20 21
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
L
limingshu 已提交
22
#include "paddle/fluid/platform/fast_divmod.h"
C
chengduoZH 已提交
23

L
limingshu 已提交
24 25 26 27 28 29
#ifdef __HIPCC__
#define POOLING_BLOCK_SIZE 256
#else
#define POOLING_BLOCK_SIZE 512
#endif

C
chengduoZH 已提交
30 31 32 33
namespace paddle {
namespace operators {
namespace math {

// Bundles the FastDivMod helpers used to split a flat 4-D element index into
// channel / height / width coordinates (see OffsetPreparationFor4Dimension).
struct FastDivModForPooling {
 public:
  platform::FastDivMod channel;
  platform::FastDivMod width;
  platform::FastDivMod height;

  // channels:      channel count of the indexed tensor.
  // output_width:  width divisor used for the decomposition.
  // output_height: height divisor used for the decomposition.
  explicit HOSTDEVICE FastDivModForPooling(const int channels,
                                           const int output_width,
                                           const int output_height)
      : channel(channels), width(output_width), height(output_height) {}
};

struct FastDivModForPoolingWithMoreStaff {
 public:
  platform::FastDivMod channel;
  platform::FastDivMod width;
  platform::FastDivMod height;
  platform::FastDivMod ksize_w;
  platform::FastDivMod ksize_h;
  platform::FastDivMod stride_w;
  platform::FastDivMod stride_h;

  explicit HOSTDEVICE FastDivModForPoolingWithMoreStaff(
      const int channels, const int input_width, const int input_height,
      const int ksize_width, const int ksize_height, const int stride_width,
      const int stride_height) {
    channel = platform::FastDivMod(channels);
    width = platform::FastDivMod(input_width);
    height = platform::FastDivMod(input_height);
    ksize_w = platform::FastDivMod(ksize_width);
    ksize_h = platform::FastDivMod(ksize_height);
    stride_w = platform::FastDivMod(stride_width);
    stride_h = platform::FastDivMod(stride_height);
  }
};

// Splits a flat element `index` into spatial/channel coordinates and computes
// the base offset `*stride` into a companion ("aux") tensor whose spatial
// extent is aux_height x aux_width.
//
// With C, H, W the divisors held in `divmods`:
//   NHWC (channel_last):  index = ((n * H + h) * W + w) * C + c
//                         *stride = n * aux_height * aux_width * C
//   NCHW:                 index = ((n * C + c) * H + h) * W + w
//                         *stride = (n * C + c) * aux_height * aux_width
// pad_width / pad_height are added to the returned w / h coordinates.
template <typename FastDivModForPooling>
__device__ void OffsetPreparationFor4Dimension(
    int index, bool channel_last, FastDivModForPooling divmods,
    const int pad_width, const int pad_height, const int aux_width,
    const int aux_height, int* w_offset, int* h_offset, int* c_offset,
    int* stride) {
  if (channel_last) { /* NHWC */
    auto by_channel = divmods.channel.Divmod(index);
    auto by_width = divmods.width.Divmod(by_channel.val[0]);
    auto by_height = divmods.height.Divmod(by_width.val[0]);
    *c_offset = by_channel.val[1];
    *w_offset = by_width.val[1] + pad_width;
    *h_offset = by_height.val[1] + pad_height;
    *stride = by_height.val[0] * aux_height * aux_width *
              divmods.channel.divisor;
    return;
  }

  /* NCHW */
  auto by_width = divmods.width.Divmod(index);
  auto by_height = divmods.height.Divmod(by_width.val[0]);
  auto by_channel = divmods.channel.Divmod(by_height.val[0]);
  *w_offset = by_width.val[1] + pad_width;
  *h_offset = by_height.val[1] + pad_height;
  *c_offset = by_channel.val[1];
  *stride = (by_channel.val[0] * divmods.channel.divisor + *c_offset) *
            aux_height * aux_width;
}

// Picks a 1-D block size for a launch over `numel` elements.
// If `numel` spread over 2x the SM count gives fewer than `threads_per_block`
// elements per block, the block size becomes that per-block count passed
// through platform::RoundToPowerOfTwo. Otherwise the same check is made
// against 4x the SM count. The returned value is never below 64.
int GetThreadsPerBlock(const platform::CUDADeviceContext& ctx,
                       int threads_per_block, int64_t numel) {
  const int sm_count = ctx.GetSMCount();
  const int64_t per_block_at_2x_sm = numel / (sm_count << 1);
  if (per_block_at_2x_sm < threads_per_block) {
    threads_per_block = platform::RoundToPowerOfTwo(per_block_at_2x_sm);
  } else {
    const int64_t per_block_at_4x_sm = numel / (sm_count << 2);
    if (per_block_at_4x_sm < threads_per_block) {
      threads_per_block = platform::RoundToPowerOfTwo(per_block_at_4x_sm);
    }
  }
  return std::max(64, threads_per_block);
}

/*
 * Forward 2-D pooling kernel.
 *
 * Grid-stride loop: each iteration writes one output element
 * output_data[index] for index in [0, nthreads). `divmods` must decompose
 * output indices (the launch sites build it from channels, output_width,
 * output_height). channel_last selects NHWC (true) or NCHW (false).
 *
 * Window: in adaptive mode it comes from AdaptStartIndex / AdaptEndIndex;
 * otherwise it is the strided, padded ksize window clipped to the input.
 * The count passed to pool_process.finalize is the clipped window area when
 * `exclusive || adaptive`, else ksize_height * ksize_width.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(
    const int nthreads, const T* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, FastDivModForPooling divmods,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(
        index, channel_last, divmods, 0, 0, input_width, input_height,
        &w_offset, &h_offset, &c_offset, &input_offset);
    // Base of the input region this output reads: the (n, c) plane for NCHW,
    // the n-th image for NHWC. Kept in a per-iteration local: advancing
    // `input_data` itself would carry this offset into the next grid-stride
    // iteration and read from the wrong location whenever a thread handles
    // more than one output.
    const T* input_base = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);
      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);
      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        auto input_idx = channel_last
                             ? (h * input_width + w) * channels + c_offset
                             : h * input_width + w;
        pool_process.compute(input_base[input_idx], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

/*
 * Generic backward 2-D pooling kernel.
 *
 * Grid-stride loop: each iteration produces one input-gradient element
 * input_grad[index] by summing the contribution of every output window that
 * covers that input element. `divmods` decomposes input indices and also
 * provides fast division by the strides. pool_process.compute receives
 * (x, y, dy, 1 / pool_size, &dx); x and y are only read when
 * pool_process.use_x is set.
 *
 * pool_size matches the forward kernel (KernelPool2D):
 *   adaptive  -> exact adaptive window area,
 *   exclusive -> padded window clipped to the input,
 *   otherwise -> ksize_height * ksize_width.
 */
template <typename T, typename PoolProcess>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* __restrict__ input_data,
    const T* __restrict__ output_data, const T* __restrict__ output_grad,
    const int output_width, const int output_height, const int input_width,
    const int input_height, const int ksize_width, const int ksize_height,
    const int stride_width, const int stride_height, const int padding_width,
    const int padding_height, FastDivModForPoolingWithMoreStaff divmods,
    PoolProcess pool_process, bool exclusive, bool adaptive,
    T* __restrict__ input_grad, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    T input = static_cast<T>(0);
    T input_grad_data = static_cast<T>(0);
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    OffsetPreparationFor4Dimension<>(index, channel_last, divmods,
                                     padding_width, padding_height,
                                     output_width, output_height, &w_offset,
                                     &h_offset, &c_offset, &output_offset);
    // Per-iteration base pointers into the output tensors. The kernel
    // arguments themselves must not be advanced: in a grid-stride loop the
    // offset would accumulate across iterations.
    const T* output_grad_base = output_grad + output_offset;
    const T* output_base = output_data;
    if (pool_process.use_x) {
      input = input_data[index];
      output_base = output_data + output_offset;
    }

    // Range of output rows/cols whose window covers this input element.
    if (adaptive) {
      auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
      auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
      phstart = divmods.height.Div(h_offset * output_height);
      pwstart = divmods.width.Div(w_offset * output_width);
      phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
      pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];
    } else {
      // Only divide (h_offset - ksize) when it is non-negative.
      phstart = (h_offset < ksize_height)
                    ? 0
                    : divmods.stride_h.Div(h_offset - ksize_height) + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : divmods.stride_w.Div(w_offset - ksize_width) + 1;
      phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
      pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);
    }

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        int pool_size;
        if (adaptive) {
          // Exact area of adaptive window (ph, pw), identical to the divisor
          // the forward kernel uses. A single ceil(input / ksize) estimate
          // is wrong for windows whose size differs when input is not a
          // multiple of output.
          pool_size = (AdaptEndIndex(ph, input_height, output_height) -
                       AdaptStartIndex(ph, input_height, output_height)) *
                      (AdaptEndIndex(pw, input_width, output_width) -
                       AdaptStartIndex(pw, input_width, output_width));
        } else if (exclusive) {
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          pool_size = (hend - hstart) * (wend - wstart);
        } else {
          pool_size = ksize_height * ksize_width;
        }

        const int tmp_idx = ph * output_width + pw;
        const int output_sub_idx =
            channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                         : tmp_idx;
        T output_value = pool_process.use_x ? output_base[output_sub_idx]
                                            : static_cast<T>(0);
        pool_process.compute(input, output_value,
                             output_grad_base[output_sub_idx],
                             static_cast<T>(1.0 / pool_size),
                             &input_grad_data);
      }
    }
    input_grad[index] = input_grad_data;
  }
}

/*
 * Max-pool backward kernel driven by output elements.
 *
 * Grid-stride loop: each iteration handles output element `index`. It scans
 * that output's clipped window in row-major order, finds the first input
 * equal to output_data[index], and adds output_grad[index] to the matching
 * input_grad element. Overlapping windows can select the same input element
 * from different threads, hence the atomic add. `divmods` must decompose
 * output indices (channels, output_width, output_height).
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad, FastDivModForPooling divmods, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(
        index, channel_last, divmods, 0, 0, input_width, input_height,
        &w_offset, &h_offset, &c_offset, &input_offset);
    // Per-iteration bases. Advancing the kernel arguments in place would
    // accumulate offsets across grid-stride iterations and corrupt both the
    // reads and the atomic writes for any thread that handles >1 output.
    const T* input_base = input_data + input_offset;
    T* input_grad_base = input_grad + input_offset;

    int hstart = h_offset * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = w_offset * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        int input_data_idx = channel_last
                                 ? (h * input_width + w) * channels + c_offset
                                 : h * input_width + w;
        if (ele == input_base[input_data_idx]) {
          maxIndex = input_data_idx;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      platform::CudaAtomicAdd(input_grad_base + maxIndex, output_grad[index]);
    }
  }
}

// Launches KernelPool2D on raw device buffers using the caller-supplied
// stream. input_shape / output_shape are read as NCHW; ksize, strides and
// paddings are read as {height, width}. One loop iteration per output
// element, in blocks of 1024 threads (512 on WITH_NV_JETSON builds, since
// this overload has no device context to pass to platform::ChangeThreadNum).
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    bool exclusive, bool adaptive, T* output, gpuStream_t stream,
    PoolProcess pool_compute) {
  const int channels = input_shape[1];
  const int in_height = input_shape[2];
  const int in_width = input_shape[3];
  const int out_height = output_shape[2];
  const int out_width = output_shape[3];
  const int nthreads = input_shape[0] * output_shape[1] * out_height * out_width;

#ifdef WITH_NV_JETSON
  const int thread_num = 512;
#else
  const int thread_num = 1024;
#endif
  const dim3 threads(thread_num, 1);
  const dim3 grid((nthreads + thread_num - 1) / thread_num, 1);

  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      nthreads, input, channels, in_height, in_width, out_height, out_width,
      ksize[0], ksize[1], strides[0], strides[1], paddings[0], paddings[1],
      FastDivModForPooling(channels, out_width, out_height), pool_compute,
      exclusive, adaptive, output);
}

/*
 * Forward 2-D pooling on CUDA.
 * Tensors are in NCHW or NHWC layout.
 * ksize and strides each hold two elements: {height, width}.
 * paddings[0] and paddings[1] are used as the height and width padding;
 * no other padding elements are read.
 */
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCHW overload: identical to the data_format overload with "NCHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* output,
                  PoolProcess pool_process) {
    (*this)(context, input, ksize, strides, paddings, "NCHW", exclusive,
            adaptive, output, pool_process);
  }

  // Allocates `output` on the context's place and launches KernelPool2D on
  // context.stream() with one loop iteration per output element.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* output, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NHWC");
    // Positions of the C, H and W axes for the selected layout.
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];
    const int output_channels = output->dims()[c_axis];
    const int output_height = output->dims()[h_axis];
    const int output_width = output->dims()[w_axis];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    const dim3 threads(thread_num, 1);
    const dim3 grid((nthreads + thread_num - 1) / thread_num, 1);

    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize[0], ksize[1], strides[0],
        strides[1], paddings[0], paddings[1],
        FastDivModForPooling(input_channels, output_width, output_height),
        pool_process, exclusive, adaptive, output_data, channel_last);
  }
};

/*
 * Generic backward 2-D pooling on CUDA.
 * Tensors are in NCHW or NHWC layout.
 * ksize and strides each hold two elements: {height, width}.
 * paddings[0] and paddings[1] are used as the height and width padding;
 * no other padding elements are read.
 */
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCHW overload: identical to the data_format overload with "NCHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* input_grad,
                  PoolProcess pool_process) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCHW", exclusive, adaptive, input_grad, pool_process);
  }

  // Allocates `input_grad` on the context's place and launches
  // KernelPool2DGrad on context.stream() with one loop iteration per
  // input-gradient element; the block size comes from GetThreadsPerBlock.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* input_grad, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NHWC");
    // Positions of the C, H and W axes for the selected layout.
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];
    const int output_height = output.dims()[h_axis];
    const int output_width = output.dims()[w_axis];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * input_channels * input_height * input_width;
    const int threads_per_block =
        GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
    const int num_blocks =
        (nthreads + threads_per_block - 1) / threads_per_block;

    auto divmods = FastDivModForPoolingWithMoreStaff(
        input_channels, input_width, input_height, ksize[1], ksize[0],
        strides[1], strides[0]);

    KernelPool2DGrad<T, PoolProcess>
        <<<num_blocks, threads_per_block, 0, context.stream()>>>(
            nthreads, input_data, output_data, output_grad_data, output_width,
            output_height, input_width, input_height, ksize[1], ksize[0],
            strides[1], strides[0], paddings[1], paddings[0], divmods,
            pool_process, exclusive, adaptive, input_grad_data, channel_last);
  }
};

/*
 * Max-pool backward on CUDA (gradient routed to the first maximum in each
 * window via KernelMaxPool2DGrad).
 * Tensors are in NCHW or NHWC layout.
 * ksize and strides each hold two elements: {height, width}.
 * paddings[0] and paddings[1] are used as the height and width padding;
 * no other padding elements are read.
 */
template <typename T>
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // NCHW overload: identical to the data_format overload with "NCHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCHW", input_grad);
  }

  // Allocates `input_grad` on the context's place and launches
  // KernelMaxPool2DGrad on context.stream() with one loop iteration per
  // output element, in blocks of 1024 threads.
  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, framework::Tensor* input_grad) {
    const bool channel_last = (data_format == "NHWC");
    // Positions of the C, H and W axes for the selected layout.
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];
    const int output_channels = output.dims()[c_axis];
    const int output_height = output.dims()[h_axis];
    const int output_width = output.dims()[w_axis];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    constexpr int kThreadsPerBlock = 1024;
    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    const dim3 threads(kThreadsPerBlock, 1);
    const dim3 grid((nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock, 1);

    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize[0],
        ksize[1], strides[0], strides[1], paddings[0], paddings[1],
        input_grad_data,
        FastDivModForPooling(input_channels, output_width, output_height),
        channel_last);
  }
};

// ---------------------------------------------------------------------------
// Explicit instantiations for the 2-D pooling functors.
// ---------------------------------------------------------------------------

// Raw-pointer, stream-based forward functors (float only).
template class Pool2dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
                                       float>;
template class Pool2dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
                                       float>;

// Tensor-based forward functors: float, double, float16.
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;
template class Pool2dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPool<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool2dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPool<paddle::platform::float16>,
    paddle::platform::float16>;

// Tensor-based backward functors: float, double, float16.
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;
template class Pool2dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool2dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;

// Dedicated max-pool backward functors: float, double, float16.
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext,
                                    paddle::platform::float16>;

708
/*
 * Forward 3-D pooling kernel. PoolProcess supplies initial(), compute() and
 * finalize() (this file instantiates it with MaxPool and AvgPool).
 *
 * Each iteration of the grid-stride loop produces one output element, so
 * correctness does not depend on the grid covering nthreads in one pass.
 *
 *   nthreads     number of output elements (N * C * D_out * H_out * W_out).
 *   input_data   input tensor, NCDHW when channel_last is false, else NDHWC.
 *   adaptive     window bounds come from AdaptStartIndex / AdaptEndIndex and
 *                the ksize / stride / padding arguments are not used for them.
 *   exclusive    with exclusive (or adaptive), finalize() receives the count of
 *                in-bounds elements; otherwise the full kernel volume.
 *   output_data  output tensor, laid out like the input.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(
    const int nthreads, const T* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decompose the flat output index into (batch, channel, depth, h, w).
    int pw, ph, pd, c, batch_idx;
    if (!channel_last) {
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else {
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    // Input window [start, end) along each spatial axis.
    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    int input_data_stride;
    if (!channel_last) { /* NCDHW */
      input_data_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else { /* NDHWC */
      input_data_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    // Offset a per-iteration pointer instead of advancing input_data itself.
    // Mutating the argument would accumulate offsets across grid-stride
    // iterations whenever a thread handles more than one output element.
    const T* input_ptr = input_data + input_data_stride;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          auto input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          pool_process.compute(input_ptr[input_data_idx], &ele);
        }
      }
    }
    int pool_size = (exclusive || adaptive)
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

L
limingshu 已提交
791
/*
 * Generic backward 3-D pooling kernel. PoolProcess::compute() adds the
 * gradient contribution of one output element to an input element.
 *
 * Each iteration of the grid-stride loop handles one *input* element. It
 * enumerates the output windows that may cover that element, accumulates the
 * contributions in a register, and writes input_grad[index] exactly once, so
 * no atomics are required.
 *
 * output_data (and the input value) are only read when pool_process.use_x is
 * set. Layout is NCDHW when channel_last is false, else NDHWC.
 */
template <typename T, typename PoolProcess>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* __restrict__ input_data,
    const T* __restrict__ output_data, const T* __restrict__ output_grad,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, PoolProcess pool_process, bool exclusive,
    bool adaptive, T* input_grad, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decompose the flat input index. Spatial offsets are shifted by the
    // padding, i.e. expressed in padded coordinates.
    int w_offset, h_offset, d_offset, c_offset, batch_idx, output_stride;
    if (!channel_last) { /* "NCDHW" */
      w_offset = index % input_width + padding_width;
      h_offset = (index / input_width) % input_height + padding_height;
      d_offset =
          (index / input_width / input_height) % input_depth + padding_depth;
      c_offset = (index / input_width / input_height / input_depth) % channels;
      batch_idx = index / input_width / input_height / input_depth / channels;
      output_stride = (batch_idx * channels + c_offset) * output_depth *
                      output_height * output_width;
    } else { /* "NDHWC" */
      c_offset = index % channels;
      w_offset = (index / channels) % input_width + padding_width;
      h_offset =
          (index / channels / input_width) % input_height + padding_height;
      d_offset = (index / channels / input_width / input_height) % input_depth +
                 padding_depth;
      batch_idx = index / channels / input_width / input_height / input_depth;
      output_stride =
          batch_idx * output_depth * output_height * output_width * channels;
    }

    // Range [p*start, p*end) of output windows that may contain this input.
    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = AdaptStartIndex(d_offset, output_depth, input_depth);
      pdend = AdaptEndIndex(d_offset, output_depth, input_depth);

      phstart = AdaptStartIndex(h_offset, output_height, input_height);
      phend = AdaptEndIndex(h_offset, output_height, input_height);

      pwstart = AdaptStartIndex(w_offset, output_width, input_width);
      pwend = AdaptEndIndex(w_offset, output_width, input_width);
    } else {
      pdstart = (d_offset < ksize_depth)
                    ? 0
                    : (d_offset - ksize_depth) / stride_depth + 1;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      pdend = min((d_offset) / stride_depth + 1, output_depth);
      phend = min((h_offset) / stride_height + 1, output_height);
      pwend = min((w_offset) / stride_width + 1, output_width);
    }

    // Per-iteration views of this sample's output / output-gradient data.
    // The kernel arguments themselves are not advanced, so offsets do not
    // accumulate across grid-stride iterations.
    T input = static_cast<T>(0);
    const T* output_ptr = output_data;
    if (pool_process.use_x) {
      input = input_data[index];
      output_ptr += output_stride;
    }
    const T* output_grad_ptr = output_grad + output_stride;

    T input_grad_data = static_cast<T>(0.0);

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // Divisor used by the forward pass for output window (pd, ph, pw).
          int pool_size;
          if (adaptive) {
            // Exact adaptive window, identical to the bounds KernelPool3D
            // divides by. A ceil(input / ksize) estimate disagrees with the
            // forward pass whenever input size is not a multiple of output
            // size (e.g. 5 -> 3 gives windows of sizes 2, 3, 2).
            const int dstart = AdaptStartIndex(pd, input_depth, output_depth);
            const int dend = AdaptEndIndex(pd, input_depth, output_depth);
            const int hstart = AdaptStartIndex(ph, input_height, output_height);
            const int hend = AdaptEndIndex(ph, input_height, output_height);
            const int wstart = AdaptStartIndex(pw, input_width, output_width);
            const int wend = AdaptEndIndex(pw, input_width, output_width);
            pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          } else {
            int dstart = pd * stride_depth - padding_depth;
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int dend = min(dstart + ksize_depth, input_depth);
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            dstart = max(dstart, 0);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            pool_size =
                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                          : ksize_depth * ksize_height * ksize_width;
          }

          int output_sub_idx =
              channel_last
                  ? ((pd * output_height + ph) * output_width + pw) * channels +
                        c_offset
                  : (pd * output_height + ph) * output_width + pw;
          T output_value = pool_process.use_x ? output_ptr[output_sub_idx]
                                              : static_cast<T>(0);
          pool_process.compute(input, output_value,
                               output_grad_ptr[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

905
/*
 * Backward kernel for 3-D max pooling.
 *
 * Each iteration of the grid-stride loop handles one *output* element. It
 * scans the output element's input window in (d, h, w) order, picks the first
 * position whose value equals the pooled output, and atomically adds
 * output_grad[index] to input_grad at that position. The add is atomic
 * because overlapping windows can select the same input element. The kernel
 * only adds, so input_grad must already hold its starting values.
 * Layout is NCDHW when channel_last is false, else NDHWC.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, pd, c, batch_idx;

    if (!channel_last) { /*NCDHW*/
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else { /*NDHWC*/
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    // Clamped input window [start, end) along each spatial axis.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;

    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);

    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    int input_stride;
    if (!channel_last) {
      input_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else {
      input_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    // Per-iteration base pointers. The kernel arguments are not advanced, so
    // the offset does not accumulate across grid-stride iterations.
    const T* input_ptr = input_data + input_stride;
    T* input_grad_ptr = input_grad + input_stride;

    const T ele = output_data[index];
    bool stop = false;
    int max_idx = -1;

    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          int input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          if (ele == input_ptr[input_data_idx]) {
            stop = true;
            max_idx = input_data_idx;
          }
        }
      }
    }
    if (max_idx != -1) {
      // Atomic: several output windows may route to the same input element.
      platform::CudaAtomicAdd(input_grad_ptr + max_idx, output_grad[index]);
    }
  }
}

F
feng_shuai 已提交
982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024
/*
 * Launches KernelPool3D on raw device pointers on the given stream.
 *
 * The kernel is launched without channel_last, so data is treated as NCDHW.
 * input_shape / output_shape are read as [N, C, D, H, W]; ksize, strides and
 * paddings as [depth, height, width].
 */
template <typename PoolProcess, typename T>
void Pool3dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    bool exclusive, bool adaptive, T* output, gpuStream_t stream,
    PoolProcess pool_compute) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_depth = input_shape[2];
  const int input_height = input_shape[3];
  const int input_width = input_shape[4];
  const int output_channels = output_shape[1];
  const int output_depth = output_shape[2];
  const int output_height = output_shape[3];
  const int output_width = output_shape[4];
  const int ksize_depth = ksize[0];
  const int ksize_height = ksize[1];
  const int ksize_width = ksize[2];
  const int stride_depth = strides[0];
  const int stride_height = strides[1];
  const int stride_width = strides[2];
  const int padding_depth = paddings[0];
  const int padding_height = paddings[1];
  const int padding_width = paddings[2];

  const int nthreads = batch_size * output_channels * output_depth *
                       output_height * output_width;
  // An empty output needs no work, and launching a zero-sized grid is an
  // invalid launch configuration.
  if (nthreads == 0) return;

  int thread_num = 1024;
#ifdef WITH_NV_JETSON
  thread_num = 512;
#endif
  const int blocks = (nthreads + thread_num - 1) / thread_num;
  dim3 threads(thread_num, 1);
  dim3 grid(blocks, 1);

  KernelPool3D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      nthreads, input, input_channels, input_depth, input_height, input_width,
      output_depth, output_height, output_width, ksize_depth, ksize_height,
      ksize_width, stride_depth, stride_height, stride_width, padding_depth,
      padding_height, padding_width, pool_compute, exclusive, adaptive, output);
}

C
chengduoZH 已提交
1025
/*
 * CUDA specialization of the 3-D pooling forward functor.
 *
 * Tensors are in NCDHW or NDHWC layout. ksize, strides and paddings are each
 * read as three values ordered (depth, height, width); only paddings[0..2]
 * are used.
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // Channel-first (NCDHW) entry point; forwards to the data_format overload.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* output,
                  PoolProcess pool_process) {
    (*this)(context, input, ksize, strides, paddings, "NCDHW", exclusive,
            adaptive, output, pool_process);
  }

  // data_format "NDHWC" selects channel-last; any other value is treated as
  // NCDHW. Allocates output on the context's place and launches KernelPool3D
  // on the context's stream.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* output, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output->dims()[4] : output->dims()[1];
    const int output_depth =
        channel_last ? output->dims()[1] : output->dims()[2];
    const int output_height =
        channel_last ? output->dims()[2] : output->dims()[3];
    const int output_width =
        channel_last ? output->dims()[3] : output->dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    const int nthreads = batch_size * output_channels * output_depth *
                         output_height * output_width;
    // An empty output needs no kernel; a zero-sized grid is an invalid launch
    // configuration.
    if (nthreads == 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    const int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data, channel_last);
  }
};

C
chengduoZH 已提交
1138
/*
 * CUDA specialization of the generic 3-D pooling backward functor.
 *
 * Tensors are in NCDHW or NDHWC layout. ksize, strides and paddings are each
 * read as three values ordered (depth, height, width); only paddings[0..2]
 * are used. KernelPool3DGrad is launched with one thread per input element,
 * and each thread overwrites its input_grad entry.
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // Channel-first (NCDHW) entry point; forwards to the data_format overload.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* input_grad,
                  PoolProcess pool_process) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCDHW", exclusive, adaptive, input_grad, pool_process);
  }

  // data_format "NDHWC" selects channel-last; any other value is treated as
  // NCDHW.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* input_grad, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NDHWC");

    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per input element.
    const int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Nothing to write for an empty input; a zero-sized grid is an invalid
    // launch configuration.
    if (nthreads == 0) return;

    // Same block-size selection as the forward functors (reduced on Jetson).
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    const int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, exclusive, adaptive, input_grad_data,
        channel_last);
  }
};

C
chengduoZH 已提交
1252
/*
 * CUDA specialization of the dedicated 3-D max-pooling backward functor.
 *
 * Tensors are in NCDHW or NDHWC layout. ksize, strides and paddings are each
 * read as three values ordered (depth, height, width); only paddings[0..2]
 * are used. KernelMaxPool3DGrad runs one thread per output element and
 * atomically adds into input_grad, so it never clears it.
 * NOTE(review): input_grad is presumably zero-filled by the caller — confirm.
 */
template <class T>
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Channel-first (NCDHW) entry point; forwards to the data_format overload.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCDHW", input_grad);
  }

  // data_format "NDHWC" selects channel-last; any other value is treated as
  // NCDHW.
  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, framework::Tensor* input_grad) {
    const bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output.dims()[4] : output.dims()[1];
    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per output element.
    const int nthreads = batch_size * output_channels * output_depth *
                         output_height * output_width;
    // No gradient to route for an empty output; a zero-sized grid is an
    // invalid launch configuration.
    if (nthreads == 0) return;

    // Same block-size selection as the forward functors (reduced on Jetson).
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    const int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, input_grad_data, channel_last);
  }
};

F
feng_shuai 已提交
1361 1362 1363 1364 1365
// ---------------------------------------------------------------------------
// Explicit instantiations of the 3D pooling functors defined above.
// ---------------------------------------------------------------------------

// Direct (context-free) 3D pooling, float only.
template class Pool3dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
                                       float>;
template class Pool3dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
                                       float>;

// Dedicated max-pool 3D backward.
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext,
                                    paddle::platform::float16>;

// Generic 3D pooling forward: max and average for float / double / float16.
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;
template class Pool3dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPool<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool3dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPool<paddle::platform::float16>,
    paddle::platform::float16>;

// Generic 3D pooling backward: max and average for float / double / float16.
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;
template class Pool3dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool3dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;

C
chengduoZH 已提交
1409
/*
 * Max-pool 2D forward (NCHW) that also records the argmax.
 *
 * Grid-stride loop over output elements `index` in [0, nthreads). For each
 * one it writes the window maximum to output_data[index]. It writes to
 * mask_data[index] the position of that maximum inside its own input feature
 * map, flattened as h * input_width + w. If the pooling window is empty,
 * -FLT_MAX and -1 are written instead.
 *
 * `divmods` is built by MaxPool2dWithIndexFunctor from
 * (input_channels, output_width, output_height). `channels` is not read here.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, bool adaptive, T1* output_data, T2* mask_data,
    FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    // NOTE(review): OffsetPreparationFor4Dimension is defined elsewhere.
    // Presumably it yields the (w, h, c) output coordinates of `index` and the
    // offset of the matching input feature map -- confirm.
    OffsetPreparationFor4Dimension<FastDivModForPooling>(
        index, false, divmods, 0, 0, input_width, input_height, &w_offset,
        &h_offset, &c_offset, &input_offset);
    // Use a per-iteration pointer. Advancing `input_data` itself would
    // compound the offsets across iterations of the grid-stride loop.
    const T1* input_feature_map = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);

      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      // Clip the padded window [start, start + ksize) to the input bounds.
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        if (ele < input_feature_map[input_index]) {
          max_index = input_index;
          ele = input_feature_map[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1458
/*
 * Backward of KernelMaxPool2dWithIdx (NCHW).
 *
 * Grid-stride loop over input elements `index` in [0, nthreads). For the
 * element at (h_offset, w_offset) it considers every output position
 * (ph, pw) whose pooling window can contain it. It sums output_grad over
 * those positions whose recorded argmax in `mask_data` equals
 * h_offset * input_width + w_offset. The sum is stored in input_grad[index],
 * overwriting the previous value.
 *
 * `divmods` is built by MaxPool2dWithIndexGradFunctor from
 * (input_channels, input_width, input_height). `channels` is not read here.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, bool adaptive,
    T1* input_grad, FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    // NOTE(review): OffsetPreparationFor4Dimension is defined elsewhere.
    // Presumably it yields the (w, h, c) input coordinates of `index` and the
    // offset of the matching output feature map -- confirm.
    OffsetPreparationFor4Dimension<FastDivModForPooling>(
        index, false, divmods, 0, 0, output_width, output_height, &w_offset,
        &h_offset, &c_offset, &output_offset);
    // Use per-iteration pointers. Advancing `mask_data` / `output_grad` in
    // place would compound offsets across iterations of the grid-stride loop.
    const T2* mask_map = mask_data + output_offset;
    const T1* grad_map = output_grad + output_offset;

    if (adaptive) {
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // Range of output positions whose padded window covers this element.
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 input_grad_data = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        const int out_pos = ph * output_width + pw;
        if (mask_map[out_pos] == input_current_featuremap_idx)
          input_grad_data += grad_map[out_pos];
      }
    }
    input_grad[index] = input_grad_data;
  }
}

C
chengduoZH 已提交
1509 1510 1511 1512 1513
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  // Allocates `output` and `mask` on the context's place. It then launches
  // KernelMaxPool2dWithIdx on context.stream() to fill them with the pooled
  // maxima and their in-feature-map argmax (h * input_width + w).
  // The launch is asynchronous; this function does not synchronize.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // An empty output needs no work. A launch with zero blocks is also an
    // invalid CUDA configuration, so return once the outputs are allocated.
    if (nthreads == 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, adaptive, output_data,
        mask_data, pool_divmods);
  }
};

C
chengduoZH 已提交
1560 1561 1562 1563 1564
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  // Allocates `input_grad` on the context's place. It then launches
  // KernelMaxPool2DWithIdxGrad on context.stream(), which overwrites every
  // element of input_grad with the gradient routed back through `mask`.
  // The launch is asynchronous; this function does not synchronize.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    // An empty input_grad needs no work. A launch with zero blocks is also an
    // invalid CUDA configuration.
    if (nthreads == 0) return;

    // Same Jetson thread-count adjustment as the forward functor.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, input_width, input_height);
    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width, adaptive,
        input_grad_data, pool_divmods);
  }
};

Q
QI JUN 已提交
1606 1607 1608 1609 1610 1611 1612 1613
// 2D max pooling with argmax mask (mask stored as int): forward.
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
// 2D max pooling with argmax mask: backward.
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1614

C
chengduoZH 已提交
1615
/*
 * Max-pool 3D forward (NCDHW) that also records the argmax.
 *
 * Grid-stride loop over output elements `index` in [0, nthreads), decomposed
 * as (batch, channel, pd, ph, pw). It writes the window maximum to
 * output_data[index]. It writes to mask_data[index] the position of that
 * maximum inside its own input volume, flattened as
 * (d * input_height + h) * input_width + w. If the pooling window is empty,
 * -FLT_MAX and -1 are written instead.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      // Clip the padded window [start, start + ksize) to the input bounds.
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Per-iteration pointer to this (batch, channel) input volume. Advancing
    // `input_data` itself would compound offsets across iterations of the
    // grid-stride loop.
    const T1* input_volume =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          int input_idx = (d * input_height + h) * input_width + w;
          if (ele < input_volume[input_idx]) {
            max_index = input_idx;
            ele = input_volume[input_idx];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1677
/*
 * Backward of KernelMaxPool3DWithIdx (NCDHW).
 *
 * Grid-stride loop over input elements `index` in [0, nthreads), decomposed
 * as (batch, channel, d, h, w). It considers every output position
 * (pd, ph, pw) whose pooling window can contain the element. It sums
 * output_grad over those positions whose recorded argmax in `mask` equals the
 * element's flattened in-volume position (d * input_height + h) * input_width
 * + w. The sum is stored in input_grad[index], overwriting the previous value.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, bool adaptive, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = d_offset * output_depth / input_depth;
      pdend =
          min((d_offset + 1) * output_depth / input_depth + 1, output_depth);
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // Range of output positions whose padded window covers this element.
      pdstart =
          (d_offset + padding_depth < ksize_depth)
              ? 0
              : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth);
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    // Per-iteration pointers to this (batch, channel) output volume. Advancing
    // `mask` / `output_grad` in place would compound offsets across iterations
    // of the grid-stride loop.
    int output_idx = (batch_idx * channels + c_offset) * output_depth *
                     output_height * output_width;
    const T2* mask_volume = mask + output_idx;
    const T1* grad_volume = output_grad + output_idx;

    T1 input_grad_data = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          int out_pos = (pd * output_height + ph) * output_width + pw;
          if (mask_volume[out_pos] == input_current_feature_map_idx)
            input_grad_data += grad_volume[out_pos];
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

C
chengduoZH 已提交
1749 1750 1751 1752 1753
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  // Allocates `output` and `mask` on the context's place. It then launches
  // KernelMaxPool3DWithIdx on context.stream() to fill them with the pooled
  // maxima and their in-volume argmax
  // ((d * input_height + h) * input_width + w).
  // The launch is asynchronous; this function does not synchronize.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // An empty output needs no work. A launch with zero blocks is also an
    // invalid CUDA configuration, so return once the outputs are allocated.
    if (nthreads == 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

C
chengduoZH 已提交
1805 1806 1807 1808 1809
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  // Allocates `input_grad` on the context's place. It then launches
  // KernelMaxPool3DWithIdxGrad on context.stream(), which overwrites every
  // element of input_grad with the gradient routed back through `mask`.
  // The launch is asynchronous; this function does not synchronize.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // An empty input_grad needs no work. A launch with zero blocks is also an
    // invalid CUDA configuration.
    if (nthreads == 0) return;

    // Same Jetson thread-count adjustment as the forward functor.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

Q
QI JUN 已提交
1856 1857 1858 1859 1860 1861 1862 1863
// 3D max pooling with argmax mask (mask stored as int): forward.
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
// 3D max pooling with argmax mask: backward.
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1864 1865 1866 1867

}  // namespace math
}  // namespace operators
}  // namespace paddle