pooling.cu 102.3 KB
Newer Older
F
From00 已提交
1
/* Copyright (c) 2022 paddlepaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
17

18
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
L
limingshu 已提交
19
#include "paddle/fluid/platform/fast_divmod.h"
F
From00 已提交
20
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
21
#include "paddle/phi/kernels/funcs/pooling.h"
C
chengduoZH 已提交
22

F
From00 已提交
23 24
namespace phi {
namespace funcs {
C
chengduoZH 已提交
25

L
limingshu 已提交
26 27
/*
 * Bundle of FastDivMod divisors used to split a flat tensor index into
 * (channel, height, width) coordinates (see OffsetPreparationFor4Dimension).
 * The forward/max-grad launchers in this file build it from the channel count
 * and the OUTPUT spatial size.
 */
struct FastDivModForPooling {
  paddle::platform::FastDivMod channel;
  paddle::platform::FastDivMod width;
  paddle::platform::FastDivMod height;

  // channels / output_width / output_height become the channel, width and
  // height divisors respectively.
  explicit HOSTDEVICE FastDivModForPooling(const int channels,
                                           const int output_width,
                                           const int output_height)
      : channel(channels), width(output_width), height(output_height) {}
};

/*
 * Extended divisor bundle passed to KernelPool2DGrad: besides the
 * channel/width/height divisors used for index decomposition, it carries
 * FastDivMod divisors for the kernel size and the stride of each spatial
 * dimension.
 */
struct FastDivModForPoolingWithMoreStaff {
  paddle::platform::FastDivMod channel;
  paddle::platform::FastDivMod width;
  paddle::platform::FastDivMod height;
  paddle::platform::FastDivMod ksize_w;
  paddle::platform::FastDivMod ksize_h;
  paddle::platform::FastDivMod stride_w;
  paddle::platform::FastDivMod stride_h;

  // Each argument becomes the divisor of the correspondingly named member.
  explicit HOSTDEVICE FastDivModForPoolingWithMoreStaff(
      const int channels,
      const int input_width,
      const int input_height,
      const int ksize_width,
      const int ksize_height,
      const int stride_width,
      const int stride_height)
      : channel(channels),
        width(input_width),
        height(input_height),
        ksize_w(ksize_width),
        ksize_h(ksize_height),
        stride_w(stride_width),
        stride_h(stride_height) {}
};

/*
 * Decomposes a flat 4-D index into spatial/channel coordinates.
 *
 * Let C, H, W be the divisors held in divmods.channel/height/width.
 *   NHWC (channel_last):  index = ((n * H + h) * W + w) * C + c
 *   NCHW (!channel_last): index = ((n * C + c) * H + h) * W + w
 *
 * Outputs:
 *   *w_offset = w + pad_width
 *   *h_offset = h + pad_height
 *   *c_offset = c
 *   *stride   = base offset, in a tensor whose spatial size is
 *               aux_height x aux_width, of sample n (NHWC) or of the
 *               (n, c) plane (NCHW).
 */
template <typename FastDivModForPooling>
__device__ void OffsetPreparationFor4Dimension(int index,
                                               bool channel_last,
                                               FastDivModForPooling divmods,
                                               const int pad_width,
                                               const int pad_height,
                                               const int aux_width,
                                               const int aux_height,
                                               int* w_offset,
                                               int* h_offset,
                                               int* c_offset,
                                               int* stride) {
  if (channel_last) {
    // Peel off c first (innermost), then w, then h; the final quotient is n.
    auto split_c = divmods.channel.Divmod(index);
    auto split_w = divmods.width.Divmod(split_c.val[0]);
    auto split_h = divmods.height.Divmod(split_w.val[0]);
    *c_offset = split_c.val[1];
    *w_offset = split_w.val[1] + pad_width;
    *h_offset = split_h.val[1] + pad_height;
    *stride =
        split_h.val[0] * aux_height * aux_width * divmods.channel.divisor;
    return;
  }

  // NCHW: peel off w first (innermost), then h, then c.
  auto split_w = divmods.width.Divmod(index);
  auto split_h = divmods.height.Divmod(split_w.val[0]);
  auto split_c = divmods.channel.Divmod(split_h.val[0]);
  *w_offset = split_w.val[1] + pad_width;
  *h_offset = split_h.val[1] + pad_height;
  *c_offset = split_c.val[1];
  *stride = (split_c.val[0] * divmods.channel.divisor + *c_offset) *
            aux_height * aux_width;
}

102
/*
 * Forward 2-D pooling kernel; one loop iteration computes one output element.
 *
 * `index` is a flat output index. It is decomposed by
 * OffsetPreparationFor4Dimension using `divmods` (the launchers in this file
 * build it from the channel count and the output width/height) into
 * (c_offset, h_offset, w_offset) plus the base offset of the matching input
 * plane (NCHW) or sample (NHWC).
 *
 * The loop advances by blockDim.x * gridDim.x, so one thread may process
 * several outputs. All per-element state must therefore stay in loop locals.
 *
 * Window selection:
 *   adaptive  -> [AdaptStartIndex, AdaptEndIndex) bins of the input.
 *   otherwise -> stride/padding window clipped to the input bounds.
 * The divisor handed to pool_process.finalize is the number of visited
 * elements when exclusive or adaptive is set, else ksize_height*ksize_width.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads,
                             const T* input_data,
                             const int channels,
                             const int input_height,
                             const int input_width,
                             const int output_height,
                             const int output_width,
                             const int ksize_height,
                             const int ksize_width,
                             const int stride_height,
                             const int stride_width,
                             const int padding_height,
                             const int padding_width,
                             FastDivModForPooling divmods,
                             PoolProcess pool_process,
                             bool exclusive,
                             bool adaptive,
                             T* output_data,
                             bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         channel_last,
                                                         divmods,
                                                         0,
                                                         0,
                                                         input_width,
                                                         input_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &input_offset);
    // Base pointer for THIS output element. It must be a loop local.
    // Advancing input_data in place would accumulate offsets whenever a
    // thread runs more than one iteration of the stride loop.
    const T* input_base = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);
      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);
      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        auto input_idx = channel_last
                             ? (h * input_width + w) * channels + c_offset
                             : h * input_width + w;
        pool_process.compute(input_base[input_idx], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}
L
limingshu 已提交
168 169

/*
 * Backward kernel for generic (non-max-index) 2-D pooling; one loop iteration
 * computes the gradient of one input element.
 *
 * `index` is a flat input index. OffsetPreparationFor4Dimension (with the
 * input-size divisors in `divmods` and the padding added) yields the padded
 * input coordinates (h_offset, w_offset), the channel c_offset, and the base
 * offset of the matching output plane/sample. Every output window (ph, pw)
 * that covers this input element contributes
 *   pool_process.compute(x, y, dy, 1 / pool_size, &dx).
 *
 * pool_size matches the forward kernel (KernelPool2D):
 *   adaptive        -> actual bin size (AdaptEndIndex - AdaptStartIndex)
 *   exclusive       -> number of unpadded elements inside the window
 *   otherwise       -> ksize_height * ksize_width
 *
 * The loop advances by blockDim.x * gridDim.x, so per-element pointers are
 * kept in loop locals rather than advancing the parameters in place.
 */
