pooling.cu 74.2 KB
Newer Older
1
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
17

Y
Yi Wang 已提交
18
#include "paddle/fluid/operators/math/pooling.h"
D
dzhwinter 已提交
19
#include "paddle/fluid/platform/cuda_primitives.h"
F
feng_shuai 已提交
20
#include "paddle/fluid/platform/gpu_launch_config.h"
C
chengduoZH 已提交
21 22 23 24 25

namespace paddle {
namespace operators {
namespace math {

26
/*
 * Forward 2-D pooling kernel.
 *
 * Each loop iteration computes output_data[index] for one flat output index.
 * The index is decoded as NCHW when `channel_last` is false and as NHWC when
 * it is true. The loop advances by blockDim.x * gridDim.x, so one thread may
 * process several outputs. Per-element offsets are therefore kept in locals
 * instead of being accumulated into the pointer arguments.
 *
 * Pooling window along each spatial axis:
 *  - adaptive: [AdaptStartIndex(p, in, out), AdaptEndIndex(p, in, out));
 *    ksize, stride and padding are not used.
 *  - otherwise: [p * stride - padding, p * stride - padding + ksize),
 *    clipped to [0, in).
 * pool_process.finalize receives the clipped window area when `exclusive` or
 * `adaptive` is set, and ksize_height * ksize_width otherwise.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, const int output_height,
                             const int output_width, const int ksize_height,
                             const int ksize_width, const int stride_height,
                             const int stride_width, const int padding_height,
                             const int padding_width, PoolProcess pool_process,
                             bool exclusive, bool adaptive, T* output_data,
                             bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, c, batch_idx;
    if (!channel_last) { /* NCHW */
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      c = (index / output_width / output_height) % channels;
      batch_idx = index / output_width / output_height / channels;
    } else { /* NHWC */
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      batch_idx = index / channels / output_width / output_height;
    }

    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      hstart = ph * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = pw * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    // Start of the (batch_idx, c) plane for NCHW, or of image batch_idx for
    // NHWC. Kept local: advancing input_data itself would accumulate the
    // offset across iterations of this loop.
    const int input_offset =
        channel_last ? batch_idx * input_height * input_width * channels
                     : (batch_idx * channels + c) * input_height * input_width;
    const T* input_base = input_data + input_offset;

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        auto input_idx = channel_last ? (h * input_width + w) * channels + c
                                      : h * input_width + w;
        pool_process.compute(input_base[input_idx], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}
/*
 * Generic backward 2-D pooling kernel.
 *
 * Each loop iteration computes input_grad[index] for one flat input index.
 * The index is decoded as NCHW when `channel_last` is false and as NHWC when
 * it is true. The kernel visits the outputs in [phstart, phend) x
 * [pwstart, pwend) whose windows may contain that input, and for each one
 * calls pool_process.compute(x, y, dy, 1 / pool_size, &gradient).
 *
 * pool_size matches the divisor used by KernelPool2D:
 *  - adaptive: the area of the adaptive window of output (ph, pw);
 *  - exclusive: the clipped window area;
 *  - otherwise: ksize_height * ksize_width.
 * Adaptive windows are inverted with unpadded input coordinates, because the
 * forward adaptive windows do not use padding.
 *
 * The loop advances by blockDim.x * gridDim.x, so per-element offsets are
 * kept in locals instead of being accumulated into the pointer arguments.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Unpadded spatial coordinates of this input element.
    int w_idx, h_idx, offsetC, batch_idx;
    if (!channel_last) { /* NCHW */
      w_idx = index % input_width;
      h_idx = (index / input_width) % input_height;
      offsetC = (index / input_width / input_height) % channels;
      batch_idx = index / input_width / input_height / channels;
    } else { /* NHWC */
      offsetC = index % channels;
      w_idx = (index / channels) % input_width;
      h_idx = (index / channels / input_width) % input_height;
      batch_idx = index / channels / input_width / input_height;
    }

    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      // Outputs whose adaptive window contains (h_idx, w_idx).
      phstart = AdaptStartIndex(h_idx, output_height, input_height);
      phend = AdaptEndIndex(h_idx, output_height, input_height);

