/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
17

Y
Yi Wang 已提交
18
#include "paddle/fluid/operators/math/pooling.h"
19 20
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
L
limingshu 已提交
21
#include "paddle/fluid/platform/fast_divmod.h"
C
chengduoZH 已提交
22 23 24 25 26

namespace paddle {
namespace operators {
namespace math {

L
limingshu 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
// Bundle of FastDivMod helpers used to split a flat 4-D element index into
// (channel, width, height) coordinates with multiply/shift instead of
// hardware integer division.
//   channel : divides by the channel count
//   width   : divides by the spatial width of the enumerated tensor
//   height  : divides by the spatial height of the enumerated tensor
// The forward kernels in this file construct it from the output width/height.
struct FastDivModForPooling {
  platform::FastDivMod channel;
  platform::FastDivMod width;
  platform::FastDivMod height;

  explicit HOSTDEVICE FastDivModForPooling(const int channels,
                                           const int output_width,
                                           const int output_height)
      : channel(channels), width(output_width), height(output_height) {}
};

// FastDivMod bundle for the pooling backward kernel. In addition to the
// (channel, width, height) divisors of the enumerated (input) tensor, it
// carries divisors for the kernel window size and the stride in each spatial
// direction. These are used to find the output windows covering an input
// position without hardware integer division.
struct FastDivModForPoolingWithMoreStaff {
  platform::FastDivMod channel;
  platform::FastDivMod width;
  platform::FastDivMod height;
  platform::FastDivMod ksize_w;
  platform::FastDivMod ksize_h;
  platform::FastDivMod stride_w;
  platform::FastDivMod stride_h;

  explicit HOSTDEVICE FastDivModForPoolingWithMoreStaff(
      const int channels, const int input_width, const int input_height,
      const int ksize_width, const int ksize_height, const int stride_width,
      const int stride_height)
      : channel(channels),
        width(input_width),
        height(input_height),
        ksize_w(ksize_width),
        ksize_h(ksize_height),
        stride_w(stride_width),
        stride_h(stride_height) {}
};

/*
 * Splits the flat element index `index` of a 4-D tensor into its width,
 * height and channel coordinates using the divisors in `divmods`.
 * It also computes, in *stride, the element offset of the matching block in
 * an auxiliary tensor whose spatial size is aux_height x aux_width:
 *
 *   channel_last == false (NCHW): *stride = (n * C + c) * aux_h * aux_w,
 *       i.e. the start of the (n, c) plane.
 *   channel_last == true  (NHWC): *stride = n * aux_h * aux_w * C,
 *       i.e. the start of image n; callers add *c_offset per element.
 *
 * pad_width / pad_height are added to the decoded w / h coordinates.
 */
template <typename DivModsT>
__device__ void OffsetPreparationFor4Dimension(
    int index, bool channel_last, DivModsT divmods, const int pad_width,
    const int pad_height, const int aux_width, const int aux_height,
    int* w_offset, int* h_offset, int* c_offset, int* stride) {
  if (channel_last) { /* NHWC: channel is the fastest-varying axis */
    auto by_c = divmods.channel.Divmod(index);
    auto by_w = divmods.width.Divmod(by_c.val[0]);
    auto by_h = divmods.height.Divmod(by_w.val[0]);
    *c_offset = by_c.val[1];
    *w_offset = by_w.val[1] + pad_width;
    *h_offset = by_h.val[1] + pad_height;
    // by_h.val[0] is the batch index.
    *stride = by_h.val[0] * aux_height * aux_width * divmods.channel.divisor;
    return;
  }

  /* NCHW: width is the fastest-varying axis */
  auto by_w = divmods.width.Divmod(index);
  auto by_h = divmods.height.Divmod(by_w.val[0]);
  auto by_c = divmods.channel.Divmod(by_h.val[0]);
  *w_offset = by_w.val[1] + pad_width;
  *h_offset = by_h.val[1] + pad_height;
  *c_offset = by_c.val[1];
  // by_c.val[0] is the batch index; (n * C + c) selects the plane.
  *stride = (by_c.val[0] * divmods.channel.divisor + *c_offset) * aux_height *
            aux_width;
}

93
/*
 * Forward 2-D pooling kernel; the reduction (e.g. MaxPool / AvgPool) is
 * supplied by PoolProcess through initial() / compute() / finalize().
 *
 * Each loop iteration produces one output element. `index` walks all
 * `nthreads` output elements (NCHW order, or NHWC when channel_last) in a
 * grid-stride loop, so any 1-D launch configuration covers the whole output.
 * `divmods` must be built from (channels, output_width, output_height).
 *
 * Window bounds for output position (h_offset, w_offset):
 *   adaptive  -> [AdaptStartIndex, AdaptEndIndex) over the input extent;
 *   otherwise -> [o * stride - pad, o * stride - pad + ksize), clipped to
 *                [0, input extent).
 * The divisor passed to finalize() is the clipped window area when
 * `exclusive` or `adaptive` is set, and ksize_height * ksize_width otherwise.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(
    const int nthreads, const T* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, FastDivModForPooling divmods,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(
        index, channel_last, divmods, 0, 0, input_width, input_height,
        &w_offset, &h_offset, &c_offset, &input_offset);
    // Per-iteration base pointer. Advancing the `input_data` argument itself
    // would accumulate offsets across grid-stride iterations and read the
    // wrong plane whenever a thread processes more than one output element.
    const T* input_plane = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);
      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);
      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        auto input_idx = channel_last
                             ? (h * input_width + w) * channels + c_offset
                             : h * input_width + w;
        pool_process.compute(input_plane[input_idx], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}
L
limingshu 已提交
140 141

/*
 * Backward kernel for generic 2-D pooling. PoolProcess supplies the
 * per-window gradient rule through
 *   compute(x, y, dy, scale, &dx)
 * with the argument order used below.
 *
 * Each loop iteration produces one element of input_grad. `index` walks all
 * `nthreads` input elements in a grid-stride loop. `divmods` must be built
 * from the input (channels, width, height) plus the kernel and stride sizes.
 * For the input position decoded from `index`, every output window covering
 * it is visited. Each visit accumulates compute(...) with scale = 1/pool_size:
 *   adaptive                : ceil(input_h / ksize_h) * ceil(input_w / ksize_w)
 *   exclusive (not adaptive): the window area clipped to the input
 *   otherwise               : ksize_height * ksize_width
 * input_data and output_data are only read when pool_process.use_x is set.
 */
template <typename T, typename PoolProcess>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* __restrict__ input_data,
    const T* __restrict__ output_data, const T* __restrict__ output_grad,
    const int output_width, const int output_height, const int input_width,
    const int input_height, const int ksize_width, const int ksize_height,
    const int stride_width, const int stride_height, const int padding_width,
    const int padding_height, FastDivModForPoolingWithMoreStaff divmods,
    PoolProcess pool_process, bool exclusive, bool adaptive,
    T* __restrict__ input_grad, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    T input = static_cast<T>(0);
    T input_grad_data = static_cast<T>(0);
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    // w_offset / h_offset come back already shifted by the padding;
    // output_offset is the start of the matching block of the output.
    OffsetPreparationFor4Dimension<>(index, channel_last, divmods,
                                     padding_width, padding_height,
                                     output_width, output_height, &w_offset,
                                     &h_offset, &c_offset, &output_offset);
    if (pool_process.use_x) {
      input = input_data[index];
    }
    // NOTE: output_data / output_grad are indexed as
    // output_offset + sub_idx rather than advancing the pointer arguments.
    // Advancing them inside this grid-stride loop accumulated offsets across
    // iterations, which corrupts results whenever the launch grid covers
    // fewer than nthreads threads.