template <typename T, typename PoolProcess>
__global__ void KernelPool2DGrad(const int nthreads,
                                 const T* __restrict__ input_data,
                                 const T* __restrict__ output_data,
                                 const T* __restrict__ output_grad,
                                 const int output_width,
                                 const int output_height,
                                 const int input_width,
                                 const int input_height,
                                 const int ksize_width,
                                 const int ksize_height,
                                 const int stride_width,
                                 const int stride_height,
                                 const int padding_width,
                                 const int padding_height,
                                 FastDivModForPoolingWithMoreStaff divmods,
                                 PoolProcess pool_process,
                                 bool exclusive,
                                 bool adaptive,
                                 T* __restrict__ input_grad,
                                 bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    T input = static_cast<T>(0);
    T input_grad_data = static_cast<T>(0);
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    OffsetPreparationFor4Dimension<>(index,
                                     channel_last,
                                     divmods,
                                     padding_width,
                                     padding_height,
                                     output_width,
                                     output_height,
                                     &w_offset,
                                     &h_offset,
                                     &c_offset,
                                     &output_offset);
    // Per-element base pointers (loop locals, so offsets never accumulate
    // across stride-loop iterations). output_data is only dereferenced when
    // the pool process uses the forward values.
    const T* output_base = output_data;
    if (pool_process.use_x) {
      input = input_data[index];
      output_base = output_data + output_offset;
    }
    const T* output_grad_base = output_grad + output_offset;

    if (adaptive) {
      // Range of output bins whose adaptive window contains h_offset/w_offset.
      auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
      auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
      phstart = divmods.height.Div(h_offset * output_height);
      pwstart = divmods.width.Div(w_offset * output_width);
      phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
      pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];

      for (int ph = phstart; ph < phend; ++ph) {
        // Use the real size of bin ph, exactly as the forward kernel does.
        // A uniform ceil(input / ksize) over-counts or under-counts bins
        // whenever the input size is not a multiple of the output size.
        const int bin_height = AdaptEndIndex(ph, input_height, output_height) -
                               AdaptStartIndex(ph, input_height, output_height);
        for (int pw = pwstart; pw < pwend; ++pw) {
          const int bin_width = AdaptEndIndex(pw, input_width, output_width) -
                                AdaptStartIndex(pw, input_width, output_width);
          int pool_size = bin_height * bin_width;
          int tmp_idx = ph * output_width + pw;
          int output_sub_idx =
              channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                           : tmp_idx;
          T output_value = pool_process.use_x ? output_base[output_sub_idx]
                                              : static_cast<T>(0);
          pool_process.compute(input,
                               output_value,
                               output_grad_base[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    } else {
      // Window ph covers padded rows [ph * stride, ph * stride + ksize), so
      // it contains h_offset iff
      // (h_offset - ksize) / stride < ph <= h_offset / stride.
      phstart = (h_offset < ksize_height)
                    ? 0
                    : divmods.stride_h.Div(h_offset - ksize_height) + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : divmods.stride_w.Div(w_offset - ksize_width) + 1;
      phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
      pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);

      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          int pool_size = ksize_height * ksize_width;
          if (exclusive) {
            // Count only the window elements that lie inside the input.
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            pool_size = (hend - hstart) * (wend - wstart);
          }
          int tmp_idx = ph * output_width + pw;
          int output_sub_idx =
              channel_last ? tmp_idx * divmods.channel.divisor + c_offset
                           : tmp_idx;
          T output_value = pool_process.use_x ? output_base[output_sub_idx]
                                              : static_cast<T>(0);
          pool_process.compute(input,
                               output_value,
                               output_grad_base[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

298
/*
 * Backward kernel for 2-D max pooling; one loop iteration handles one output
 * element.
 *
 * For output `index`, the kernel scans its input window in row-major order.
 * It routes output_grad[index] to the first input element equal to the pooled
 * value output_data[index]. The add is atomic because several output windows
 * may pick the same input element.
 *
 * `divmods` is decomposed with OffsetPreparationFor4Dimension (aux size =
 * input size) to get the output coordinates and the input plane/sample base.
 * The loop advances by blockDim.x * gridDim.x, so the per-element base
 * pointers are loop locals rather than in-place parameter updates.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(const int nthreads,
                                    const T* input_data,
                                    const T* output_data,
                                    const T* output_grad,
                                    const int channels,
                                    const int input_height,
                                    const int input_width,
                                    const int output_height,
                                    const int output_width,
                                    const int ksize_height,
                                    const int ksize_width,
                                    const int stride_height,
                                    const int stride_width,
                                    const int padding_height,
                                    const int padding_width,
                                    T* input_grad,
                                    FastDivModForPooling divmods,
                                    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         channel_last,
                                                         divmods,
                                                         0,
                                                         0,
                                                         input_width,
                                                         input_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &input_offset);
    // Loop-local bases. Advancing input_data/input_grad in place would
    // accumulate offsets across stride-loop iterations.
    const T* input_base = input_data + input_offset;
    T* input_grad_base = input_grad + input_offset;

    int hstart = h_offset * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = w_offset * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        int input_data_idx = channel_last
                                 ? (h * input_width + w) * channels + c_offset
                                 : h * input_width + w;
        if (ele == input_base[input_data_idx]) {
          maxIndex = input_data_idx;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // Overlapping windows may select the same input element.
      paddle::platform::CudaAtomicAdd(input_grad_base + maxIndex,
                                      output_grad[index]);
    }
  }
}

N
nhzlx 已提交
365 366
/*
 * Launches KernelPool2D on raw device pointers (no DenseTensor / context).
 * Shapes are read as NCHW: {batch, channels, height, width}. ksize, strides
 * and paddings are read as {height, width}. The kernel is enqueued on
 * `stream`; no synchronization is performed here.
 */
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input,
    const std::vector<int>& input_shape,
    const std::vector<int>& output_shape,
    const std::vector<int>& ksize,
    const std::vector<int>& strides,
    const std::vector<int>& paddings,
    bool exclusive,
    bool adaptive,
    T* output,
    gpuStream_t stream,
    PoolProcess pool_compute) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_height = input_shape[2];
  const int input_width = input_shape[3];

  const int output_channels = output_shape[1];
  const int output_height = output_shape[2];
  const int output_width = output_shape[3];

  // One thread per output element.
  const int nthreads =
      batch_size * output_channels * output_height * output_width;

  // No device context is available here to query a per-device limit, so
  // Jetson builds use a fixed smaller block size.
#ifdef WITH_NV_JETSON
  const int thread_num = 512;
#else
  const int thread_num = 1024;
#endif
  const dim3 threads(thread_num, 1);
  const dim3 grid((nthreads + thread_num - 1) / thread_num, 1);

  const FastDivModForPooling pool_divmods(
      input_channels, output_width, output_height);
  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads,
                                                             input,
                                                             input_channels,
                                                             input_height,
                                                             input_width,
                                                             output_height,
                                                             output_width,
                                                             ksize[0],
                                                             ksize[1],
                                                             strides[0],
                                                             strides[1],
                                                             paddings[0],
                                                             paddings[1],
                                                             pool_divmods,
                                                             pool_compute,
                                                             exclusive,
                                                             adaptive,
                                                             output);
}

C
chengduoZH 已提交
423
/*
 * Tensors are in NCHW or NHWC format.
 * Ksize and strides have two elements, representing height and width,
 * respectively.
 * Only the first two elements of paddings are read: paddings[0] is the
 * height padding and paddings[1] is the width padding.
 */
430
template <typename PoolProcess, typename T>
class Pool2dFunctor<phi::GPUContext, PoolProcess, T> {
 public:
  // NCHW-only entry point. Equivalent to the data_format overload called
  // with "NCHW" (same dimension indices, channel_last == false).
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    (*this)(context,
            input,
            ksize,
            strides,
            paddings,
            "NCHW",
            exclusive,
            adaptive,
            output,
            pool_process);
  }

  // Layout-aware entry point: data_format == "NHWC" selects channel-last
  // indexing; any other value is treated as NCHW. Allocates `output` and
  // enqueues KernelPool2D on context.stream() with one thread per output
  // element.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    const bool channel_last = (data_format == "NHWC");
    // Positions of the C, H and W axes in the tensor dims for this layout.
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];

    const int output_channels = output->dims()[c_axis];
    const int output_height = output->dims()[h_axis];
    const int output_width = output->dims()[w_axis];

    const T* input_data = input.data<T>();
    T* output_data = context.template Alloc<T>(output);

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    const dim3 threads(thread_num, 1);
    const dim3 grid((nthreads + thread_num - 1) / thread_num, 1);

    const FastDivModForPooling pool_divmods(
        input_channels, output_width, output_height);
    KernelPool2D<PoolProcess, T>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 input_channels,
                                                 input_height,
                                                 input_width,
                                                 output_height,
                                                 output_width,
                                                 ksize[0],
                                                 ksize[1],
                                                 strides[0],
                                                 strides[1],
                                                 paddings[0],
                                                 paddings[1],
                                                 pool_divmods,
                                                 pool_process,
                                                 exclusive,
                                                 adaptive,
                                                 output_data,
                                                 channel_last);
  }
};
C
chengduoZH 已提交
559
/*
 * Tensors are in NCHW or NHWC format.
 * Ksize and strides have two elements, representing height and width,
 * respectively.
 * Only the first two elements of paddings are read: paddings[0] is the
 * height padding and paddings[1] is the width padding.
 */