      pwstart = AdaptStartIndex(w_idx, output_width, input_width);
      pwend = AdaptEndIndex(w_idx, output_width, input_width);
    } else {
      // Coordinates in the padded input frame.
      const int h_offset = h_idx + padding_height;
      const int w_offset = w_idx + padding_width;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      phend = min(h_offset / stride_height + 1, output_height);
      pwend = min(w_offset / stride_width + 1, output_width);
    }

    T gradient = static_cast<T>(0.0);
    T input = input_data[index];

    // Start of the (batch_idx, offsetC) plane for NCHW, or of image batch_idx
    // for NHWC, in output_data / output_grad. Kept local so repeated loop
    // iterations do not accumulate the offset.
    const int output_stride =
        channel_last
            ? batch_idx * output_height * output_width * channels
            : (batch_idx * channels + offsetC) * output_height * output_width;
    const T* output_base = output_data + output_stride;
    const T* output_grad_base = output_grad + output_stride;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        int pool_size;
        if (adaptive) {
          // Same window area KernelPool2D divided by for output (ph, pw).
          pool_size = (AdaptEndIndex(ph, input_height, output_height) -
                       AdaptStartIndex(ph, input_height, output_height)) *
                      (AdaptEndIndex(pw, input_width, output_width) -
                       AdaptStartIndex(pw, input_width, output_width));
        } else {
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          pool_size = exclusive ? (hend - hstart) * (wend - wstart)
                                : ksize_height * ksize_width;
        }

        int output_sub_idx = channel_last
                                 ? (ph * output_width + pw) * channels + offsetC
                                 : ph * output_width + pw;
        pool_process.compute(input, output_base[output_sub_idx],
                             output_grad_base[output_sub_idx],
                             static_cast<T>(1.0 / pool_size), &gradient);
      }
    }
    input_grad[index] = gradient;
  }
}

176
/*
 * Backward kernel for 2-D max pooling.
 *
 * Each loop iteration handles one flat output index, decoded as NCHW or NHWC
 * according to `channel_last`. It scans that output's clipped window in
 * row-major order and takes the first input equal to the pooled value
 * output_data[index]. It then adds output_grad[index] to the matching
 * input_grad element. The add is atomic because different outputs can select
 * the same input element.
 *
 * The loop advances by blockDim.x * gridDim.x, so per-element offsets are
 * kept in locals instead of being accumulated into the pointer arguments.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, c, batch_idx;
    if (!channel_last) { /* NCHW */
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      c = (index / output_width / output_height) % channels;
      batch_idx = index / output_width / output_height / channels;
    } else { /* NHWC */
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      batch_idx = index / channels / output_width / output_height;
    }
    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Start of the (batch_idx, c) plane for NCHW, or of image batch_idx for
    // NHWC. Kept local so repeated loop iterations do not accumulate it.
    const int input_stride =
        channel_last ? batch_idx * input_height * input_width * channels
                     : (batch_idx * channels + c) * input_height * input_width;
    const T* input_base = input_data + input_stride;
    T* input_grad_base = input_grad + input_stride;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        int input_data_idx = channel_last ? (h * input_width + w) * channels + c
                                          : h * input_width + w;
        if (ele == input_base[input_data_idx]) {
          maxIndex = input_data_idx;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // Atomic: several outputs may route their gradient to the same input.
      platform::CudaAtomicAdd(input_grad_base + maxIndex, output_grad[index]);
    }
  }
}

N
nhzlx 已提交
236 237 238 239 240
/*
 * Launches KernelPool2D on raw device pointers using the caller's stream.
 * input_shape / output_shape are read as [N, C, H, W]. Only the first two
 * entries of ksize, strides and paddings are used, as (height, width).
 * Launches one thread per output element.
 */
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    bool exclusive, bool adaptive, T* output, gpuStream_t stream,
    PoolProcess pool_compute) {
  const int in_c = input_shape[1];
  const int in_h = input_shape[2];
  const int in_w = input_shape[3];
  const int out_h = output_shape[2];
  const int out_w = output_shape[3];
  const int total = input_shape[0] * output_shape[1] * out_h * out_w;

#ifdef WITH_NV_JETSON
  // No device context is available here, so platform::ChangeThreadNum cannot
  // be used; Jetson builds fall back to a fixed, smaller block.
  const int block_size = 512;
#else
  const int block_size = 1024;
#endif
  const dim3 threads(block_size, 1);
  const dim3 grid((total + block_size - 1) / block_size, 1);

  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      total, input, in_c, in_h, in_w, out_h, out_w, ksize[0], ksize[1],
      strides[0], strides[1], paddings[0], paddings[1], pool_compute,
      exclusive, adaptive, output);
}

C
chengduoZH 已提交
273
/*
 * Tensors are in NCHW or NHWC format.
 * Ksize and strides have two elements, representing height and width,
 * respectively.
 * Only the first two elements of paddings are read: paddings[0] is the
 * height padding and paddings[1] is the width padding.
 */