    if (adaptive) {
      auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
      auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
      phstart = divmods.height.Div(h_offset * output_height);
      pwstart = divmods.width.Div(w_offset * output_width);
      phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
      pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];

      // The scale does not depend on (ph, pw), so compute it once.
      auto ksize_w_divmod = divmods.ksize_w.Divmod(input_width);
      auto ksize_h_divmod = divmods.ksize_h.Divmod(input_height);
      auto tmp_width = ksize_w_divmod.val[1] > 0 ? ksize_w_divmod.val[0] + 1
                                                 : ksize_w_divmod.val[0];
      auto tmp_height = ksize_h_divmod.val[1] > 0 ? ksize_h_divmod.val[0] + 1
                                                  : ksize_h_divmod.val[0];
      int pool_size = tmp_height * tmp_width;
      T scale = static_cast<T>(1.0 / pool_size);

      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          int tmp_idx = ph * output_width + pw;
          int output_sub_idx =
              channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                           : tmp_idx;
          int output_idx = output_offset + output_sub_idx;
          T output_value = pool_process.use_x ? output_data[output_idx]
                                              : static_cast<T>(0);
          pool_process.compute(input, output_value, output_grad[output_idx],
                               scale, &input_grad_data);
        }
      }
    } else {
      // Range of pooled positions whose window covers this input position.
      auto stride_height_div = divmods.stride_h.Div(h_offset - ksize_height);
      auto stride_width_div = divmods.stride_w.Div(w_offset - ksize_width);
      phstart = (h_offset < ksize_height) ? 0 : stride_height_div + 1;
      pwstart = (w_offset < ksize_width) ? 0 : stride_width_div + 1;
      phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
      pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);

      if (exclusive) {
        for (int ph = phstart; ph < phend; ++ph) {
          for (int pw = pwstart; pw < pwend; ++pw) {
            // Window area clipped to the input extent.
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            int pool_size = (hend - hstart) * (wend - wstart);
            int tmp_idx = ph * output_width + pw;
            int output_sub_idx =
                channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                             : tmp_idx;
            int output_idx = output_offset + output_sub_idx;
            T output_value = pool_process.use_x ? output_data[output_idx]
                                                : static_cast<T>(0);
            pool_process.compute(input, output_value, output_grad[output_idx],
                                 static_cast<T>(1.0 / pool_size),
                                 &input_grad_data);
          }
        }
      } else {
        // Full kernel area: constant for every window, so compute it once.
        int pool_size = ksize_height * ksize_width;
        T scale = static_cast<T>(1.0 / pool_size);
        for (int ph = phstart; ph < phend; ++ph) {
          for (int pw = pwstart; pw < pwend; ++pw) {
            int tmp_idx = ph * output_width + pw;
            int output_sub_idx =
                channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                             : tmp_idx;
            int output_idx = output_offset + output_sub_idx;
            T output_value = pool_process.use_x ? output_data[output_idx]
                                                : static_cast<T>(0);
            pool_process.compute(input, output_value, output_grad[output_idx],
                                 scale, &input_grad_data);
          }
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

246
/*
 * Backward kernel specialised for max pooling.
 *
 * Each loop iteration handles one output element (grid-stride over
 * `nthreads` output elements; NCHW, or NHWC when channel_last). It scans the
 * element's input window row by row and routes output_grad[index] to the
 * FIRST input position whose value equals output_data[index].
 * Several output elements can select the same input position, so the write
 * is an atomic add. input_grad is expected to be zero-initialised by the
 * caller (TODO confirm at the call site).
 * `divmods` must be built from (channels, output_width, output_height).
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad, FastDivModForPooling divmods, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(
        index, channel_last, divmods, 0, 0, input_width, input_height,
        &w_offset, &h_offset, &c_offset, &input_offset);
    // Per-iteration base pointers. Advancing the pointer arguments
    // themselves would accumulate offsets across grid-stride iterations.
    const T* input_plane = input_data + input_offset;
    T* input_grad_plane = input_grad + input_offset;

    int hstart = h_offset * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = w_offset * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        int input_data_idx = channel_last
                                 ? (h * input_width + w) * channels + c_offset
                                 : h * input_width + w;
        if (ele == input_plane[input_data_idx]) {
          maxIndex = input_data_idx;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // Atomic: another output element may target the same input position.
      platform::CudaAtomicAdd(input_grad_plane + maxIndex, output_grad[index]);
    }
  }
}

N
nhzlx 已提交
293 294 295 296 297
// Launches KernelPool2D on the caller's `stream` for raw NCHW device buffers
// (no framework::Tensor involved). input_shape / output_shape are read as
// {N, C, H, W}; ksize, strides and paddings are read as {height, width}.
// One thread is assigned per output element.
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    bool exclusive, bool adaptive, T* output, gpuStream_t stream,
    PoolProcess pool_compute) {
  const int in_n = input_shape[0];
  const int in_c = input_shape[1];
  const int in_h = input_shape[2];
  const int in_w = input_shape[3];
  const int out_c = output_shape[1];
  const int out_h = output_shape[2];
  const int out_w = output_shape[3];

  const int total = in_n * out_c * out_h * out_w;

#ifdef WITH_NV_JETSON
  // Only a stream is available here (no device context), so the Jetson
  // block size is fixed instead of being tuned via platform::ChangeThreadNum.
  int block_size = 512;
#else
  int block_size = 1024;
#endif
  dim3 threads(block_size, 1);
  dim3 grid((total + block_size - 1) / block_size, 1);

  auto divmods = FastDivModForPooling(in_c, out_w, out_h);
  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      total, input, in_c, in_h, in_w, out_h, out_w, ksize[0], ksize[1],
      strides[0], strides[1], paddings[0], paddings[1], divmods, pool_compute,
      exclusive, adaptive, output);
}

C
chengduoZH 已提交
333
/*
 * CUDA forward 2-D pooling functor.
 * Tensors are NCHW, or NHWC when data_format == "NHWC" (any other value is
 * treated as NCHW).
 * ksize and strides hold two elements: {height, width}.
 * Only paddings[0] (height) and paddings[1] (width) are read. They are
 * subtracted from each window's start position.
 */
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCHW entry point: identical to the data_format overload with "NCHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* output,
                  PoolProcess pool_process) {
    (*this)(context, input, ksize, strides, paddings, "NCHW", exclusive,
            adaptive, output, pool_process);
  }

  // Allocates `output` on context's place and launches KernelPool2D with one
  // thread per output element on context.stream().
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* output, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NHWC");
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];
    const int output_channels = output->dims()[c_axis];
    const int output_height = output->dims()[h_axis];
    const int output_width = output->dims()[w_axis];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    dim3 threads(thread_num, 1);
    dim3 grid((nthreads + thread_num - 1) / thread_num, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize[0], ksize[1], strides[0],
        strides[1], paddings[0], paddings[1], pool_divmods, pool_process,
        exclusive, adaptive, output_data, channel_last);
  }
};
C
chengduoZH 已提交
433
/*
 * CUDA backward functor for generic 2-D pooling.
 * Tensors are NCHW, or NHWC when data_format == "NHWC" (any other value is
 * treated as NCHW).
 * ksize and strides hold two elements: {height, width}.
 * Only paddings[0] (height) and paddings[1] (width) are read.
 */
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCHW entry point: identical to the data_format overload with "NCHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* input_grad,
                  PoolProcess pool_process) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCHW", exclusive, adaptive, input_grad, pool_process);
  }

  // Allocates `input_grad` on context's place and launches KernelPool2DGrad
  // with one thread per INPUT element. Each thread gathers the contributions
  // of every window covering its input position.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* input_grad, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NHWC");
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];
    const int output_height = output.dims()[h_axis];
    const int output_width = output.dims()[w_axis];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * input_channels * input_height * input_width;
    auto pool_divmods = FastDivModForPoolingWithMoreStaff(
        input_channels, input_width, input_height, ksize[1], ksize[0],
        strides[1], strides[0]);

    auto config = GetGpuLaunchConfig1D(context, nthreads);
    KernelPool2DGrad<T, PoolProcess><<<
        config.block_per_grid, config.thread_per_block, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, output_width,
        output_height, input_width, input_height, ksize[1], ksize[0],
        strides[1], strides[0], paddings[1], paddings[0], pool_divmods,
        pool_process, exclusive, adaptive, input_grad_data, channel_last);
  }
};

