pooling.cu 103.9 KB
Newer Older
F
From00 已提交
1
/* Copyright (c) 2022 paddlepaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
17

18
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
L
limingshu 已提交
19
#include "paddle/fluid/platform/fast_divmod.h"
F
From00 已提交
20
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
21
#include "paddle/phi/kernels/funcs/pooling.h"
C
chengduoZH 已提交
22

F
From00 已提交
23 24
namespace phi {
namespace funcs {
C
chengduoZH 已提交
25

L
limingshu 已提交
26 27
// Precomputed fast integer divisors for the channel, width and height
// extents. OffsetPreparationFor4Dimension uses them to split a flat element
// index into (width, height, channel) coordinates plus a base offset.
struct FastDivModForPooling {
 public:
  paddle::platform::FastDivMod channel;
  paddle::platform::FastDivMod width;
  paddle::platform::FastDivMod height;

  // channels / output_width / output_height become the respective divisors.
  explicit HOSTDEVICE FastDivModForPooling(const int channels,
                                           const int output_width,
                                           const int output_height)
      : channel(channels), width(output_width), height(output_height) {}
};

41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
// Precomputed fast integer divisors for the channel, width, height and depth
// extents of a 5-D (3-D spatial) pooling problem.
struct FastDivModForPooling3D {
 public:
  paddle::platform::FastDivMod channel;
  paddle::platform::FastDivMod width;
  paddle::platform::FastDivMod height;
  paddle::platform::FastDivMod depth;

  // Each constructor argument becomes the divisor of the matching member.
  explicit HOSTDEVICE FastDivModForPooling3D(const int channels,
                                             const int output_width,
                                             const int output_height,
                                             const int output_depth)
      : channel(channels),
        width(output_width),
        height(output_height),
        depth(output_depth) {}
};

L
limingshu 已提交
59 60
// Fast integer divisors used by the 2-D pooling backward kernel: the spatial
// and channel extents of the input plus the kernel size and stride along each
// spatial axis. ("Staff" is a historical misspelling of "Stuff"; the name is
// kept because other code refers to it.)
struct FastDivModForPoolingWithMoreStaff {
 public:
  paddle::platform::FastDivMod channel;
  paddle::platform::FastDivMod width;
  paddle::platform::FastDivMod height;
  paddle::platform::FastDivMod ksize_w;
  paddle::platform::FastDivMod ksize_h;
  paddle::platform::FastDivMod stride_w;
  paddle::platform::FastDivMod stride_h;

  // Each constructor argument becomes the divisor of the matching member.
  explicit HOSTDEVICE FastDivModForPoolingWithMoreStaff(
      const int channels,
      const int input_width,
      const int input_height,
      const int ksize_width,
      const int ksize_height,
      const int stride_width,
      const int stride_height)
      : channel(channels),
        width(input_width),
        height(input_height),
        ksize_w(ksize_width),
        ksize_h(ksize_height),
        stride_w(stride_width),
        stride_h(stride_height) {}
};

/*
 * Splits a flat element index of a 4-D tensor into coordinates.
 *
 *   index        flat index being decomposed with divmods.width / height /
 *                channel.
 *   channel_last true for NHWC layout, false for NCHW.
 *   divmods      any struct exposing `channel`, `width`, `height` FastDivMods.
 *   pad_width / pad_height   added to the recovered w / h coordinates.
 *   aux_width / aux_height   spatial extent of the tensor that `*stride`
 *                            should point into.
 *
 * Outputs: *w_offset, *h_offset (padded coordinates), *c_offset (channel),
 * and *stride, the element offset of the start of the (n, c) plane (NCHW) or
 * of batch n (NHWC) in a tensor with aux_height x aux_width spatial extent.
 */
template <typename DivModsT>
__device__ void OffsetPreparationFor4Dimension(int index,
                                               bool channel_last,
                                               DivModsT divmods,
                                               const int pad_width,
                                               const int pad_height,
                                               const int aux_width,
                                               const int aux_height,
                                               int* w_offset,
                                               int* h_offset,
                                               int* c_offset,
                                               int* stride) {
  if (channel_last) { /* NHWC: channel varies fastest */
    auto split_c = divmods.channel.Divmod(index);
    auto split_w = divmods.width.Divmod(split_c.val[0]);
    auto split_h = divmods.height.Divmod(split_w.val[0]);
    *c_offset = split_c.val[1];
    *w_offset = split_w.val[1] + pad_width;
    *h_offset = split_h.val[1] + pad_height;
    // split_h.val[0] is the batch index.
    *stride =
        split_h.val[0] * aux_height * aux_width * divmods.channel.divisor;
    return;
  }

  /* NCHW: width varies fastest */
  auto split_w = divmods.width.Divmod(index);
  auto split_h = divmods.height.Divmod(split_w.val[0]);
  auto split_c = divmods.channel.Divmod(split_h.val[0]);
  *w_offset = split_w.val[1] + pad_width;
  *h_offset = split_h.val[1] + pad_height;
  *c_offset = split_c.val[1];
  // (batch * channels + channel) selects the plane; scale by plane size.
  *stride = (split_c.val[0] * divmods.channel.divisor + *c_offset) *
            aux_height * aux_width;
}

120
/*
 * Forward 2-D pooling. Each grid-stride iteration produces one output element
 * output_data[index], for index in [0, nthreads).
 *
 * divmods decomposes `index` into output coordinates (w_offset, h_offset,
 * c_offset); input_offset is the start of the matching input plane (NCHW) or
 * batch (NHWC). The window is either the adaptive window (AdaptStartIndex /
 * AdaptEndIndex) or the strided/padded window clamped to the input. The
 * pooled value is finalized with the window size when exclusive or adaptive,
 * otherwise with ksize_height * ksize_width.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads,
                             const T* input_data,
                             const int channels,
                             const int input_height,
                             const int input_width,
                             const int output_height,
                             const int output_width,
                             const int ksize_height,
                             const int ksize_width,
                             const int stride_height,
                             const int stride_width,
                             const int padding_height,
                             const int padding_width,
                             FastDivModForPooling divmods,
                             PoolProcess pool_process,
                             bool exclusive,
                             bool adaptive,
                             T* output_data,
                             bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         channel_last,
                                                         divmods,
                                                         0,
                                                         0,
                                                         input_width,
                                                         input_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &input_offset);
    // Use a per-iteration base pointer. Advancing input_data itself would
    // accumulate offsets whenever a thread runs more than one grid-stride
    // iteration (i.e. when the grid is smaller than nthreads).
    const T* input_base = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);
      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);
      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        auto input_idx = channel_last
                             ? (h * input_width + w) * channels + c_offset
                             : h * input_width + w;
        pool_process.compute(input_base[input_idx], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}
L
limingshu 已提交
186 187

/*
 * Backward 2-D pooling. Each grid-stride iteration produces one input
 * gradient element input_grad[index], for index in [0, nthreads).
 *
 * divmods (built from the input extents) decomposes `index` into padded input
 * coordinates (w_offset, h_offset) and channel c_offset; output_offset is the
 * start of the matching output plane (NCHW) or batch (NHWC). The kernel then
 * visits every output window [phstart, phend) x [pwstart, pwend) that covers
 * this input element and accumulates pool_process.compute(...) with scale
 * 1 / pool_size. When pool_process.use_x is set, the input value and the
 * forward output are read as well.
 */