280
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCHW-only forward pooling. Equivalent to the data_format overload below
  // called with "NCHW".
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* output,
                  PoolProcess pool_process) {
    (*this)(context, input, ksize, strides, paddings, "NCHW", exclusive,
            adaptive, output, pool_process);
  }

  // Forward pooling over a 4-D tensor. The layout is NHWC when
  // data_format == "NHWC" and NCHW otherwise. The launch size is taken from
  // output->dims(), with one thread per output element on context.stream().
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* output, PoolProcess pool_process) {
    const bool channel_last = data_format == "NHWC";
    // Positions of C, H and W within the 4-D dims for the selected layout.
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch = input.dims()[0];
    const int in_c = input.dims()[c_axis];
    const int in_h = input.dims()[h_axis];
    const int in_w = input.dims()[w_axis];

    const int out_c = output->dims()[c_axis];
    const int out_h = output->dims()[h_axis];
    const int out_w = output->dims()[w_axis];

    const T* in_ptr = input.data<T>();
    T* out_ptr = output->mutable_data<T>(context.GetPlace());

    const int total = batch * out_c * out_h * out_w;
    int block_size = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &block_size);
#endif
    const dim3 threads(block_size, 1);
    const dim3 grid((total + block_size - 1) / block_size, 1);

    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        total, in_ptr, in_c, in_h, in_w, out_h, out_w, ksize[0], ksize[1],
        strides[0], strides[1], paddings[0], paddings[1], pool_process,
        exclusive, adaptive, out_ptr, channel_last);
  }
};
C
chengduoZH 已提交
367
/*
 * Tensors are in NCHW or NHWC format.
 * Ksize and strides have two elements, representing height and width,
 * respectively.
 * Only the first two elements of paddings are read: paddings[0] is the
 * height padding and paddings[1] is the width padding.
 */
374
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCHW-only backward pooling. Forwards to the data_format overload with
  // "NCHW", which reads the same dims and launches the same kernel.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* input_grad,
                  PoolProcess pool_process) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCHW", exclusive, adaptive, input_grad, pool_process);
  }

  // Backward pooling over a 4-D tensor. The layout is NHWC when
  // data_format == "NHWC" and NCHW otherwise. Launches KernelPool2DGrad on
  // context.stream() with one thread per input element. Each thread writes
  // its input_grad element directly.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* input_grad, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NHWC");
    // Positions of C, H and W within the 4-D dims for the selected layout.
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];

    const int output_height = output.dims()[h_axis];
    const int output_width = output.dims()[w_axis];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Match the forward functors, which let ChangeThreadNum choose the block
    // size on Jetson builds.
    platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        pool_process, exclusive, adaptive, input_grad_data, channel_last);
  }
};

C
chengduoZH 已提交
466
/*
 * Tensors are in NCHW or NHWC format.
 * Ksize and strides have two elements, representing height and width,
 * respectively.
 * Only the first two elements of paddings are read: paddings[0] is the
 * height padding and paddings[1] is the width padding.
 */
473
template <typename T>
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // NCHW-only max-pool backward. Forwards to the data_format overload with
  // "NCHW", which reads the same dims and launches the same kernel.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCHW", input_grad);
  }

  // Max-pool backward over a 4-D tensor. The layout is NHWC when
  // data_format == "NHWC" and NCHW otherwise. Launches KernelMaxPool2DGrad on
  // context.stream() with one thread per output element. Each thread
  // atomically adds its output gradient into input_grad.
  // NOTE(review): input_grad is only accumulated into and is not zeroed
  // here; presumably the caller zero-fills it first. Confirm at call sites.
  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, framework::Tensor* input_grad) {
    const bool channel_last = (data_format == "NHWC");
    // Positions of C, H and W within the 4-D dims for the selected layout.
    const int c_axis = channel_last ? 3 : 1;
    const int h_axis = channel_last ? 1 : 2;
    const int w_axis = channel_last ? 2 : 3;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[c_axis];
    const int input_height = input.dims()[h_axis];
    const int input_width = input.dims()[w_axis];

    const int output_channels = output.dims()[c_axis];
    const int output_height = output.dims()[h_axis];
    const int output_width = output.dims()[w_axis];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Match the forward functors, which let ChangeThreadNum choose the block
    // size on Jetson builds.
    platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        input_grad_data, channel_last);
  }
};

N
nhzlx 已提交
561 562 563 564 565
// Explicit instantiations of the 2-D pooling functors for CUDA.

// Raw-pointer forward pooling (float only).
template class Pool2dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
                                       float>;
template class Pool2dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
                                       float>;

// Tensor-based forward pooling.
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;
template class Pool2dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPool<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool2dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPool<paddle::platform::float16>,
    paddle::platform::float16>;

// Generic (PoolProcess-driven) backward pooling.
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;
template class Pool2dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool2dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;

// Max-pool backward (argmax search + atomic add).
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext,
                                    paddle::platform::float16>;

