pooling.cu 51.7 KB
Newer Older
1
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
Y
Yi Wang 已提交
17
#include "paddle/fluid/operators/math/pooling.h"
D
dzhwinter 已提交
18
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduoZH 已提交
19 20 21 22 23

namespace paddle {
namespace operators {
namespace math {

24
/*
 * Forward 2-D pooling kernel over an NCHW input.
 *
 * Each flat index in [0, nthreads) names one output element
 * (batch_idx, c, ph, pw). The loop advances by blockDim.x * gridDim.x, so any
 * 1-D launch configuration covers every output element.
 *
 * Window per output element:
 *   - adaptive: [AdaptStartIndex, AdaptEndIndex) along each axis
 *     (ksize/stride/padding are not used);
 *   - otherwise: the ksize/stride/padding window clipped to the input.
 * Elements are folded with pool_process.compute() starting from
 * pool_process.initial(). pool_process.finalize() then receives the window
 * size as a divisor: the clipped window area when `exclusive` or `adaptive`
 * is set, else ksize_height * ksize_width.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, const int output_height,
                             const int output_width, const int ksize_height,
                             const int ksize_width, const int stride_height,
                             const int stride_width, const int padding_height,
                             const int padding_width, PoolProcess pool_process,
                             bool exclusive, bool adaptive, T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      hstart = ph * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = pw * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    // Per-iteration plane pointer. Advancing the `input_data` parameter
    // itself would accumulate offsets whenever one thread handles more than
    // one index of the stride loop.
    const T* input_plane =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        pool_process.compute(input_plane[h * input_width + w], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

/*
 * Backward 2-D pooling kernel (generic PoolProcess) over NCHW tensors.
 *
 * Each flat index in [0, nthreads) names one INPUT element
 * (batch_idx, offsetC, h_in, w_in). The thread gathers contributions from
 * every output element whose pooling window contains it, using
 * pool_process.compute(x, y, dy, 1 / pool_size, &dx). It then writes the sum
 * to input_grad[index]. No atomics are needed because each input element is
 * owned by exactly one index.
 *
 * pool_size mirrors the divisor used by KernelPool2D:
 *   - adaptive: the real adaptive window area;
 *   - exclusive: the clipped window area;
 *   - otherwise: ksize_height * ksize_width.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_in = index % input_width;
    int h_in = (index / input_width) % input_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      // Adaptive windows in the forward kernel ignore padding, so use the
      // unpadded coordinate. This candidate range is meant to be a superset
      // of the windows containing (h_in, w_in). That assumes AdaptStartIndex /
      // AdaptEndIndex use floor/ceil bounds (TODO: confirm in pooling.h).
      // Non-containing windows are filtered out below.
      phstart = h_in * output_height / input_height;
      phend =
          min((h_in + 1) * output_height / input_height + 1, output_height);
      pwstart = w_in * output_width / input_width;
      pwend = min((w_in + 1) * output_width / input_width + 1, output_width);
    } else {
      int h_offset = h_in + padding_height;
      int w_offset = w_in + padding_width;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      phend = min(h_offset / stride_height + 1, output_height);
      pwend = min(w_offset / stride_width + 1, output_width);
    }

    // Per-iteration plane pointers. Advancing the parameters themselves would
    // accumulate offsets across iterations of the stride loop.
    const int plane_offset =
        (batch_idx * channels + offsetC) * output_height * output_width;
    const T* output_plane = output_data + plane_offset;
    const T* output_grad_plane = output_grad + plane_offset;

    T gradient = 0;
    T input = input_data[index];
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        int pool_size;
        if (adaptive) {
          // Recompute the exact forward window (same helpers as KernelPool2D).
          int hstart = AdaptStartIndex(ph, input_height, output_height);
          int hend = AdaptEndIndex(ph, input_height, output_height);
          int wstart = AdaptStartIndex(pw, input_width, output_width);
          int wend = AdaptEndIndex(pw, input_width, output_width);
          if (h_in < hstart || h_in >= hend || w_in < wstart || w_in >= wend) {
            continue;  // this output window does not cover the element
          }
          pool_size = (hend - hstart) * (wend - wstart);
        } else {
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          pool_size = exclusive ? (hend - hstart) * (wend - wstart)
                                : ksize_height * ksize_width;
        }
        int output_sub_idx = ph * output_width + pw;
        pool_process.compute(input, output_plane[output_sub_idx],
                             output_grad_plane[output_sub_idx],
                             static_cast<T>(1.0 / pool_size), &gradient);
      }
    }
    input_grad[index] = gradient;
  }
}

140
/*
 * Backward kernel specialised for 2-D max pooling over NCHW tensors.
 *
 * Each flat index in [0, nthreads) names one OUTPUT element. The thread scans
 * that element's clipped ksize/stride/padding window in row-major order. It
 * routes output_grad[index] to the FIRST input element equal to the pooled
 * value; ties are broken by scan order.
 *
 * Overlapping windows may hit the same input element, so the gradient is
 * accumulated with platform::CudaAtomicAdd. This kernel never clears
 * input_grad, so the caller is expected to zero-initialise it before launch.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Per-iteration plane pointers. Advancing the parameters themselves would
    // accumulate offsets across iterations of the stride loop.
    const int plane_offset =
        (batch_idx * channels + c) * input_height * input_width;
    const T* input_plane = input_data + plane_offset;
    T* input_grad_plane = input_grad + plane_offset;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        if (ele == input_plane[h * input_width + w]) {
          maxIndex = h * input_width + w;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // Atomic: windows of different output elements may overlap.
      platform::CudaAtomicAdd(input_grad_plane + maxIndex, output_grad[index]);
    }
  }
}

N
nhzlx 已提交
185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
/*
 * Forward 2-D pooling on raw device pointers, launched on an explicit stream.
 *
 * input_shape / output_shape are NCHW ({N, C, H, W}). ksize, strides and
 * paddings are {height, width}. Adaptive pooling is not exposed here: the
 * kernel is always launched with adaptive = false. The launch is
 * asynchronous on `stream`.
 */
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    PoolProcess pool_compute, bool exclusive, T* output, cudaStream_t stream) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_height = input_shape[2];
  const int input_width = input_shape[3];
  const int output_channels = output_shape[1];
  const int output_height = output_shape[2];
  const int output_width = output_shape[3];
  const int ksize_height = ksize[0];
  const int ksize_width = ksize[1];
  const int stride_height = strides[0];
  const int stride_width = strides[1];
  const int padding_height = paddings[0];
  const int padding_width = paddings[1];

