pooling.cu 47.5 KB
Newer Older
1
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
Y
Yi Wang 已提交
17
#include "paddle/fluid/operators/math/pooling.h"
D
dzhwinter 已提交
18
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduoZH 已提交
19 20 21 22 23

namespace paddle {
namespace operators {
namespace math {

24
/*
 * Forward 2-D pooling kernel.
 *
 * Each loop iteration produces one element of output_data. The flat `index`
 * is decomposed as (batch, channel, ph, pw), i.e. width varies fastest.
 * The kernel iterates with a stride of blockDim.x * gridDim.x, so any 1-D
 * launch configuration covers all `nthreads` elements.
 *
 * nthreads      number of output elements to produce.
 * input_data    base pointer of the input; indexed per (batch, channel)
 *               slice of input_height * input_width elements.
 * exclusive     when true, finalize() receives the clipped window size
 *               instead of ksize_height * ksize_width.
 * adaptive      when true, window bounds come from ADAPT_START_INDEX /
 *               ADAPT_END_INDEX instead of stride/padding/ksize.
 *               NOTE(review): these macros are not defined in this file;
 *               presumably they come from pooling.h -- confirm.
 * pool_process  provides initial(), compute(value, &acc) and
 *               finalize(pool_size, &acc).
 * output_data   destination, written at output_data[index].
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, const int output_height,
                             const int output_width, const int ksize_height,
                             const int ksize_width, const int stride_height,
                             const int stride_width, const int padding_height,
                             const int padding_width, PoolProcess pool_process,
                             bool exclusive, bool adaptive, T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    // Declared before the branch so the window bounds remain in scope for
    // the pooling loop and the pool_size computation below.
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      hstart = ADAPT_START_INDEX(ph, input_height, output_height);
      hend = ADAPT_END_INDEX(ph, input_height, output_height);

      wstart = ADAPT_START_INDEX(pw, input_width, output_width);
      wend = ADAPT_END_INDEX(pw, input_width, output_width);
    } else {
      // Window in input coordinates, clipped to the input extent.
      hstart = ph * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = pw * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    // Use a per-iteration pointer: advancing input_data itself would compound
    // the offset if this thread executes more than one loop iteration.
    const T* input_slice =
        input_data + (batch_idx * channels + c) * input_height * input_width;
    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        pool_process.compute(input_slice[h * input_width + w], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

/*
 * Backward 2-D pooling kernel.
 *
 * Each loop iteration produces one element of input_grad. The flat `index`
 * is decomposed as (batch, channel, h, w) over the input extent. For that
 * input element the kernel visits every output position (ph, pw) whose
 * pooling window contains it and lets pool_process.compute() accumulate
 * into `gradient`.
 *
 * nthreads      number of input elements to process.
 * input_data    forward input, read at input_data[index].
 * output_data   forward output, indexed per (batch, channel) slice.
 * output_grad   gradient w.r.t. the output, same layout as output_data.
 * exclusive     when true, the scale passed to compute() uses the clipped
 *               window size instead of ksize_height * ksize_width.
 * pool_process  provides compute(input, output, output_grad, scale, &grad).
 * input_grad    destination, written at input_grad[index].
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Position of this input element in padded coordinates.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Half-open range of output rows/columns whose window covers the element.
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int phend = min(offsetH / stride_height + 1, output_height);
    int pwend = min(offsetW / stride_width + 1, output_width);
    T gradient = 0;
    T input = input_data[index];
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;
    // Per-iteration pointers: advancing output_data/output_grad themselves
    // would compound the offset if this thread runs more than one iteration.
    const T* output_slice = output_data + output_idx;
    const T* output_grad_slice = output_grad + output_idx;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // Recompute the forward window of (ph, pw) to obtain its size.
        int hstart = ph * stride_height - padding_height;
        int wstart = pw * stride_width - padding_width;
        int hend = min(hstart + ksize_height, input_height);
        int wend = min(wstart + ksize_width, input_width);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
        int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
                                  : ksize_height * ksize_width;
        int output_sub_idx = ph * output_width + pw;
        pool_process.compute(input, output_slice[output_sub_idx],
                             output_grad_slice[output_sub_idx],
                             static_cast<T>(1.0 / pool_size), &gradient);
      }
    }
    input_grad[index] = gradient;
  }
}

119
/*
 * Backward kernel for 2-D max pooling.
 *
 * Each loop iteration handles one output element. The flat `index` is
 * decomposed as (batch, channel, ph, pw). The kernel scans the pooling
 * window in row-major order for the first input value equal to
 * output_data[index] and atomically adds output_grad[index] to the matching
 * position of input_grad.
 *
 * NOTE(review): because the result is accumulated with an atomic add,
 * input_grad is presumably expected to be zero-initialized by the caller --
 * confirm.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    // Window in input coordinates, clipped to the input extent.
    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Per-iteration pointers: advancing input_data/input_grad themselves
    // would compound the offset if this thread runs more than one iteration.
    int slice_offset = (batch_idx * channels + c) * input_height * input_width;
    const T* input_slice = input_data + slice_offset;
    T* input_grad_slice = input_grad + slice_offset;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    // Stop at the first element that equals the pooled maximum.
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        if (ele == input_slice[h * input_width + w]) {
          maxIndex = h * input_width + w;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // atomic add: several output windows may select the same input element
      platform::CudaAtomicAdd(input_grad_slice + maxIndex, output_grad[index]);
    }
  }
}

N
nhzlx 已提交
164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
/*
 * Launches KernelPool2D directly on raw pointers and an explicit stream.
 *
 * input_shape / output_shape are read as four elements
 * (batch, channels, height, width); ksize, strides and paddings are read as
 * two elements (height, width). One thread is launched per output element.
 */
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    PoolProcess pool_compute, bool exclusive, T* output, cudaStream_t stream) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_height = input_shape[2];
  const int input_width = input_shape[3];
  const int output_channels = output_shape[1];
  const int output_height = output_shape[2];
  const int output_width = output_shape[3];
  const int ksize_height = ksize[0];
  const int ksize_width = ksize[1];
  const int stride_height = strides[0];
  const int stride_width = strides[1];
  const int padding_height = paddings[0];
  const int padding_width = paddings[1];