609
/*
 * Forward 3-D pooling kernel.
 *
 * Each loop iteration computes output_data[index] for one flat output index.
 * The index is decoded as NCDHW when `channel_last` is false and as NDHWC
 * when it is true. The loop advances by blockDim.x * gridDim.x, so one thread
 * may process several outputs. Per-element offsets are therefore kept in
 * locals instead of being accumulated into the pointer arguments.
 *
 * Pooling window along each spatial axis:
 *  - adaptive: [AdaptStartIndex(p, in, out), AdaptEndIndex(p, in, out));
 *    ksize, stride and padding are not used.
 *  - otherwise: [p * stride - padding, p * stride - padding + ksize),
 *    clipped to [0, in).
 * pool_process.finalize receives the clipped window volume when `exclusive`
 * or `adaptive` is set, and ksize_depth * ksize_height * ksize_width
 * otherwise.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(
    const int nthreads, const T* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw, ph, pd, c, batch_idx;
    if (!channel_last) { /* NCDHW */
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else { /* NDHWC */
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Start of the (batch_idx, c) volume for NCDHW, or of sample batch_idx
    // for NDHWC. Kept local: advancing input_data itself would accumulate the
    // offset across iterations of this loop.
    const int input_data_stride =
        channel_last
            ? batch_idx * input_depth * input_height * input_width * channels
            : (batch_idx * channels + c) * input_depth * input_height *
                  input_width;
    const T* input_base = input_data + input_data_stride;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          auto input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          pool_process.compute(input_base[input_data_idx], &ele);
        }
      }
    }
    int pool_size = (exclusive || adaptive)
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

/*
 * Gradient kernel for 3-D pooling driven by a PoolProcess functor
 * (pool_process.compute(x, y, dy, scale, &dx) accumulates one contribution).
 *
 * Each iteration of the loop handles one INPUT element: `index` runs over the
 * flattened input tensor (NCDHW, or NDHWC when channel_last is true). The
 * kernel visits every output window that covers that element, accumulates the
 * contributions in a local value and stores it to input_grad[index], so every
 * input_grad element is written by exactly one iteration (no atomics needed).
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, PoolProcess pool_process,
    bool exclusive, bool adaptive, T* input_grad, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat index into (batch, channel, d, h, w). The spatial
    // offsets are shifted by the padding so they are in padded coordinates.
    int w_offset, h_offset, d_offset, offsetC, batch_idx;
    if (!channel_last) { /* "NCDHW" */
      w_offset = index % input_width + padding_width;
      h_offset = (index / input_width) % input_height + padding_height;
      d_offset =
          (index / input_width / input_height) % input_depth + padding_depth;
      offsetC = (index / input_width / input_height / input_depth) % channels;
      batch_idx = index / input_width / input_height / input_depth / channels;
    } else { /* "NDHWC" */
      offsetC = index % channels;
      w_offset = (index / channels) % input_width + padding_width;
      h_offset =
          (index / channels / input_width) % input_height + padding_height;
      d_offset = (index / channels / input_width / input_height) % input_depth +
                 padding_depth;
      batch_idx = index / channels / input_width / input_height / input_depth;
    }

    // Half-open ranges [p*start, p*end) of output positions whose pooling
    // window contains this input element.
    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = AdaptStartIndex(d_offset, output_depth, input_depth);
      pdend = AdaptEndIndex(d_offset, output_depth, input_depth);

      phstart = AdaptStartIndex(h_offset, output_height, input_height);
      phend = AdaptEndIndex(h_offset, output_height, input_height);

      pwstart = AdaptStartIndex(w_offset, output_width, input_width);
      pwend = AdaptEndIndex(w_offset, output_width, input_width);
    } else {
      pdstart = (d_offset < ksize_depth)
                    ? 0
                    : (d_offset - ksize_depth) / stride_depth + 1;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      pdend = min((d_offset) / stride_depth + 1, output_depth);
      phend = min((h_offset) / stride_height + 1, output_height);
      pwend = min((w_offset) / stride_width + 1, output_width);
    }

    T gradient = static_cast<T>(0.0);
    T input = input_data[index];

    // Start of the output slice for this element: the (batch, channel) plane
    // for NCDHW, or the whole batch item for NDHWC.
    int output_stride;
    if (!channel_last) {
      output_stride = (batch_idx * channels + offsetC) * output_depth *
                      output_height * output_width;
    } else {
      output_stride =
          batch_idx * output_depth * output_height * output_width * channels;
    }
    // Use per-iteration pointers instead of advancing the pointer parameters:
    // the parameters persist across iterations of this grid-stride loop, so
    // advancing them would re-apply the offset for every further index the
    // thread handles.
    const T* output_data_slice = output_data + output_stride;
    const T* output_grad_slice = output_grad + output_stride;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // Divisor the forward pass used for window (pd, ph, pw).
          int pool_size;
          if (adaptive) {
            // Exact extent of this adaptive window, computed the same way as
            // in the forward kernel. A uniform estimate such as
            // ceil(input / ksize) differs from the real window size whenever
            // the input extent is not a multiple of the output extent.
            int dstart = AdaptStartIndex(pd, input_depth, output_depth);
            int dend = AdaptEndIndex(pd, input_depth, output_depth);
            int hstart = AdaptStartIndex(ph, input_height, output_height);
            int hend = AdaptEndIndex(ph, input_height, output_height);
            int wstart = AdaptStartIndex(pw, input_width, output_width);
            int wend = AdaptEndIndex(pw, input_width, output_width);
            pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          } else {
            int dstart = pd * stride_depth - padding_depth;
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int dend = min(dstart + ksize_depth, input_depth);
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            dstart = max(dstart, 0);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            // exclusive: divide by the clipped window; otherwise by the full
            // kernel volume.
            pool_size =
                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                          : ksize_depth * ksize_height * ksize_width;
          }

          int output_sub_idx =
              channel_last
                  ? ((pd * output_height + ph) * output_width + pw) * channels +
                        offsetC
                  : (pd * output_height + ph) * output_width + pw;

          pool_process.compute(input, output_data_slice[output_sub_idx],
                               output_grad_slice[output_sub_idx],
                               static_cast<T>(1.0 / pool_size), &gradient);
        }
      }
    }
    input_grad[index] = gradient;
  }
}

