pooling.cu 93.3 KB
Newer Older
F
From00 已提交
1
/* Copyright (c) 2022 paddlepaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

F
From00 已提交
15 16
#include "paddle/phi/kernels/funcs/pooling.h"

C
chengduo 已提交
17 18
#include <algorithm>
#include <vector>
19
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
L
limingshu 已提交
20
#include "paddle/fluid/platform/fast_divmod.h"
F
From00 已提交
21
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
C
chengduoZH 已提交
22

F
From00 已提交
23 24
namespace phi {
namespace funcs {
C
chengduoZH 已提交
25

L
limingshu 已提交
26 27
// Bundle of precomputed fast integer divisors used on the device to split a
// flat element index into (width, height, channel) coordinates.
// The constructor arguments are the channel count and the output width and
// height; each becomes the divisor of the matching FastDivMod member.
struct FastDivModForPooling {
 public:
  paddle::platform::FastDivMod channel;
  paddle::platform::FastDivMod width;
  paddle::platform::FastDivMod height;

  explicit HOSTDEVICE FastDivModForPooling(const int channels,
                                           const int output_width,
                                           const int output_height)
      : channel(channels), width(output_width), height(output_height) {}
};

// Extended divisor bundle: besides channel/width/height it carries divisors
// for the kernel size and the stride in each spatial direction, so device
// code can divide by any of them without a hardware integer division.
// Each constructor argument becomes the divisor of the member it names.
struct FastDivModForPoolingWithMoreStaff {
 public:
  paddle::platform::FastDivMod channel;
  paddle::platform::FastDivMod width;
  paddle::platform::FastDivMod height;
  paddle::platform::FastDivMod ksize_w;
  paddle::platform::FastDivMod ksize_h;
  paddle::platform::FastDivMod stride_w;
  paddle::platform::FastDivMod stride_h;

  explicit HOSTDEVICE FastDivModForPoolingWithMoreStaff(
      const int channels,
      const int input_width,
      const int input_height,
      const int ksize_width,
      const int ksize_height,
      const int stride_width,
      const int stride_height)
      : channel(channels),
        width(input_width),
        height(input_height),
        ksize_w(ksize_width),
        ksize_h(ksize_height),
        stride_w(stride_width),
        stride_h(stride_height) {}
};

// Decodes a flat element index into spatial/channel coordinates and the
// element offset of the start of its (batch, channel) data.
//
//   channel_last == true  (NHWC): index is split channel-fastest, then width,
//       then height; *stride = batch * aux_height * aux_width * C.
//   channel_last == false (NCHW): index is split width-fastest, then height,
//       then channel; *stride = (batch * C + channel) * aux_height * aux_width.
//
// pad_width / pad_height are added to the decoded w / h coordinates.
// aux_width / aux_height give the plane size used for *stride. They may
// differ from the divisors stored in `divmods`, e.g. input sizes for the
// stride while the divisors are output sizes.
template <typename DivModPack>
__device__ void OffsetPreparationFor4Dimension(int index,
                                               bool channel_last,
                                               DivModPack divmods,
                                               const int pad_width,
                                               const int pad_height,
                                               const int aux_width,
                                               const int aux_height,
                                               int* w_offset,
                                               int* h_offset,
                                               int* c_offset,
                                               int* stride) {
  if (channel_last) {
    auto by_channel = divmods.channel.Divmod(index);
    auto by_width = divmods.width.Divmod(by_channel.val[0]);
    auto by_height = divmods.height.Divmod(by_width.val[0]);
    *c_offset = by_channel.val[1];
    *w_offset = by_width.val[1] + pad_width;
    *h_offset = by_height.val[1] + pad_height;
    *stride =
        by_height.val[0] * aux_height * aux_width * divmods.channel.divisor;
    return;
  }

  auto by_width = divmods.width.Divmod(index);
  auto by_height = divmods.height.Divmod(by_width.val[0]);
  auto by_channel = divmods.channel.Divmod(by_height.val[0]);
  *w_offset = by_width.val[1] + pad_width;
  *h_offset = by_height.val[1] + pad_height;
  *c_offset = by_channel.val[1];
  *stride = (by_channel.val[0] * divmods.channel.divisor + *c_offset) *
            aux_height * aux_width;
}

102
/*
 * Forward 2-D pooling kernel.
 *
 * Loop structure: each iteration handles one output element `index` in
 * [0, nthreads), advancing by blockDim.x * gridDim.x. Correctness therefore
 * does not depend on the launch size.
 *
 * `divmods` is expected to be built from (channels, output_width,
 * output_height), as the functors in this file do. The decoded w/h offsets
 * are then positions in the output grid, and `input_offset` is the start of
 * the matching (batch, channel) input data.
 *
 * Window:
 *   - adaptive: AdaptStartIndex/AdaptEndIndex over input/output sizes.
 *   - otherwise: ksize/stride/padding, clipped to the input bounds.
 *
 * Divisor passed to pool_process.finalize():
 *   - exclusive or adaptive: the number of in-bounds elements in the window.
 *   - otherwise: ksize_height * ksize_width.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads,
                             const T* input_data,
                             const int channels,
                             const int input_height,
                             const int input_width,
                             const int output_height,
                             const int output_width,
                             const int ksize_height,
                             const int ksize_width,
                             const int stride_height,
                             const int stride_width,
                             const int padding_height,
                             const int padding_width,
                             FastDivModForPooling divmods,
                             PoolProcess pool_process,
                             bool exclusive,
                             bool adaptive,
                             T* output_data,
                             bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         channel_last,
                                                         divmods,
                                                         0,
                                                         0,
                                                         input_width,
                                                         input_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &input_offset);
    // Per-iteration base pointer. The previous code advanced `input_data`
    // itself, which accumulated offsets across loop iterations whenever a
    // thread processed more than one element.
    const T* input_plane = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);
      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      // Compute the end from the unclipped start, then clip both ends to
      // the input bounds.
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);
      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        auto input_idx = channel_last
                             ? (h * input_width + w) * channels + c_offset
                             : h * input_width + w;
        pool_process.compute(input_plane[input_idx], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}
L
limingshu 已提交
168 169

/*
 * Backward kernel for 2-D pooling. The per-window gradient rule is supplied
 * by PoolProcess::compute.
 *
 * Loop structure: each iteration produces input_grad[index] for one input
 * element. It visits every output window that covers that element and
 * accumulates pool_process.compute() into a local before a single store.
 *
 * `divmods` is expected to hold the input channel/width/height plus the
 * kernel and stride divisors (as Pool2dGradFunctor builds it).
 * OffsetPreparationFor4Dimension decodes `index` into a padded input
 * position (w_offset, h_offset) and the start of the matching (batch,
 * channel) output data.
 *
 * Window sizes passed as the 1/pool_size scale:
 *   - adaptive: the exact window the forward pass used for that output
 *     cell, [floor(p * in / out), ceil((p + 1) * in / out)).
 *   - exclusive: the in-bounds window size.
 *   - otherwise: ksize_height * ksize_width.
 */