566
template <typename PoolProcess, typename T>
F
From00 已提交
567
class Pool2dGradFunctor<phi::GPUContext, PoolProcess, T> {
568
 public:
F
From00 已提交
569 570 571 572
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
C
chengduo 已提交
573 574
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
F
From00 已提交
575 576 577 578
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
579
                  PoolProcess pool_process) {
580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
F
From00 已提交
596
    T* input_grad_data = context.template Alloc<T>(input_grad);
597 598

    int nthreads = batch_size * input_channels * input_height * input_width;
F
From00 已提交
599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629
    auto pool_divmods = FastDivModForPoolingWithMoreStaff(input_channels,
                                                          input_width,
                                                          input_height,
                                                          ksize_width,
                                                          ksize_height,
                                                          stride_width,
                                                          stride_height);

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(context, nthreads);
    KernelPool2DGrad<T, PoolProcess><<<config.block_per_grid,
                                       config.thread_per_block,
                                       0,
                                       context.stream()>>>(nthreads,
                                                           input_data,
                                                           output_data,
                                                           output_grad_data,
                                                           output_width,
                                                           output_height,
                                                           input_width,
                                                           input_height,
                                                           ksize_width,
                                                           ksize_height,
                                                           stride_width,
                                                           stride_height,
                                                           padding_width,
                                                           padding_height,
                                                           pool_divmods,
                                                           pool_process,
                                                           exclusive,
                                                           adaptive,
                                                           input_grad_data);
630
  }
F
From00 已提交
631 632 633 634
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
635 636 637
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
F
From00 已提交
638 639 640 641 642
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
                  PoolProcess pool_process) {
643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667
    bool channel_last = (data_format == "NHWC");

    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_channels =
        channel_last ? output.dims()[3] : output.dims()[1];
    const int output_height =
        channel_last ? output.dims()[1] : output.dims()[2];
    const int output_width = channel_last ? output.dims()[2] : output.dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
F
From00 已提交
668
    T* input_grad_data = context.template Alloc<T>(input_grad);
669 670

    int nthreads = batch_size * input_channels * input_height * input_width;
F
From00 已提交
671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702
    auto pool_divmods = FastDivModForPoolingWithMoreStaff(input_channels,
                                                          input_width,
                                                          input_height,
                                                          ksize_width,
                                                          ksize_height,
                                                          stride_width,
                                                          stride_height);

    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(context, nthreads);
    KernelPool2DGrad<T, PoolProcess><<<config.block_per_grid,
                                       config.thread_per_block,
                                       0,
                                       context.stream()>>>(nthreads,
                                                           input_data,
                                                           output_data,
                                                           output_grad_data,
                                                           output_width,
                                                           output_height,
                                                           input_width,
                                                           input_height,
                                                           ksize_width,
                                                           ksize_height,
                                                           stride_width,
                                                           stride_height,
                                                           padding_width,
                                                           padding_height,
                                                           pool_divmods,
                                                           pool_process,
                                                           exclusive,
                                                           adaptive,
                                                           input_grad_data,
                                                           channel_last);
703
  }
704 705
};

C
chengduoZH 已提交
706
/*
 * GPU backward functor for 2-D max pooling; writes the gradient with respect
 * to `input` into input_grad via KernelMaxPool2DGrad.
 *
 * ksize and strides hold two elements each: height first, then width.
 * Only paddings[0] (height) and paddings[1] (width) are read here.
 * The overload without data_format expects NCHW tensors; the other one
 * treats data_format == "NHWC" as channel-last and anything else as NCHW.
 */
template <typename T>
class MaxPool2dGradFunctor<phi::GPUContext, T> {
 public:
  // NCHW-only entry point. Forwards to the data_format overload so that both
  // layouts share a single launch path ("NCHW" => channel_last == false).
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  DenseTensor* input_grad) {
    (*this)(context,
            input,
            output,
            output_grad,
            ksize,
            strides,
            paddings,
            std::string("NCHW"),
            input_grad);
  }

  // Layout-aware entry point: decodes NCHW/NHWC dims, allocates input_grad
  // on the device and launches KernelMaxPool2DGrad on context.stream().
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  DenseTensor* input_grad) {
    const bool channel_last = (data_format == "NHWC");

    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_channels =
        channel_last ? output.dims()[3] : output.dims()[1];
    const int output_height =
        channel_last ? output.dims()[1] : output.dims()[2];
    const int output_width = channel_last ? output.dims()[2] : output.dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    // One thread per output element, in blocks of kThreadsPerBlock threads.
    constexpr int kThreadsPerBlock = 1024;
    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    auto pool_divmods =
        FastDivModForPooling(input_channels, output_width, output_height);

    KernelMaxPool2DGrad<T>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 output_data,
                                                 output_grad_data,
                                                 input_channels,
                                                 input_height,
                                                 input_width,
                                                 output_height,
                                                 output_width,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_height,
                                                 stride_width,
                                                 padding_height,
                                                 padding_width,
                                                 input_grad_data,
                                                 pool_divmods,
                                                 channel_last);
  }
};

F
From00 已提交
836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863
// Explicit instantiations of the 2-D GPU pooling functors, grouped by
// element type.

// Pool2dDirectCUDAFunctor is provided for float only.
template class Pool2dDirectCUDAFunctor<MaxPool<float>, float>;
template class Pool2dDirectCUDAFunctor<AvgPool<float>, float>;

// float
template class Pool2dFunctor<phi::GPUContext, MaxPool<float>, float>;
template class Pool2dFunctor<phi::GPUContext, AvgPool<float>, float>;
template class Pool2dGradFunctor<phi::GPUContext, MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<phi::GPUContext, AvgPoolGrad<float>, float>;
template class MaxPool2dGradFunctor<phi::GPUContext, float>;