template <typename T, typename PoolProcess>
__global__ void KernelPool2DGrad(const int nthreads,
                                 const T* __restrict__ input_data,
                                 const T* __restrict__ output_data,
                                 const T* __restrict__ output_grad,
                                 const int output_width,
                                 const int output_height,
                                 const int input_width,
                                 const int input_height,
                                 const int ksize_width,
                                 const int ksize_height,
                                 const int stride_width,
                                 const int stride_height,
                                 const int padding_width,
                                 const int padding_height,
                                 FastDivModForPoolingWithMoreStaff divmods,
                                 PoolProcess pool_process,
                                 bool exclusive,
                                 bool adaptive,
                                 T* __restrict__ input_grad,
                                 bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    T input = static_cast<T>(0);
    T input_grad_data = static_cast<T>(0);
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    OffsetPreparationFor4Dimension<>(index,
                                     channel_last,
                                     divmods,
                                     padding_width,
                                     padding_height,
                                     output_width,
                                     output_height,
                                     &w_offset,
                                     &h_offset,
                                     &c_offset,
                                     &output_offset);
    // Offset per-iteration copies rather than the kernel arguments: advancing
    // output_data / output_grad in place would accumulate offsets whenever a
    // thread runs more than one grid-stride iteration.
    const T* output_data_cur = output_data;
    if (pool_process.use_x) {
      input = input_data[index];
      output_data_cur = output_data + output_offset;
    }
    const T* output_grad_cur = output_grad + output_offset;

    if (adaptive) {
      auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
      auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
      phstart = divmods.height.Div(h_offset * output_height);
      pwstart = divmods.width.Div(w_offset * output_width);
      phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
      pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];

      for (int ph = phstart; ph < phend; ++ph) {
        // Exact extent of adaptive window ph, computed the same way the
        // forward KernelPool2D does. Adaptive windows can differ in size
        // (e.g. 5 -> 3 gives 2, 3, 2), so a single ceil(in / out) divisor
        // would not match the forward average.
        const int window_height =
            AdaptEndIndex(ph, input_height, output_height) -
            AdaptStartIndex(ph, input_height, output_height);
        for (int pw = pwstart; pw < pwend; ++pw) {
          const int window_width =
              AdaptEndIndex(pw, input_width, output_width) -
              AdaptStartIndex(pw, input_width, output_width);
          int pool_size = window_height * window_width;
          int tmp_idx = ph * output_width + pw;
          int output_sub_idx =
              channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                           : tmp_idx;
          T ouput_value = pool_process.use_x ? output_data_cur[output_sub_idx]
                                             : static_cast<T>(0);
          pool_process.compute(input,
                               ouput_value,
                               output_grad_cur[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    } else {
      // When h_offset < ksize_height the quotient below is discarded.
      auto stride_height_div = divmods.stride_h.Div(h_offset - ksize_height);
      auto stride_width_div = divmods.stride_w.Div(w_offset - ksize_width);
      phstart = (h_offset < ksize_height) ? 0 : stride_height_div + 1;
      pwstart = (w_offset < ksize_width) ? 0 : stride_width_div + 1;
      phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
      pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);

      if (exclusive) {
        for (int ph = phstart; ph < phend; ++ph) {
          for (int pw = pwstart; pw < pwend; ++pw) {
            // Exclusive mode: divide by the window size clamped to the input.
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            int pool_size = (hend - hstart) * (wend - wstart);
            int tmp_idx = ph * output_width + pw;
            int output_sub_idx =
                channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                             : tmp_idx;
            T ouput_value = pool_process.use_x
                                ? output_data_cur[output_sub_idx]
                                : static_cast<T>(0);
            pool_process.compute(input,
                                 ouput_value,
                                 output_grad_cur[output_sub_idx],
                                 static_cast<T>(1.0 / pool_size),
                                 &input_grad_data);
          }
        }
      } else {
        for (int ph = phstart; ph < phend; ++ph) {
          for (int pw = pwstart; pw < pwend; ++pw) {
            // Inclusive mode: always divide by the full kernel area.
            int pool_size = ksize_height * ksize_width;
            int tmp_idx = ph * output_width + pw;
            int output_sub_idx =
                channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                             : tmp_idx;
            T ouput_value = pool_process.use_x
                                ? output_data_cur[output_sub_idx]
                                : static_cast<T>(0);
            pool_process.compute(input,
                                 ouput_value,
                                 output_grad_cur[output_sub_idx],
                                 static_cast<T>(1.0 / pool_size),
                                 &input_grad_data);
          }
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

316
/*
 * Backward 2-D max pooling. Each grid-stride iteration handles one forward
 * output element `index` in [0, nthreads).
 *
 * The kernel scans that element's strided/padded window in the input for the
 * first position whose value equals output_data[index]. It then atomically
 * adds output_grad[index] to the gradient at that position. Atomics are used
 * because overlapping windows may select the same input element.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(const int nthreads,
                                    const T* input_data,
                                    const T* output_data,
                                    const T* output_grad,
                                    const int channels,
                                    const int input_height,
                                    const int input_width,
                                    const int output_height,
                                    const int output_width,
                                    const int ksize_height,
                                    const int ksize_width,
                                    const int stride_height,
                                    const int stride_width,
                                    const int padding_height,
                                    const int padding_width,
                                    T* input_grad,
                                    FastDivModForPooling divmods,
                                    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         channel_last,
                                                         divmods,
                                                         0,
                                                         0,
                                                         input_width,
                                                         input_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &input_offset);
    // Per-iteration base pointers. Advancing input_data / input_grad in place
    // would accumulate offsets whenever a thread runs more than one
    // grid-stride iteration.
    const T* input_base = input_data + input_offset;
    T* input_grad_base = input_grad + input_offset;

    int hstart = h_offset * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = w_offset * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        int input_data_idx = channel_last
                                 ? (h * input_width + w) * channels + c_offset
                                 : h * input_width + w;
        if (ele == input_base[input_data_idx]) {
          maxIndex = input_data_idx;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // Overlapping windows can pick the same input element.
      paddle::platform::CudaAtomicAdd(input_grad_base + maxIndex,
                                      output_grad[index]);
    }
  }
}

N
nhzlx 已提交
383 384
/*
 * Launches KernelPool2D on raw device pointers (NCHW layout only; the kernel's
 * channel_last argument is left at its default) on the given stream.
 * input_shape / output_shape are read as [N, C, H, W]; ksize, strides and
 * paddings are read as [height, width].
 */
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input,
    const std::vector<int>& input_shape,
    const std::vector<int>& output_shape,
    const std::vector<int>& ksize,
    const std::vector<int>& strides,
    const std::vector<int>& paddings,
    bool exclusive,
    bool adaptive,
    T* output,
    gpuStream_t stream,
    PoolProcess pool_compute) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_height = input_shape[2];
  const int input_width = input_shape[3];
  const int output_channels = output_shape[1];
  const int output_height = output_shape[2];
  const int output_width = output_shape[3];
  const int ksize_height = ksize[0];
  const int ksize_width = ksize[1];
  const int stride_height = strides[0];
  const int stride_width = strides[1];
  const int padding_height = paddings[0];
  const int padding_width = paddings[1];

  int nthreads = batch_size * output_channels * output_height * output_width;
  // An empty output would produce a zero-block grid, which is an invalid
  // launch configuration; there is nothing to compute anyway.
  if (nthreads <= 0) {
    return;
  }
  int thread_num = 1024;
#ifdef WITH_NV_JETSON
  // No device context is passed to this functor, so the Jetson block size is
  // hard-coded instead of calling
  // backends::gpu::ChangeThreadNum(context, &thread_num);
  thread_num = 512;
#endif
  int blocks = (nthreads + thread_num - 1) / thread_num;
  dim3 threads(thread_num, 1);
  dim3 grid(blocks, 1);

  auto pool_divmods =
      FastDivModForPooling(input_channels, output_width, output_height);
  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads,
                                                             input,
                                                             input_channels,
                                                             input_height,
                                                             input_width,
                                                             output_height,
                                                             output_width,
                                                             ksize_height,
                                                             ksize_width,
                                                             stride_height,
                                                             stride_width,
                                                             padding_height,
                                                             padding_width,
                                                             pool_divmods,
                                                             pool_compute,
                                                             exclusive,
                                                             adaptive,
                                                             output);
}

C
chengduoZH 已提交
442
/*
443 444 445 446 447 448
 * Tensors are in NCHW or NHWC format.
 * Ksize, strides are two elements. These two elements represent height
 * and width, respectively.
 * Paddings are four elements. These four elements represent height_up,
 * height_down, width_left and width_right, respectively.
 */
449
template <typename PoolProcess, typename T>
class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
 public:
  // NCHW overload. Exactly equivalent to the data_format overload below
  // called with "NCHW", so it simply delegates.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    (*this)(context,
            input,
            ksize,
            strides,
            paddings,
            "NCHW",
            exclusive,
            adaptive,
            output,
            pool_process);
  }

  // Allocates `output` and launches KernelPool2D on context.stream().
  // data_format == "NHWC" selects channel-last indexing; any other value is
  // treated as NCHW. ksize, strides and paddings are read as
  // [height, width].
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    bool channel_last = (data_format == "NHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_channels =
        channel_last ? output->dims()[3] : output->dims()[1];
    const int output_height =
        channel_last ? output->dims()[1] : output->dims()[2];
    const int output_width =
        channel_last ? output->dims()[2] : output->dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    T* output_data = context.template Alloc<T>(output);

    int nthreads = batch_size * output_channels * output_height * output_width;
    // An empty output would produce a zero-block grid, which is an invalid
    // launch configuration. The output is still allocated above.
    if (nthreads <= 0) {
      return;
    }
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelPool2D<PoolProcess, T>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 input_channels,
                                                 input_height,
                                                 input_width,
                                                 output_height,
                                                 output_width,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_height,
                                                 stride_width,
                                                 padding_height,
                                                 padding_width,
                                                 pool_divmods,
                                                 pool_process,
                                                 exclusive,
                                                 adaptive,
                                                 output_data,
                                                 channel_last);
  }
};
C
chengduoZH 已提交
578
/*
579 580 581 582 583 584
 * Tensors are in NCHW or NHWC format.
 * Ksize, strides are two elements. These two elements represent height
 * and width, respectively.
 * Paddings are four elements. These four elements represent height_up,
 * height_down, width_left and width_right, respectively.
 */