808
/*
 * Gradient kernel for 3-D max pooling without a stored index mask.
 *
 * Each loop iteration handles one OUTPUT element `index` (flattened NCDHW, or
 * NDHWC when channel_last is true). It rescans the element's pooling window,
 * picks the first input position (in d, h, w order) whose value equals the
 * pooled output, and adds output_grad[index] to input_grad at that position.
 * Distinct output elements can select the same input position (e.g. when
 * windows overlap), so the update is an atomic add: input_grad is accumulated
 * into, not overwritten.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat output index into (batch, channel, pd, ph, pw).
    int pw, ph, pd, c, batch_idx;

    if (!channel_last) { /*NCDHW*/
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else { /*NDHWC*/
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    // Input window of this output element, clipped to the unpadded input.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;

    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);

    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;

    // Start of the input slice for this element: the (batch, channel) volume
    // for NCDHW, or the whole batch item for NDHWC.
    int input_stride;
    if (!channel_last) {
      input_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else {
      input_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    // Per-iteration pointers: advancing the pointer parameters themselves
    // would accumulate the offset across iterations of this grid-stride loop.
    const T* input_data_slice = input_data + input_stride;
    T* input_grad_slice = input_grad + input_stride;

    // First position in the window holding the max value (ties: first wins).
    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          int input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          if (ele == input_data_slice[input_data_idx]) {
            stop = true;
            maxIdx = input_data_idx;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // Atomic: several output elements may route to the same input element.
      platform::CudaAtomicAdd(input_grad_slice + maxIdx, output_grad[index]);
    }
  }
}

F
feng_shuai 已提交
885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927
// Launches KernelPool3D on raw device pointers and a caller-supplied stream,
// without a device context.
// input_shape / output_shape are read as [N, C, D, H, W]; ksize, strides and
// paddings are read as (depth, height, width). The element count handed to
// the kernel is the number of output elements.
template <typename PoolProcess, typename T>
void Pool3dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    bool exclusive, bool adaptive, T* output, gpuStream_t stream,
    PoolProcess pool_compute) {
#ifdef WITH_NV_JETSON
  // Smaller blocks on Jetson; there is no device context here to query.
  const int kBlockSize = 512;
#else
  const int kBlockSize = 1024;
#endif

  // Output element count; the batch dimension is taken from the input shape.
  const int total_outputs = input_shape[0] * output_shape[1] *
                            output_shape[2] * output_shape[3] *
                            output_shape[4];
  const dim3 block_dim(kBlockSize, 1);
  const dim3 grid_dim((total_outputs + kBlockSize - 1) / kBlockSize, 1);

  KernelPool3D<PoolProcess, T><<<grid_dim, block_dim, 0, stream>>>(
      total_outputs, input,
      /*channels=*/input_shape[1],
      /*input_depth=*/input_shape[2], /*input_height=*/input_shape[3],
      /*input_width=*/input_shape[4],
      /*output_depth=*/output_shape[2], /*output_height=*/output_shape[3],
      /*output_width=*/output_shape[4],
      /*ksize_depth=*/ksize[0], /*ksize_height=*/ksize[1],
      /*ksize_width=*/ksize[2],
      /*stride_depth=*/strides[0], /*stride_height=*/strides[1],
      /*stride_width=*/strides[2],
      /*padding_depth=*/paddings[0], /*padding_height=*/paddings[1],
      /*padding_width=*/paddings[2], pool_compute, exclusive, adaptive,
      output);
}