  int nthreads = batch_size * output_channels * output_height * output_width;
  int blocks = (nthreads + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);

  // KernelPool2D takes an `adaptive` flag between `exclusive` and the output
  // pointer. This entry point has no adaptive parameter, so it always
  // requests the regular stride/padding windows.
  const bool adaptive = false;
  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      nthreads, input, input_channels, input_height, input_width, output_height,
      output_width, ksize_height, ksize_width, stride_height, stride_width,
      padding_height, padding_width, pool_compute, exclusive, adaptive, output);
}

C
chengduoZH 已提交
195 196 197 198 199
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
200
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  /*
   * Runs forward 2-D pooling on the stream of `context`.
   *
   * Dimensions 0..3 of `input` are read as (batch, channels, height, width);
   * the channel/height/width of the result are taken from `output`, whose
   * buffer is obtained via mutable_data on the context's place. One thread
   * is launched per output element.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    // KernelPool2D takes an `adaptive` flag between `exclusive` and the
    // output pointer. This functor has no adaptive parameter, so it always
    // requests the regular stride/padding windows.
    const bool adaptive = false;
    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data);
  }
};

C
chengduoZH 已提交
238 239 240 241 242
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
243
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  /*
   * Runs backward 2-D pooling on the stream of `context`.
   *
   * Launches KernelPool2DGrad with one thread per element of `input` and
   * stores the result in the buffer of `input_grad`, obtained via
   * mutable_data on the context's place.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, framework::Tensor* input_grad) {
    // Input extents, read as (batch, channels, height, width).
    const int num = input.dims()[0];
    const int chans = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];
    // Spatial extents of the forward output.
    const int out_h = output.dims()[2];
    const int out_w = output.dims()[3];

    // Window geometry, each given as (height, width).
    const int k_h = ksize[0];
    const int k_w = ksize[1];
    const int s_h = strides[0];
    const int s_w = strides[1];
    const int p_h = paddings[0];
    const int p_w = paddings[1];

    const T* in_ptr = input.data<T>();
    const T* out_ptr = output.data<T>();
    const T* out_grad_ptr = output_grad.data<T>();
    T* in_grad_ptr = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per input element, 1024 threads per block (ceil-div grid).
    const int kThreadsPerBlock = 1024;
    int total = num * chans * in_h * in_w;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid((total + kThreadsPerBlock - 1) / kThreadsPerBlock, 1);

    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        total, in_ptr, out_ptr, out_grad_ptr, chans, in_h, in_w, out_h, out_w,
        k_h, k_w, s_h, s_w, p_h, p_w, pool_process, exclusive, in_grad_ptr);
  }
};

C
chengduoZH 已提交
285 286 287 288 289
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
290
template <typename T>
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  /*
   * Runs the 2-D max-pooling backward pass on the stream of `context`.
   *
   * Launches KernelMaxPool2DGrad with one thread per element of `output`;
   * results are accumulated into the buffer of `input_grad`, obtained via
   * mutable_data on the context's place.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input extents, read as (batch, channels, height, width).
    const int num = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];
    // Output extents; the channel count sizes the launch.
    const int out_c = output.dims()[1];
    const int out_h = output.dims()[2];
    const int out_w = output.dims()[3];

    // Window geometry, each given as (height, width).
    const int k_h = ksize[0];
    const int k_w = ksize[1];
    const int s_h = strides[0];
    const int s_w = strides[1];
    const int p_h = paddings[0];
    const int p_w = paddings[1];

    const T* in_ptr = input.data<T>();
    const T* out_ptr = output.data<T>();
    const T* out_grad_ptr = output_grad.data<T>();
    T* in_grad_ptr = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per output element, 1024 threads per block (ceil-div grid).
    const int kThreadsPerBlock = 1024;
    int total = num * out_c * out_h * out_w;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid((total + kThreadsPerBlock - 1) / kThreadsPerBlock, 1);

    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        total, in_ptr, out_ptr, out_grad_ptr, in_c, in_h, in_w, out_h, out_w,
        k_h, k_w, s_h, s_w, p_h, p_w, in_grad_ptr);
  }
};

N
nhzlx 已提交
333 334 335 336 337
// Explicit instantiations of the 2-D pooling functors.

// Raw-pointer forward functor (float only).
template class Pool2dDirectCUDAFunctor<MaxPool<float>, float>;
template class Pool2dDirectCUDAFunctor<AvgPool<float>, float>;

// Max-pooling backward functor.
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;

// Forward / backward functors, float.
template class Pool2dFunctor<platform::CUDADeviceContext, MaxPool<float>,
                             float>;
template class Pool2dFunctor<platform::CUDADeviceContext, AvgPool<float>,
                             float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<float>, float>;

// Forward / backward functors, double.
template class Pool2dFunctor<platform::CUDADeviceContext, MaxPool<double>,
                             double>;
template class Pool2dFunctor<platform::CUDADeviceContext, AvgPool<double>,
                             double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<double>, double>;
361 362

/*
 * Forward 3-D pooling kernel.
 *
 * Each loop iteration produces one element of output_data. The flat `index`
 * is decomposed as (batch, channel, pd, ph, pw), i.e. width varies fastest.
 * The kernel iterates with a stride of blockDim.x * gridDim.x, so any 1-D
 * launch configuration covers all `nthreads` elements.
 *
 * nthreads      number of output elements to produce.
 * input_data    base pointer of the input; indexed per (batch, channel)
 *               slice of input_depth * input_height * input_width elements.
 * exclusive     when true, finalize() receives the clipped window size
 *               instead of ksize_depth * ksize_height * ksize_width.
 * pool_process  provides initial(), compute(value, &acc) and
 *               finalize(pool_size, &acc).
 * output_data   destination, written at output_data[index].
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(
    const int nthreads, const T* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
    // Window in input coordinates, clipped to the input extent.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    T ele = pool_process.initial();
    // Use a per-iteration pointer: advancing input_data itself would compound
    // the offset if this thread executes more than one loop iteration.
    const T* input_slice =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          pool_process.compute(
              input_slice[(d * input_height + h) * input_width + w], &ele);
        }
      }
    }
    int pool_size = exclusive
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

/*
 * Backward 3-D pooling kernel.
 *
 * Each loop iteration produces one element of input_grad. The flat `index`
 * is decomposed as (batch, channel, d, h, w) over the input extent. For that
 * input element the kernel visits every output position (pd, ph, pw) whose
 * pooling window contains it and lets pool_process.compute() accumulate
 * into `gradient`.
 *
 * nthreads      number of input elements to process.
 * input_data    forward input, read at input_data[index].
 * output_data   forward output, indexed per (batch, channel) slice.
 * output_grad   gradient w.r.t. the output, same layout as output_data.
 * exclusive     when true, the scale passed to compute() uses the clipped
 *               window size instead of the full kernel volume.
 * pool_process  provides compute(input, output, output_grad, scale, &grad).
 * input_grad    destination, written at input_grad[index].
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, PoolProcess pool_process,
    bool exclusive, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Position of this input element in padded coordinates.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetD =
        (index / input_width / input_height) % input_depth + padding_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // Half-open ranges of output positions whose window covers the element.
    int pdstart = (offsetD < ksize_depth)
                      ? 0
                      : (offsetD - ksize_depth) / stride_depth + 1;
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int pdend = min((offsetD) / stride_depth + 1, output_depth);
    int phend = min((offsetH) / stride_height + 1, output_height);
    int pwend = min((offsetW) / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;
    // Per-iteration pointers: advancing output_data/output_grad themselves
    // would compound the offset if this thread runs more than one iteration.
    const T* output_slice = output_data + output_idx;
    const T* output_grad_slice = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // figure out the pooling size
          int dstart = pd * stride_depth - padding_depth;
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int dend = min(dstart + ksize_depth, input_depth);
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          dstart = max(dstart, 0);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          int pool_size =
              exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
          pool_process.compute(input, output_slice[output_sub_idx],
                               output_grad_slice[output_sub_idx],
                               static_cast<T>(1.0 / pool_size), &gradient);
        }
      }
    }
    input_grad[index] = gradient;
  }
}

473
/*
 * Backward kernel for 3-D max pooling.
 *
 * Each loop iteration handles one output element. The flat `index` is
 * decomposed as (batch, channel, pd, ph, pw). The kernel scans the pooling
 * window (depth, then height, then width) for the first input value equal
 * to output_data[index] and atomically adds output_grad[index] to the
 * matching position of input_grad.
 *
 * NOTE(review): because the result is accumulated with an atomic add,
 * input_grad is presumably expected to be zero-initialized by the caller --
 * confirm.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
    // Window in input coordinates, clipped to the input extent.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;
    // Per-iteration pointers: advancing input_data/input_grad themselves
    // would compound the offset if this thread runs more than one iteration.
    int slice_offset =
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    const T* input_slice = input_data + slice_offset;
    T* input_grad_slice = input_grad + slice_offset;

    // Stop at the first element that equals the pooled maximum.
    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          if (ele == input_slice[(d * input_height + h) * input_width + w]) {
            stop = true;
            maxIdx = (d * input_height + h) * input_width + w;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // atomic add: several output windows may select the same input element
      platform::CudaAtomicAdd(input_grad_slice + maxIdx, output_grad[index]);
    }
  }
}

C
chengduoZH 已提交
524 525 526 527 528
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
529
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  /*
   * Runs forward 3-D pooling on the stream of `context`.
   *
   * Launches KernelPool3D with one thread per element of `output`, whose
   * buffer is obtained via mutable_data on the context's place.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, framework::Tensor* output) {
    // Input extents, read as (batch, channels, depth, height, width).
    const int num = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_d = input.dims()[2];
    const int in_h = input.dims()[3];
    const int in_w = input.dims()[4];
    // Output extents; the channel count sizes the launch.
    const int out_c = output->dims()[1];
    const int out_d = output->dims()[2];
    const int out_h = output->dims()[3];
    const int out_w = output->dims()[4];

    // Window geometry, each given as (depth, height, width).
    const int k_d = ksize[0];
    const int k_h = ksize[1];
    const int k_w = ksize[2];
    const int s_d = strides[0];
    const int s_h = strides[1];
    const int s_w = strides[2];
    const int p_d = paddings[0];
    const int p_h = paddings[1];
    const int p_w = paddings[2];

    const T* in_ptr = input.data<T>();
    T* out_ptr = output->mutable_data<T>(context.GetPlace());

    // One thread per output element, 1024 threads per block (ceil-div grid).
    const int kThreadsPerBlock = 1024;
    int total = num * out_c * out_d * out_h * out_w;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid((total + kThreadsPerBlock - 1) / kThreadsPerBlock, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        total, in_ptr, in_c, in_d, in_h, in_w, out_d, out_h, out_w, k_d, k_h,
        k_w, s_d, s_h, s_w, p_d, p_h, p_w, pool_process, exclusive, out_ptr);
  }
};

C
chengduoZH 已提交
574 575 576 577 578
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
579
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  /*
   * Runs backward 3-D pooling on the stream of `context`.
   *
   * Launches KernelPool3DGrad with one thread per element of `input` and
   * stores the result in the buffer of `input_grad`, obtained via
   * mutable_data on the context's place.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, framework::Tensor* input_grad) {
    // Input extents, read as (batch, channels, depth, height, width).
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    // Only the spatial extents of the output are needed: the launch is sized
    // by the input, and the kernel receives input_channels.
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per input element, 1024 threads per block (ceil-div grid).
    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, exclusive, input_grad_data);
  }
};

C
chengduoZH 已提交
629 630 631 632 633
/*
 * CUDA backward pass of 3-D max pooling.
 * Tensors are indexed as NCDHW. ksize, strides and paddings each hold three
 * entries, ordered depth, height, width.
 */