template <typename T, typename PoolProcess>
__global__ void KernelPool2DGrad(const int nthreads,
                                 const T* __restrict__ input_data,
                                 const T* __restrict__ output_data,
                                 const T* __restrict__ output_grad,
                                 const int output_width,
                                 const int output_height,
                                 const int input_width,
                                 const int input_height,
                                 const int ksize_width,
                                 const int ksize_height,
                                 const int stride_width,
                                 const int stride_height,
                                 const int padding_width,
                                 const int padding_height,
                                 FastDivModForPoolingWithMoreStaff divmods,
                                 PoolProcess pool_process,
                                 bool exclusive,
                                 bool adaptive,
                                 T* __restrict__ input_grad,
                                 bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    T input = static_cast<T>(0);
    T input_grad_data = static_cast<T>(0);
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    OffsetPreparationFor4Dimension<>(index,
                                     channel_last,
                                     divmods,
                                     padding_width,
                                     padding_height,
                                     output_width,
                                     output_height,
                                     &w_offset,
                                     &h_offset,
                                     &c_offset,
                                     &output_offset);

    // Per-iteration base pointers into this element's (batch, channel)
    // output data. Offsetting locals rather than the kernel parameters keeps
    // offsets from accumulating when a thread handles several elements.
    // output_data is only dereferenced when use_x is set, so it is only
    // offset in that case.
    const T* output_plane = output_data;
    if (pool_process.use_x) {
      input = input_data[index];
      output_plane = output_data + output_offset;
    }
    const T* output_grad_plane = output_grad + output_offset;

    if (adaptive) {
      // Output cells whose adaptive window covers this input position:
      // [floor(h * out / in), ceil((h + 1) * out / in)), and likewise for w.
      auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
      auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
      phstart = divmods.height.Div(h_offset * output_height);
      pwstart = divmods.width.Div(w_offset * output_width);
      phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
      pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];

      for (int ph = phstart; ph < phend; ++ph) {
        // Height of output row ph's window:
        //   ceil((ph + 1) * H_in / H_out) - floor(ph * H_in / H_out).
        // This is the same window the forward kernel pooled over. The old
        // ceil(H_in / ksize) is wrong whenever window sizes differ across
        // cells (e.g. H_in = 5, H_out = 3 gives windows of 2, 3, 2).
        // NOTE(review): assumes AdaptStartIndex/AdaptEndIndex are floor/ceil
        // of these ratios, consistent with the phstart/phend inversion
        // above. Confirm against pooling.h.
        const int window_h =
            ((ph + 1) * input_height + output_height - 1) / output_height -
            (ph * input_height) / output_height;
        for (int pw = pwstart; pw < pwend; ++pw) {
          const int window_w =
              ((pw + 1) * input_width + output_width - 1) / output_width -
              (pw * input_width) / output_width;
          int pool_size = window_h * window_w;
          int tmp_idx = ph * output_width + pw;
          int output_sub_idx =
              channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                           : tmp_idx;
          T ouput_value = pool_process.use_x ? output_plane[output_sub_idx]
                                             : static_cast<T>(0);
          pool_process.compute(input,
                               ouput_value,
                               output_grad_plane[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    } else {
      // Output windows whose [p * stride, p * stride + ksize) range (in
      // padded coordinates) contains this input position.
      auto stride_height_div = divmods.stride_h.Div(h_offset - ksize_height);
      auto stride_width_div = divmods.stride_w.Div(w_offset - ksize_width);
      phstart = (h_offset < ksize_height) ? 0 : stride_height_div + 1;
      pwstart = (w_offset < ksize_width) ? 0 : stride_width_div + 1;
      phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
      pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);

      if (exclusive) {
        for (int ph = phstart; ph < phend; ++ph) {
          for (int pw = pwstart; pw < pwend; ++pw) {
            // Exclusive mode: scale by the in-bounds part of the window,
            // matching the forward kernel.
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            int pool_size = (hend - hstart) * (wend - wstart);
            int tmp_idx = ph * output_width + pw;
            int output_sub_idx =
                channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                             : tmp_idx;
            T ouput_value = pool_process.use_x ? output_plane[output_sub_idx]
                                               : static_cast<T>(0);
            pool_process.compute(input,
                                 ouput_value,
                                 output_grad_plane[output_sub_idx],
                                 static_cast<T>(1.0 / pool_size),
                                 &input_grad_data);
          }
        }
      } else {
        for (int ph = phstart; ph < phend; ++ph) {
          for (int pw = pwstart; pw < pwend; ++pw) {
            int pool_size = ksize_height * ksize_width;
            int tmp_idx = ph * output_width + pw;
            int output_sub_idx =
                channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                             : tmp_idx;
            T ouput_value = pool_process.use_x ? output_plane[output_sub_idx]
                                               : static_cast<T>(0);
            pool_process.compute(input,
                                 ouput_value,
                                 output_grad_plane[output_sub_idx],
                                 static_cast<T>(1.0 / pool_size),
                                 &input_grad_data);
          }
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

298
/*
 * Backward kernel for 2-D max pooling.
 *
 * Loop structure: each iteration handles one output element `index`. It
 *   1. scans that element's pooling window in row-major (h, then w) order;
 *   2. takes the first input equal to output_data[index];
 *   3. atomically adds output_grad[index] to the matching input_grad entry.
 *
 * The add is atomic because windows can overlap (e.g. stride < ksize), so
 * several output elements may select the same input.
 *
 * `divmods` is expected to be built from output sizes, so the decoded
 * w/h offsets are output positions. `input_offset` is the start of the
 * matching (batch, channel) input data.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(const int nthreads,
                                    const T* input_data,
                                    const T* output_data,
                                    const T* output_grad,
                                    const int channels,
                                    const int input_height,
                                    const int input_width,
                                    const int output_height,
                                    const int output_width,
                                    const int ksize_height,
                                    const int ksize_width,
                                    const int stride_height,
                                    const int stride_width,
                                    const int padding_height,
                                    const int padding_width,
                                    T* input_grad,
                                    FastDivModForPooling divmods,
                                    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         channel_last,
                                                         divmods,
                                                         0,
                                                         0,
                                                         input_width,
                                                         input_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &input_offset);
    // Per-iteration base pointers. The previous code advanced the
    // `input_data` / `input_grad` parameters themselves, which accumulated
    // offsets across loop iterations whenever a thread processed more than
    // one element.
    const T* input_plane = input_data + input_offset;
    T* input_grad_plane = input_grad + input_offset;

    // Compute each end from the unclipped start, then clip to input bounds.
    int hstart = h_offset * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = w_offset * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        int input_data_idx = channel_last
                                 ? (h * input_width + w) * channels + c_offset
                                 : h * input_width + w;
        if (ele == input_plane[input_data_idx]) {
          maxIndex = input_data_idx;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      paddle::platform::CudaAtomicAdd(input_grad_plane + maxIndex,
                                      output_grad[index]);
    }
  }
}

N
nhzlx 已提交
365 366
// Launches KernelPool2D on raw device pointers, without a GPUContext.
// Shapes are read as [batch, channels, height, width] (NCHW). ksize,
// strides and paddings each supply [height, width] in elements 0 and 1.
// The kernel is enqueued on `stream`; this function does not synchronize.
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input,
    const std::vector<int>& input_shape,
    const std::vector<int>& output_shape,
    const std::vector<int>& ksize,
    const std::vector<int>& strides,
    const std::vector<int>& paddings,
    bool exclusive,
    bool adaptive,
    T* output,
    gpuStream_t stream,
    PoolProcess pool_compute) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_height = input_shape[2];
  const int input_width = input_shape[3];
  const int output_channels = output_shape[1];
  const int output_height = output_shape[2];
  const int output_width = output_shape[3];
  const int ksize_height = ksize[0];
  const int ksize_width = ksize[1];
  const int stride_height = strides[0];
  const int stride_width = strides[1];
  const int padding_height = paddings[0];
  const int padding_width = paddings[1];

  const int nthreads =
      batch_size * output_channels * output_height * output_width;
  if (nthreads <= 0) {
    // Empty output: nothing to compute. Launching with a zero-sized grid
    // would be rejected as an invalid launch configuration.
    return;
  }

  int thread_num = 1024;
#ifdef WITH_NV_JETSON
  // No GPUContext is available here, so ChangeThreadNum() cannot be used.
  // Fall back to a fixed, smaller block size on Jetson.
  thread_num = 512;
#endif
  int blocks = (nthreads + thread_num - 1) / thread_num;
  dim3 threads(thread_num, 1);
  dim3 grid(blocks, 1);

  auto pool_divmods =
      FastDivModForPooling(input_channels, output_width, output_height);
  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads,
                                                             input,
                                                             input_channels,
                                                             input_height,
                                                             input_width,
                                                             output_height,
                                                             output_width,
                                                             ksize_height,
                                                             ksize_width,
                                                             stride_height,
                                                             stride_width,
                                                             padding_height,
                                                             padding_width,
                                                             pool_divmods,
                                                             pool_compute,
                                                             exclusive,
                                                             adaptive,
                                                             output);
}

C
chengduoZH 已提交
424
/*
 * Tensors are in NCHW or NHWC format.
 * Ksize and strides have two elements, representing height and width,
 * respectively.
 * Only paddings[0] and paddings[1] are read; they are used as the height and
 * width padding, respectively.
 */
431
template <typename PoolProcess, typename T>
class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
 public:
  // NCHW-only entry point.
  // Forwards to the data_format overload with "NCHW". That overload reads
  // the same dims (channels = dims[1], height = dims[2], width = dims[3])
  // and launches the same kernel with channel_last == false.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    (*this)(context,
            input,
            ksize,
            strides,
            paddings,
            std::string("NCHW"),
            exclusive,
            adaptive,
            output,
            pool_process);
  }

  // Layout-aware entry point.
  // data_format == "NHWC" selects channel-last indexing; any other value is
  // treated as NCHW. ksize/strides/paddings supply [height, width] in
  // elements 0 and 1.
  // Allocates `output` via the context and enqueues KernelPool2D on
  // context.stream(). Does not synchronize.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    const bool channel_last = (data_format == "NHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_channels =
        channel_last ? output->dims()[3] : output->dims()[1];
    const int output_height =
        channel_last ? output->dims()[1] : output->dims()[2];
    const int output_width =
        channel_last ? output->dims()[2] : output->dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    T* output_data = context.template Alloc<T>(output);

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    if (nthreads <= 0) {
      // Empty output: nothing to compute. Launching with a zero-sized grid
      // would be rejected as an invalid launch configuration.
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        input_channels,
        input_height,
        input_width,
        output_height,
        output_width,
        ksize_height,
        ksize_width,
        stride_height,
        stride_width,
        padding_height,
        padding_width,
        pool_divmods,
        pool_process,
        exclusive,
        adaptive,
        output_data,
        channel_last);
  }
};
C
chengduoZH 已提交
560
/*
 * Tensors are in NCHW or NHWC format.
 * Ksize and strides have two elements, representing height and width,
 * respectively.
 * Only paddings[0] and paddings[1] are read; they are used as the height and
 * width padding, respectively.
 */
567
template <typename PoolProcess, typename T>
F
From00 已提交
568
class Pool2dGradFunctor<phi::GPUContext, PoolProcess, T> {
569
 public:
F
From00 已提交
570 571 572 573
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
C
chengduo 已提交
574 575
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
F
From00 已提交
576 577 578 579
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
580
                  PoolProcess pool_process) {
581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
F
From00 已提交
597
    T* input_grad_data = context.template Alloc<T>(input_grad);
598 599

    int nthreads = batch_size * input_channels * input_height * input_width;
F
From00 已提交
600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630
    auto pool_divmods = FastDivModForPoolingWithMoreStaff(input_channels,
                                                          input_width,
                                                          input_height,
                                                          ksize_width,
                                                          ksize_height,
                                                          stride_width,
                                                          stride_height);

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(context, nthreads);
    KernelPool2DGrad<T, PoolProcess><<<config.block_per_grid,
                                       config.thread_per_block,
                                       0,
                                       context.stream()>>>(nthreads,
                                                           input_data,
                                                           output_data,
                                                           output_grad_data,
                                                           output_width,
                                                           output_height,
                                                           input_width,
                                                           input_height,
                                                           ksize_width,
                                                           ksize_height,
                                                           stride_width,
                                                           stride_height,
                                                           padding_width,
                                                           padding_height,
                                                           pool_divmods,
                                                           pool_process,
                                                           exclusive,
                                                           adaptive,
                                                           input_grad_data);
631
  }
F
From00 已提交
632 633 634 635
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
636 637 638
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
F
From00 已提交
639 640 641 642 643
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
                  PoolProcess pool_process) {
644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668
    bool channel_last = (data_format == "NHWC");

    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_channels =
        channel_last ? output.dims()[3] : output.dims()[1];
    const int output_height =
        channel_last ? output.dims()[1] : output.dims()[2];
    const int output_width = channel_last ? output.dims()[2] : output.dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
F
From00 已提交
669
    T* input_grad_data = context.template Alloc<T>(input_grad);
670 671

    int nthreads = batch_size * input_channels * input_height * input_width;
F
From00 已提交
672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703
    auto pool_divmods = FastDivModForPoolingWithMoreStaff(input_channels,
                                                          input_width,
                                                          input_height,
                                                          ksize_width,
                                                          ksize_height,
                                                          stride_width,
                                                          stride_height);

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(context, nthreads);
    KernelPool2DGrad<T, PoolProcess><<<config.block_per_grid,
                                       config.thread_per_block,
                                       0,
                                       context.stream()>>>(nthreads,
                                                           input_data,
                                                           output_data,
                                                           output_grad_data,
                                                           output_width,
                                                           output_height,
                                                           input_width,
                                                           input_height,
                                                           ksize_width,
                                                           ksize_height,
                                                           stride_width,
                                                           stride_height,
                                                           padding_width,
                                                           padding_height,
                                                           pool_divmods,
                                                           pool_process,
                                                           exclusive,
                                                           adaptive,
                                                           input_grad_data,
                                                           channel_last);
704
  }
705 706
};