C
chengduoZH 已提交
535
/*
 * CUDA backward functor specialised for max pooling.
 * Tensors are NCHW, or NHWC when data_format == "NHWC" (any other value is
 * treated as NCHW).
 * ksize and strides hold two elements: {height, width}.
 * Only paddings[0] (height) and paddings[1] (width) are read.
 */
template <typename T>
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // NCHW entry point: identical to the data_format overload with "NCHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCHW", input_grad);
  }

  // Allocates `input_grad` on context's place and launches
  // KernelMaxPool2DGrad with one thread per OUTPUT element (1024 threads per
  // block). Each thread scatters its gradient into input_grad.
  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, framework::Tensor* input_grad) {
    const bool channel_last = (data_format == "NHWC");
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];
    const int output_channels = output.dims()[c_axis];
    const int output_height = output.dims()[h_axis];
    const int output_width = output.dims()[w_axis];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    const int block_size = 1024;
    dim3 threads(block_size, 1);
    dim3 grid((nthreads + block_size - 1) / block_size, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize[0],
        ksize[1], strides[0], strides[1], paddings[0], paddings[1],
        input_grad_data, pool_divmods, channel_last);
  }
};

N
nhzlx 已提交
635 636 637 638 639
// Explicit instantiations of the CUDA 2-D pooling functors defined above.
// Names are resolved from within namespace paddle::operators::math.