C
chengduoZH 已提交
928
/*
 * Tensors are in NCDHW or NDHWC format.
 * Ksize, strides and paddings each have three elements, representing depth,
 * height and width, respectively. This functor reads only paddings[0],
 * paddings[1] and paddings[2] (depth, height and width padding), so a
 * six-element padding vector (depth_front, depth_back, height_top,
 * height_bottom, width_left, width_right) must be reduced to the symmetric
 * three-element form before calling it.
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCDHW-only overload. Behaves exactly like the data_format overload below
  // called with "NCDHW"; it forwards there so the launch logic lives in one
  // place.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* output,
                  PoolProcess pool_process) {
    (*this)(context, input, ksize, strides, paddings, "NCDHW", exclusive,
            adaptive, output, pool_process);
  }

  // Runs the 3-D pooling forward pass (KernelPool3D) on context.stream().
  //   data_format: "NDHWC" selects the channel-last layout; any other value
  //                is treated as NCDHW.
  //   ksize / strides / paddings: (depth, height, width); only the first
  //                three padding entries are read.
  //   exclusive / adaptive: forwarded unchanged to KernelPool3D.
  // output must already carry its final dims; its memory is obtained here
  // via mutable_data. The kernel is given one loop index per output element.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* output, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output->dims()[4] : output->dims()[1];
    const int output_depth =
        channel_last ? output->dims()[1] : output->dims()[2];
    const int output_height =
        channel_last ? output->dims()[2] : output->dims()[3];
    const int output_width =
        channel_last ? output->dims()[3] : output->dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data, channel_last);
  }
};

C
chengduoZH 已提交
1041
/*
 * Tensors are in NCDHW or NDHWC format.
 * Ksize, strides and paddings each have three elements, representing depth,
 * height and width, respectively. This functor reads only paddings[0],
 * paddings[1] and paddings[2] (depth, height and width padding), so a
 * six-element padding vector (depth_front, depth_back, height_top,
 * height_bottom, width_left, width_right) must be reduced to the symmetric
 * three-element form before calling it.
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // NCDHW-only overload. Behaves exactly like the data_format overload below
  // called with "NCDHW"; it forwards there so the launch logic lives in one
  // place.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool exclusive,
                  bool adaptive, framework::Tensor* input_grad,
                  PoolProcess pool_process) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCDHW", exclusive, adaptive, input_grad, pool_process);
  }

  // Runs the 3-D pooling backward pass (KernelPool3DGrad) on
  // context.stream().
  //   data_format: "NDHWC" selects the channel-last layout; any other value
  //                is treated as NCDHW.
  //   ksize / strides / paddings: (depth, height, width); only the first
  //                three padding entries are read.
  //   exclusive / adaptive: forwarded unchanged to KernelPool3DGrad.
  // The kernel is given one loop index per INPUT element and stores (not
  // accumulates) each input_grad value.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, bool exclusive, bool adaptive,
                  framework::Tensor* input_grad, PoolProcess pool_process) {
    const bool channel_last = (data_format == "NDHWC");

    const int batch_size = input.dims()[0];
    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Same block-size adjustment the forward functors apply on Jetson.
    platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, exclusive, adaptive, input_grad_data,
        channel_last);
  }
};

C
chengduoZH 已提交
1155
/*
 * Tensors are in NCDHW or NDHWC format.
 * Ksize, strides and paddings each have three elements, representing depth,
 * height and width, respectively. This functor reads only paddings[0],
 * paddings[1] and paddings[2] (depth, height and width padding), so a
 * six-element padding vector (depth_front, depth_back, height_top,
 * height_bottom, width_left, width_right) must be reduced to the symmetric
 * three-element form before calling it.
 */
template <class T>
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // NCDHW-only overload. Behaves exactly like the data_format overload below
  // called with "NCDHW"; it forwards there so the launch logic lives in one
  // place.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            "NCDHW", input_grad);
  }

  // Runs the 3-D max-pooling backward pass (KernelMaxPool3DGrad) on
  // context.stream().
  //   data_format: "NDHWC" selects the channel-last layout; any other value
  //                is treated as NCDHW.
  //   ksize / strides / paddings: (depth, height, width); only the first
  //                three padding entries are read.
  // The kernel is given one loop index per OUTPUT element and atomically
  // adds into input_grad.
  // NOTE(review): input_grad is not cleared here, so callers presumably
  // zero-fill it beforehand — confirm.
  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, framework::Tensor* input_grad) {
    const bool channel_last = (data_format == "NDHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[4] : input.dims()[1];
    const int input_depth = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_height = channel_last ? input.dims()[2] : input.dims()[3];
    const int input_width = channel_last ? input.dims()[3] : input.dims()[4];

    const int output_channels =
        channel_last ? output.dims()[4] : output.dims()[1];
    const int output_depth = channel_last ? output.dims()[1] : output.dims()[2];
    const int output_height =
        channel_last ? output.dims()[2] : output.dims()[3];
    const int output_width = channel_last ? output.dims()[3] : output.dims()[4];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    // Same block-size adjustment the forward functors apply on Jetson.
    platform::ChangeThreadNum(context, &thread_num);
#endif
    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, input_grad_data, channel_last);
  }
};

