/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

__device__ __forceinline__ int ADAPT_START_INDEX(int ph, int input_size,
                                                 int output_size) {
  return static_cast<int>(
      floor(static_cast<double>(ph * input_size) / output_size));
}

__device__ __forceinline__ int ADAPT_END_INDEX(int ph, int input_size,
                                               int output_size) {
  return static_cast<int>(
      ceil(static_cast<double>((ph + 1) * input_size) / output_size));
}
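
// For example, with input_size = 5 and output_size = 3 the three adaptive
// bins are [0, 2), [1, 4), and [3, 5):
//   ph = 0: floor(0 * 5 / 3) = 0, ceil(1 * 5 / 3) = 2
//   ph = 1: floor(1 * 5 / 3) = 1, ceil(2 * 5 / 3) = 4
//   ph = 2: floor(2 * 5 / 3) = 3, ceil(3 * 5 / 3) = 5
// Bins may overlap but always cover the whole input, so every input element
// falls inside at least one pooling window.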

template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, const int output_height,
                             const int output_width, const int ksize_height,
                             const int ksize_width, const int stride_height,
                             const int stride_width, const int padding_height,
                             const int padding_width, PoolProcess pool_process,
                             bool exclusive, bool adaptive, T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      hstart = ADAPT_START_INDEX(ph, input_height, output_height);
      hend = ADAPT_END_INDEX(ph, input_height, output_height);

      wstart = ADAPT_START_INDEX(pw, input_width, output_width);
      wend = ADAPT_END_INDEX(pw, input_width, output_width);
    } else {
      hstart = ph * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = pw * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    input_data += (batch_idx * channels + c) * input_height * input_width;
    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        pool_process.compute(input_data[h * input_width + w], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      phstart = offsetH * output_height / input_height;
      phend =
          min((offsetH + 1) * output_height / input_height + 1, output_height);
      pwstart = offsetW * output_width / input_width;
      pwend = min((offsetW + 1) * output_width / input_width + 1, output_width);
    } else {
      phstart = (offsetH < ksize_height)
                    ? 0
                    : (offsetH - ksize_height) / stride_height + 1;
      pwstart = (offsetW < ksize_width)
                    ? 0
                    : (offsetW - ksize_width) / stride_width + 1;
      phend = min(offsetH / stride_height + 1, output_height);
      pwend = min(offsetW / stride_width + 1, output_width);
    }
    T gradient = 0;
    T input = input_data[index];
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;
    output_data += output_idx;
    output_grad += output_idx;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        int pool_size;
        if (adaptive) {
          pool_size = static_cast<int>(ceil(static_cast<double>(input_height) /
                                            ksize_height)) *
                      static_cast<int>(
                          ceil(static_cast<double>(input_width) / ksize_width));
        } else {
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          pool_size = exclusive ? (hend - hstart) * (wend - wstart)
                                : ksize_height * ksize_width;
        }
        int output_sub_idx = ph * output_width + pw;
        pool_process.compute(input, output_data[output_sub_idx],
                             output_grad[output_sub_idx],
                             static_cast<T>(1.0 / pool_size), &gradient);
      }
    }
    input_grad[index] = gradient;
  }
}
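
// Note on the non-adaptive range above: a padded input offset offsetH is
// covered by output row ph exactly when
//   ph * stride_height <= offsetH < ph * stride_height + ksize_height,
// which in integer arithmetic gives
//   (offsetH - ksize_height) / stride_height + 1 <= ph <= offsetH / stride_height,
// clamped to [0, output_height). The width bounds are derived the same way.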

template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    input_data += (batch_idx * channels + c) * input_height * input_width;
    input_grad += (batch_idx * channels + c) * input_height * input_width;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        if (ele == input_data[h * input_width + w]) {
          maxIndex = h * input_width + w;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // atomic add
      platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]);
    }
  }
}
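
// Each output gradient above is routed to the first input element (in scan
// order) that equals the pooled maximum. Adjacent windows may overlap, so
// several outputs can select the same input location; the accumulation into
// input_grad therefore uses CudaAtomicAdd rather than a plain store.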

template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    PoolProcess pool_compute, bool exclusive, T* output, cudaStream_t stream) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_height = input_shape[2];
  const int input_width = input_shape[3];
  const int output_channels = output_shape[1];
  const int output_height = output_shape[2];
  const int output_width = output_shape[3];
  const int ksize_height = ksize[0];
  const int ksize_width = ksize[1];
  const int stride_height = strides[0];
  const int stride_width = strides[1];
  const int padding_height = paddings[0];
  const int padding_width = paddings[1];

  int nthreads = batch_size * output_channels * output_height * output_width;
  int blocks = (nthreads + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);

  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      nthreads, input, input_channels, input_height, input_width, output_height,
      output_width, ksize_height, ksize_width, stride_height, stride_width,
      padding_height, padding_width, pool_compute, exclusive, false, output);
}
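
// A minimal host-side sketch of the direct interface (illustrative only; the
// device buffers d_input/d_output and `stream` are assumed to exist, and the
// shapes are made up):
//
//   paddle::operators::math::Pool2dDirectCUDAFunctor<
//       paddle::operators::math::AvgPool<float>, float>
//       pool2d;
//   paddle::operators::math::AvgPool<float> pool_process;
//   // NCHW input 1x3x8x8, 2x2 kernel, stride 2, no padding -> 1x3x4x4 output.
//   pool2d(d_input, {1, 3, 8, 8}, {1, 3, 4, 4}, {2, 2}, {2, 2}, {0, 0},
//          pool_process, /*exclusive=*/true, d_output, stream);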

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data);
  }
};
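
// A similar sketch for the tensor-based interface (illustrative only;
// `context` is an existing CUDADeviceContext and `in`/`out` are NCHW tensors
// whose shapes already agree with ksize, strides, and paddings):
//
//   paddle::operators::math::Pool2dFunctor<
//       platform::CUDADeviceContext, paddle::operators::math::MaxPool<float>,
//       float>
//       pool2d;
//   paddle::operators::math::MaxPool<float> pool_process;
//   pool2d(context, in, /*ksize=*/{2, 2}, /*strides=*/{2, 2},
//          /*paddings=*/{0, 0}, pool_process, /*exclusive=*/true,
//          /*adaptive=*/false, &out);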

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        pool_process, exclusive, adaptive, input_grad_data);
  }
};

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T>
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        input_grad_data);
  }
};

template class Pool2dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
                                       float>;
template class Pool2dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
                                       float>;

template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;

template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;

template <typename PoolProcess, typename T>
__global__ void KernelPool3D(
    const int nthreads, const T* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = ADAPT_START_INDEX(pd, input_depth, output_depth);
      dend = ADAPT_END_INDEX(pd, input_depth, output_depth);

      hstart = ADAPT_START_INDEX(ph, input_height, output_height);
      hend = ADAPT_END_INDEX(ph, input_height, output_height);

      wstart = ADAPT_START_INDEX(pw, input_width, output_width);
      wend = ADAPT_END_INDEX(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }
    T ele = pool_process.initial();
    input_data +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          pool_process.compute(
              input_data[(d * input_height + h) * input_width + w], &ele);
        }
      }
    }
    int pool_size = (exclusive || adaptive)
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

template <typename PoolProcess, typename T>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, PoolProcess pool_process,
    bool exclusive, bool adaptive, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetD =
        (index / input_width / input_height) % input_depth + padding_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = offsetD * output_depth / input_depth;
      pdend = min((offsetD + 1) * output_depth / input_depth + 1, output_depth);
      phstart = offsetH * output_height / input_height;
      phend =
          min((offsetH + 1) * output_height / input_height + 1, output_height);
      pwstart = offsetW * output_width / input_width;
      pwend = min((offsetW + 1) * output_width / input_width + 1, output_width);
    } else {
      pdstart = (offsetD < ksize_depth)
                    ? 0
                    : (offsetD - ksize_depth) / stride_depth + 1;
      phstart = (offsetH < ksize_height)
                    ? 0
                    : (offsetH - ksize_height) / stride_height + 1;
      pwstart = (offsetW < ksize_width)
                    ? 0
                    : (offsetW - ksize_width) / stride_width + 1;
      pdend = min((offsetD) / stride_depth + 1, output_depth);
      phend = min((offsetH) / stride_height + 1, output_height);
      pwend = min((offsetW) / stride_width + 1, output_width);
    }

    T gradient = 0;
    T input = input_data[index];
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;
    output_data += output_idx;
    output_grad += output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // figure out the pooling size
          int pool_size;
          if (adaptive) {
            pool_size =
                static_cast<int>(
                    ceil(static_cast<double>(input_depth) / ksize_depth)) *
                static_cast<int>(
                    ceil(static_cast<double>(input_height) / ksize_height)) *
                static_cast<int>(
                    ceil(static_cast<double>(input_width) / ksize_width));
          } else {
            int dstart = pd * stride_depth - padding_depth;
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int dend = min(dstart + ksize_depth, input_depth);
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            dstart = max(dstart, 0);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            pool_size =
                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                          : ksize_depth * ksize_height * ksize_width;
          }
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
          pool_process.compute(input, output_data[output_sub_idx],
                               output_grad[output_sub_idx],
                               static_cast<T>(1.0 / pool_size), &gradient);
        }
      }
    }
    input_grad[index] = gradient;
  }
}

template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;
    input_data +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    input_grad +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          if (ele == input_data[(d * input_height + h) * input_width + w]) {
            stop = true;
            maxIdx = (d * input_height + h) * input_width + w;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // atomic add
      platform::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]);
    }
  }
}

/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data);
  }
};

/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, exclusive, adaptive, input_grad_data);
  }
};

/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <class T>
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, input_grad_data);
  }
};

template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;

template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;

template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      hstart = ADAPT_START_INDEX(ph, input_height, output_height);
      hend = ADAPT_END_INDEX(ph, input_height, output_height);

      wstart = ADAPT_START_INDEX(pw, input_width, output_width);
      wend = ADAPT_END_INDEX(pw, input_width, output_width);
    } else {
      hstart = ph * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = pw * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    input_data += (batch_idx * channels + c) * input_height * input_width;
    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        if (ele < input_data[input_index]) {
          max_index = input_index;
          ele = input_data[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}
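
// The mask written above stores, for each output element, the flattened
// h * input_width + w offset of the argmax within its (batch, channel) slice.
// KernelMaxPool2DWithIdxGrad below compares that stored index against each
// input position to decide which output gradients to accumulate.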

template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, bool adaptive,
    T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int offsetW = index % input_width;
    int offsetH = (index / input_width) % input_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      phstart = offsetH * output_height / input_height;
      phend =
          min((offsetH + 1) * output_height / input_height + 1, output_height);
      pwstart = offsetW * output_width / input_width;
      pwend = min((offsetW + 1) * output_width / input_width + 1, output_width);
    } else {
      phstart =
          (offsetH + padding_height < ksize_height)
              ? 0
              : (offsetH + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (offsetW + padding_width < ksize_width)
              ? 0
              : (offsetW + padding_width - ksize_width) / stride_width + 1;
      phend =
          min((offsetH + padding_height) / stride_height + 1, output_height);
      pwend = min((offsetW + padding_width) / stride_width + 1, output_width);
    }

    T1 gradient = 0;
    int input_current_featuremap_idx = offsetH * input_width + offsetW;
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;

    mask_data += output_idx;
    output_grad += output_idx;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        if (mask_data[ph * output_width + pw] == input_current_featuremap_idx)
          gradient += output_grad[ph * output_width + pw];
      }
    }
    input_grad[index] = gradient;
  }
}

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;

template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = ADAPT_START_INDEX(pd, input_depth, output_depth);
      dend = ADAPT_END_INDEX(pd, input_depth, output_depth);

      hstart = ADAPT_START_INDEX(ph, input_height, output_height);
      hend = ADAPT_END_INDEX(ph, input_height, output_height);

      wstart = ADAPT_START_INDEX(pw, input_width, output_width);
      wend = ADAPT_END_INDEX(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    T1 ele = -FLT_MAX;
    int max_index = -1;
    input_data +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          if (ele < input_data[(d * input_height + h) * input_width + w]) {
            max_index = (d * input_height + h) * input_width + w;
            ele = input_data[max_index];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, bool adaptive, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int offsetW = index % input_width;
    int offsetH = (index / input_width) % input_height;
    int offsetD = (index / input_width / input_height) % input_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = offsetD * output_depth / input_depth;
      pdend = min((offsetD + 1) * output_depth / input_depth + 1, output_depth);
      phstart = offsetH * output_height / input_height;
      phend =
          min((offsetH + 1) * output_height / input_height + 1, output_height);
      pwstart = offsetW * output_width / input_width;
      pwend = min((offsetW + 1) * output_width / input_width + 1, output_width);
    } else {
      pdstart =
          (offsetD + padding_depth < ksize_depth)
              ? 0
              : (offsetD + padding_depth - ksize_depth) / stride_depth + 1;
      phstart =
          (offsetH + padding_height < ksize_height)
              ? 0
              : (offsetH + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (offsetW + padding_width < ksize_width)
              ? 0
              : (offsetW + padding_width - ksize_width) / stride_width + 1;
      pdend = min((offsetD + padding_depth) / stride_depth + 1, output_depth);
      phend =
          min((offsetH + padding_height) / stride_height + 1, output_height);
      pwend = min((offsetW + padding_width) / stride_width + 1, output_width);
    }

    T1 gradient = 0;
    int input_current_feature_map_idx =
        (offsetD * input_height + offsetH) * input_width + offsetW;
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;
    mask += output_idx;
    output_grad += output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          if (mask[(pd * output_height + ph) * output_width + pw] ==
              input_current_feature_map_idx)
            gradient +=
                output_grad[(pd * output_height + ph) * output_width + pw];
        }
      }
    }
    input_grad[index] = gradient;
  }
}

/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;

}  // namespace math
}  // namespace operators
}  // namespace paddle