  int nthreads = batch_size * output_channels * output_height * output_width;
  // Nothing to compute; a zero-block grid is an invalid launch configuration.
  if (nthreads <= 0) return;

  int blocks = (nthreads + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);

  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      nthreads, input, input_channels, input_height, input_width, output_height,
      output_width, ksize_height, ksize_width, stride_height, stride_width,
      padding_height, padding_width, pool_compute, exclusive, false, output);
}

C
chengduoZH 已提交
216 217 218 219 220
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 *
 * Forward 2-D pooling on the context's stream. The output tensor's shape must
 * already be set; its memory is allocated here via mutable_data.
 */
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // Empty output: skip the launch (a zero-block grid is invalid).
    if (nthreads <= 0) return;

    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data);
  }
};

C
chengduoZH 已提交
259 260 261 262 263
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 *
 * Backward 2-D pooling (generic PoolProcess). One kernel thread per input
 * element writes input_grad directly, so input_grad does not need to be
 * pre-cleared by this launch.
 */
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    // Empty input: skip the launch (a zero-block grid is invalid).
    if (nthreads <= 0) return;

    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        pool_process, exclusive, adaptive, input_grad_data);
  }
};

C
chengduoZH 已提交
307 308 309 310 311
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 *
 * Backward 2-D max pooling. One kernel thread per OUTPUT element atomically
 * adds its gradient into input_grad. This functor only allocates input_grad
 * (mutable_data) and does not zero it, so the caller is expected to have
 * cleared it.
 */
template <typename T>
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // Empty output: nothing to scatter (a zero-block grid is invalid).
    if (nthreads <= 0) return;

    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        input_grad_data);
  }
};

N
nhzlx 已提交
355 356 357 358 359
// ---------------------------------------------------------------------------
// Explicit instantiations of the 2-D CUDA pooling functors defined above.
// ---------------------------------------------------------------------------

// Raw-pointer forward pooling on an explicit stream (float only).
template class Pool2dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
                                       float>;