/*
 * Gradient of 2-D max pooling on the GPU.
 * Tensors are in NCHW or NHWC format; the overload without `data_format`
 * treats them as NCHW.
 * Ksize and strides have two elements: height and width, respectively.
 * Paddings: only paddings[0] (height) and paddings[1] (width) are read here.
 */
template <typename T>
class MaxPool2dGradFunctor<phi::GPUContext, T> {
 public:
  // NCHW entry point. Delegates to the data_format overload so that both
  // layouts share a single shape-extraction and kernel-launch path.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  DenseTensor* input_grad) {
    (*this)(context,
            input,
            output,
            output_grad,
            ksize,
            strides,
            paddings,
            "NCHW",
            input_grad);
  }

  // Allocates input_grad and launches KernelMaxPool2DGrad over
  // nthreads = number of elements of `output`.
  // data_format == "NHWC" selects the channel-last layout; any other value is
  // treated as NCHW.
  // NOTE(review): input_grad is not zeroed here; presumably the caller clears
  // it if the kernel accumulates (confirm against KernelMaxPool2DGrad).
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  DenseTensor* input_grad) {
    const bool channel_last = (data_format == "NHWC");

    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_channels =
        channel_last ? output.dims()[3] : output.dims()[1];
    const int output_height =
        channel_last ? output.dims()[1] : output.dims()[2];
    const int output_width = channel_last ? output.dims()[2] : output.dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    // Launching a grid with zero blocks is an invalid configuration; an empty
    // output has no gradient to propagate.
    if (nthreads == 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Same per-device block-size adjustment the forward Pool2dFunctor applies.
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif
    const int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);

    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        output_data,
        output_grad_data,
        input_channels,
        input_height,
        input_width,
        output_height,
        output_width,
        ksize_height,
        ksize_width,
        stride_height,
        stride_width,
        padding_height,
        padding_width,
        input_grad_data,
        pool_divmods,
        channel_last);
  }
};