F
feng_shuai 已提交
1264 1265 1266 1267 1268
// Explicit instantiations of the 3-D pooling functors, grouped by element
// type.

// ---- float ----
template class Pool3dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
                                       float>;
template class Pool3dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
                                       float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;

// ---- double ----
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;

// ---- float16 ----
template class Pool3dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPool<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool3dFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPool<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool3dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::MaxPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;
template class Pool3dGradFunctor<
    platform::CUDADeviceContext,
    paddle::operators::math::AvgPoolGrad<paddle::platform::float16>,
    paddle::platform::float16>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext,
                                    paddle::platform::float16>;

C
chengduoZH 已提交
1312
/*
 * Forward 2-D max pooling that also records where each maximum came from.
 *
 * Each flat index in [0, nthreads) names one output element, decoded as
 * (batch_idx, c, ph, pw) in NCHW order. The kernel writes the window maximum
 * to output_data[index]. It writes the offset `h * input_width + w` of the
 * chosen input element, relative to its own (batch, channel) feature map, to
 * mask_data[index].
 *
 * If the pooling window is empty, the output is -FLT_MAX and the mask is -1.
 * The outer loop advances by blockDim.x * gridDim.x. No state carries
 * between iterations, so any 1-D launch that covers the work is valid.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      hstart = ph * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = pw * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    // Use a per-iteration local pointer. Advancing `input_data` itself would
    // accumulate the offset across iterations of the loop above.
    const T1* input_map =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    // Seed from the first element of a non-empty window. Otherwise values
    // <= -FLT_MAX (e.g. -inf, or very negative doubles) could never be picked,
    // leaving mask == -1 and dropping the gradient for that window.
    T1 ele = -FLT_MAX;
    int max_index = -1;
    if (hstart < hend && wstart < wend) {
      max_index = hstart * input_width + wstart;
      ele = input_map[max_index];
    }
    // Strict '<' keeps the first occurrence (row-major) of the maximum.
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        if (ele < input_map[input_index]) {
          max_index = input_index;
          ele = input_map[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1361
/*
 * Backward pass of 2-D max pooling with index.
 *
 * Each flat index in [0, nthreads) names one input-gradient element, decoded
 * as (batch_idx, offsetC, h_offset, w_offset) in NCHW order. The kernel sums
 * output_grad over every output position whose mask entry equals this
 * element's in-feature-map offset `h_offset * input_width + w_offset`. It
 * writes the sum to input_grad[index]; every element is written, zero if no
 * output selected it.
 *
 * The outer loop advances by blockDim.x * gridDim.x. No state carries
 * between iterations, so any 1-D launch that covers the work is valid.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, bool adaptive,
    T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Range of output positions whose window may contain this input element.
    // The range may be wider than necessary; the mask comparison below keeps
    // only outputs that actually selected this element.
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 gradient = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;

    // Per-iteration local pointers. Advancing the kernel parameters would
    // accumulate offsets across iterations of the loop above.
    const T2* mask_map = mask_data + output_idx;
    const T1* output_grad_map = output_grad + output_idx;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        int out_offset = ph * output_width + pw;
        if (mask_map[out_offset] == input_current_featuremap_idx)
          gradient += output_grad_map[out_offset];
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
1416 1417 1418 1419 1420
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 *
 * Computes 2-D max pooling into `output` and, for every output element,
 * stores the argmax offset (h * input_width + w within its feature map) into
 * `mask`. The kernel is enqueued on context.stream(); the call does not wait
 * for it to finish.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // An empty output would yield a 0-block grid, which is an invalid launch
    // configuration; there is nothing to compute.
    if (nthreads <= 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

C
chengduoZH 已提交
1464 1465 1466 1467 1468
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 *
 * Scatters `output_grad` back to `input_grad` using the argmax offsets in
 * `mask` produced by the forward pass. The kernel is enqueued on
 * context.stream().
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    // A 0-block grid is an invalid launch configuration; nothing to compute.
    if (nthreads <= 0) return;

    // Pick the block size the same way as the forward functor, including the
    // Jetson-specific adjustment.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

Q
QI JUN 已提交
1508 1509 1510 1511 1512 1513 1514 1515
// Explicit instantiations of the 2-D max-pool-with-index functors for the
// CUDA device context, all using an int mask. Forward first, then gradient.
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1516

C
chengduoZH 已提交
1517
/*
 * Forward 3-D max pooling that also records where each maximum came from.
 *
 * Each flat index in [0, nthreads) names one output element, decoded as
 * (batch_idx, c, pd, ph, pw) in NCDHW order. The kernel writes the window
 * maximum to output_data[index]. It writes the offset
 * `(d * input_height + h) * input_width + w` of the chosen input element,
 * relative to its own (batch, channel) volume, to mask_data[index].
 *
 * If the pooling window is empty, the output is -FLT_MAX and the mask is -1.
 * The outer loop advances by blockDim.x * gridDim.x. No state carries
 * between iterations, so any 1-D launch that covers the work is valid.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Use a per-iteration local pointer. Advancing `input_data` itself would
    // accumulate the offset across iterations of the loop above.
    const T1* input_map = input_data + (batch_idx * channels + c) *
                                           input_depth * input_height *
                                           input_width;

    // Seed from the first element of a non-empty window. Otherwise values
    // <= -FLT_MAX (e.g. -inf, or very negative doubles) could never be picked,
    // leaving mask == -1 and dropping the gradient for that window.
    T1 ele = -FLT_MAX;
    int max_index = -1;
    if (dstart < dend && hstart < hend && wstart < wend) {
      max_index = (dstart * input_height + hstart) * input_width + wstart;
      ele = input_map[max_index];
    }
    // Strict '<' keeps the first occurrence (d, h, w order) of the maximum.
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          int input_index = (d * input_height + h) * input_width + w;
          if (ele < input_map[input_index]) {
            max_index = input_index;
            ele = input_map[input_index];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1579
/*
 * Backward pass of 3-D max pooling with index.
 *
 * Each flat index in [0, nthreads) names one input-gradient element, decoded
 * as (batch_idx, offsetC, d_offset, h_offset, w_offset) in NCDHW order. The
 * kernel sums output_grad over every output position whose mask entry equals
 * this element's in-volume offset. It writes the sum to input_grad[index];
 * every element is written, zero if no output selected it.
 *
 * The outer loop advances by blockDim.x * gridDim.x. No state carries
 * between iterations, so any 1-D launch that covers the work is valid.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, bool adaptive, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // Range of output positions whose window may contain this input element.
    // The range may be wider than necessary; the mask comparison below keeps
    // only outputs that actually selected this element.
    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = d_offset * output_depth / input_depth;
      pdend =
          min((d_offset + 1) * output_depth / input_depth + 1, output_depth);
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      pdstart =
          (d_offset + padding_depth < ksize_depth)
              ? 0
              : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth);
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 gradient = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;

    // Per-iteration local pointers. Advancing the kernel parameters would
    // accumulate offsets across iterations of the loop above.
    const T2* mask_map = mask + output_idx;
    const T1* output_grad_map = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          int out_offset = (pd * output_height + ph) * output_width + pw;
          if (mask_map[out_offset] == input_current_feature_map_idx)
            gradient += output_grad_map[out_offset];
        }
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
1650 1651 1652 1653 1654
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 *
 * Computes 3-D max pooling into `output` and, for every output element,
 * stores the argmax offset within its (batch, channel) volume into `mask`.
 * The kernel is enqueued on context.stream(); the call does not wait for it
 * to finish.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // An empty output would yield a 0-block grid, which is an invalid launch
    // configuration; there is nothing to compute.
    if (nthreads <= 0) return;

    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

C
chengduoZH 已提交
1706 1707 1708 1709 1710
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 *
 * Scatters `output_grad` back to `input_grad` using the argmax offsets in
 * `mask` produced by the forward pass. The kernel is enqueued on
 * context.stream().
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // A 0-block grid is an invalid launch configuration; nothing to compute.
    if (nthreads <= 0) return;

    // Pick the block size the same way as the forward functor, including the
    // Jetson-specific adjustment.
    int thread_num = 1024;
#ifdef WITH_NV_JETSON
    platform::ChangeThreadNum(context, &thread_num);
#endif

    int blocks = (nthreads + thread_num - 1) / thread_num;
    dim3 threads(thread_num, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

Q
QI JUN 已提交
1757 1758 1759 1760 1761 1762 1763 1764
// Explicit instantiations of the 3-D max-pool-with-index functors for the
// CUDA device context, all using an int mask. Forward first, then gradient.
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1765 1766 1767 1768

}  // namespace math
}  // namespace operators
}  // namespace paddle