585
template <typename PoolProcess, typename T>
F
From00 已提交
586
class Pool2dGradFunctor<phi::GPUContext, PoolProcess, T> {
587
 public:
F
From00 已提交
588 589 590 591
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
C
chengduo 已提交
592 593
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
F
From00 已提交
594 595 596 597
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
598
                  PoolProcess pool_process) {
599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
F
From00 已提交
615
    T* input_grad_data = context.template Alloc<T>(input_grad);
616 617

    int nthreads = batch_size * input_channels * input_height * input_width;
F
From00 已提交
618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648
    auto pool_divmods = FastDivModForPoolingWithMoreStaff(input_channels,
                                                          input_width,
                                                          input_height,
                                                          ksize_width,
                                                          ksize_height,
                                                          stride_width,
                                                          stride_height);

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(context, nthreads);
    KernelPool2DGrad<T, PoolProcess><<<config.block_per_grid,
                                       config.thread_per_block,
                                       0,
                                       context.stream()>>>(nthreads,
                                                           input_data,
                                                           output_data,
                                                           output_grad_data,
                                                           output_width,
                                                           output_height,
                                                           input_width,
                                                           input_height,
                                                           ksize_width,
                                                           ksize_height,
                                                           stride_width,
                                                           stride_height,
                                                           padding_width,
                                                           padding_height,
                                                           pool_divmods,
                                                           pool_process,
                                                           exclusive,
                                                           adaptive,
                                                           input_grad_data);
649
  }
F
From00 已提交
650 651 652 653
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
654 655 656
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
F
From00 已提交
657 658 659 660 661
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
                  PoolProcess pool_process) {
662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686
    bool channel_last = (data_format == "NHWC");

    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_channels =
        channel_last ? output.dims()[3] : output.dims()[1];
    const int output_height =
        channel_last ? output.dims()[1] : output.dims()[2];
    const int output_width = channel_last ? output.dims()[2] : output.dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
F
From00 已提交
687
    T* input_grad_data = context.template Alloc<T>(input_grad);
688 689

    int nthreads = batch_size * input_channels * input_height * input_width;
F
From00 已提交
690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721
    auto pool_divmods = FastDivModForPoolingWithMoreStaff(input_channels,
                                                          input_width,
                                                          input_height,
                                                          ksize_width,
                                                          ksize_height,
                                                          stride_width,
                                                          stride_height);

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(context, nthreads);
    KernelPool2DGrad<T, PoolProcess><<<config.block_per_grid,
                                       config.thread_per_block,
                                       0,
                                       context.stream()>>>(nthreads,
                                                           input_data,
                                                           output_data,
                                                           output_grad_data,
                                                           output_width,
                                                           output_height,
                                                           input_width,
                                                           input_height,
                                                           ksize_width,
                                                           ksize_height,
                                                           stride_width,
                                                           stride_height,
                                                           padding_width,
                                                           padding_height,
                                                           pool_divmods,
                                                           pool_process,
                                                           exclusive,
                                                           adaptive,
                                                           input_grad_data,
                                                           channel_last);
722
  }
723 724
};

/*
 * Backward pass of 2-D max pooling on the GPU.
 *
 * Tensors are in NCHW or NHWC format.
 * ksize and strides hold two elements: {height, width}.
 * Only paddings[0] (height) and paddings[1] (width) are read.
 */