// Raw-buffer forward functor (float only).
template class Pool2dDirectCUDAFunctor<MaxPool<float>, float>;
template class Pool2dDirectCUDAFunctor<AvgPool<float>, float>;

// Max-pool backward functor.
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext,
                                    platform::float16>;

// float forward / backward.
template class Pool2dFunctor<platform::CUDADeviceContext, MaxPool<float>,
                             float>;
template class Pool2dFunctor<platform::CUDADeviceContext, AvgPool<float>,
                             float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<float>, float>;

// double forward / backward.
template class Pool2dFunctor<platform::CUDADeviceContext, MaxPool<double>,
                             double>;
template class Pool2dFunctor<platform::CUDADeviceContext, AvgPool<double>,
                             double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<double>, double>;
665

666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682
template class Pool2dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPool<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool2dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPool<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool2dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool2dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;

683
// Forward 3-D pooling kernel.
//
// Each loop iteration produces one output element. `index` is decoded as
// (batch, c, pd, ph, pw) for NCDHW, or as (batch, pd, ph, pw, c) for NDHWC
// when `channel_last` is true.
//
// The window for that output cell is either:
//   * the adaptive window (AdaptStartIndex/AdaptEndIndex) when `adaptive`, or
//   * the strided/padded ksize window clipped to the input extent.
//
// pool_process supplies initial()/compute()/finalize(). finalize() receives
// the divisor `pool_size`: the clipped window volume when
// `exclusive || adaptive`, otherwise the full ksize volume.
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(
    const int nthreads, const T* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, pd, c, batch_idx;
    if (!channel_last) {
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else {
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    int input_data_stride;
    if (!channel_last) { /* NCDHW */
      input_data_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else { /* NDHWC */
      input_data_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    // Per-iteration view of this sample's input. The original code advanced the
    // `input_data` argument itself, so offsets accumulated whenever a thread
    // handled more than one index of the loop.
    const T* input_slice = input_data + input_data_stride;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          auto input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          pool_process.compute(input_slice[input_data_idx], &ele);
        }
      }
    }
    int pool_size = (exclusive || adaptive)
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

L
limingshu 已提交
766
// Backward kernel for generic 3-D pooling (the pooling rule comes from
// pool_process).
//
// Each loop iteration handles one *input* element:
//   1. `index` is decoded into input coordinates (NCDHW, or NDHWC when
//      `channel_last`), each shifted by the corresponding padding.
//   2. Every output cell in the [pdstart, pdend) x [phstart, phend) x
//      [pwstart, pwend) range is visited.
//   3. pool_process.compute(x, y, dy, 1 / pool_size, &dx) accumulates into a
//      local value.
//   4. That value is written once to input_grad[index], so no atomics are
//      needed.
//
// input_data and output_data are only read when pool_process.use_x is true.
template <typename T, typename PoolProcess>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* __restrict__ input_data,
    const T* __restrict__ output_data, const T* __restrict__ output_grad,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, PoolProcess pool_process, bool exclusive,
    bool adaptive, T* input_grad, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset, h_offset, d_offset, c_offset, batch_idx, output_stride;
    T input = static_cast<T>(0);
    if (!channel_last) { /* "NCDHW" */
      w_offset = index % input_width + padding_width;
      h_offset = (index / input_width) % input_height + padding_height;
      d_offset =
          (index / input_width / input_height) % input_depth + padding_depth;
      c_offset = (index / input_width / input_height / input_depth) % channels;
      batch_idx = index / input_width / input_height / input_depth / channels;
      output_stride = (batch_idx * channels + c_offset) * output_depth *
                      output_height * output_width;
    } else { /* "NDHWC" */
      c_offset = index % channels;
      w_offset = (index / channels) % input_width + padding_width;
      h_offset =
          (index / channels / input_width) % input_height + padding_height;
      d_offset = (index / channels / input_width / input_height) % input_depth +
                 padding_depth;
      batch_idx = index / channels / input_width / input_height / input_depth;
      output_stride =
          batch_idx * output_depth * output_height * output_width * channels;
    }

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = AdaptStartIndex(d_offset, output_depth, input_depth);
      pdend = AdaptEndIndex(d_offset, output_depth, input_depth);

      phstart = AdaptStartIndex(h_offset, output_height, input_height);
      phend = AdaptEndIndex(h_offset, output_height, input_height);

      pwstart = AdaptStartIndex(w_offset, output_width, input_width);
      pwend = AdaptEndIndex(w_offset, output_width, input_width);
    } else {
      pdstart = (d_offset < ksize_depth)
                    ? 0
                    : (d_offset - ksize_depth) / stride_depth + 1;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      pdend = min((d_offset) / stride_depth + 1, output_depth);
      phend = min((h_offset) / stride_height + 1, output_height);
      pwend = min((w_offset) / stride_width + 1, output_width);
    }

    // Per-iteration views of this sample's output tensors. Using locals
    // (instead of advancing the kernel arguments) keeps the offsets from
    // accumulating when a thread handles more than one index of the loop.
    const T* output_slice = nullptr;
    if (pool_process.use_x) {
      input = input_data[index];
      output_slice = output_data + output_stride;
    }
    const T* output_grad_slice = output_grad + output_stride;

    T input_grad_data = static_cast<T>(0.0);

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // Figure out the divisor the forward pass used for this output cell.
          int pool_size;
          if (adaptive) {
            // Use the exact adaptive window of cell (pd, ph, pw), i.e. the
            // same window KernelPool3D averaged over. Window sizes differ
            // between cells when an input extent is not a multiple of the
            // output extent, so a single ceil(in / ksize) is not correct.
            const int dstart = AdaptStartIndex(pd, input_depth, output_depth);
            const int dend = AdaptEndIndex(pd, input_depth, output_depth);
            const int hstart = AdaptStartIndex(ph, input_height, output_height);
            const int hend = AdaptEndIndex(ph, input_height, output_height);
            const int wstart = AdaptStartIndex(pw, input_width, output_width);
            const int wend = AdaptEndIndex(pw, input_width, output_width);
            pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          } else {
            int dstart = pd * stride_depth - padding_depth;
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int dend = min(dstart + ksize_depth, input_depth);
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            dstart = max(dstart, 0);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            pool_size =
                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                          : ksize_depth * ksize_height * ksize_width;
          }

          int output_sub_idx =
              channel_last
                  ? ((pd * output_height + ph) * output_width + pw) * channels +
                        c_offset
                  : (pd * output_height + ph) * output_width + pw;
          T output_value = pool_process.use_x ? output_slice[output_sub_idx]
                                              : static_cast<T>(0);
          pool_process.compute(input, output_value,
                               output_grad_slice[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

880
// Backward kernel for 3-D max pooling.
//
// Each loop iteration handles one *output* element. Within that element's
// clipped ksize window, the first input position (scanning d, then h, then w)
// whose value equals output_data[index] receives output_grad[index].
//
// The gradient is added with platform::CudaAtomicAdd because overlapping
// windows can pick the same input position. input_grad is only accumulated
// into here, never initialized.
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, pd, c, batch_idx;

    if (!channel_last) { /*NCDHW*/
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else { /*NDHWC*/
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;

    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);

    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    const T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;

    int input_stride;
    if (!channel_last) {
      input_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else {
      input_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    // Per-iteration views of this sample's input and input gradient. The
    // original code advanced the kernel arguments themselves, so offsets
    // accumulated whenever a thread handled more than one index of the loop.
    const T* input_slice = input_data + input_stride;
    T* input_grad_slice = input_grad + input_stride;

    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          int input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          if (ele == input_slice[input_data_idx]) {
            stop = true;
            maxIdx = input_data_idx;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // Atomic because several output cells may select the same input.
      platform::CudaAtomicAdd(input_grad_slice + maxIdx, output_grad[index]);
    }
  }
}

F
feng_shuai 已提交
957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999
// Launches KernelPool3D for raw device pointers on the caller's stream.
//
// Shapes are read as NCDHW:
//   * index 0 of input_shape is the batch;
//   * index 1 is channels;
//   * indices 2, 3 and 4 are depth, height and width.
// ksize, strides and paddings are read as (depth, height, width).
//
// One thread per output element. Blocks have 1024 threads, or 512 on
// WITH_NV_JETSON builds.
template <typename PoolProcess, typename T>
void Pool3dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    bool exclusive, bool adaptive, T* output, gpuStream_t stream,
    PoolProcess pool_compute) {
#ifdef WITH_NV_JETSON
  constexpr int kThreadsPerBlock = 512;
#else
  constexpr int kThreadsPerBlock = 1024;
#endif
  const int total_outputs = input_shape[0] * output_shape[1] *
                            output_shape[2] * output_shape[3] *
                            output_shape[4];
  const dim3 block_dim(kThreadsPerBlock, 1);
  const dim3 grid_dim((total_outputs + kThreadsPerBlock - 1) / kThreadsPerBlock,
                      1);

  KernelPool3D<PoolProcess, T><<<grid_dim, block_dim, 0, stream>>>(
      total_outputs, input, /*channels=*/input_shape[1],
      /*input D/H/W=*/input_shape[2], input_shape[3], input_shape[4],
      /*output D/H/W=*/output_shape[2], output_shape[3], output_shape[4],
      /*ksize=*/ksize[0], ksize[1], ksize[2],
      /*strides=*/strides[0], strides[1], strides[2],
      /*paddings=*/paddings[0], paddings[1], paddings[2], pool_compute,
      exclusive, adaptive, output);
}

C
chengduoZH 已提交
1000
/*
 * CUDA specialization of 3-D pooling (forward).
 *
 * Tensors are NCDHW, or NDHWC when the data_format overload is given "NDHWC".
 * ksize, strides and paddings each hold three elements: depth, height and
 * width. Only paddings[0..2] are read, one value per spatial dimension.
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCDHW-only entry point: equivalent to the data_format overload with
  // "NCDHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* output,
                  PoolProcess pool_process) {
    (*this)(context, input, ksize, strides, paddings, "NCDHW", exclusive,
            adaptive, output, pool_process);
  }

  // Pools `input` into `output`, launching one thread per output element.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* output, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NDHWC");
    // Position of the channel axis and of the first spatial axis (depth);
    // height and width follow depth directly in both layouts.
    const int c_axis = channel_last ? 4 : 1;
    const int d_axis = channel_last ? 1 : 2;

    const auto& in_dims = input.dims();
    const auto& out_dims = output->dims();
    const int batch_size = in_dims[0];
    const int input_channels = in_dims[c_axis];
    const int input_depth = in_dims[d_axis];
    const int input_height = in_dims[d_axis + 1];
    const int input_width = in_dims[d_axis + 2];
    const int output_channels = out_dims[c_axis];
    const int output_depth = out_dims[d_axis];
    const int output_height = out_dims[d_axis + 1];
    const int output_width = out_dims[d_axis + 2];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    const int nthreads = batch_size * output_channels * output_depth *
                         output_height * output_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    const dim3 threads(thread_num, 1);
    const dim3 grid((nthreads + thread_num - 1) / thread_num, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize[0],
        ksize[1], ksize[2], strides[0], strides[1], strides[2], paddings[0],
        paddings[1], paddings[2], pool_process, exclusive, adaptive,
        output_data, channel_last);
  }
};

C
chengduoZH 已提交
1113
/*
 * CUDA specialization of the generic 3-D pooling gradient.
 *
 * Tensors are NCDHW, or NDHWC when the data_format overload is given "NDHWC".
 * ksize, strides and paddings each hold three elements: depth, height and
 * width. Only paddings[0..2] are read, one value per spatial dimension.
 *
 * KernelPool3DGrad runs one thread per input element and writes every
 * element of input_grad directly.
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCDHW-only entry point: equivalent to the data_format overload with
  // "NCDHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* input_grad,
                  PoolProcess pool_process) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCDHW", exclusive, adaptive, input_grad, pool_process);
  }

  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* input_grad, PoolProcess pool_process) {
    bool channel_last = (data_format == "NDHWC");

    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Same block-size policy as the forward Pool3dFunctor in this file:
    // 1024 threads, adjusted via platform::ChangeThreadNum on
    // WITH_NV_JETSON builds. The kernel's result does not depend on the
    // block size.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, exclusive, adaptive, input_grad_data,
        channel_last);
  }
};

C
chengduoZH 已提交
1227
/*
 * CUDA specialization of the 3-D max-pooling gradient.
 *
 * Tensors are NCDHW, or NDHWC when the data_format overload is given "NDHWC".
 * ksize, strides and paddings each hold three elements: depth, height and
 * width. Only paddings[0..2] are read, one value per spatial dimension.
 *
 * KernelMaxPool3DGrad runs one thread per output element and only atomically
 * adds into input_grad; this functor does not initialize it. NOTE(review):
 * input_grad is presumably zero-filled by the caller; verify at call sites.
 */
template <class T>
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // NCDHW-only entry point: equivalent to the data_format overload with
  // "NCDHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCDHW", input_grad);
  }

  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, framework::Tensor* input_grad) {
    bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output.dims()[4] : output.dims()[1];
    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Same block-size policy as the forward Pool3dFunctor in this file:
    // 1024 threads, adjusted via platform::ChangeThreadNum on
    // WITH_NV_JETSON builds. The kernel's result does not depend on the
    // block size.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, input_grad_data, channel_last);
  }
};