template class Pool2dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
                                       float>;

// Dedicated max-pooling backward.
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;

// Generic forward pooling: max and average, float and double.
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;

// Generic backward pooling: max and average, float and double.
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;
383 384

/*
 * Forward 3-D pooling kernel over an NCDHW input.
 *
 * Each flat index in [0, nthreads) names one output element
 * (batch_idx, c, pd, ph, pw). The loop advances by blockDim.x * gridDim.x,
 * so any 1-D launch configuration covers every output element.
 *
 * Window per output element:
 *   - adaptive: [AdaptStartIndex, AdaptEndIndex) along each axis;
 *   - otherwise: the ksize/stride/padding window clipped to the input.
 * pool_process.finalize() receives the clipped window volume when
 * `exclusive` or `adaptive` is set, else ksize_depth*ksize_height*ksize_width.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(
    const int nthreads, const T* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Per-iteration volume pointer. Advancing the `input_data` parameter
    // itself would accumulate offsets across iterations of the stride loop.
    const T* input_volume =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          pool_process.compute(
              input_volume[(d * input_height + h) * input_width + w], &ele);
        }
      }
    }
    int pool_size = (exclusive || adaptive)
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

/*
 * Backward 3-D pooling kernel (generic PoolProcess) over NCDHW tensors.
 *
 * Each flat index in [0, nthreads) names one INPUT element
 * (batch_idx, offsetC, d_in, h_in, w_in). The thread sums
 * pool_process.compute(x, y, dy, 1 / pool_size, &dx) over every output
 * element whose window contains it, then writes input_grad[index] directly.
 * pool_size mirrors the divisor used by KernelPool3D (real adaptive window
 * volume, clipped volume when exclusive, else full ksize volume).
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, PoolProcess pool_process,
    bool exclusive, bool adaptive, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_in = index % input_width;
    int h_in = (index / input_width) % input_height;
    int d_in = (index / input_width / input_height) % input_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      // Forward adaptive windows ignore padding, so use unpadded coordinates.
      // Candidate range intended as a superset of the containing windows
      // (assumes floor/ceil AdaptStartIndex/AdaptEndIndex -- TODO confirm in
      // pooling.h). Non-containing windows are filtered out below.
      pdstart = d_in * output_depth / input_depth;
      pdend = min((d_in + 1) * output_depth / input_depth + 1, output_depth);
      phstart = h_in * output_height / input_height;
      phend =
          min((h_in + 1) * output_height / input_height + 1, output_height);
      pwstart = w_in * output_width / input_width;
      pwend = min((w_in + 1) * output_width / input_width + 1, output_width);
    } else {
      int d_offset = d_in + padding_depth;
      int h_offset = h_in + padding_height;
      int w_offset = w_in + padding_width;
      pdstart = (d_offset < ksize_depth)
                    ? 0
                    : (d_offset - ksize_depth) / stride_depth + 1;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      pdend = min(d_offset / stride_depth + 1, output_depth);
      phend = min(h_offset / stride_height + 1, output_height);
      pwend = min(w_offset / stride_width + 1, output_width);
    }

    // Per-iteration volume pointers. Advancing the parameters themselves
    // would accumulate offsets across iterations of the stride loop.
    const int volume_offset = (batch_idx * channels + offsetC) * output_depth *
                              output_height * output_width;
    const T* output_volume = output_data + volume_offset;
    const T* output_grad_volume = output_grad + volume_offset;

    T gradient = 0;
    T input = input_data[index];
    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // figure out the pooling size used by the forward pass
          int pool_size;
          if (adaptive) {
            int dstart = AdaptStartIndex(pd, input_depth, output_depth);
            int dend = AdaptEndIndex(pd, input_depth, output_depth);
            int hstart = AdaptStartIndex(ph, input_height, output_height);
            int hend = AdaptEndIndex(ph, input_height, output_height);
            int wstart = AdaptStartIndex(pw, input_width, output_width);
            int wend = AdaptEndIndex(pw, input_width, output_width);
            if (d_in < dstart || d_in >= dend || h_in < hstart ||
                h_in >= hend || w_in < wstart || w_in >= wend) {
              continue;  // this output window does not cover the element
            }
            pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          } else {
            int dstart = pd * stride_depth - padding_depth;
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int dend = min(dstart + ksize_depth, input_depth);
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            dstart = max(dstart, 0);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            pool_size =
                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                          : ksize_depth * ksize_height * ksize_width;
          }
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
          pool_process.compute(input, output_volume[output_sub_idx],
                               output_grad_volume[output_sub_idx],
                               static_cast<T>(1.0 / pool_size), &gradient);
        }
      }
    }
    input_grad[index] = gradient;
  }
}

536
/*
 * Backward kernel specialised for 3-D max pooling over NCDHW tensors.
 *
 * Each flat index in [0, nthreads) names one OUTPUT element. The thread scans
 * its clipped window in (d, h, w) order. It routes output_grad[index] to the
 * first input element equal to the pooled value.
 *
 * Accumulation uses platform::CudaAtomicAdd because windows may overlap. This
 * kernel never clears input_grad, so it is expected to be zero-initialised
 * before launch.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // Per-iteration volume pointers. Advancing the parameters themselves
    // would accumulate offsets across iterations of the stride loop.
    const int volume_offset =
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    const T* input_volume = input_data + volume_offset;
    T* input_grad_volume = input_grad + volume_offset;

    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;
    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          if (ele == input_volume[(d * input_height + h) * input_width + w]) {
            stop = true;
            maxIdx = (d * input_height + h) * input_width + w;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // Atomic: windows of different output elements may overlap.
      platform::CudaAtomicAdd(input_grad_volume + maxIdx, output_grad[index]);
    }
  }
}

C
chengduoZH 已提交
587 588 589 590 591
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 *
 * Forward 3-D pooling on the context's stream. The output tensor's shape must
 * already be set; its memory is allocated here via mutable_data.
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: skip the launch (a zero-block grid is invalid).
    if (nthreads <= 0) return;

    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data);
  }
};

C
chengduoZH 已提交
637 638 639 640 641
/*
 * Gradient of generic 3-D pooling on the GPU.
 * All tensors are in NCDHW format.
 * ksize, strides and paddings each hold three elements: depth, height and
 * width, respectively.
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  /*
   * Launches KernelPool3DGrad on context's stream with one work item per
   * element of the forward input (and therefore of input_grad).
   *
   * input / output / output_grad: forward input, forward output and the
   *   gradient w.r.t. the forward output.
   * ksize / strides / paddings: {depth, height, width} triples.
   * pool_process, exclusive, adaptive: forwarded unchanged to the kernel.
   * input_grad: allocated on context's place here; its contents are written
   *   by the kernel.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // An empty tensor (e.g. batch_size == 0) would produce a zero-block grid,
    // which CUDA rejects as an invalid launch configuration.
    if (nthreads <= 0) return;
    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, exclusive, adaptive, input_grad_data);
  }
};

C
chengduoZH 已提交
693 694 695 696 697
/*
 * Gradient of 3-D max pooling (without a precomputed index mask).
 * All tensors are in NCDHW format.
 * ksize, strides and paddings each hold three elements: depth, height and
 * width, respectively.
 */