// Explicit instantiations of the 2-D GPU pooling functors defined above.

// Context-free forward functors (float only).
template class Pool2dDirectCUDAFunctor<MaxPool<float>, float>;
template class Pool2dDirectCUDAFunctor<AvgPool<float>, float>;

// float
template class Pool2dFunctor<phi::GPUContext, MaxPool<float>, float>;
template class Pool2dFunctor<phi::GPUContext, AvgPool<float>, float>;
template class Pool2dGradFunctor<phi::GPUContext, MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<phi::GPUContext, AvgPoolGrad<float>, float>;
template class MaxPool2dGradFunctor<phi::GPUContext, float>;

// double
template class Pool2dFunctor<phi::GPUContext, MaxPool<double>, double>;
template class Pool2dFunctor<phi::GPUContext, AvgPool<double>, double>;
template class Pool2dGradFunctor<phi::GPUContext, MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<phi::GPUContext, AvgPoolGrad<double>, double>;
template class MaxPool2dGradFunctor<phi::GPUContext, double>;

// float16
template class Pool2dFunctor<phi::GPUContext,
                             MaxPool<dtype::float16>,
                             dtype::float16>;
template class Pool2dFunctor<phi::GPUContext,
                             AvgPool<dtype::float16>,
                             dtype::float16>;
template class Pool2dGradFunctor<phi::GPUContext,
                                 MaxPoolGrad<dtype::float16>,
                                 dtype::float16>;
template class Pool2dGradFunctor<phi::GPUContext,
                                 AvgPoolGrad<dtype::float16>,
                                 dtype::float16>;
template class MaxPool2dGradFunctor<phi::GPUContext, dtype::float16>;
865

866
template <typename PoolProcess, typename T>
F
From00 已提交
867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889
__global__ void KernelPool3D(const int nthreads,
                             const T* input_data,
                             const int channels,
                             const int input_depth,
                             const int input_height,
                             const int input_width,
                             const int output_depth,
                             const int output_height,
                             const int output_width,
                             const int ksize_depth,
                             const int ksize_height,
                             const int ksize_width,
                             const int stride_depth,
                             const int stride_height,
                             const int stride_width,
                             const int padding_depth,
                             const int padding_height,
                             const int padding_width,
                             PoolProcess pool_process,
                             bool exclusive,
                             bool adaptive,
                             T* output_data,
                             bool channel_last = false) {
890
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
891
       index += blockDim.x * gridDim.x) {
892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907
    int pw, ph, pd, c, batch_idx;
    if (!channel_last) {
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else {
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }
908 909 910 911 912

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
D
dengkaipeng 已提交
913 914
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);
915

D
dengkaipeng 已提交
916 917
      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);
918

D
dengkaipeng 已提交
919 920
      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
921 922 923 924 925 926 927 928 929 930 931
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }
932 933 934 935 936 937 938 939 940 941 942

    int input_data_stride;
    if (!channel_last) { /* NCDHW */
      input_data_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else { /* NDHWC */
      input_data_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    input_data += input_data_stride;

943
    T ele = pool_process.initial();
944 945 946
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
947 948 949 950 951
          auto input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          pool_process.compute(input_data[input_data_idx], &ele);
952 953 954
        }
      }
    }
955
    int pool_size = (exclusive || adaptive)
956 957
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
C
chengduo 已提交
958
    pool_process.finalize(static_cast<T>(pool_size), &ele);