// double
template class Pool2dFunctor<phi::GPUContext, MaxPool<double>, double>;
template class Pool2dFunctor<phi::GPUContext, AvgPool<double>, double>;
template class Pool2dGradFunctor<phi::GPUContext, MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<phi::GPUContext, AvgPoolGrad<double>, double>;
template class MaxPool2dGradFunctor<phi::GPUContext, double>;

// float16
template class Pool2dFunctor<phi::GPUContext,
                             MaxPool<dtype::float16>,
                             dtype::float16>;
template class Pool2dFunctor<phi::GPUContext,
                             AvgPool<dtype::float16>,
                             dtype::float16>;
template class Pool2dGradFunctor<phi::GPUContext,
                                 MaxPoolGrad<dtype::float16>,
                                 dtype::float16>;
template class Pool2dGradFunctor<phi::GPUContext,
                                 AvgPoolGrad<dtype::float16>,
                                 dtype::float16>;
template class MaxPool2dGradFunctor<phi::GPUContext, dtype::float16>;
864

865
/*
 * Forward 3-D pooling kernel (max or average, selected by PoolProcess).
 *
 * Each loop index addresses one output element, decoded as
 * (batch_idx, c, pd, ph, pw) from an NCDHW (channel_last == false) or NDHWC
 * (channel_last == true) flat index; the result is stored at that same index.
 * Indices advance by blockDim.x * gridDim.x, so a thread may handle several
 * outputs.
 *
 * Window bounds: with `adaptive`, they come from AdaptStartIndex /
 * AdaptEndIndex and ksize/stride/padding are not used; otherwise the window
 * is stride * p - padding .. + ksize, clipped to the input.
 * The divisor passed to pool_process.finalize is the clipped window volume
 * when `exclusive` or `adaptive` is set, otherwise the full ksize volume.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(const int nthreads,
                             const T* input_data,
                             const int channels,
                             const int input_depth,
                             const int input_height,
                             const int input_width,
                             const int output_depth,
                             const int output_height,
                             const int output_width,
                             const int ksize_depth,
                             const int ksize_height,
                             const int ksize_width,
                             const int stride_depth,
                             const int stride_height,
                             const int stride_width,
                             const int padding_depth,
                             const int padding_height,
                             const int padding_width,
                             PoolProcess pool_process,
                             bool exclusive,
                             bool adaptive,
                             T* output_data,
                             bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat output index.
    int pw, ph, pd, c, batch_idx;
    if (!channel_last) {
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else {
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Offset of this (batch, channel) plane (NCDHW) or this sample (NDHWC).
    // Kept in a local pointer: advancing the `input_data` argument here would
    // accumulate offsets across iterations of the stride loop.
    int input_data_stride;
    if (!channel_last) { /* NCDHW */
      input_data_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else { /* NDHWC */
      input_data_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    const T* input_slice = input_data + input_data_stride;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          auto input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          pool_process.compute(input_slice[input_data_idx], &ele);
        }
      }
    }
    int pool_size = (exclusive || adaptive)
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

L
limingshu 已提交
962
/*
 * Backward of KernelPool3D for a generic PoolProcess.
 *
 * Each loop index addresses one *input* element (NCDHW when channel_last is
 * false, NDHWC otherwise). For every output window covering it, the kernel
 * accumulates pool_process.compute(x, y, dy, 1 / pool_size, &dx), where
 * x and y are read only when pool_process.use_x is set. The accumulated value
 * is written to input_grad[index]. Indices advance by blockDim.x * gridDim.x.
 *
 * pool_size matches the divisor used by the forward KernelPool3D:
 *   - adaptive: exact extent of adaptive window (pd, ph, pw), obtained from
 *     AdaptStartIndex / AdaptEndIndex;
 *   - otherwise: clipped window volume if `exclusive`, else the full ksize
 *     volume.
 */
template <typename T, typename PoolProcess>
__global__ void KernelPool3DGrad(const int nthreads,
                                 const T* __restrict__ input_data,
                                 const T* __restrict__ output_data,
                                 const T* __restrict__ output_grad,
                                 const int channels,
                                 const int input_depth,
                                 const int input_height,
                                 const int input_width,
                                 const int output_depth,
                                 const int output_height,
                                 const int output_width,
                                 const int ksize_depth,
                                 const int ksize_height,
                                 const int ksize_width,
                                 const int stride_depth,
                                 const int stride_height,
                                 const int stride_width,
                                 const int padding_depth,
                                 const int padding_height,
                                 const int padding_width,
                                 PoolProcess pool_process,
                                 bool exclusive,
                                 bool adaptive,
                                 T* input_grad,
                                 bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset, h_offset, d_offset, c_offset, batch_idx, output_stride;
    T input = static_cast<T>(0);
    if (!channel_last) { /* "NCDHW" */
      w_offset = index % input_width + padding_width;
      h_offset = (index / input_width) % input_height + padding_height;
      d_offset =
          (index / input_width / input_height) % input_depth + padding_depth;
      c_offset = (index / input_width / input_height / input_depth) % channels;
      batch_idx = index / input_width / input_height / input_depth / channels;
      output_stride = (batch_idx * channels + c_offset) * output_depth *
                      output_height * output_width;
    } else { /* "NDHWC" */
      c_offset = index % channels;
      w_offset = (index / channels) % input_width + padding_width;
      h_offset =
          (index / channels / input_width) % input_height + padding_height;
      d_offset = (index / channels / input_width / input_height) % input_depth +
                 padding_depth;
      batch_idx = index / channels / input_width / input_height / input_depth;
      output_stride =
          batch_idx * output_depth * output_height * output_width * channels;
    }

    // Range of output windows [p*start, p*end) that may cover this input.
    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = AdaptStartIndex(d_offset, output_depth, input_depth);
      pdend = AdaptEndIndex(d_offset, output_depth, input_depth);

      phstart = AdaptStartIndex(h_offset, output_height, input_height);
      phend = AdaptEndIndex(h_offset, output_height, input_height);

      pwstart = AdaptStartIndex(w_offset, output_width, input_width);
      pwend = AdaptEndIndex(w_offset, output_width, input_width);
    } else {
      pdstart = (d_offset < ksize_depth)
                    ? 0
                    : (d_offset - ksize_depth) / stride_depth + 1;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      pdend = min((d_offset) / stride_depth + 1, output_depth);
      phend = min((h_offset) / stride_height + 1, output_height);
      pwend = min((w_offset) / stride_width + 1, output_width);
    }

    // Offset the output pointers through locals. Advancing the kernel
    // arguments here would accumulate offsets across stride-loop iterations.
    const T* output_data_slice = output_data;
    if (pool_process.use_x) {
      input = input_data[index];
      output_data_slice = output_data + output_stride;
    }
    const T* output_grad_slice = output_grad + output_stride;
    T input_grad_data = static_cast<T>(0.0);

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // figure out the pooling size, identically to the forward pass
          int pool_size;
          if (adaptive) {
            pool_size = (AdaptEndIndex(pd, input_depth, output_depth) -
                         AdaptStartIndex(pd, input_depth, output_depth)) *
                        (AdaptEndIndex(ph, input_height, output_height) -
                         AdaptStartIndex(ph, input_height, output_height)) *
                        (AdaptEndIndex(pw, input_width, output_width) -
                         AdaptStartIndex(pw, input_width, output_width));
          } else {
            int dstart = pd * stride_depth - padding_depth;
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int dend = min(dstart + ksize_depth, input_depth);
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            dstart = max(dstart, 0);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            pool_size =
                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                          : ksize_depth * ksize_height * ksize_width;
          }

          int output_sub_idx =
              channel_last
                  ? ((pd * output_height + ph) * output_width + pw) * channels +
                        c_offset
                  : (pd * output_height + ph) * output_width + pw;
          T output_value = pool_process.use_x
                               ? output_data_slice[output_sub_idx]
                               : static_cast<T>(0);
          pool_process.compute(input,
                               output_value,
                               output_grad_slice[output_sub_idx],
                               static_cast<T>(1.0 / pool_size),
                               &input_grad_data);
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