template <typename T>
class MaxPool2dGradFunctor<phi::GPUContext, T> {
 public:
  // NCHW overload: forwards to the data_format overload with "NCHW".
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  DenseTensor* input_grad) {
    (*this)(context,
            input,
            output,
            output_grad,
            ksize,
            strides,
            paddings,
            std::string("NCHW"),
            input_grad);
  }

  // Allocates input_grad and launches KernelMaxPool2DGrad on context's
  // stream. data_format "NHWC" selects channel-last layout; anything else is
  // treated as NCHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  DenseTensor* input_grad) {
    const bool is_nhwc = (data_format == "NHWC");

    // Positions of the C, H and W axes for the selected layout.
    const int c_axis = is_nhwc ? 3 : 1;
    const int h_axis = is_nhwc ? 1 : 2;
    const int w_axis = is_nhwc ? 2 : 3;

    const auto& in_dims = input.dims();
    const auto& out_dims = output.dims();

    const int batch = in_dims[0];
    const int in_c = in_dims[c_axis];
    const int in_h = in_dims[h_axis];
    const int in_w = in_dims[w_axis];
    const int out_c = out_dims[c_axis];
    const int out_h = out_dims[h_axis];
    const int out_w = out_dims[w_axis];

    const T* x = input.data<T>();
    const T* y = output.data<T>();
    const T* dy = output_grad.data<T>();
    T* dx = context.template Alloc<T>(input_grad);

    // One thread per output element.
    constexpr int kThreadsPerBlock = 1024;
    const int num_outputs = batch * out_c * out_h * out_w;
    const dim3 block_dim(kThreadsPerBlock, 1);
    const dim3 grid_dim((num_outputs + kThreadsPerBlock - 1) / kThreadsPerBlock,
                        1);

    auto divmods = FastDivModForPooling(in_c, out_w, out_h);
    KernelMaxPool2DGrad<T>
        <<<grid_dim, block_dim, 0, context.stream()>>>(num_outputs,
                                                       x,
                                                       y,
                                                       dy,
                                                       in_c,
                                                       in_h,
                                                       in_w,
                                                       out_h,
                                                       out_w,
                                                       ksize[0],
                                                       ksize[1],
                                                       strides[0],
                                                       strides[1],
                                                       paddings[0],
                                                       paddings[1],
                                                       dx,
                                                       divmods,
                                                       is_nhwc);
  }
};

// Explicit instantiations of the 2-D pooling functors, grouped by element
// type.

// float
template class Pool2dDirectCUDAFunctor<MaxPool<float>, float>;
template class Pool2dDirectCUDAFunctor<AvgPool<float>, float>;
template class Pool2dFunctor<phi::GPUContext, MaxPool<float>, float>;
template class Pool2dFunctor<phi::GPUContext, AvgPool<float>, float>;
template class Pool2dGradFunctor<phi::GPUContext, MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<phi::GPUContext, AvgPoolGrad<float>, float>;
template class MaxPool2dGradFunctor<phi::GPUContext, float>;

// double
template class Pool2dFunctor<phi::GPUContext, MaxPool<double>, double>;
template class Pool2dFunctor<phi::GPUContext, AvgPool<double>, double>;
template class Pool2dGradFunctor<phi::GPUContext, MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<phi::GPUContext, AvgPoolGrad<double>, double>;
template class MaxPool2dGradFunctor<phi::GPUContext, double>;

// float16
template class Pool2dFunctor<phi::GPUContext,
                             MaxPool<dtype::float16>,
                             dtype::float16>;
template class Pool2dFunctor<phi::GPUContext,
                             AvgPool<dtype::float16>,
                             dtype::float16>;
template class Pool2dGradFunctor<phi::GPUContext,
                                 MaxPoolGrad<dtype::float16>,
                                 dtype::float16>;
template class Pool2dGradFunctor<phi::GPUContext,
                                 AvgPoolGrad<dtype::float16>,
                                 dtype::float16>;
template class MaxPool2dGradFunctor<phi::GPUContext, dtype::float16>;
883

884
template <typename PoolProcess, typename T>
F
From00 已提交
885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907
__global__ void KernelPool3D(const int nthreads,
                             const T* input_data,
                             const int channels,
                             const int input_depth,
                             const int input_height,
                             const int input_width,
                             const int output_depth,
                             const int output_height,
                             const int output_width,
                             const int ksize_depth,
                             const int ksize_height,
                             const int ksize_width,
                             const int stride_depth,
                             const int stride_height,
                             const int stride_width,
                             const int padding_depth,
                             const int padding_height,
                             const int padding_width,
                             PoolProcess pool_process,
                             bool exclusive,
                             bool adaptive,
                             T* output_data,
                             bool channel_last = false) {
908
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
909
       index += blockDim.x * gridDim.x) {
910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925
    int pw, ph, pd, c, batch_idx;
    if (!channel_last) {
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else {
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }
926 927 928 929 930

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
D
dengkaipeng 已提交
931 932
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);
933

D
dengkaipeng 已提交
934 935
      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);
936

D
dengkaipeng 已提交
937 938
      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
939 940 941 942 943 944 945 946 947 948 949
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }
950 951 952 953 954 955 956 957 958 959 960

    int input_data_stride;
    if (!channel_last) { /* NCDHW */
      input_data_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else { /* NDHWC */
      input_data_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    input_data += input_data_stride;

961
    T ele = pool_process.initial();
962 963 964
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
965 966 967 968 969
          auto input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          pool_process.compute(input_data[input_data_idx], &ele);
970 971 972
        }
      }
    }
973
    int pool_size = (exclusive || adaptive)
974 975
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
C
chengduo 已提交
976
    pool_process.finalize(static_cast<T>(pool_size), &ele);
977 978 979 980
    output_data[index] = ele;
  }
}

/*
 * Backward kernel for generic 3-D pooling (PoolProcess selects the type).
 *
 * Each index in [0, nthreads) is one *input* element, decoded as
 * (batch, c, d, h, w) for NCDHW or (batch, d, h, w, c) for NDHWC, and visited
 * with a grid-stride loop. For every output cell in the computed range
 * [p*start, p*end), pool_process.compute() accumulates a contribution scaled
 * by 1 / pool_size. The sum is stored (not atomically added) into
 * input_grad[index]. input_data and output_data are only read when
 * pool_process.use_x is set.
 */