// Explicit instantiations of the 3-D pooling functors for the CUDA device
// context (float / double / float16 where supported).
template class Pool3dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
                                       float>;
template class Pool3dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
                                       float>;

template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext,
                                    paddle::platform::float16>;

template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;

template class Pool3dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPool<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool3dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPool<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool3dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool3dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;

/*
 * 2-D max pooling that also records where each maximum was found.
 *
 * One loop iteration handles one output element; `index` walks the flattened
 * output (the host launcher passes nthreads = batch * channels *
 * output_height * output_width) in steps of blockDim.x * gridDim.x, so any
 * 1-D launch size is valid. OffsetPreparationFor4Dimension is defined
 * elsewhere in this file; from how its results are used below it yields the
 * output coordinates (w_offset, h_offset, c_offset) and the start offset of
 * the matching input feature map.
 *
 * Writes the window maximum to output_data[index] and its position inside
 * the input feature map (h * input_width + w) to mask_data[index]. A window
 * with no element greater than -FLT_MAX (including an empty window) produces
 * -FLT_MAX and mask -1.
 * NOTE(review): -FLT_MAX is not the lowest double; for T1 = double a window
 * whose values are all below it is reported as -FLT_MAX / -1.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, bool adaptive, T1* output_data, T2* mask_data,
    FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(
        index, false, divmods, 0, 0, input_width, input_height, &w_offset,
        &h_offset, &c_offset, &input_offset);
    // Per-iteration base pointer of this (batch, channel) input plane.
    // Advancing `input_data` itself would accumulate offsets across
    // grid-stride iterations and read the wrong plane.
    const T1* input_map = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);

      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      // Window in input coordinates, clipped so padded positions are never
      // read.
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        if (ele < input_map[input_index]) {
          max_index = input_index;
          ele = input_map[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

/*
 * Backward of KernelMaxPool2dWithIdx.
 *
 * One loop iteration handles one input element; `index` walks the flattened
 * input gradient (the host launcher passes nthreads = batch * channels *
 * input_height * input_width) in steps of blockDim.x * gridDim.x.
 * OffsetPreparationFor4Dimension is defined elsewhere in this file; from how
 * its results are used below it yields the input coordinates (w_offset,
 * h_offset, c_offset) and the start offset of the matching output plane.
 *
 * The gradient of an input element is the sum of output_grad over every
 * output position in the candidate range [phstart, phend) x [pwstart, pwend)
 * whose recorded mask equals this element's position
 * (h_offset * input_width + w_offset). The result is written, not
 * accumulated, into input_grad[index].
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, bool adaptive,
    T1* input_grad, FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(
        index, false, divmods, 0, 0, output_width, output_height, &w_offset,
        &h_offset, &c_offset, &output_offset);
    // Per-iteration views of this (batch, channel) output plane. Advancing
    // the kernel arguments directly would compound offsets across
    // grid-stride iterations.
    const T2* mask_map = mask_data + output_offset;
    const T1* output_grad_map = output_grad + output_offset;

    if (adaptive) {
      // Conservative range of output cells whose adaptive window may cover
      // this input cell; the mask comparison below filters the exact ones.
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // Output cells whose window [p * stride - pad, p * stride - pad + k)
      // contains this input coordinate.
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 input_grad_data = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        int output_sub_idx = ph * output_width + pw;
        if (mask_map[output_sub_idx] == input_current_featuremap_idx)
          input_grad_data += output_grad_map[output_sub_idx];
      }
    }
    input_grad[index] = input_grad_data;
  }
}

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 *
 * Launches KernelMaxPool2dWithIdx: writes the pooled maxima to `output` and,
 * per output element, the position (h * input_width + w) of the chosen
 * element inside its input feature map to `mask`. With `adaptive` the
 * windows come from AdaptStartIndex/AdaptEndIndex instead of
 * ksize/strides/paddings.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];

    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // Outputs are allocated above; with no elements there is nothing to
    // compute, and a zero-block grid is an invalid launch configuration.
    if (nthreads == 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, adaptive, output_data,
        mask_data, pool_divmods);
  }
};

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 *
 * Launches KernelMaxPool2DWithIdxGrad: every element of `input_grad` is
 * overwritten with the sum of `output_grad` over the output positions whose
 * `mask` entry points at that element.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];

    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    // Nothing to compute for an empty gradient; a zero-block grid is an
    // invalid launch configuration.
    if (nthreads == 0) return;

    // Same block-size policy as the forward functor, so devices that need a
    // smaller block (WITH_NV_JETSON builds) get it for backward as well.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, input_width, input_height);
    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width, adaptive,
        input_grad_data, pool_divmods);
  }
};

// Explicit instantiations of the 2-D max-pool-with-index functors
// (values in float / double, mask indices in int).
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;

/*
 * 3-D max pooling (NCDHW) that also records where each maximum was found.
 *
 * One loop iteration handles one output element; `index` walks the flattened
 * output (nthreads = batch * channels * output_depth * output_height *
 * output_width from the host launcher) in steps of blockDim.x * gridDim.x.
 *
 * Writes the window maximum to output_data[index] and its position inside
 * the input volume ((d * input_height + h) * input_width + w) to
 * mask_data[index]. A window with no element greater than -FLT_MAX
 * (including an empty window) produces -FLT_MAX and mask -1.
 * NOTE(review): -FLT_MAX is not the lowest double; see the 2-D kernel.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat output index into (batch, c, pd, ph, pw).
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      // Window in input coordinates, clipped so padded positions are never
      // read.
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Per-iteration base pointer of this (batch, channel) input volume.
    // Advancing `input_data` itself would accumulate offsets across
    // grid-stride iterations.
    const T1* input_map = input_data + (batch_idx * channels + c) *
                                           input_depth * input_height *
                                           input_width;

    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          int input_index = (d * input_height + h) * input_width + w;
          if (ele < input_map[input_index]) {
            max_index = input_index;
            ele = input_map[input_index];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

/*
 * Backward of KernelMaxPool3DWithIdx (NCDHW).
 *
 * One loop iteration handles one input element; `index` walks the flattened
 * input gradient (nthreads = batch * channels * input_depth * input_height *
 * input_width from the host launcher) in steps of blockDim.x * gridDim.x.
 *
 * The gradient of an input element is the sum of output_grad over the output
 * positions in the candidate range whose recorded mask equals this element's
 * position ((d * input_height + h) * input_width + w). The result is written,
 * not accumulated, into input_grad[index].
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, bool adaptive, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat input index into (batch, c, d, h, w).
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      // Conservative range of output cells whose adaptive window may cover
      // this input cell; the mask comparison below filters the exact ones.
      pdstart = d_offset * output_depth / input_depth;
      pdend =
          min((d_offset + 1) * output_depth / input_depth + 1, output_depth);
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // Output cells whose window [p * stride - pad, p * stride - pad + k)
      // contains this input coordinate.
      pdstart =
          (d_offset + padding_depth < ksize_depth)
              ? 0
              : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth);
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 input_grad_data = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + c_offset) * output_depth *
                     output_height * output_width;
    // Per-iteration views of this (batch, channel) output volume. Advancing
    // `mask` / `output_grad` directly would compound offsets across
    // grid-stride iterations.
    const T2* mask_map = mask + output_idx;
    const T1* output_grad_map = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
          if (mask_map[output_sub_idx] == input_current_feature_map_idx)
            input_grad_data += output_grad_map[output_sub_idx];
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 *
 * Launches KernelMaxPool3DWithIdx: writes the pooled maxima to `output` and,
 * per output element, the position of the chosen element inside its input
 * volume to `mask`.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];

    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Outputs are allocated above; with no elements there is nothing to
    // compute, and a zero-block grid is an invalid launch configuration.
    if (nthreads == 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 *
 * Launches KernelMaxPool3DWithIdxGrad: every element of `input_grad` is
 * overwritten with the sum of `output_grad` over the output positions whose
 * `mask` entry points at that element.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];

    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Nothing to compute for an empty gradient; a zero-block grid is an
    // invalid launch configuration.
    if (nthreads == 0) return;

    // Same block-size policy as the forward functor, so devices that need a
    // smaller block (WITH_NV_JETSON builds) get it for backward as well.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

// Explicit instantiations of the 3-D max-pool-with-index functors
// (values in float / double, mask indices in int).
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;

}  // namespace math
}  // namespace operators
}  // namespace paddle