template <class T>
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input extents (N, C, D, H, W).
    const int num = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_d = input.dims()[2];
    const int in_h = input.dims()[3];
    const int in_w = input.dims()[4];

    // Output extents (C, D, H, W); the batch extent is taken from the input.
    const int out_c = output.dims()[1];
    const int out_d = output.dims()[2];
    const int out_h = output.dims()[3];
    const int out_w = output.dims()[4];

    // Window geometry, each ordered depth / height / width.
    const int k_d = ksize[0];
    const int k_h = ksize[1];
    const int k_w = ksize[2];
    const int s_d = strides[0];
    const int s_h = strides[1];
    const int s_w = strides[2];
    const int p_d = paddings[0];
    const int p_h = paddings[1];
    const int p_w = paddings[2];

    const T* x = input.data<T>();
    const T* y = output.data<T>();
    const T* dy = output_grad.data<T>();
    T* dx = input_grad->mutable_data<T>(context.GetPlace());

    // The work count handed to the kernel is the number of output elements.
    const int kBlockSize = 1024;
    const int total = num * out_c * out_d * out_h * out_w;
    const int num_blocks = (total + kBlockSize - 1) / kBlockSize;
    dim3 threads(kBlockSize, 1);
    dim3 grid(num_blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        total, x, y, dy, in_c, in_d, in_h, in_w, out_d, out_h, out_w, k_d, k_h,
        k_w, s_d, s_h, s_w, p_d, p_h, p_w, dx);
  }
};