template <typename T, typename PoolProcess>
__global__ void KernelPool3DGrad(const int nthreads,
                                 const T* __restrict__ input_data,
                                 const T* __restrict__ output_data,
                                 const T* __restrict__ output_grad,
                                 const int channels,
                                 const int input_depth,
                                 const int input_height,
                                 const int input_width,
                                 const int output_depth,
                                 const int output_height,
                                 const int output_width,
                                 const int ksize_depth,
                                 const int ksize_height,
                                 const int ksize_width,
                                 const int stride_depth,
                                 const int stride_height,
                                 const int stride_width,
                                 const int padding_depth,
                                 const int padding_height,
                                 const int padding_width,
                                 PoolProcess pool_process,
                                 bool exclusive,
                                 bool adaptive,
                                 T* input_grad,
                                 bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // *_offset are this element's coordinates in the padded input.
    int w_offset, h_offset, d_offset, c_offset, batch_idx, output_stride;
    T input = static_cast<T>(0);
    if (!channel_last) { /* "NCDHW" */
      w_offset = index % input_width + padding_width;
      h_offset = (index / input_width) % input_height + padding_height;
      d_offset =
          (index / input_width / input_height) % input_depth + padding_depth;
      c_offset = (index / input_width / input_height / input_depth) % channels;
      batch_idx = index / input_width / input_height / input_depth / channels;
      output_stride = (batch_idx * channels + c_offset) * output_depth *
                      output_height * output_width;
    } else { /* "NDHWC" */
      c_offset = index % channels;
      w_offset = (index / channels) % input_width + padding_width;
      h_offset =
          (index / channels / input_width) % input_height + padding_height;
      d_offset = (index / channels / input_width / input_height) % input_depth +
                 padding_depth;
      batch_idx = index / channels / input_width / input_height / input_depth;
      output_stride =
          batch_idx * output_depth * output_height * output_width * channels;
    }

    // Range of output cells that may receive a contribution from this element.
    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = AdaptStartIndex(d_offset, output_depth, input_depth);
      pdend = AdaptEndIndex(d_offset, output_depth, input_depth);

      phstart = AdaptStartIndex(h_offset, output_height, input_height);
      phend = AdaptEndIndex(h_offset, output_height, input_height);

      pwstart = AdaptStartIndex(w_offset, output_width, input_width);
      pwend = AdaptEndIndex(w_offset, output_width, input_width);
    } else {
      pdstart = (d_offset < ksize_depth)
                    ? 0
                    : (d_offset - ksize_depth) / stride_depth + 1;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      pdend = min((d_offset) / stride_depth + 1, output_depth);
      phend = min((h_offset) / stride_height + 1, output_height);
      pwend = min((w_offset) / stride_width + 1, output_width);
    }

    // Per-iteration views into this sample's output / output-grad data. The
    // pointer parameters themselves are never advanced; doing so would
    // compound the offset on later grid-stride iterations.
    const T* output_slice = output_data;
    if (pool_process.use_x) {
      input = input_data[index];
      output_slice = output_data + output_stride;
    }
    const T* output_grad_slice = output_grad + output_stride;
    T input_grad_data = static_cast<T>(0.0);

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // Divisor of output cell (pd, ph, pw), matching the forward pass.
          int pool_size;
          if (adaptive) {
            // The exact adaptive window of this cell. KernelPool3D divides by
            // this same volume when adaptive is set; a single
            // ceil(input / ksize) estimate disagrees whenever the input size
            // is not a multiple of the output size.
            int dstart = AdaptStartIndex(pd, input_depth, output_depth);
            int dend = AdaptEndIndex(pd, input_depth, output_depth);
            int hstart = AdaptStartIndex(ph, input_height, output_height);
            int hend = AdaptEndIndex(ph, input_height, output_height);
            int wstart = AdaptStartIndex(pw, input_width, output_width);
            int wend = AdaptEndIndex(pw, input_width, output_width);
            pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          } else {
            int dstart = pd * stride_depth - padding_depth;
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int dend = min(dstart + ksize_depth, input_depth);
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            dstart = max(dstart, 0);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            pool_size =
                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                          : ksize_depth * ksize_height * ksize_width;
          }

          int output_sub_idx =
              channel_last
                  ? ((pd * output_height + ph) * output_width + pw) * channels +
                        c_offset
                  : (pd * output_height + ph) * output_width + pw;
          T output_value = pool_process.use_x ? output_slice[output_sub_idx]
                                              : static_cast<T>(0);
          pool_process.compute(input,
                               output_value,
                               output_grad_slice[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

/*
 * Backward kernel for 3-D max pooling.
 *
 * Each index in [0, nthreads) is one output element, decoded as
 * (batch, c, pd, ph, pw) for NCDHW or (batch, pd, ph, pw, c) for NDHWC, and
 * visited with a grid-stride loop. The kernel scans the element's clipped
 * input window in d, h, w order and stops at the first input value equal to
 * output_data[index]. It then atomically adds output_grad[index] to that
 * input position. The add is atomic because several output elements may
 * select the same input position. input_grad is accumulated into, not
 * overwritten.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(const int nthreads,
                                    const T* input_data,
                                    const T* output_data,
                                    const T* output_grad,
                                    const int channels,
                                    const int input_depth,
                                    const int input_height,
                                    const int input_width,
                                    const int output_depth,
                                    const int output_height,
                                    const int output_width,
                                    const int ksize_depth,
                                    const int ksize_height,
                                    const int ksize_width,
                                    const int stride_depth,
                                    const int stride_height,
                                    const int stride_width,
                                    const int padding_depth,
                                    const int padding_height,
                                    const int padding_width,
                                    T* input_grad,
                                    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, pd, c, batch_idx;

    if (!channel_last) { /*NCDHW*/
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else { /*NDHWC*/
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    // Input window [start, end) of this output element, clipped to bounds.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;

    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);

    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    bool stop = false;
    int max_idx = -1;

    // Per-iteration views into this element's input / input-grad data. The
    // pointer parameters are never advanced, otherwise the offset would
    // compound on later grid-stride iterations.
    int input_stride;
    if (!channel_last) {
      input_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else {
      input_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    const T* input_slice = input_data + input_stride;
    T* input_grad_slice = input_grad + input_stride;

    // The first position (in d, h, w scan order) equal to the pooled maximum
    // receives the gradient.
    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          int input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          if (ele == input_slice[input_data_idx]) {
            stop = true;
            max_idx = input_data_idx;
          }
        }
      }
    }
    if (max_idx != -1) {
      paddle::platform::CudaAtomicAdd(input_grad_slice + max_idx,
                                      output_grad[index]);
    }
  }
}

// Runs 3-D pooling directly on raw device pointers (no DenseTensor / device
// context), enqueuing KernelPool3D on `stream` with one thread per output
// element.
//   input_shape / output_shape: read as [N, C, D, H, W] (NCDHW layout).
//   ksize / strides / paddings: read as [depth, height, width]; only the first
//   three padding entries are used.
//   pool_compute: PoolProcess functor forwarded unchanged to the kernel.
// The launch is asynchronous; no synchronization or error query happens here.
template <typename PoolProcess, typename T>
void Pool3dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input,
    const std::vector<int>& input_shape,
    const std::vector<int>& output_shape,
    const std::vector<int>& ksize,
    const std::vector<int>& strides,
    const std::vector<int>& paddings,
    bool exclusive,
    bool adaptive,
    T* output,
    gpuStream_t stream,
    PoolProcess pool_compute) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_depth = input_shape[2];
  const int input_height = input_shape[3];
  const int input_width = input_shape[4];
  const int output_channels = output_shape[1];
  const int output_depth = output_shape[2];
  const int output_height = output_shape[3];
  const int output_width = output_shape[4];
  const int ksize_depth = ksize[0];
  const int ksize_height = ksize[1];
  const int ksize_width = ksize[2];
  const int stride_depth = strides[0];
  const int stride_height = strides[1];
  const int stride_width = strides[2];
  const int padding_depth = paddings[0];
  const int padding_height = paddings[1];
  const int padding_width = paddings[2];

  int nthreads = batch_size * output_channels * output_depth * output_height *
                 output_width;
  // An empty output would produce a zero-sized grid, which CUDA rejects as an
  // invalid launch configuration; there is nothing to compute anyway.
  if (nthreads == 0) {
    return;
  }

  int thread_num = 1024;
#ifdef WITH_NV_JETSON
  // No device context is available here to query, so Jetson builds fall back
  // to a fixed smaller block size.
  thread_num = 512;
#endif
  int blocks = (nthreads + thread_num - 1) / thread_num;
  dim3 threads(thread_num, 1);
  dim3 grid(blocks, 1);

  KernelPool3D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads,
                                                             input,
                                                             input_channels,
                                                             input_depth,
                                                             input_height,
                                                             input_width,
                                                             output_depth,
                                                             output_height,
                                                             output_width,
                                                             ksize_depth,
                                                             ksize_height,
                                                             ksize_width,
                                                             stride_depth,
                                                             stride_height,
                                                             stride_width,
                                                             padding_depth,
                                                             padding_height,
                                                             padding_width,
                                                             pool_compute,
                                                             exclusive,
                                                             adaptive,
                                                             output);
}