1093
/*
 * Max-pooling backward for 3-D inputs (NCDHW when channel_last is false,
 * NDHWC otherwise).
 *
 * Each loop index addresses one output element. Its non-adaptive window
 * (stride * p - padding .. + ksize, clipped to the input) is scanned in
 * d, h, w order for the first input equal to output_data[index].
 * output_grad[index] is then added to that input's gradient. Several output
 * elements may select the same input, hence CudaAtomicAdd.
 * If no element matches, nothing is written. Indices advance by
 * blockDim.x * gridDim.x.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(const int nthreads,
                                    const T* input_data,
                                    const T* output_data,
                                    const T* output_grad,
                                    const int channels,
                                    const int input_depth,
                                    const int input_height,
                                    const int input_width,
                                    const int output_depth,
                                    const int output_height,
                                    const int output_width,
                                    const int ksize_depth,
                                    const int ksize_height,
                                    const int ksize_width,
                                    const int stride_depth,
                                    const int stride_height,
                                    const int stride_width,
                                    const int padding_depth,
                                    const int padding_height,
                                    const int padding_width,
                                    T* input_grad,
                                    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, pd, c, batch_idx;

    if (!channel_last) { /*NCDHW*/
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else { /*NDHWC*/
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;

    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);

    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;

    // Offset of this (batch, channel) plane (NCDHW) or this sample (NDHWC).
    // Applied through locals: advancing the kernel arguments here would
    // accumulate offsets across iterations of the stride loop.
    int input_stride;
    if (!channel_last) {
      input_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else {
      input_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    const T* input_slice = input_data + input_stride;
    T* input_grad_slice = input_grad + input_stride;

    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          int input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          if (ele == input_slice[input_data_idx]) {
            stop = true;
            maxIdx = input_data_idx;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // atomic add: another output window may have selected the same input
      paddle::platform::CudaAtomicAdd(input_grad_slice + maxIdx,
                                      output_grad[index]);
    }
  }
}

F
feng_shuai 已提交
1183 1184
// Launches KernelPool3D on a caller-supplied stream, working directly on raw
// device pointers instead of DenseTensors.
//
// input_shape / output_shape are read as [N, C, D, H, W] (NCDHW).
// ksize, strides and paddings are read as [depth, height, width]; only the
// first three entries of each are used.
// One GPU thread is assigned per output element. The launch is enqueued on
// `stream` and is not synchronized or error-checked here.
template <typename PoolProcess, typename T>
void Pool3dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input,
    const std::vector<int>& input_shape,
    const std::vector<int>& output_shape,
    const std::vector<int>& ksize,
    const std::vector<int>& strides,
    const std::vector<int>& paddings,
    bool exclusive,
    bool adaptive,
    T* output,
    gpuStream_t stream,
    PoolProcess pool_compute) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_depth = input_shape[2];
  const int input_height = input_shape[3];
  const int input_width = input_shape[4];
  const int output_channels = output_shape[1];
  const int output_depth = output_shape[2];
  const int output_height = output_shape[3];
  const int output_width = output_shape[4];
  const int ksize_depth = ksize[0];
  const int ksize_height = ksize[1];
  const int ksize_width = ksize[2];
  const int stride_depth = strides[0];
  const int stride_height = strides[1];
  const int stride_width = strides[2];
  const int padding_depth = paddings[0];
  const int padding_height = paddings[1];
  const int padding_width = paddings[2];

  int nthreads = batch_size * output_channels * output_depth * output_height *
                 output_width;
  // Nothing to compute for an empty output. Launching with zero blocks would
  // also be an invalid CUDA launch configuration.
  if (nthreads == 0) {
    return;
  }

  int thread_num = 1024;
#ifdef WITH_NV_JETSON
  thread_num = 512;
#endif
  int blocks = (nthreads + thread_num - 1) / thread_num;
  dim3 threads(thread_num, 1);
  dim3 grid(blocks, 1);

  KernelPool3D<PoolProcess, T><<<grid, threads, 0, stream>>>(nthreads,
                                                             input,
                                                             input_channels,
                                                             input_depth,
                                                             input_height,
                                                             input_width,
                                                             output_depth,
                                                             output_height,
                                                             output_width,
                                                             ksize_depth,
                                                             ksize_height,
                                                             ksize_width,
                                                             stride_depth,
                                                             stride_height,
                                                             stride_width,
                                                             padding_depth,
                                                             padding_height,
                                                             padding_width,
                                                             pool_compute,
                                                             exclusive,
                                                             adaptive,
                                                             output);
}

C
chengduoZH 已提交
1249
/*
 * GPU forward 3-D pooling.
 *
 * Tensors are in NCDHW format. The overload that takes `data_format` also
 * accepts "NDHWC"; any other string is treated as NCDHW.
 * ksize and strides hold three elements: depth, height and width.
 * Only paddings[0], paddings[1] and paddings[2] are read, and they are used
 * as the depth, height and width padding respectively.
 * NOTE(review): callers presumably reduce any per-side padding to one value
 * per dimension before calling; TODO confirm.
 * One GPU thread is launched per output element, on context.stream().
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<phi::GPUContext, PoolProcess, T> {
 public:
  // NCDHW-only overload.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = context.template Alloc<T>(output);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: nothing to compute, and a zero-block grid would be an
    // invalid launch configuration.
    if (nthreads == 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 pool_process,
                                                 exclusive,
                                                 adaptive,
                                                 output_data);
  }

  // Layout-aware overload: data_format == "NDHWC" selects channel-last
  // indexing; every other value is treated as NCDHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* output,
                  PoolProcess pool_process) {
    bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output->dims()[4] : output->dims()[1];
    const int output_depth =
        channel_last ? output->dims()[1] : output->dims()[2];
    const int output_height =
        channel_last ? output->dims()[2] : output->dims()[3];
    const int output_width =
        channel_last ? output->dims()[3] : output->dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = context.template Alloc<T>(output);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: nothing to compute, and a zero-block grid would be an
    // invalid launch configuration.
    if (nthreads == 0) {
      return;
    }

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 pool_process,
                                                 exclusive,
                                                 adaptive,
                                                 output_data,
                                                 channel_last);
  }
};

C
chengduoZH 已提交
1404
/*
 * GPU backward 3-D pooling for a generic PoolProcess (e.g. MaxPoolGrad,
 * AvgPoolGrad).
 *
 * Tensors are in NCDHW format. The overload that takes `data_format` also
 * accepts "NDHWC"; any other string is treated as NCDHW.
 * ksize and strides hold three elements: depth, height and width.
 * Only paddings[0], paddings[1] and paddings[2] are read, and they are used
 * as the depth, height and width padding respectively.
 * One GPU thread is launched per *input* element, with a fixed block size of
 * 1024, on context.stream().
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<phi::GPUContext, PoolProcess, T> {
 public:
  // NCDHW-only overload.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Empty input: no gradient to write, and a zero-block grid would be an
    // invalid launch configuration.
    if (nthreads == 0) {
      return;
    }
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<T, PoolProcess>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 output_data,
                                                 output_grad_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 pool_process,
                                                 exclusive,
                                                 adaptive,
                                                 input_grad_data);
  }

  // Layout-aware overload: data_format == "NDHWC" selects channel-last
  // indexing; every other value is treated as NCDHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  bool exclusive,
                  bool adaptive,
                  DenseTensor* input_grad,
                  PoolProcess pool_process) {
    bool channel_last = (data_format == "NDHWC");

    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output.dims()[4] : output.dims()[1];
    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Empty input: no gradient to write, and a zero-block grid would be an
    // invalid launch configuration.
    if (nthreads == 0) {
      return;
    }
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        output_data,
        output_grad_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        pool_process,
        exclusive,
        adaptive,
        input_grad_data,
        channel_last);
  }
};

C
chengduoZH 已提交
1561
/*
 * GPU backward 3-D max pooling.
 *
 * Tensors are in NCDHW format. The overload that takes `data_format` also
 * accepts "NDHWC"; any other string is treated as NCDHW.
 * ksize and strides hold three elements: depth, height and width.
 * Only paddings[0], paddings[1] and paddings[2] are read, and they are used
 * as the depth, height and width padding respectively.
 * One GPU thread is launched per *output* element, with a fixed block size of
 * 1024, on context.stream().
 *
 * NOTE(review): the max-pool grad kernel in this file appears to scatter
 * output_grad into input_grad with CudaAtomicAdd, so input_grad is presumably
 * expected to be zero-filled by the caller. Alloc here does not visibly
 * initialize it; TODO confirm.
 */