Q
QI JUN 已提交
684 685
// Explicit instantiations of the 3-D pooling functors for float and double.
// The pooling policies live in this namespace, so they are named unqualified.

// Max-pool backward.
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;

// Forward, float.
template class Pool3dFunctor<platform::CUDADeviceContext, MaxPool<float>,
                             float>;
template class Pool3dFunctor<platform::CUDADeviceContext, AvgPool<float>,
                             float>;

// Backward, float.
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<float>, float>;

// Forward, double.
template class Pool3dFunctor<platform::CUDADeviceContext, MaxPool<double>,
                             double>;
template class Pool3dFunctor<platform::CUDADeviceContext, AvgPool<double>,
                             double>;

// Backward, double.
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<double>, double>;
707

C
chengduoZH 已提交
708
/*
 * 2-D max pooling that also records where each maximum came from.
 *
 * Every loop iteration handles one output element, decoded from the flat
 * `index` as (batch_idx, c, ph, pw). The pooling window is clipped to the
 * input bounds and the largest value found in it is stored in
 * output_data[index]. mask_data[index] receives that value's offset
 * (h * input_width + w) inside its own input feature map, or -1 when no
 * element of the window compared greater than the initial sentinel.
 *
 * The loop advances by blockDim.x * gridDim.x, so one thread may process
 * several indices. All addressing is therefore recomputed from the unmodified
 * kernel arguments on each iteration.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat output index.
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    // Window bounds, clipped to the input extent.
    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Base of this (batch, channel) feature map. Kept in a local pointer:
    // advancing `input_data` itself would stack the offsets of successive
    // iterations on top of each other.
    const T1* input_slice =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    // NOTE(review): -FLT_MAX is the float bound. Inputs that are not greater
    // than it (e.g. -inf, or doubles below -FLT_MAX) are never selected, and
    // the mask then stays -1.
    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        if (ele < input_slice[input_index]) {
          max_index = input_index;
          ele = input_slice[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
747
/*
 * Backward pass of 2-D max pooling driven by the recorded argmax mask.
 *
 * Every loop iteration handles one input element, decoded from the flat
 * `index` as (batch_idx, c_offset, h_offset, w_offset). It visits the output
 * positions [ph_start, ph_end) x [pw_start, pw_end) and sums output_grad for
 * those whose mask entry equals this element's offset within its feature map
 * (h_offset * input_width + w_offset). The sum is written to
 * input_grad[index].
 *
 * The loop advances by blockDim.x * gridDim.x, so one thread may process
 * several indices. All addressing is therefore recomputed from the unmodified
 * kernel arguments on each iteration.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat input index.
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int c_offset = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Range of output rows / columns to inspect for this input element.
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);

    T1 gradient = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    int output_idx =
        (batch_idx * channels + c_offset) * output_height * output_width;

    // Bases of this (batch, channel) output map. Kept in local pointers:
    // advancing the kernel arguments themselves would stack the offsets of
    // successive iterations on top of each other.
    const T2* mask_slice = mask_data + output_idx;
    const T1* grad_slice = output_grad + output_idx;
    for (int ph = ph_start; ph < ph_end; ++ph) {
      for (int pw = pw_start; pw < pw_end; ++pw) {
        if (mask_slice[ph * output_width + pw] == input_current_featuremap_idx)
          gradient += grad_slice[ph * output_width + pw];
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
791 792 793 794 795
/*
 * CUDA forward pass of 2-D max pooling that also produces the argmax mask.
 * Tensors are indexed as NCHW. ksize, strides and paddings each hold two
 * entries, ordered height, width.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* output,
                  framework::Tensor* mask) {
    // Input extents (N, C, H, W).
    const int num = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];

    // Output extents (C, H, W); the batch extent is taken from the input.
    const int out_c = output->dims()[1];
    const int out_h = output->dims()[2];
    const int out_w = output->dims()[3];

    // Window geometry, each ordered height / width.
    const int k_h = ksize[0];
    const int k_w = ksize[1];
    const int s_h = strides[0];
    const int s_w = strides[1];
    const int p_h = paddings[0];
    const int p_w = paddings[1];

    const T1* src = input.data<T1>();
    T1* dst = output->mutable_data<T1>(context.GetPlace());
    T2* argmax = mask->mutable_data<T2>(context.GetPlace());

    // The work count handed to the kernel is the number of output elements.
    const int kBlockSize = 1024;
    const int total = num * out_c * out_h * out_w;
    const int num_blocks = (total + kBlockSize - 1) / kBlockSize;
    dim3 threads(kBlockSize, 1);
    dim3 grid(num_blocks, 1);

    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        total, src, in_c, in_h, in_w, out_h, out_w, k_h, k_w, s_h, s_w, p_h,
        p_w, dst, argmax);
  }
};

C
chengduoZH 已提交
834 835 836 837 838
/*
 * CUDA backward pass of 2-D max pooling that uses the recorded argmax mask.
 * Tensors are indexed as NCHW. ksize, strides and paddings each hold two
 * entries, ordered height, width.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input-gradient extents (N, C, H, W).
    const int num = input_grad->dims()[0];
    const int in_c = input_grad->dims()[1];
    const int in_h = input_grad->dims()[2];
    const int in_w = input_grad->dims()[3];

    // Spatial extents of the output gradient.
    const int out_h = output_grad.dims()[2];
    const int out_w = output_grad.dims()[3];

    // Window geometry, each ordered height / width.
    const int k_h = ksize[0];
    const int k_w = ksize[1];
    const int s_h = strides[0];
    const int s_w = strides[1];
    const int p_h = paddings[0];
    const int p_w = paddings[1];

    const T2* argmax = mask.data<T2>();
    const T1* dy = output_grad.data<T1>();
    T1* dx = input_grad->mutable_data<T1>(context.GetPlace());

    // The work count handed to the kernel is the number of input elements.
    const int kBlockSize = 1024;
    const int total = num * in_c * in_h * in_w;
    const int num_blocks = (total + kBlockSize - 1) / kBlockSize;
    dim3 threads(kBlockSize, 1);
    dim3 grid(num_blocks, 1);

    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        total, dy, argmax, in_c, in_h, in_w, out_h, out_w, k_h, k_w, s_h, s_w,
        p_h, p_w, dx);
  }
};

Q
QI JUN 已提交
878 879 880 881 882 883 884 885
// Explicit instantiations of the 2-D with-index functors: float and double
// values, int mask entries.

// Forward.
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;

// Backward.
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
886

C
chengduoZH 已提交
887
/*
 * 3-D max pooling that also records where each maximum came from.
 *
 * Every loop iteration handles one output element, decoded from the flat
 * `index` as (batch_idx, c, pd, ph, pw). The pooling window is clipped to the
 * input bounds and the largest value found in it is stored in
 * output_data[index]. mask_data[index] receives that value's offset
 * ((d * input_height + h) * input_width + w) inside its own input feature
 * map, or -1 when no element of the window compared greater than the initial
 * sentinel.
 *
 * The loop advances by blockDim.x * gridDim.x, so one thread may process
 * several indices. All addressing is therefore recomputed from the unmodified
 * kernel arguments on each iteration.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat output index.
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    // Window bounds, clipped to the input extent.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // NOTE(review): -FLT_MAX is the float bound. Inputs that are not greater
    // than it (e.g. -inf, or doubles below -FLT_MAX) are never selected, and
    // the mask then stays -1.
    T1 ele = -FLT_MAX;
    int max_index = -1;

    // Base of this (batch, channel) feature map. Kept in a local pointer:
    // advancing `input_data` itself would stack the offsets of successive
    // iterations on top of each other.
    const T1* input_slice =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          if (ele < input_slice[(d * input_height + h) * input_width + w]) {
            max_index = (d * input_height + h) * input_width + w;
            ele = input_slice[max_index];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
935
/*
 * Backward pass of 3-D max pooling driven by the recorded argmax mask.
 *
 * Every loop iteration handles one input element, decoded from the flat
 * `index` as (batch_idx, c_offset, d_offset, h_offset, w_offset). It visits
 * the output positions [pd_start, pd_end) x [ph_start, ph_end) x
 * [pw_start, pw_end) and sums output_grad for those whose mask entry equals
 * this element's offset within its feature map
 * ((d_offset * input_height + h_offset) * input_width + w_offset). The sum is
 * written to input_grad[index].
 *
 * The loop advances by blockDim.x * gridDim.x, so one thread may process
 * several indices. All addressing is therefore recomputed from the unmodified
 * kernel arguments on each iteration.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat input index.
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // Range of output planes / rows / columns to inspect for this element.
    int pd_start =
        (d_offset + padding_depth < ksize_depth)
            ? 0
            : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int pd_end =
        min((d_offset + padding_depth) / stride_depth + 1, output_depth);
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);

    T1 gradient = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + c_offset) * output_depth *
                     output_height * output_width;

    // Bases of this (batch, channel) output map. Kept in local pointers:
    // advancing the kernel arguments themselves would stack the offsets of
    // successive iterations on top of each other.
    const T2* mask_slice = mask + output_idx;
    const T1* grad_slice = output_grad + output_idx;

    for (int pd = pd_start; pd < pd_end; ++pd) {
      for (int ph = ph_start; ph < ph_end; ++ph) {
        for (int pw = pw_start; pw < pw_end; ++pw) {
          if (mask_slice[(pd * output_height + ph) * output_width + pw] ==
              input_current_feature_map_idx)
            gradient +=
                grad_slice[(pd * output_height + ph) * output_width + pw];
        }
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
994 995 996 997 998
/*
 * CUDA forward pass of 3-D max pooling that also produces the argmax mask.
 * Tensors are indexed as NCDHW. ksize, strides and paddings each hold three
 * entries, ordered depth, height, width.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* output,
                  framework::Tensor* mask) {
    // Input extents (N, C, D, H, W).
    const int num = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_d = input.dims()[2];
    const int in_h = input.dims()[3];
    const int in_w = input.dims()[4];

    // Output extents (C, D, H, W); the batch extent is taken from the input.
    const int out_c = output->dims()[1];
    const int out_d = output->dims()[2];
    const int out_h = output->dims()[3];
    const int out_w = output->dims()[4];

    // Window geometry, each ordered depth / height / width.
    const int k_d = ksize[0];
    const int k_h = ksize[1];
    const int k_w = ksize[2];
    const int s_d = strides[0];
    const int s_h = strides[1];
    const int s_w = strides[2];
    const int p_d = paddings[0];
    const int p_h = paddings[1];
    const int p_w = paddings[2];

    const T1* src = input.data<T1>();
    T1* dst = output->mutable_data<T1>(context.GetPlace());
    T2* argmax = mask->mutable_data<T2>(context.GetPlace());

    // The work count handed to the kernel is the number of output elements.
    const int kBlockSize = 1024;
    const int total = num * out_c * out_d * out_h * out_w;
    const int num_blocks = (total + kBlockSize - 1) / kBlockSize;
    dim3 threads(kBlockSize, 1);
    dim3 grid(num_blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        total, src, in_c, in_d, in_h, in_w, out_d, out_h, out_w, k_d, k_h, k_w,
        s_d, s_h, s_w, p_d, p_h, p_w, dst, argmax);
  }
};

C
chengduoZH 已提交
1044 1045 1046 1047 1048
/*
 * CUDA backward pass of 3-D max pooling that uses the recorded argmax mask.
 * Tensors are indexed as NCDHW. ksize, strides and paddings each hold three
 * entries, ordered depth, height, width.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input-gradient extents (N, C, D, H, W).
    const int num = input_grad->dims()[0];
    const int in_c = input_grad->dims()[1];
    const int in_d = input_grad->dims()[2];
    const int in_h = input_grad->dims()[3];
    const int in_w = input_grad->dims()[4];

    // Spatial extents of the output gradient.
    const int out_d = output_grad.dims()[2];
    const int out_h = output_grad.dims()[3];
    const int out_w = output_grad.dims()[4];

    // Window geometry, each ordered depth / height / width.
    const int k_d = ksize[0];
    const int k_h = ksize[1];
    const int k_w = ksize[2];
    const int s_d = strides[0];
    const int s_h = strides[1];
    const int s_w = strides[2];
    const int p_d = paddings[0];
    const int p_h = paddings[1];
    const int p_w = paddings[2];

    const T1* dy = output_grad.data<T1>();
    const T2* argmax = mask.data<T2>();
    T1* dx = input_grad->mutable_data<T1>(context.GetPlace());

    // The work count handed to the kernel is the number of input elements.
    const int kBlockSize = 1024;
    const int total = num * in_c * in_d * in_h * in_w;
    const int num_blocks = (total + kBlockSize - 1) / kBlockSize;
    dim3 threads(kBlockSize, 1);
    dim3 grid(num_blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        total, dy, argmax, in_c, in_d, in_h, in_w, out_d, out_h, out_w, k_d,
        k_h, k_w, s_d, s_h, s_w, p_d, p_h, p_w, dx);
  }
};

Q
QI JUN 已提交
1095 1096 1097 1098 1099 1100 1101 1102
// Explicit instantiations of the 3-D with-index functors: float and double
// values, int mask entries.

// Forward.
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;

// Backward.
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1103 1104 1105 1106

}  // namespace math
}  // namespace operators
}  // namespace paddle