C
chengduoZH 已提交
1268
/*
 * GPU forward 3-D pooling: enqueues KernelPool3D on context.stream() with one
 * thread per output element.
 * Tensors are in NCDHW layout, or NDHWC when the data_format overload is
 * called with "NDHWC".
 * ksize and strides hold three elements: depth, height and width.
 * Only paddings[0], paddings[1] and paddings[2] (depth, height, width) are
 * read here.
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<phi::GPUContext, PoolProcess, T> {
 public:
  // NCDHW overload: dims are read as [N, C, D, H, W].
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = context.template Alloc<T>(output);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Skip the launch for an empty output: a zero-sized grid is an invalid
    // launch configuration.
    if (nthreads == 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 pool_process,
                                                 exclusive,
                                                 adaptive,
                                                 output_data);
  }

  // Layout-aware overload: data_format == "NDHWC" selects channel-last
  // indexing; any other value is treated as NCDHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output->dims()[4] : output->dims()[1];
    const int output_depth =
        channel_last ? output->dims()[1] : output->dims()[2];
    const int output_height =
        channel_last ? output->dims()[2] : output->dims()[3];
    const int output_width =
        channel_last ? output->dims()[3] : output->dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = context.template Alloc<T>(output);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Skip the launch for an empty output: a zero-sized grid is an invalid
    // launch configuration.
    if (nthreads == 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 pool_process,
                                                 exclusive,
                                                 adaptive,
                                                 output_data,
                                                 channel_last);
  }
};

C
chengduoZH 已提交
1423
/*
 * GPU backward 3-D pooling: enqueues KernelPool3DGrad on context.stream()
 * with one thread per element of the input gradient (nthreads = input numel).
 * Tensors are in NCDHW layout, or NDHWC when the data_format overload is
 * called with "NDHWC".
 * ksize and strides hold three elements: depth, height and width.
 * Only paddings[0], paddings[1] and paddings[2] (depth, height, width) are
 * read here.
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<phi::GPUContext, PoolProcess, T> {
 public:
  // NCDHW overload: dims are read as [N, C, D, H, W].
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Skip the launch for an empty input gradient: a zero-sized grid is an
    // invalid launch configuration.
    if (nthreads == 0) {
      return;
    }

    // Mirror the forward functors: on Jetson builds let ChangeThreadNum pick
    // the block size instead of always requesting 1024 threads.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<T, PoolProcess>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 output_data,
                                                 output_grad_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 pool_process,
                                                 exclusive,
                                                 adaptive,
                                                 input_grad_data);
  }

  // Layout-aware overload: data_format == "NDHWC" selects channel-last
  // indexing; any other value is treated as NCDHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
                  PoolProcess pool_process) {
    bool channel_last = (data_format == "NDHWC");

    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output.dims()[4] : output.dims()[1];
    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Skip the launch for an empty input gradient: a zero-sized grid is an
    // invalid launch configuration.
    if (nthreads == 0) {
      return;
    }

    // Mirror the forward functors: on Jetson builds let ChangeThreadNum pick
    // the block size instead of always requesting 1024 threads.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        output_data,
        output_grad_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        pool_process,
        exclusive,
        adaptive,
        input_grad_data,
        channel_last);
  }
};

C
chengduoZH 已提交
1580
/*
 * GPU backward 3-D max pooling: enqueues KernelMaxPool3DGrad on
 * context.stream() with one thread per output element (nthreads = output
 * numel).
 * Tensors are in NCDHW layout, or NDHWC when the data_format overload is
 * called with "NDHWC".
 * ksize and strides hold three elements: depth, height and width.
 * Only paddings[0], paddings[1] and paddings[2] (depth, height, width) are
 * read here.
 * NOTE(review): the max-pool grad kernel accumulates into input_grad with an
 * atomic add, and this functor only allocates input_grad without zeroing it;
 * presumably the caller zero-fills it beforehand — confirm.
 */
template <class T>
class MaxPool3dGradFunctor<phi::GPUContext, T> {
 public:
  // NCDHW overload: dims are read as [N, C, D, H, W].
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  DenseTensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Skip the launch for an empty output: a zero-sized grid is an invalid
    // launch configuration, and there is no gradient to scatter.
    if (nthreads == 0) {
      return;
    }

    // Mirror the forward functors: on Jetson builds let ChangeThreadNum pick
    // the block size instead of always requesting 1024 threads.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 output_data,
                                                 output_grad_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 input_grad_data);
  }

  // Layout-aware overload: data_format == "NDHWC" selects channel-last
  // indexing; any other value is treated as NCDHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  DenseTensor* input_grad) {
    bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output.dims()[4] : output.dims()[1];
    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Skip the launch for an empty output: a zero-sized grid is an invalid
    // launch configuration, and there is no gradient to scatter.
    if (nthreads == 0) {
      return;
    }

    // Mirror the forward functors: on Jetson builds let ChangeThreadNum pick
    // the block size instead of always requesting 1024 threads.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        output_data,
        output_grad_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        input_grad_data,
        channel_last);
  }
};

F
From00 已提交
1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752
// Explicit instantiations of the GPU 3-D pooling functors defined in this
// file, grouped by functor.

// Raw-pointer forward functor (float only).
template class Pool3dDirectCUDAFunctor<MaxPool<float>, float>;
template class Pool3dDirectCUDAFunctor<AvgPool<float>, float>;

// Forward functors: float, double, float16.
template class Pool3dFunctor<phi::GPUContext, MaxPool<float>, float>;
template class Pool3dFunctor<phi::GPUContext, AvgPool<float>, float>;
template class Pool3dFunctor<phi::GPUContext, MaxPool<double>, double>;
template class Pool3dFunctor<phi::GPUContext, AvgPool<double>, double>;
template class Pool3dFunctor<phi::GPUContext,
                             MaxPool<dtype::float16>,
                             dtype::float16>;
template class Pool3dFunctor<phi::GPUContext,
                             AvgPool<dtype::float16>,
                             dtype::float16>;

// Generic backward functors: float, double, float16.
template class Pool3dGradFunctor<phi::GPUContext, MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<phi::GPUContext, AvgPoolGrad<float>, float>;
template class Pool3dGradFunctor<phi::GPUContext, MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<phi::GPUContext, AvgPoolGrad<double>, double>;
template class Pool3dGradFunctor<phi::GPUContext,
                                 MaxPoolGrad<dtype::float16>,
                                 dtype::float16>;
template class Pool3dGradFunctor<phi::GPUContext,
                                 AvgPoolGrad<dtype::float16>,
                                 dtype::float16>;

// Max-pool backward functor: float, double, float16.
template class MaxPool3dGradFunctor<phi::GPUContext, float>;
template class MaxPool3dGradFunctor<phi::GPUContext, double>;
template class MaxPool3dGradFunctor<phi::GPUContext, dtype::float16>;
1753