template <class T>
class MaxPool3dGradFunctor<phi::GPUContext, T> {
 public:
  // NCDHW-only overload.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  DenseTensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: nothing to scatter, and a zero-block grid would be an
    // invalid launch configuration.
    if (nthreads == 0) {
      return;
    }
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 input_data,
                                                 output_data,
                                                 output_grad_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 input_grad_data);
  }

  // Layout-aware overload: data_format == "NDHWC" selects channel-last
  // indexing; every other value is treated as NCDHW.
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format,
                  DenseTensor* input_grad) {
    bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output.dims()[4] : output.dims()[1];
    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = context.template Alloc<T>(input_grad);

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: nothing to scatter, and a zero-block grid would be an
    // invalid launch configuration.
    if (nthreads == 0) {
      return;
    }
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads,
        input_data,
        output_data,
        output_grad_data,
        input_channels,
        input_depth,
        input_height,
        input_width,
        output_depth,
        output_height,
        output_width,
        ksize_depth,
        ksize_height,
        ksize_width,
        stride_depth,
        stride_height,
        stride_width,
        padding_depth,
        padding_height,
        padding_width,
        input_grad_data,
        channel_last);
  }
};

F
From00 已提交
1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733
// Explicit instantiations of the 3-D pooling functors defined above, grouped
// by functor.

// Raw-pointer forward pooling (float only).
template class Pool3dDirectCUDAFunctor<MaxPool<float>, float>;
template class Pool3dDirectCUDAFunctor<AvgPool<float>, float>;

// Tensor-based forward pooling.
template class Pool3dFunctor<phi::GPUContext, MaxPool<float>, float>;
template class Pool3dFunctor<phi::GPUContext, AvgPool<float>, float>;
template class Pool3dFunctor<phi::GPUContext, MaxPool<double>, double>;
template class Pool3dFunctor<phi::GPUContext, AvgPool<double>, double>;
template class Pool3dFunctor<phi::GPUContext,
                             MaxPool<dtype::float16>,
                             dtype::float16>;
template class Pool3dFunctor<phi::GPUContext,
                             AvgPool<dtype::float16>,
                             dtype::float16>;

// Generic (policy-driven) pooling backward.
template class Pool3dGradFunctor<phi::GPUContext, MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<phi::GPUContext, AvgPoolGrad<float>, float>;
template class Pool3dGradFunctor<phi::GPUContext, MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<phi::GPUContext, AvgPoolGrad<double>, double>;
template class Pool3dGradFunctor<phi::GPUContext,
                                 MaxPoolGrad<dtype::float16>,
                                 dtype::float16>;
template class Pool3dGradFunctor<phi::GPUContext,
                                 AvgPoolGrad<dtype::float16>,
                                 dtype::float16>;

// Dedicated max-pooling backward.
template class MaxPool3dGradFunctor<phi::GPUContext, float>;
template class MaxPool3dGradFunctor<phi::GPUContext, double>;
template class MaxPool3dGradFunctor<phi::GPUContext, dtype::float16>;
1734