template <class T>
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  /*
   * Launches KernelMaxPool3DGrad on context's stream with one work item per
   * element of the forward output.
   *
   * input / output / output_grad: forward input, forward output and the
   *   gradient w.r.t. the forward output.
   * ksize / strides / paddings: {depth, height, width} triples.
   * input_grad: allocated on context's place here but NOT zero-filled by
   *   this function. NOTE(review): if the kernel accumulates into
   *   input_grad, the caller must clear it first; confirm against
   *   KernelMaxPool3DGrad.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    const int nthreads = batch_size * output_channels * output_depth *
                         output_height * output_width;
    // An empty output would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration.
    if (nthreads <= 0) return;
    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, input_grad_data);
  }
};

Q
QI JUN 已提交
748 749
// Explicit instantiations of the CUDA 3-D pooling functors for float and
// double. The pooling policies live in this namespace, so they are named
// unqualified.

// Max-pool gradient without an index mask.
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;

// float: forward and backward for max / average policies.
template class Pool3dFunctor<platform::CUDADeviceContext, MaxPool<float>,
                             float>;
template class Pool3dFunctor<platform::CUDADeviceContext, AvgPool<float>,
                             float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<float>, float>;

// double: forward and backward for max / average policies.
template class Pool3dFunctor<platform::CUDADeviceContext, MaxPool<double>,
                             double>;
template class Pool3dFunctor<platform::CUDADeviceContext, AvgPool<double>,
                             double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<double>, double>;
771

C
chengduoZH 已提交
772
/*
 * 2-D max pooling that also records the position of each window maximum.
 *
 * Grid-stride loop: each iteration handles one output element. `index`
 * enumerates the NCHW output row-major (pw fastest), so any launch
 * configuration covers all nthreads outputs.
 *
 * adaptive == true: the window for (ph, pw) is
 *   [AdaptStartIndex, AdaptEndIndex) along each axis; ksize/stride/padding
 *   are not used.
 * adaptive == false: the strided window starting at ph*stride - padding,
 *   clipped to [0, input extent).
 *
 * output_data[index] receives the window maximum. mask_data[index] receives
 * the offset (h * input_width + w) of that maximum inside its
 * (batch, channel) feature map.
 *
 * NOTE(review): the maximum is seeded with -FLT_MAX and updated with a strict
 * `<`. Windows that are empty, or whose values are all <= -FLT_MAX (possible
 * for double) or NaN, report -FLT_MAX with mask -1.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      hstart = ph * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = pw * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    // Per-iteration pointer to this (batch, channel) feature map. Advancing
    // the input_data parameter itself would accumulate offsets across
    // grid-stride iterations and read the wrong memory.
    const T1* input_map =
        input_data + (batch_idx * channels + c) * input_height * input_width;
    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        if (ele < input_map[input_index]) {
          max_index = input_index;
          ele = input_map[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
821
/*
 * Backward of KernelMaxPool2dWithIdx, written as a gather.
 *
 * Grid-stride loop: each iteration handles one input_grad element. `index`
 * enumerates the NCHW input row-major (w fastest). For that element, the
 * kernel scans a candidate range of output positions [phstart, phend) x
 * [pwstart, pwend) in the same (batch, channel) map. It sums output_grad at
 * every position whose mask equals this element's in-map offset
 * (h_offset * input_width + w_offset). The mask check discards candidates
 * that did not select this element. input_grad[index] is always written
 * (0 when nothing matched).
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, bool adaptive,
    T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 gradient = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;

    // Per-iteration views of this (batch, channel) output map. Advancing the
    // mask_data / output_grad parameters themselves would accumulate offsets
    // across grid-stride iterations.
    const T2* mask_map = mask_data + output_idx;
    const T1* output_grad_map = output_grad + output_idx;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        if (mask_map[ph * output_width + pw] == input_current_featuremap_idx)
          gradient += output_grad_map[ph * output_width + pw];
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
876 877 878 879 880
/*
 * 2-D max pooling that also produces the argmax mask.
 * All tensors are in NCHW format.
 * ksize, strides and paddings each hold two elements: height and width,
 * respectively.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  /*
   * Allocates output and mask on context's place and launches
   * KernelMaxPool2dWithIdx on context's stream, one work item per output
   * element. For each output element, mask receives the offset
   * (h * input_width + w) of the selected input inside its feature map.
   * adaptive selects adaptive windows instead of ksize/strides/paddings.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    // An empty output would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration.
    if (nthreads <= 0) return;
    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

C
chengduoZH 已提交
920 921 922 923 924
/*
 * Backward of 2-D max pooling driven by the argmax mask.
 * All tensors are in NCHW format.
 * ksize, strides and paddings each hold two elements: height and width,
 * respectively.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  /*
   * Allocates input_grad on context's place (its shape is read from
   * input_grad->dims()) and launches KernelMaxPool2DWithIdxGrad on context's
   * stream, one work item per input_grad element.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    const int nthreads =
        batch_size * input_channels * input_height * input_width;
    // An empty input_grad would produce a zero-block grid, which CUDA rejects
    // as an invalid launch configuration.
    if (nthreads <= 0) return;
    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

Q
QI JUN 已提交
964 965 966 967 968 969 970 971
// Explicit instantiations of the 2-D max-pool-with-index functors. The mask
// element type is int for both float and double data.
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
972

C
chengduoZH 已提交
973
/*
 * 3-D max pooling that also records the position of each window maximum.
 *
 * Grid-stride loop: each iteration handles one output element. `index`
 * enumerates the NCDHW output row-major (pw fastest), so any launch
 * configuration covers all nthreads outputs.
 *
 * adaptive == true: the window for (pd, ph, pw) is
 *   [AdaptStartIndex, AdaptEndIndex) along each axis; ksize/stride/padding
 *   are not used.
 * adaptive == false: the strided window starting at p*stride - padding,
 *   clipped to [0, input extent) on each axis.
 *
 * output_data[index] receives the window maximum. mask_data[index] receives
 * the offset ((d * input_height + h) * input_width + w) of that maximum
 * inside its (batch, channel) volume.
 *
 * NOTE(review): the maximum is seeded with -FLT_MAX and updated with a strict
 * `<`. Windows that are empty, or whose values are all <= -FLT_MAX (possible
 * for double) or NaN, report -FLT_MAX with mask -1.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Per-iteration pointer to this (batch, channel) volume. Advancing the
    // input_data parameter itself would accumulate offsets across
    // grid-stride iterations and read the wrong memory.
    const T1* input_map = input_data + (batch_idx * channels + c) *
                                           input_depth * input_height *
                                           input_width;
    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          int input_index = (d * input_height + h) * input_width + w;
          if (ele < input_map[input_index]) {
            max_index = input_index;
            ele = input_map[input_index];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1035
/*
 * Backward of KernelMaxPool3DWithIdx, written as a gather.
 *
 * Grid-stride loop: each iteration handles one input_grad element. `index`
 * enumerates the NCDHW input row-major (w fastest). For that element, the
 * kernel scans a candidate box of output positions
 * [pdstart, pdend) x [phstart, phend) x [pwstart, pwend) in the same
 * (batch, channel) volume. It sums output_grad wherever the mask equals this
 * element's in-volume offset. The mask check discards candidates that did
 * not select this element. input_grad[index] is always written (0 when
 * nothing matched).
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, bool adaptive, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = d_offset * output_depth / input_depth;
      pdend =
          min((d_offset + 1) * output_depth / input_depth + 1, output_depth);
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      pdstart =
          (d_offset + padding_depth < ksize_depth)
              ? 0
              : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth);
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 gradient = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;

    // Per-iteration views of this (batch, channel) output volume. Advancing
    // the mask / output_grad parameters themselves would accumulate offsets
    // across grid-stride iterations.
    const T2* mask_map = mask + output_idx;
    const T1* output_grad_map = output_grad + output_idx;
    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          int output_offset = (pd * output_height + ph) * output_width + pw;
          if (mask_map[output_offset] == input_current_feature_map_idx)
            gradient += output_grad_map[output_offset];
        }
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
1106 1107 1108 1109 1110
/*
 * 3-D max pooling that also produces the argmax mask.
 * All tensors are in NCDHW format.
 * ksize, strides and paddings each hold three elements: depth, height and
 * width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  /*
   * Allocates output and mask on context's place and launches
   * KernelMaxPool3DWithIdx on context's stream, one work item per output
   * element. For each output element, mask receives the offset
   * ((d * input_height + h) * input_width + w) of the selected input inside
   * its (batch, channel) volume. adaptive selects adaptive windows instead
   * of ksize/strides/paddings.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    const int nthreads = batch_size * output_channels * output_depth *
                         output_height * output_width;
    // An empty output would produce a zero-block grid, which CUDA rejects as
    // an invalid launch configuration.
    if (nthreads <= 0) return;
    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

C
chengduoZH 已提交
1157 1158 1159 1160 1161
/*
 * Backward of 3-D max pooling driven by the argmax mask.
 * All tensors are in NCDHW format.
 * ksize, strides and paddings each hold three elements: depth, height and
 * width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  /*
   * Allocates input_grad on context's place (its shape is read from
   * input_grad->dims()) and launches KernelMaxPool3DWithIdxGrad on context's
   * stream, one work item per input_grad element.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    const int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // An empty input_grad would produce a zero-block grid, which CUDA rejects
    // as an invalid launch configuration.
    if (nthreads <= 0) return;
    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

Q
QI JUN 已提交
1208 1209 1210 1211 1212 1213 1214 1215
// Explicit instantiations of the 3-D max-pool-with-index functors. The mask
// element type is int for both float and double data.
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1216 1217 1218 1219

}  // namespace math
}  // namespace operators
}  // namespace paddle