C
chengduoZH 已提交
1754
// 2-D max pooling (NCHW) that also records, per output element, the index of
// the selected input element within its H*W plane.
//   nthreads:    number of output elements to process.
//   output_data: receives the pooled maximum for each output element.
//   mask_data:   receives h * input_width + w of the chosen input element, or
//                -1 if no element in the window beat the initial -FLT_MAX.
//   adaptive:    if true, window bounds come from AdaptStartIndex /
//                AdaptEndIndex; otherwise from ksize/stride/padding clamped to
//                the input extent.
//   divmods:     precomputed FastDivMod helpers used to decompose `index`.
// Uses a grid-stride loop, so it is correct for any grid size.
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(const int nthreads,
                                       const T1* input_data,
                                       const int channels,
                                       const int input_height,
                                       const int input_width,
                                       const int output_height,
                                       const int output_width,
                                       const int ksize_height,
                                       const int ksize_width,
                                       const int stride_height,
                                       const int stride_width,
                                       const int padding_height,
                                       const int padding_width,
                                       bool adaptive,
                                       T1* output_data,
                                       T2* mask_data,
                                       FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         false,
                                                         divmods,
                                                         0,
                                                         0,
                                                         input_width,
                                                         input_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &input_offset);
    // input_offset is recomputed from `index` on every iteration and is
    // applied relative to the kernel's base input pointer. Derive the plane
    // pointer locally: advancing `input_data` itself would accumulate offsets
    // across iterations whenever a thread handles more than one index.
    const T1* input_plane = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);

      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    // NOTE(review): -FLT_MAX is the initial value for every T1; for double
    // inputs this is not the lowest representable value, so windows whose
    // values are all below -FLT_MAX keep -FLT_MAX and mask -1.
    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        // Strict '<' keeps the first occurrence of the maximum.
        if (ele < input_plane[input_index]) {
          max_index = input_index;
          ele = input_plane[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1821
/*
 * Backward kernel of 2-D max pooling with index (NCHW layout).
 *
 * Each work item produces one element of `input_grad`; `nthreads` is the total
 * number of input elements and the loop advances by the total number of
 * launched threads, so any 1-D grid size gives the same result. For input
 * position (h_offset, w_offset) the kernel enumerates the output cells whose
 * pooling window can contain it and sums `output_grad` over those cells whose
 * recorded argmax in `mask_data` equals h_offset * input_width + w_offset.
 *
 * `divmods` is built by the caller from (channels, input_width, input_height)
 * so the flat input index can be decomposed; `output_offset` is presumably the
 * start of the matching (n, c) plane in the output-sized tensors — it is only
 * used that way below.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(const int nthreads,
                                           const T1* output_grad,
                                           const T2* mask_data,
                                           const int channels,
                                           const int input_height,
                                           const int input_width,
                                           const int output_height,
                                           const int output_width,
                                           const int ksize_height,
                                           const int ksize_width,
                                           const int stride_height,
                                           const int stride_width,
                                           const int padding_height,
                                           const int padding_width,
                                           bool adaptive,
                                           T1* input_grad,
                                           FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         false,
                                                         divmods,
                                                         0,
                                                         0,
                                                         output_width,
                                                         output_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &output_offset);
    // Offset local copies instead of the kernel arguments: advancing
    // `mask_data`/`output_grad` in place would accumulate offsets whenever a
    // thread executes more than one iteration of this loop.
    const T2* mask_plane = mask_data + output_offset;
    const T1* output_grad_plane = output_grad + output_offset;

    if (adaptive) {
      // Candidate adaptive bins; the trailing +1 may over-approximate the
      // range, which is harmless because of the mask check below.
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // Output cell p starts its window at p * stride - pad and spans ksize
      // inputs; invert that to find the cells that can reach this position.
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 input_grad_data = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        const int output_idx = ph * output_width + pw;
        // Only the output cell that picked this element as its max
        // contributes gradient.
        if (mask_plane[output_idx] == input_current_featuremap_idx)
          input_grad_data += output_grad_plane[output_idx];
      }
    }
    input_grad[index] = input_grad_data;
  }
}

C
chengduoZH 已提交
1890 1891 1892 1893 1894
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 *
 * Computes the max-pooled `output` and stores in `mask` the argmax position
 * of every output element as a flat h * input_width + w index into its input
 * feature map.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<phi::GPUContext, T1, T2> {
 public:
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* output,
                  DenseTensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T1* input_data = input.data<T1>();
    T1* output_data = context.template Alloc<T1>(output);
    T2* mask_data = context.template Alloc<T2>(mask);

    int nthreads = batch_size * output_channels * output_height * output_width;
    // A grid with zero blocks is an invalid launch configuration; an empty
    // output has nothing to compute once it is allocated.
    if (nthreads == 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelMaxPool2dWithIdx<T1, T2>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 input_channels,
                                                 input_height,
                                                 input_width,
                                                 output_height,
                                                 output_width,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_height,
                                                 stride_width,
                                                 padding_height,
                                                 padding_width,
                                                 adaptive,
                                                 output_data,
                                                 mask_data,
                                                 pool_divmods);
  }
};

C
chengduoZH 已提交
1957 1958 1959 1960 1961
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 *
 * Propagates `output_grad` back to `input_grad` using the argmax positions
 * saved in `mask` by the forward pass; one thread per input element.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
 public:
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& output_grad,
                  const DenseTensor& mask,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = context.template Alloc<T1>(input_grad);

    int nthreads = batch_size * input_channels * input_height * input_width;
    // A grid with zero blocks is an invalid launch configuration; an empty
    // input gradient has nothing to compute once it is allocated.
    if (nthreads == 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Same block-size adjustment MaxPool2dWithIndexFunctor applies on Jetson.
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, input_width, input_height);
    KernelMaxPool2DWithIdxGrad<T1, T2>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 output_grad_data,
                                                 mask_data,
                                                 input_channels,
                                                 input_height,
                                                 input_width,
                                                 output_height,
                                                 output_width,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_height,
                                                 stride_width,
                                                 padding_height,
                                                 padding_width,
                                                 adaptive,
                                                 input_grad_data,
                                                 pool_divmods);
  }
};

F
From00 已提交
2018 2019 2020 2021
// Explicit instantiations of the 2-D GPU max-pool-with-index functors for
// float and double data with an `int` argmax mask.
template class MaxPool2dWithIndexFunctor<phi::GPUContext, float, int>;
template class MaxPool2dWithIndexFunctor<phi::GPUContext, double, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, float, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, double, int>;
C
chengduoZH 已提交
2022

C
chengduoZH 已提交
2023
/*
 * Forward kernel of 3-D max pooling with index (NCDHW layout).
 *
 * Launch layout: the x/y thread coordinates map to output width/height; the z
 * coordinate walks the fused (n * c, d) axis of length `ncd` in a loop, so
 * grid.z may be smaller than `ncd`. `divmods_output.depth` splits that fused
 * index into the output depth and the (n, c) plane index.
 *
 * Writes the pooled maximum to `output_data` and its position, as a flat
 * (d * input_height + h) * input_width + w index within the (n, c) input
 * volume, to `mask_data`. A window with no non-NaN element yields -FLT_MAX
 * and index -1.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(const int ncd,
                                       const T1* input_data,
                                       const int channels,
                                       const int input_depth,
                                       const int input_height,
                                       const int input_width,
                                       const int output_depth,
                                       const int output_height,
                                       const int output_width,
                                       const int ksize_depth,
                                       const int ksize_height,
                                       const int ksize_width,
                                       const int stride_depth,
                                       const int stride_height,
                                       const int stride_width,
                                       const int padding_depth,
                                       const int padding_height,
                                       const int padding_width,
                                       bool adaptive,
                                       T1* output_data,
                                       T2* mask_data,
                                       FastDivModForPooling3D divmods_output) {
  int w_offset, h_offset, d_offset, nc_offset;
  int dstart, dend, hstart, hend, wstart, wend;
  const T1* input_data_cur;

  w_offset = blockIdx.x * blockDim.x + threadIdx.x;
  h_offset = blockIdx.y * blockDim.y + threadIdx.y;

  if (w_offset < output_width && h_offset < output_height) {
    for (int index_z = blockIdx.z * blockDim.z + threadIdx.z; index_z < ncd;
         index_z += gridDim.z * blockDim.z) {
      auto output_depth_divmod = divmods_output.depth.Divmod(index_z);
      d_offset = output_depth_divmod.val[1];
      nc_offset = output_depth_divmod.val[0];
      int output_index =
          nc_offset * output_depth * output_height * output_width +
          d_offset * output_height * output_width + h_offset * output_width +
          w_offset;
      int input_offset = nc_offset * input_depth * input_height * input_width;
      input_data_cur = input_data + input_offset;

      if (adaptive) {
        dstart = AdaptStartIndex(d_offset, input_depth, output_depth);
        dend = AdaptEndIndex(d_offset, input_depth, output_depth);

        hstart = AdaptStartIndex(h_offset, input_height, output_height);
        hend = AdaptEndIndex(h_offset, input_height, output_height);

        wstart = AdaptStartIndex(w_offset, input_width, output_width);
        wend = AdaptEndIndex(w_offset, input_width, output_width);
      } else {
        // Window starts at offset * stride - pad, spans ksize, and is
        // clipped to the input volume.
        dstart = d_offset * stride_depth - padding_depth;
        hstart = h_offset * stride_height - padding_height;
        wstart = w_offset * stride_width - padding_width;
        dend = min(dstart + ksize_depth, input_depth);
        hend = min(hstart + ksize_height, input_height);
        wend = min(wstart + ksize_width, input_width);
        dstart = max(dstart, 0);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
      }

      // -FLT_MAX / -1 is only the result for a window without any non-NaN
      // element. Otherwise the first non-NaN element seeds the running max,
      // so values at or below -FLT_MAX (e.g. -inf, or large negative doubles
      // for T1 = double) are selected correctly; afterwards only a strictly
      // greater value replaces it, keeping the first occurrence of the max.
      T1 ele = -FLT_MAX;
      int max_index = -1;
      for (int d = dstart; d < dend; ++d) {
        for (int h = hstart; h < hend; ++h) {
          for (int w = wstart; w < wend; ++w) {
            const int input_index = (d * input_height + h) * input_width + w;
            const T1 value = input_data_cur[input_index];
            // `value == value` is false only for NaN, so NaN is never picked.
            if ((max_index == -1 && value == value) || ele < value) {
              max_index = input_index;
              ele = value;
            }
          }
        }
      }
      output_data[output_index] = ele;
      mask_data[output_index] = max_index;
    }
  }
}

C
chengduoZH 已提交
2106
/*
 * Backward kernel of 3-D max pooling with index (NCDHW layout).
 *
 * One work item per `input_grad` element (`nthreads` in total), iterated with
 * a stride equal to the total launched thread count. The flat index is
 * decomposed into (n, c, d, h, w); every output cell whose window may contain
 * (d, h, w) is visited, and `output_grad` is accumulated where `mask` records
 * this element — as (d * input_height + h) * input_width + w — as the argmax.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(const int nthreads,
                                           const T1* output_grad,
                                           const T2* mask,
                                           const int channels,
                                           const int input_depth,
                                           const int input_height,
                                           const int input_width,
                                           const int output_depth,
                                           const int output_height,
                                           const int output_width,
                                           const int ksize_depth,
                                           const int ksize_height,
                                           const int ksize_width,
                                           const int stride_depth,
                                           const int stride_height,
                                           const int stride_width,
                                           const int padding_depth,
                                           const int padding_height,
                                           const int padding_width,
                                           bool adaptive,
                                           T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      // Candidate adaptive bins; the trailing +1 may over-approximate the
      // range, which is harmless because of the mask check below.
      pdstart = d_offset * output_depth / input_depth;
      pdend =
          min((d_offset + 1) * output_depth / input_depth + 1, output_depth);
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // Output cell p starts its window at p * stride - pad and spans ksize
      // inputs; invert that to find the cells that can reach this position.
      pdstart =
          (d_offset + padding_depth < ksize_depth)
              ? 0
              : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth);
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 input_grad_data = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + c_offset) * output_depth *
                     output_height * output_width;
    // Offset local copies instead of the kernel arguments: advancing
    // `mask`/`output_grad` in place would accumulate offsets whenever a
    // thread executes more than one iteration of this loop.
    const T2* mask_plane = mask + output_idx;
    const T1* output_grad_plane = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          const int cell = (pd * output_height + ph) * output_width + pw;
          if (mask_plane[cell] == input_current_feature_map_idx)
            input_grad_data += output_grad_plane[cell];
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

C
chengduoZH 已提交
2191 2192 2193 2194 2195
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 *
 * Computes the max-pooled `output` and stores in `mask` the argmax of every
 * output element as a flat (d * H + h) * W + w index into its input volume.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<phi::GPUContext, T1, T2> {
 public:
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* output,
                  DenseTensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* input_data = input.data<T1>();
    T1* output_data = context.template Alloc<T1>(output);
    T2* mask_data = context.template Alloc<T2>(mask);

    int ncd = batch_size * input_channels * output_depth;
    // Any zero grid dimension is an invalid launch configuration; an empty
    // output has nothing to compute once it is allocated.
    if (ncd == 0 || output_height == 0 || output_width == 0) return;

    // x/y tile the output width/height; z covers the fused (n, c, d) axis and
    // is capped at the device limit, the kernel looping over the remainder.
    int thread_x = 32;
    int thread_y = 8;
    int thread_z = 1;
    dim3 threads(thread_x, thread_y, thread_z);
    std::array<int, 3> max_grid_dim = context.GetCUDAMaxGridDimSize();
    int block_x = (output_width + threads.x - 1) / threads.x;
    int block_y = (output_height + threads.y - 1) / threads.y;
    int block_z = (ncd > max_grid_dim[2] * threads.z)
                      ? max_grid_dim[2]
                      : (ncd + threads.z - 1) / threads.z;
    dim3 grid(block_x, block_y, block_z);

    auto pool_divmods_output = FastDivModForPooling3D(
        input_channels, output_width, output_height, output_depth);

    KernelMaxPool3DWithIdx<T1, T2>
        <<<grid, threads, 0, context.stream()>>>(ncd,
                                                 input_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 adaptive,
                                                 output_data,
                                                 mask_data,
                                                 pool_divmods_output);
  }
};

C
chengduoZH 已提交
2273 2274 2275 2276 2277
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 *
 * Propagates `output_grad` back to `input_grad` using the argmax positions
 * saved in `mask` by the forward pass; one thread per input element.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
 public:
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& output_grad,
                  const DenseTensor& mask,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = context.template Alloc<T1>(input_grad);

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // A grid with zero blocks is an invalid launch configuration; an empty
    // input gradient has nothing to compute once it is allocated.
    if (nthreads == 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Same block-size adjustment MaxPool2dWithIndexFunctor applies on Jetson.
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 output_grad_data,
                                                 mask_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 adaptive,
                                                 input_grad_data);
  }
};

F
From00 已提交
2342 2343 2344 2345 2346 2347 2348
// Explicit instantiations of the 3-D GPU max-pool-with-index functors for
// float and double data with an `int` argmax mask.
template class MaxPool3dWithIndexFunctor<phi::GPUContext, float, int>;
template class MaxPool3dWithIndexFunctor<phi::GPUContext, double, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, float, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, double, int>;

}  // namespace funcs
}  // namespace phi