C
chengduoZH 已提交
1735
/*
 * 2-D max pooling that also records, for each output element, the flat
 * offset (h * input_width + w) of the selected input element within its own
 * input plane.
 *
 * Each loop iteration handles one output element; the index advances by
 * blockDim.x * gridDim.x, so any 1-D grid size covers all `nthreads` outputs.
 * `divmods` and OffsetPreparationFor4Dimension decompose the flat output
 * index into (w_offset, h_offset, channel) and the start of the matching
 * input plane (input_offset).
 *
 * The running maximum is seeded from the first element of the window, so
 * windows whose values are all below -FLT_MAX (e.g. -inf, or very negative
 * doubles) still produce the correct value and a valid mask index. A
 * NaN in the first window position is therefore propagated. Only an empty
 * window falls back to -FLT_MAX with mask -1.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(const int nthreads,
                                       const T1* input_data,
                                       const int channels,
                                       const int input_height,
                                       const int input_width,
                                       const int output_height,
                                       const int output_width,
                                       const int ksize_height,
                                       const int ksize_width,
                                       const int stride_height,
                                       const int stride_width,
                                       const int padding_height,
                                       const int padding_width,
                                       bool adaptive,
                                       T1* output_data,
                                       T2* mask_data,
                                       FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int hstart, hend, wstart, wend;
    int w_offset, h_offset, c_offset, input_offset;
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         false,
                                                         divmods,
                                                         0,
                                                         0,
                                                         input_width,
                                                         input_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &input_offset);
    // Use a per-iteration pointer. Advancing `input_data` itself would make
    // the offsets accumulate when one thread handles several outputs.
    const T1* input_plane = input_data + input_offset;

    if (adaptive) {
      hstart = AdaptStartIndex(h_offset, input_height, output_height);
      hend = AdaptEndIndex(h_offset, input_height, output_height);

      wstart = AdaptStartIndex(w_offset, input_width, output_width);
      wend = AdaptEndIndex(w_offset, input_width, output_width);
    } else {
      hstart = h_offset * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = w_offset * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    // Fallback value, only emitted when the window is empty.
    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        const int input_index = h * input_width + w;
        const T1 value = input_plane[input_index];
        // Take the first window element unconditionally, then keep strictly
        // larger ones (ties keep the earliest index, as before).
        if (max_index == -1 || ele < value) {
          max_index = input_index;
          ele = value;
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1802
/*
 * Backward of 2-D max pooling with index (NCHW).
 *
 * Each work item `index` in [0, nthreads) is one input-gradient element; the
 * launcher sets nthreads = batch * channels * input_height * input_width.
 * The item visits the output cells (ph, pw) whose pooling window may cover
 * its input position and sums output_grad over those cells whose recorded
 * mask equals this position's in-plane offset h_offset * input_width +
 * w_offset. The sum overwrites input_grad[index].
 *
 * `index` advances in a grid-stride loop, so the per-(batch, channel) plane
 * pointers are kept in locals. Advancing the `mask_data` / `output_grad`
 * parameters themselves would compound the offset whenever one thread
 * processes more than one index.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(const int nthreads,
                                           const T1* output_grad,
                                           const T2* mask_data,
                                           const int channels,
                                           const int input_height,
                                           const int input_width,
                                           const int output_height,
                                           const int output_width,
                                           const int ksize_height,
                                           const int ksize_width,
                                           const int stride_height,
                                           const int stride_width,
                                           const int padding_height,
                                           const int padding_width,
                                           bool adaptive,
                                           T1* input_grad,
                                           FastDivModForPooling divmods) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int phstart, phend, pwstart, pwend;
    int w_offset, h_offset, c_offset, output_offset;
    // Splits `index` into input (w, h, c) offsets using `divmods` (the caller
    // builds it from the input dims). output_offset is presumably the start
    // of the matching (batch, channel) output plane, derived from
    // output_width / output_height; the helper is defined elsewhere -- TODO
    // confirm.
    OffsetPreparationFor4Dimension<FastDivModForPooling>(index,
                                                         false,
                                                         divmods,
                                                         0,
                                                         0,
                                                         output_width,
                                                         output_height,
                                                         &w_offset,
                                                         &h_offset,
                                                         &c_offset,
                                                         &output_offset);
    // Per-iteration views of the output plane; the kernel parameters are left
    // untouched so later grid-stride iterations start from the base pointers.
    const T2* mask_plane = mask_data + output_offset;
    const T1* output_grad_plane = output_grad + output_offset;

    if (adaptive) {
      // Candidate output range for adaptive pooling. It may contain cells
      // that do not actually cover this input; the mask test below filters
      // them out.
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // Output cells p whose padded window [p * stride, p * stride + ksize)
      // contains the padded input coordinate (offset + padding).
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 input_grad_data = 0;
    // Offset of this input element within its plane, in the same encoding the
    // forward kernel stores in the mask.
    const int input_current_featuremap_idx = h_offset * input_width + w_offset;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        const int output_idx = ph * output_width + pw;
        if (mask_plane[output_idx] == input_current_featuremap_idx)
          input_grad_data += output_grad_plane[output_idx];
      }
    }
    input_grad[index] = input_grad_data;
  }
}

C
chengduoZH 已提交
1871 1872 1873 1874 1875
/*
 * GPU max pooling with index for 2-D inputs.
 *
 * All tensors are in NCHW format. `ksize`, `strides` and `paddings` each hold
 * two elements ordered {height, width}. `output` must already carry its final
 * dims (they are read before allocation). `output` receives the pooled maxima
 * and `mask` the in-plane offset of each maximum. Both are allocated here on
 * `context`, and the kernel is enqueued on context.stream().
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<phi::GPUContext, T1, T2> {
 public:
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* output,
                  DenseTensor* mask) {
    const auto& in_dims = input.dims();
    const auto& out_dims = output->dims();
    const int batch = in_dims[0];
    const int in_c = in_dims[1];
    const int in_h = in_dims[2];
    const int in_w = in_dims[3];
    const int out_c = out_dims[1];
    const int out_h = out_dims[2];
    const int out_w = out_dims[3];

    const T1* src = input.data<T1>();
    T1* dst = context.template Alloc<T1>(output);
    T2* argmax = context.template Alloc<T2>(mask);

    // One work item per output element.
    const int total = batch * out_c * out_h * out_w;
    int block_size = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &block_size);
#endif
    const dim3 threads(block_size, 1);
    const dim3 grid((total + block_size - 1) / block_size, 1);

    // Work items enumerate output positions, so the divisors use output dims.
    const FastDivModForPooling divisors(in_c, out_w, out_h);
    KernelMaxPool2dWithIdx<T1, T2>
        <<<grid, threads, 0, context.stream()>>>(total,
                                                 src,
                                                 in_c,
                                                 in_h,
                                                 in_w,
                                                 out_h,
                                                 out_w,
                                                 ksize[0],
                                                 ksize[1],
                                                 strides[0],
                                                 strides[1],
                                                 paddings[0],
                                                 paddings[1],
                                                 adaptive,
                                                 dst,
                                                 argmax,
                                                 divisors);
  }
};

C
chengduoZH 已提交
1938 1939 1940 1941 1942
/*
 * GPU gradient of max pooling with index for 2-D inputs.
 *
 * All tensors are in NCHW format. `ksize`, `strides` and `paddings` each hold
 * two elements ordered {height, width}. `mask` is the index tensor written by
 * the forward functor. `input_grad` must already carry the input dims. It is
 * allocated here, and the kernel writes every element of it (one work item
 * per input element). The kernel is enqueued on context.stream().
 *
 * The block size follows the forward functor: 1024 threads, adjusted by
 * backends::gpu::ChangeThreadNum on WITH_NV_JETSON builds. Previously this
 * path always used 1024 regardless of the device.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
 public:
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& output_grad,
                  const DenseTensor& mask,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = context.template Alloc<T1>(input_grad);

    int nthreads = batch_size * input_channels * input_height * input_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Same launch-size adjustment as the forward functor, so forward and
    // backward launches pick the same block size on Jetson builds.
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    // Work items enumerate input positions, so the divisors use input dims.
    auto pool_divmods =
        FastDivModForPooling(input_channels, input_width, input_height);
    KernelMaxPool2DWithIdxGrad<T1, T2>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 output_grad_data,
                                                 mask_data,
                                                 input_channels,
                                                 input_height,
                                                 input_width,
                                                 output_height,
                                                 output_width,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_height,
                                                 stride_width,
                                                 padding_height,
                                                 padding_width,
                                                 adaptive,
                                                 input_grad_data,
                                                 pool_divmods);
  }
};

F
From00 已提交
1999 2000 2001 2002
// Explicit instantiations of the 2-D GPU max-pool-with-index functors
// (forward and gradient) for float and double data with int index masks.
template class MaxPool2dWithIndexFunctor<phi::GPUContext, float, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, float, int>;
template class MaxPool2dWithIndexFunctor<phi::GPUContext, double, int>;
template class MaxPool2dWithIndexGradFunctor<phi::GPUContext, double, int>;
C
chengduoZH 已提交
2003

C
chengduoZH 已提交
2004
/*
 * Forward of 3-D max pooling with index (NCDHW).
 *
 * Each work item `index` in [0, nthreads) is one output element; the launcher
 * sets nthreads = batch * channels * output_depth * output_height *
 * output_width. The item scans its pooling window, writes the maximum to
 * output_data[index] and the in-volume offset (d * input_height + h) *
 * input_width + w of that maximum to mask_data[index].
 *
 * The running maximum starts at -FLT_MAX and is replaced only by strictly
 * larger values. Ties keep the first element in d, h, w scan order. A window
 * with no element above -FLT_MAX (e.g. all -inf, or an empty window) yields
 * -FLT_MAX with mask -1.
 * NOTE(review): for T1 = double, values below -FLT_MAX are never selected;
 * kept as-is to match the 2-D kernel.
 *
 * The per-(batch, channel) input pointer is a local, so the offset is not
 * re-applied on later iterations of the grid-stride loop.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(const int nthreads,
                                       const T1* input_data,
                                       const int channels,
                                       const int input_depth,
                                       const int input_height,
                                       const int input_width,
                                       const int output_depth,
                                       const int output_height,
                                       const int output_width,
                                       const int ksize_depth,
                                       const int ksize_height,
                                       const int ksize_width,
                                       const int stride_depth,
                                       const int stride_height,
                                       const int stride_width,
                                       const int padding_depth,
                                       const int padding_height,
                                       const int padding_width,
                                       bool adaptive,
                                       T1* output_data,
                                       T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decompose the flat output index as (batch_idx, c, pd, ph, pw).
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      // Window [p * stride - padding, ... + ksize), clipped to the input.
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Start of this (batch, channel) input volume; the `input_data` parameter
    // itself is not advanced, so every grid-stride iteration starts fresh.
    const T1* input_volume = input_data + (batch_idx * channels + c) *
                                              input_depth * input_height *
                                              input_width;

    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          const int input_index = (d * input_height + h) * input_width + w;
          if (ele < input_volume[input_index]) {
            max_index = input_index;
            ele = input_volume[input_index];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
2079
/*
 * Backward of 3-D max pooling with index (NCDHW).
 *
 * Each work item `index` in [0, nthreads) is one input-gradient element; the
 * launcher sets nthreads = batch * channels * input_depth * input_height *
 * input_width. The item visits the output cells whose window may cover its
 * input position and sums output_grad over those cells whose mask equals
 * this position's in-volume offset (d * input_height + h) * input_width + w.
 * The sum overwrites input_grad[index].
 *
 * The per-(batch, channel) `mask` / `output_grad` views are locals. Advancing
 * the parameters themselves would compound the offset across iterations of
 * the grid-stride loop.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(const int nthreads,
                                           const T1* output_grad,
                                           const T2* mask,
                                           const int channels,
                                           const int input_depth,
                                           const int input_height,
                                           const int input_width,
                                           const int output_depth,
                                           const int output_height,
                                           const int output_width,
                                           const int ksize_depth,
                                           const int ksize_height,
                                           const int ksize_width,
                                           const int stride_depth,
                                           const int stride_height,
                                           const int stride_width,
                                           const int padding_depth,
                                           const int padding_height,
                                           const int padding_width,
                                           bool adaptive,
                                           T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decompose the flat input index as (batch_idx, c, d, h, w).
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      // Candidate output range; extra cells are filtered by the mask test.
      pdstart = d_offset * output_depth / input_depth;
      pdend =
          min((d_offset + 1) * output_depth / input_depth + 1, output_depth);
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // Output cells p whose padded window [p * stride, p * stride + ksize)
      // contains the padded input coordinate (offset + padding).
      pdstart =
          (d_offset + padding_depth < ksize_depth)
              ? 0
              : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth);
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 input_grad_data = 0;
    // In-volume offset of this input element, in the mask's encoding.
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    // Start of this (batch, channel) output volume.
    const int output_idx = (batch_idx * channels + c_offset) * output_depth *
                           output_height * output_width;
    const T2* mask_volume = mask + output_idx;
    const T1* output_grad_volume = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          const int out_pos = (pd * output_height + ph) * output_width + pw;
          if (mask_volume[out_pos] == input_current_feature_map_idx)
            input_grad_data += output_grad_volume[out_pos];
        }
      }
    }
    input_grad[index] = input_grad_data;
  }
}

C
chengduoZH 已提交
2164 2165 2166 2167 2168
/*
 * GPU max pooling with index for 3-D inputs.
 *
 * All tensors are in NCDHW format. `ksize`, `strides` and `paddings` each
 * hold three elements ordered {depth, height, width}. `output` must already
 * carry its final dims (they are read before allocation). `output` receives
 * the pooled maxima and `mask` the in-volume offset of each maximum. Both are
 * allocated here on `context`, and the kernel is enqueued on
 * context.stream().
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<phi::GPUContext, T1, T2> {
 public:
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& input,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* output,
                  DenseTensor* mask) {
    const auto& in_dims = input.dims();
    const auto& out_dims = output->dims();
    const int batch = in_dims[0];
    const int in_c = in_dims[1];
    const int in_d = in_dims[2];
    const int in_h = in_dims[3];
    const int in_w = in_dims[4];
    const int out_c = out_dims[1];
    const int out_d = out_dims[2];
    const int out_h = out_dims[3];
    const int out_w = out_dims[4];

    const T1* src = input.data<T1>();
    T1* dst = context.template Alloc<T1>(output);
    T2* argmax = context.template Alloc<T2>(mask);

    // One work item per output element.
    const int total = batch * out_c * out_d * out_h * out_w;
    int block_size = 1024;
#ifdef WITH_NV_JETSON
    backends::gpu::ChangeThreadNum(context, &block_size);
#endif
    const dim3 threads(block_size, 1);
    const dim3 grid((total + block_size - 1) / block_size, 1);

    KernelMaxPool3DWithIdx<T1, T2>
        <<<grid, threads, 0, context.stream()>>>(total,
                                                 src,
                                                 in_c,
                                                 in_d,
                                                 in_h,
                                                 in_w,
                                                 out_d,
                                                 out_h,
                                                 out_w,
                                                 ksize[0],
                                                 ksize[1],
                                                 ksize[2],
                                                 strides[0],
                                                 strides[1],
                                                 strides[2],
                                                 paddings[0],
                                                 paddings[1],
                                                 paddings[2],
                                                 adaptive,
                                                 dst,
                                                 argmax);
  }
};

C
chengduoZH 已提交
2239 2240 2241 2242 2243
/*
 * GPU gradient of max pooling with index for 3-D inputs.
 *
 * All tensors are in NCDHW format. `ksize`, `strides` and `paddings` each
 * hold three elements ordered {depth, height, width}. `mask` is the index
 * tensor written by the forward functor. `input_grad` must already carry the
 * input dims. It is allocated here, and the kernel writes every element of it
 * (one work item per input element). The kernel is enqueued on
 * context.stream().
 *
 * The block size follows the forward functor: 1024 threads, adjusted by
 * backends::gpu::ChangeThreadNum on WITH_NV_JETSON builds. Previously this
 * path always used 1024 regardless of the device.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<phi::GPUContext, T1, T2> {
 public:
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& output_grad,
                  const DenseTensor& mask,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  bool adaptive,
                  DenseTensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = context.template Alloc<T1>(input_grad);

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Same launch-size adjustment as the forward functor, so forward and
    // backward launches pick the same block size on Jetson builds.
    backends::gpu::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2>
        <<<grid, threads, 0, context.stream()>>>(nthreads,
                                                 output_grad_data,
                                                 mask_data,
                                                 input_channels,
                                                 input_depth,
                                                 input_height,
                                                 input_width,
                                                 output_depth,
                                                 output_height,
                                                 output_width,
                                                 ksize_depth,
                                                 ksize_height,
                                                 ksize_width,
                                                 stride_depth,
                                                 stride_height,
                                                 stride_width,
                                                 padding_depth,
                                                 padding_height,
                                                 padding_width,
                                                 adaptive,
                                                 input_grad_data);
  }
};

F
From00 已提交
2308 2309 2310 2311 2312 2313 2314
// Explicit instantiations of the 3-D GPU max-pool-with-index functors
// (forward and gradient) for float and double data with int index masks.
template class MaxPool3dWithIndexFunctor<phi::GPUContext, float, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, float, int>;
template class MaxPool3dWithIndexFunctor<phi::GPUContext, double, int>;
template class MaxPool3dWithIndexGradFunctor<phi::GPUContext, double, int>;

}  // namespace funcs
}  // namespace phi