959 960 961 962
    output_data[index] = ele;
  }
}

// Backward 3-D pooling kernel. Each loop iteration computes the gradient of
// input element `index` (NCDHW order, or NDHWC when channel_last) by summing,
// via pool_process.compute, the contributions of every output whose window
// covers it, each scaled by 1 / pool_size. pool_size must equal the divisor
// used by the forward KernelPool3D for that output.
template <typename T, typename PoolProcess>
__global__ void KernelPool3DGrad(const int nthreads,
                                 const T* __restrict__ input_data,
                                 const T* __restrict__ output_data,
                                 const T* __restrict__ output_grad,
                                 const int channels,
                                 const int input_depth,
                                 const int input_height,
                                 const int input_width,
                                 const int output_depth,
                                 const int output_height,
                                 const int output_width,
                                 const int ksize_depth,
                                 const int ksize_height,
                                 const int ksize_width,
                                 const int stride_depth,
                                 const int stride_height,
                                 const int stride_width,
                                 const int padding_depth,
                                 const int padding_height,
                                 const int padding_width,
                                 PoolProcess pool_process,
                                 bool exclusive,
                                 bool adaptive,
                                 T* input_grad,
                                 bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset, h_offset, d_offset, c_offset, batch_idx, output_stride;
    if (!channel_last) { /* "NCDHW" */
      w_offset = index % input_width + padding_width;
      h_offset = (index / input_width) % input_height + padding_height;
      d_offset =
          (index / input_width / input_height) % input_depth + padding_depth;
      c_offset = (index / input_width / input_height / input_depth) % channels;
      batch_idx = index / input_width / input_height / input_depth / channels;
      output_stride = (batch_idx * channels + c_offset) * output_depth *
                      output_height * output_width;
    } else { /* "NDHWC" */
      c_offset = index % channels;
      w_offset = (index / channels) % input_width + padding_width;
      h_offset =
          (index / channels / input_width) % input_height + padding_height;
      d_offset = (index / channels / input_width / input_height) % input_depth +
                 padding_depth;
      batch_idx = index / channels / input_width / input_height / input_depth;
      output_stride =
          batch_idx * output_depth * output_height * output_width * channels;
    }

    // Range of output positions whose window may contain this input element.
    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = AdaptStartIndex(d_offset, output_depth, input_depth);
      pdend = AdaptEndIndex(d_offset, output_depth, input_depth);

      phstart = AdaptStartIndex(h_offset, output_height, input_height);
      phend = AdaptEndIndex(h_offset, output_height, input_height);

      pwstart = AdaptStartIndex(w_offset, output_width, input_width);
      pwend = AdaptEndIndex(w_offset, output_width, input_width);
    } else {
      pdstart = (d_offset < ksize_depth)
                    ? 0
                    : (d_offset - ksize_depth) / stride_depth + 1;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      pdend = min((d_offset) / stride_depth + 1, output_depth);
      phend = min((h_offset) / stride_height + 1, output_height);
      pwend = min((w_offset) / stride_width + 1, output_width);
    }

    // Per-iteration bases of this (batch, channel) slice. They are kept in
    // locals instead of advancing the parameters, because the loop may run
    // more than once per thread and the offsets must not accumulate.
    // output_data is only read when the gradient needs x/y (use_x).
    T input = static_cast<T>(0);
    const T* output_base = output_data;
    if (pool_process.use_x) {
      input = input_data[index];
      output_base = output_data + output_stride;
    }
    const T* output_grad_base = output_grad + output_stride;

    T input_grad_data = static_cast<T>(0.0);

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // Number of inputs pooled into output (pd, ph, pw). In adaptive
          // mode this is the actual adaptive window volume, exactly as in the
          // forward KernelPool3D; ceil(input / ksize) is wrong whenever the
          // input size is not a multiple of the output size.
          int pool_size;
          if (adaptive) {
            pool_size = (AdaptEndIndex(pd, input_depth, output_depth) -
                         AdaptStartIndex(pd, input_depth, output_depth)) *
                        (AdaptEndIndex(ph, input_height, output_height) -
                         AdaptStartIndex(ph, input_height, output_height)) *
                        (AdaptEndIndex(pw, input_width, output_width) -
                         AdaptStartIndex(pw, input_width, output_width));
          } else {
            int dstart = pd * stride_depth - padding_depth;
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int dend = min(dstart + ksize_depth, input_depth);
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            dstart = max(dstart, 0);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            pool_size =
                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                          : ksize_depth * ksize_height * ksize_width;
          }

          int output_sub_idx =
              channel_last
                  ? ((pd * output_height + ph) * output_width + pw) * channels +
                        c_offset
                  : (pd * output_height + ph) * output_width + pw;
          T output_value = pool_process.use_x ? output_base[output_sub_idx]
                                              : static_cast<T>(0);
          pool_process.compute(input,
                               output_value,
                               output_grad_base[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

// Backward 3-D max pooling kernel. Each loop iteration handles output element
// `index` (NCDHW order, or NDHWC when channel_last). It scans that output's
// window in d/h/w order and finds the first input equal to the pooled
// maximum. output_grad[index] is then atomically added to that input's
// gradient. Because of the atomic accumulation, input_grad is presumably
// zero-initialized by the caller (not done here).
template <typename T>
__global__ void KernelMaxPool3DGrad(const int nthreads,
                                    const T* input_data,
                                    const T* output_data,
                                    const T* output_grad,
                                    const int channels,
                                    const int input_depth,
                                    const int input_height,
                                    const int input_width,
                                    const int output_depth,
                                    const int output_height,
                                    const int output_width,
                                    const int ksize_depth,
                                    const int ksize_height,
                                    const int ksize_width,
                                    const int stride_depth,
                                    const int stride_height,
                                    const int stride_width,
                                    const int padding_depth,
                                    const int padding_height,
                                    const int padding_width,
                                    T* input_grad,
                                    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, pd, c, batch_idx;

    if (!channel_last) { /*NCDHW*/
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else { /*NDHWC*/
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    // Pooling window of this output, clipped to the input.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;

    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);

    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;

    // Per-iteration bases of this (batch, channel) slice. They are kept in
    // locals instead of advancing the parameters, because the loop may run
    // more than once per thread and the offsets must not accumulate.
    int input_stride;
    if (!channel_last) {
      input_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else {
      input_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    const T* input_base = input_data + input_stride;
    T* input_grad_base = input_grad + input_stride;

    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          int input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          if (ele == input_base[input_data_idx]) {
            stop = true;
            maxIdx = input_data_idx;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // Windows can overlap, so several outputs may target the same input.
      paddle::platform::CudaAtomicAdd(input_grad_base + maxIdx,
                                      output_grad[index]);
    }
  }
}

// Launches KernelPool3D on `stream` for raw device buffers described by
// explicit shapes (no device context). The shapes are read as
// {N, C, D, H, W}, and ksize/strides/paddings as {depth, height, width}. The
// kernel is launched with its default NCDHW layout.
// Returns without launching when the output is empty, because a grid with
// zero blocks is an invalid launch configuration.
template <typename PoolProcess, typename T>
void Pool3dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input,
    const std::vector<int>& input_shape,
    const std::vector<int>& output_shape,
    const std::vector<int>& ksize,
    const std::vector<int>& strides,
    const std::vector<int>& paddings,
    bool exclusive,
    bool adaptive,
    T* output,
    gpuStream_t stream,
    PoolProcess pool_compute) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_depth = input_shape[2];
  const int input_height = input_shape[3];
  const int input_width = input_shape[4];
  const int output_channels = output_shape[1];
  const int output_depth = output_shape[2];
  const int output_height = output_shape[3];
  const int output_width = output_shape[4];
  const int ksize_depth = ksize[0];
  const int ksize_height = ksize[1];
  const int ksize_width = ksize[2];
  const int stride_depth = strides[0];
  const int stride_height = strides[1];
  const int stride_width = strides[2];
  const int padding_depth = paddings[0];
  const int padding_height = paddings[1];
  const int padding_width = paddings[2];

  const int nthreads = batch_size * output_channels * output_depth *
                       output_height * output_width;
  if (nthreads == 0) {
    return;
  }

  int thread_num = 1024;
#ifdef WITH_NV_JETSON
  // Smaller blocks on Jetson; there is no context here to query the device.
  thread_num = 512;
#endif
  const int blocks = (nthreads + thread_num - 1) / thread_num;
  dim3 threads(thread_num, 1);
  dim3 grid(blocks, 1);

  KernelPool3D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads,
                                                             input,
                                                             input_channels,
                                                             input_depth,
                                                             input_height,
                                                             input_width,
                                                             output_depth,
                                                             output_height,
                                                             output_width,
                                                             ksize_depth,
                                                             ksize_height,
                                                             ksize_width,
                                                             stride_depth,
                                                             stride_height,
                                                             stride_width,
                                                             padding_depth,
                                                             padding_height,
                                                             padding_width,
                                                             pool_compute,
                                                             exclusive,
                                                             adaptive,
                                                             output);
}

C
chengduoZH 已提交
1250
/*
 * Pool3dFunctor (GPU): forward 3-D pooling via KernelPool3D.
 *
 * ksize and strides hold three elements ordered depth, height, width.
 * paddings is read the same way: only paddings[0] (depth), paddings[1]
 * (height) and paddings[2] (width) are used here.
 * One thread is launched per output element.
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<phi::GPUContext, PoolProcess, T> {
 public:
  // Input and output are in NCDHW layout.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = context.template Alloc<T>(output);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: skip the launch instead of requesting a zero-sized grid,
    // which is an invalid launch configuration.
    if (nthreads <= 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        pool_process,
        exclusive,
        adaptive,
        output_data);
  }

  // Layout chosen by data_format: "NDHWC" selects channel-last indexing; any
  // other value is treated as NCDHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output->dims()[4] : output->dims()[1];
    const int output_depth =
        channel_last ? output->dims()[1] : output->dims()[2];
    const int output_height =
        channel_last ? output->dims()[2] : output->dims()[3];
    const int output_width =
        channel_last ? output->dims()[3] : output->dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = context.template Alloc<T>(output);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: skip the launch instead of requesting a zero-sized grid.
    if (nthreads <= 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        pool_process,
        exclusive,
        adaptive,
        output_data,
        channel_last);
  }
};

C
chengduoZH 已提交
1405
/*
 * Pool3dGradFunctor (GPU): backward of 3-D pooling via KernelPool3DGrad.
 *
 * ksize and strides hold three elements ordered depth, height, width.
 * paddings is read the same way: only paddings[0] (depth), paddings[1]
 * (height) and paddings[2] (width) are used here.
 * One thread is launched per element of input / input_grad.
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<phi::GPUContext, PoolProcess, T> {
 public:
  // All tensors are in NCDHW layout.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Empty input: skip the launch instead of requesting a zero-sized grid.
    if (nthreads <= 0) {
      return;
    }

    // Same block-size selection as the forward Pool3dFunctor launches.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        output_data,
        output_grad_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        pool_process,
        exclusive,
        adaptive,
        input_grad_data);
  }

  // Layout chosen by data_format: "NDHWC" selects channel-last indexing; any
  // other value is treated as NCDHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
                  PoolProcess pool_process) {
    bool channel_last = (data_format == "NDHWC");

    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Empty input: skip the launch instead of requesting a zero-sized grid.
    if (nthreads <= 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        output_data,
        output_grad_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        pool_process,
        exclusive,
        adaptive,
        input_grad_data,
        channel_last);
  }
};

C
chengduoZH 已提交
1562
/*
 * MaxPool3dGradFunctor (GPU): backward of 3-D max pooling via
 * KernelMaxPool3DGrad.
 *
 * ksize and strides hold three elements ordered depth, height, width.
 * paddings is read the same way: only paddings[0] (depth), paddings[1]
 * (height) and paddings[2] (width) are used here.
 * One thread is launched per OUTPUT element.
 *
 * NOTE(review): the 3-D max-pool grad kernel earlier in this file appears to
 * scatter gradients with CudaAtomicAdd. If so, input_grad must already be
 * zero-filled by the caller; this functor only allocates it. Confirm at
 * call sites.
 */
template <class T>
class MaxPool3dGradFunctor<phi::GPUContext, T> {
 public:
  // All tensors are in NCDHW layout.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  DenseTensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: skip the launch instead of requesting a zero-sized grid.
    if (nthreads <= 0) {
      return;
    }

    // Same block-size selection as the forward Pool3dFunctor launches.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        output_data,
        output_grad_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        input_grad_data);
  }

  // Layout chosen by data_format: "NDHWC" selects channel-last indexing; any
  // other value is treated as NCDHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  DenseTensor* input_grad) {
    bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output.dims()[4] : output.dims()[1];
    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: skip the launch instead of requesting a zero-sized grid.
    if (nthreads <= 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        output_data,
        output_grad_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        input_grad_data,
        channel_last);
  }
};

F
From00 已提交
1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734
// Explicit instantiations of the 3-D pooling functors defined above, grouped
// by functor.

// Stream-based forward launcher (no GPUContext): float only.
template class Pool3dDirectCUDAFunctor<MaxPool<float>, float>;
template class Pool3dDirectCUDAFunctor<AvgPool<float>, float>;

// Forward pooling: float, double, float16.
template class Pool3dFunctor<phi::GPUContext, MaxPool<float>, float>;
template class Pool3dFunctor<phi::GPUContext, AvgPool<float>, float>;
template class Pool3dFunctor<phi::GPUContext, MaxPool<double>, double>;
template class Pool3dFunctor<phi::GPUContext, AvgPool<double>, double>;
template class Pool3dFunctor<phi::GPUContext,
                             MaxPool<dtype::float16>,
                             dtype::float16>;
template class Pool3dFunctor<phi::GPUContext,
                             AvgPool<dtype::float16>,
                             dtype::float16>;

// Generic backward pooling: float, double, float16.
template class Pool3dGradFunctor<phi::GPUContext, MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<phi::GPUContext, AvgPoolGrad<float>, float>;
template class Pool3dGradFunctor<phi::GPUContext, MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<phi::GPUContext, AvgPoolGrad<double>, double>;
template class Pool3dGradFunctor<phi::GPUContext,
                                 MaxPoolGrad<dtype::float16>,
                                 dtype::float16>;
template class Pool3dGradFunctor<phi::GPUContext,
                                 AvgPoolGrad<dtype::float16>,
                                 dtype::float16>;

// Dedicated max-pool backward: float, double, float16.
template class MaxPool3dGradFunctor<phi::GPUContext, float>;
template class MaxPool3dGradFunctor<phi::GPUContext, double>;
template class MaxPool3dGradFunctor<phi::GPUContext, dtype::float16>;
1735

C
chengduoZH 已提交
1736
/*
 * KernelMaxPool2dWithIdx: NCHW max pooling that also records where each
 * maximum came from.
 *
 * For every output element `index` (nthreads of them), output_data[index]
 * receives the window maximum. mask_data[index] receives h * input_width + w
 * of the selected element inside its input plane, or -1 if no element
 * compared greater than the initial -FLT_MAX (e.g. an empty window).
 *
 * The loop over `index` uses a grid-stride step, so any grid size is
 * correct. Per-iteration base pointers are derived from the unmodified
 * input_data parameter.
 *
 * NOTE(review): the running maximum starts at -FLT_MAX, so for T1 = double,
 * inputs below -FLT_MAX are never selected.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(const int nthreads,
                                       const T1* input_data,
                                       const int channels,
                                       const int input_height,
                                       const int input_width,
                                       const int output_height,
                                       const int output_width,
                                       const int ksize_height,
                                       const int ksize_width,
                                       const int stride_height,
                                       const int stride_width,
                                       const int padding_height,
                                       const int padding_width,
                                       bool adaptive,
                                       T1* output_data,
                                       T2* mask_data,
                                       FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         false,
                                                         divmods,
                                                         0,
                                                         0,
                                                         input_width,
                                                         input_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &input_offset);
    // input_offset is recomputed from `index` on every iteration. Applying it
    // to a local copy (instead of doing `input_data += input_offset`) keeps
    // it from accumulating when a thread handles more than one index.
    const T1* input_plane = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);

      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      // Window start may fall in the padding; clamp to the valid input range.
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        // Strict '<' keeps the first maximum in scan order on ties.
        if (ele < input_plane[input_index]) {
          max_index = input_index;
          ele = input_plane[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1803
/*
 * KernelMaxPool2DWithIdxGrad: backward of KernelMaxPool2dWithIdx.
 *
 * Each `index` (nthreads of them) is one input_grad element at in-plane
 * position h_offset * input_width + w_offset. It sums output_grad over every
 * output position in the candidate range whose recorded mask equals that
 * position. input_grad[index] is overwritten, not accumulated into.
 *
 * The loop over `index` uses a grid-stride step. Per-iteration base pointers
 * are derived from the unmodified mask_data / output_grad parameters, so any
 * grid size is correct.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(const int nthreads,
                                           const T1* output_grad,
                                           const T2* mask_data,
                                           const int channels,
                                           const int input_height,
                                           const int input_width,
                                           const int output_height,
                                           const int output_width,
                                           const int ksize_height,
                                           const int ksize_width,
                                           const int stride_height,
                                           const int stride_width,
                                           const int padding_height,
                                           const int padding_width,
                                           bool adaptive,
                                           T1* input_grad,
                                           FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         false,
                                                         divmods,
                                                         0,
                                                         0,
                                                         output_width,
                                                         output_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &output_offset);
    // output_offset is recomputed from `index` each iteration. Apply it to
    // local pointers so it does not accumulate across iterations of this
    // thread (the original advanced the parameters themselves).
    const T2* mask_plane = mask_data + output_offset;
    const T1* output_grad_plane = output_grad + output_offset;

    if (adaptive) {
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // Range of output positions whose window can contain (h_offset,
      // w_offset).
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 input_grad_data = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        if (mask_plane[ph * output_width + pw] == input_current_featuremap_idx)
          input_grad_data += output_grad_plane[ph * output_width + pw];
      }
    }
    input_grad[index] = input_grad_data;
  }
}

C
chengduoZH 已提交
1872 1873 1874 1875 1876
/*
 * MaxPool2dWithIndexFunctor (GPU): max pooling that also fills `mask` with the
 * in-plane position of each selected maximum, via KernelMaxPool2dWithIdx.
 *
 * All tensors are in NCHW format.
 * ksize, strides and paddings hold two elements ordered height, width.
 * One thread is launched per output element.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<phi::GPUContext, T1, T2> {
 public:
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* output,
                  DenseTensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T1* input_data = input.data<T1>();
    T1* output_data = context.template Alloc<T1>(output);
    T2* mask_data = context.template Alloc<T2>(mask);

    int nthreads = batch_size * output_channels * output_height * output_width;
    // Empty output: skip the launch instead of requesting a zero-sized grid,
    // which is an invalid launch configuration.
    if (nthreads <= 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    // Divisors (channels, output width/height) the kernel uses to split a
    // flat output index into its coordinates.
    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);
    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        input_channels,
        input_height,
        input_width,
        output_height,
        output_width,
        ksize_height,
        ksize_width,
        stride_height,
        stride_width,
        padding_height,
        padding_width,
        adaptive,
        output_data,
        mask_data,
        pool_divmods);
  }
};

C
chengduoZH 已提交
1939 1940 1941 1942 1943
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
 public:
  // Runs KernelMaxPool2DWithIdxGrad on `context`'s stream.
  // `input_grad` must already carry its NCHW dims. It is allocated here and
  // written by the kernel, with one thread-loop iteration per input_grad
  // element. `output_grad` and `mask` are read-only inputs.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& output_grad,
                  const DenseTensor& mask,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = context.template Alloc<T1>(input_grad);

    int nthreads = batch_size * input_channels * input_height * input_width;
    // An empty input_grad would give blocks == 0, which is an invalid launch
    // configuration. There is nothing to write.
    if (nthreads == 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Use the same Jetson-specific block-size adjustment as the forward
    // functor, so both launches are configured the same way.
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, input_width, input_height);
    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        output_grad_data,
        mask_data,
        input_channels,
        input_height,
        input_width,
        output_height,
        output_width,
        ksize_height,
        ksize_width,
        stride_height,
        stride_width,
        padding_height,
        padding_width,
        adaptive,
        input_grad_data,
        pool_divmods);
  }
};

F
From00 已提交
2000 2001 2002 2003
// Explicit instantiations of the 2-D max-pool-with-index GPU functors for
// float and double data, with int masks.
template class MaxPool2dWithIndexFunctor<phi::GPUContext, float, int>;
template class MaxPool2dWithIndexFunctor<phi::GPUContext, double, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, float, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, double, int>;
C
chengduoZH 已提交
2004

C
chengduoZH 已提交
2005
// Forward 3-D max pooling with argmax, over NCDHW data.
// Each loop iteration handles one output element `index` in
// [0, nthreads). The loop strides by blockDim.x * gridDim.x over a 1-D grid,
// so any grid size is valid. For that element it writes the window maximum
// to output_data[index] and, to mask_data[index], the flat offset
// (d * input_height + h) * input_width + w of that maximum inside its
// (batch, channel) input slice.
// NOTE(review): the running maximum starts at -FLT_MAX and is compared with
// a strict '<'. Inputs <= -FLT_MAX (possible for double) and NaNs are never
// selected, and a window with no selected element yields -FLT_MAX and a
// mask of -1.
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(const int nthreads,
                                       const T1* input_data,
                                       const int channels,
                                       const int input_depth,
                                       const int input_height,
                                       const int input_width,
                                       const int output_depth,
                                       const int output_height,
                                       const int output_width,
                                       const int ksize_depth,
                                       const int ksize_height,
                                       const int ksize_width,
                                       const int stride_depth,
                                       const int stride_height,
                                       const int stride_width,
                                       const int padding_depth,
                                       const int padding_height,
                                       const int padding_width,
                                       bool adaptive,
                                       T1* output_data,
                                       T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Take the slice offset in a per-iteration local pointer. Advancing the
    // `input_data` parameter itself would accumulate offsets whenever a
    // thread runs more than one grid-stride iteration.
    const T1* input_slice =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          int offset = (d * input_height + h) * input_width + w;
          if (ele < input_slice[offset]) {
            max_index = offset;
            ele = input_slice[offset];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
2080
// Backward of KernelMaxPool3DWithIdx, over NCDHW data.
// Each loop iteration handles one input_grad element `index` in
// [0, nthreads). The loop strides by blockDim.x * gridDim.x over a 1-D grid.
// For that element it sums output_grad over every output position in the
// candidate range [pdstart, pdend) x [phstart, phend) x [pwstart, pwend)
// whose mask entry equals the element's flat offset
// (d * input_height + h) * input_width + w. It then stores the sum, so no
// atomics are needed.
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(const int nthreads,
                                           const T1* output_grad,
                                           const T2* mask,
                                           const int channels,
                                           const int input_depth,
                                           const int input_height,
                                           const int input_width,
                                           const int output_depth,
                                           const int output_height,
                                           const int output_width,
                                           const int ksize_depth,
                                           const int ksize_height,
                                           const int ksize_width,
                                           const int stride_depth,
                                           const int stride_height,
                                           const int stride_width,
                                           const int padding_depth,
                                           const int padding_height,
                                           const int padding_width,
                                           bool adaptive,
                                           T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = d_offset * output_depth / input_depth;
      pdend =
          min((d_offset + 1) * output_depth / input_depth + 1, output_depth);
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      pdstart =
          (d_offset + padding_depth < ksize_depth)
              ? 0
              : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth);
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 input_grad_data = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;

    // Use per-iteration local slice pointers. Advancing the `mask` and
    // `output_grad` parameters themselves would accumulate offsets whenever
    // a thread runs more than one grid-stride iteration.
    const int output_idx = (batch_idx * channels + c_offset) * output_depth *
                           output_height * output_width;
    const T2* mask_slice = mask + output_idx;
    const T1* output_grad_slice = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          int out_offset = (pd * output_height + ph) * output_width + pw;
          if (mask_slice[out_offset] == input_current_feature_map_idx)
            input_grad_data += output_grad_slice[out_offset];
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

C
chengduoZH 已提交
2165 2166 2167 2168 2169
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<phi::GPUContext, T1, T2> {
 public:
  // Runs KernelMaxPool3DWithIdx on `context`'s stream.
  // `output` and `mask` must already carry their NCDHW output dims. Both are
  // allocated here and filled with the window maxima and the flat in-slice
  // argmax offsets.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* output,
                  DenseTensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* input_data = input.data<T1>();
    T1* output_data = context.template Alloc<T1>(output);
    T2* mask_data = context.template Alloc<T2>(mask);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // An empty output would give blocks == 0, and a zero-sized grid is an
    // invalid launch configuration. The (empty) outputs are already
    // allocated, so there is nothing left to do.
    if (nthreads == 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        adaptive,
        output_data,
        mask_data);
  }
};

C
chengduoZH 已提交
2240 2241 2242 2243 2244
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
 public:
  // Runs KernelMaxPool3DWithIdxGrad on `context`'s stream.
  // `input_grad` must already carry its NCDHW dims. It is allocated here and
  // every element is written by the kernel. `output_grad` and `mask` are
  // read-only inputs.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& output_grad,
                  const DenseTensor& mask,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = context.template Alloc<T1>(input_grad);

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // An empty input_grad would give blocks == 0, which is an invalid launch
    // configuration. There is nothing to write.
    if (nthreads == 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Use the same Jetson-specific block-size adjustment as the forward
    // functor, so both launches are configured the same way.
    paddle::platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        output_grad_data,
        mask_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        adaptive,
        input_grad_data);
  }
};

F
From00 已提交
2309 2310 2311 2312 2313 2314 2315
// Explicit instantiations of the 3-D max-pool-with-index GPU functors for
// float and double data, with int masks.
template class MaxPool3dWithIndexFunctor<phi::GPUContext, float, int>;
template class MaxPool3dWithIndexFunctor<phi::GPUContext, double, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, float, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, double, int>;

}  // namespace funcs
}  // namespace phi