pooling.cu 69.4 KB
Newer Older
1
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
Y
Yi Wang 已提交
17
#include "paddle/fluid/operators/math/pooling.h"
D
dzhwinter 已提交
18
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduoZH 已提交
19 20 21 22 23

namespace paddle {
namespace operators {
namespace math {

24
/*
 * Forward 2-D pooling kernel.
 *
 * Runs a grid-stride loop over `nthreads` flat output indices; every loop
 * iteration produces exactly one element of `output_data`.
 *
 * The flat index is decoded as (batch, channel, row, col) when `channel_last`
 * is false (NCHW) and as (batch, row, col, channel) when it is true (NHWC).
 *
 * The pooling window is taken from AdaptStartIndex/AdaptEndIndex when
 * `adaptive` is set, otherwise from stride/padding/ksize and clipped to the
 * input extent. Reduction is delegated to `pool_process`:
 *   initial()  -> starting value,
 *   compute()  -> called once per input element in the window,
 *   finalize() -> receives the divisor `pool_size` converted to T.
 * `pool_size` is the clipped window area when `exclusive` or `adaptive` is
 * set, and ksize_height * ksize_width otherwise.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, const int output_height,
                             const int output_width, const int ksize_height,
                             const int ksize_width, const int stride_height,
                             const int stride_width, const int padding_height,
                             const int padding_width, PoolProcess pool_process,
                             bool exclusive, bool adaptive, T* output_data,
                             bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat output index into its coordinates.
    int pw, ph, c, batch_idx;
    if (!channel_last) { /* NCHW */
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      c = (index / output_width / output_height) % channels;
      batch_idx = index / output_width / output_height / channels;
    } else { /* NHWC */
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      batch_idx = index / channels / output_width / output_height;
    }

    // Window [hstart, hend) x [wstart, wend) in input coordinates.
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      hstart = ph * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = pw * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    // Base pointer for THIS iteration only. The kernel parameter must not be
    // advanced in place: with a grid-stride loop a thread may run several
    // iterations and the offsets would otherwise accumulate.
    const int input_offset =
        channel_last
            ? batch_idx * input_height * input_width * channels
            : (batch_idx * channels + c) * input_height * input_width;
    const T* input_slice = input_data + input_offset;

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        auto input_idx = channel_last ? (h * input_width + w) * channels + c
                                      : h * input_width + w;
        pool_process.compute(input_slice[input_idx], &ele);
      }
    }
    int pool_size = (exclusive || adaptive) ? (hend - hstart) * (wend - wstart)
                                            : ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}
/*
 * Backward 2-D pooling kernel (generic, driven by `pool_process`).
 *
 * Runs a grid-stride loop over `nthreads` flat input indices; every loop
 * iteration writes exactly one element of `input_grad`.
 *
 * For the input element at `index` it visits the output positions
 * [phstart, phend) x [pwstart, pwend) and, for each of them, calls
 *   pool_process.compute(input, output, output_grad, scale, &gradient)
 * where `scale` is 1 / pool_size converted to T.
 *
 * `channel_last` selects NHWC (true) or NCHW (false) index decoding.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat input index; spatial offsets are shifted by the padding
    // so they are expressed in padded coordinates.
    int w_offset, h_offset, offsetC, batch_idx;
    if (!channel_last) { /* NCHW */
      w_offset = index % input_width + padding_width;
      h_offset = (index / input_width) % input_height + padding_height;
      offsetC = (index / input_width / input_height) % channels;
      batch_idx = index / input_width / input_height / channels;
    } else { /* NHWC */
      offsetC = index % channels;
      w_offset = (index / channels) % input_width + padding_width;
      h_offset =
          (index / channels / input_width) % input_height + padding_height;
      batch_idx = index / channels / input_width / input_height;
    }

    // Range of output positions visited for this input element.
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      phstart = AdaptStartIndex(h_offset, output_height, input_height);
      phend = AdaptEndIndex(h_offset, output_height, input_height);

      pwstart = AdaptStartIndex(w_offset, output_width, input_width);
      pwend = AdaptEndIndex(w_offset, output_width, input_width);
    } else {
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      phend = min(h_offset / stride_height + 1, output_height);
      pwend = min(w_offset / stride_width + 1, output_width);
    }
    T gradient = 0;
    T input = input_data[index];

    // Base pointers for THIS iteration only. The kernel parameters must not
    // be advanced in place: with a grid-stride loop a thread may run several
    // iterations and the offsets would otherwise accumulate.
    const int output_offset =
        channel_last
            ? batch_idx * output_height * output_width * channels
            : (batch_idx * channels + offsetC) * output_height * output_width;
    const T* output_slice = output_data + output_offset;
    const T* output_grad_slice = output_grad + output_offset;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        int pool_size;
        if (adaptive) {
          // NOTE(review): this divisor is ceil(in / ksize) per axis, while
          // the forward kernel divides by the actual window area
          // (hend - hstart) * (wend - wstart). Confirm the two agree for the
          // shapes callers use.
          pool_size = static_cast<int>(ceil(static_cast<double>(input_height) /
                                            ksize_height)) *
                      static_cast<int>(
                          ceil(static_cast<double>(input_width) / ksize_width));
        } else {
          // Recompute the clipped forward window of output (ph, pw).
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          pool_size = exclusive ? (hend - hstart) * (wend - wstart)
                                : ksize_height * ksize_width;
        }

        int output_sub_idx = channel_last
                                 ? (ph * output_width + pw) * channels + offsetC
                                 : ph * output_width + pw;
        pool_process.compute(input, output_slice[output_sub_idx],
                             output_grad_slice[output_sub_idx],
                             static_cast<T>(1.0 / pool_size), &gradient);
      }
    }
    input_grad[index] = gradient;
  }
}

174
/*
 * Backward kernel for 2-D max pooling.
 *
 * Runs a grid-stride loop over `nthreads` flat OUTPUT indices. For each
 * output element it scans the (clipped) pooling window in row-major order,
 * stops at the first input value equal to the pooled output, and atomically
 * adds the matching `output_grad` value to that position of `input_grad`.
 * If no input value compares equal, nothing is written.
 *
 * Because the result is accumulated with an atomic add, `input_grad` is
 * read-modify-written rather than overwritten.
 * NOTE(review): presumably the caller zero-fills `input_grad` beforehand;
 * confirm.
 *
 * `channel_last` selects NHWC (true) or NCHW (false) index decoding.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat output index into its coordinates.
    int pw, ph, c, batch_idx;
    if (!channel_last) { /* NCHW */
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      c = (index / output_width / output_height) % channels;
      batch_idx = index / output_width / output_height / channels;
    } else { /* NHWC */
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      batch_idx = index / channels / output_width / output_height;
    }

    // Window [hstart, hend) x [wstart, wend) clipped to the input extent.
    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Base pointers for THIS iteration only. The kernel parameters must not
    // be advanced in place: with a grid-stride loop a thread may run several
    // iterations and the offsets would otherwise accumulate.
    const int input_offset =
        channel_last
            ? batch_idx * input_height * input_width * channels
            : (batch_idx * channels + c) * input_height * input_width;
    const T* input_slice = input_data + input_offset;
    T* input_grad_slice = input_grad + input_offset;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        int input_data_idx = channel_last ? (h * input_width + w) * channels + c
                                          : h * input_width + w;
        if (ele == input_slice[input_data_idx]) {
          maxIndex = input_data_idx;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // Atomic because overlapping windows can select the same input element.
      platform::CudaAtomicAdd(input_grad_slice + maxIndex, output_grad[index]);
    }
  }
}

N
nhzlx 已提交
234 235 236 237 238
/*
 * Launches KernelPool2D directly on raw device pointers and an explicit
 * stream.
 *
 * input_shape / output_shape are indexed as [0]=batch, [1]=channels,
 * [2]=height, [3]=width; the kernel is launched with its default
 * channel_last = false, i.e. the NCHW decoding path.
 * ksize, strides and paddings are read as [0]=height, [1]=width.
 *
 * One thread is launched per output element, in blocks of 1024 threads.
 */
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    PoolProcess pool_compute, bool exclusive, bool adaptive, T* output,
    cudaStream_t stream) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_height = input_shape[2];
  const int input_width = input_shape[3];
  const int output_channels = output_shape[1];
  const int output_height = output_shape[2];
  const int output_width = output_shape[3];
  const int ksize_height = ksize[0];
  const int ksize_width = ksize[1];
  const int stride_height = strides[0];
  const int stride_width = strides[1];
  const int padding_height = paddings[0];
  const int padding_width = paddings[1];

  int nthreads = batch_size * output_channels * output_height * output_width;
  // Empty output: nothing to compute. A launch with zero blocks would be
  // rejected as an invalid configuration, so skip it.
  if (nthreads == 0) {
    return;
  }
  int blocks = (nthreads + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);

  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      nthreads, input, input_channels, input_height, input_width, output_height,
      output_width, ksize_height, ksize_width, stride_height, stride_width,
      padding_height, padding_width, pool_compute, exclusive, adaptive, output);
}

C
chengduoZH 已提交
266
/*
267 268 269 270 271 272
* Tensors are in NCHW or NHWC format.
* Ksize, strides are two elements. These two elements represent height
* and width, respectively.
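* Paddings: the code below reads paddings[0] as the height padding and
* paddings[1] as the width padding; no other elements are accessed.
* NOTE(review): callers may start from four values (height_up, height_down,
* width_left, width_right) - confirm they are reduced to two before the call.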
*/
273
/*
 * CUDA forward 2-D pooling functor.
 *
 * Both overloads launch KernelPool2D with one thread per output element
 * (blocks of 1024 threads) on context.stream(). ksize, strides and paddings
 * are read as [0]=height, [1]=width.
 */
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  /*
   * NCHW entry point: tensor dims are read as (batch, channels, height,
   * width). Forwards to the data_format overload so the launch logic lives
   * in one place.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive, framework::Tensor* output) {
    (*this)(context, input, ksize, strides, paddings, std::string("NCHW"),
            pool_process, exclusive, adaptive, output);
  }

  /*
   * Layout-aware entry point. data_format == "NHWC" selects the
   * channel-last layout (batch, height, width, channels); any other value
   * is treated as (batch, channels, height, width).
   * The output buffer is allocated on context.GetPlace() before the launch.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, PoolProcess pool_process,
                  bool exclusive, bool adaptive, framework::Tensor* output) {
    bool channel_last = (data_format == "NHWC");
    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_channels =
        channel_last ? output->dims()[3] : output->dims()[1];
    const int output_height =
        channel_last ? output->dims()[1] : output->dims()[2];
    const int output_width =
        channel_last ? output->dims()[2] : output->dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // Empty output: a launch with zero blocks would be rejected as an
    // invalid configuration, so skip it.
    if (nthreads == 0) {
      return;
    }
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data, channel_last);
  }
};
C
chengduoZH 已提交
353
/*
354 355 356 357 358 359
* Tensors are in NCHW or NHWC format.
* Ksize, strides are two elements. These two elements represent height
* and width, respectively.
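* Paddings: the code below reads paddings[0] as the height padding and
* paddings[1] as the width padding; no other elements are accessed.
* NOTE(review): callers may start from four values (height_up, height_down,
* width_left, width_right) - confirm they are reduced to two before the call.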
*/
360
/*
 * CUDA backward 2-D pooling functor (generic, driven by PoolProcess).
 *
 * Both overloads launch KernelPool2DGrad with one thread per INPUT element
 * (blocks of 1024 threads) on context.stream(). ksize, strides and paddings
 * are read as [0]=height, [1]=width.
 */
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  /*
   * NCHW entry point: tensor dims are read as (batch, channels, height,
   * width). Forwards to the data_format overload so the launch logic lives
   * in one place.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            std::string("NCHW"), pool_process, exclusive, adaptive,
            input_grad);
  }

  /*
   * Layout-aware entry point. data_format == "NHWC" selects the
   * channel-last layout (batch, height, width, channels); any other value
   * is treated as (batch, channels, height, width).
   * input_grad is allocated on context.GetPlace() before the launch and
   * every element of it is written by the kernel.
   */
  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, PoolProcess pool_process, bool exclusive,
      bool adaptive, framework::Tensor* input_grad) {
    bool channel_last = (data_format == "NHWC");

    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_height =
        channel_last ? output.dims()[1] : output.dims()[2];
    const int output_width = channel_last ? output.dims()[2] : output.dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();

    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    // Empty input: a launch with zero blocks would be rejected as an
    // invalid configuration, so skip it.
    if (nthreads == 0) {
      return;
    }
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        pool_process, exclusive, adaptive, input_grad_data, channel_last);
  }
};

C
chengduoZH 已提交
450
/*
451 452 453 454 455 456
* Tensors are in NCHW or NHWC format.
* Ksize, strides are two elements. These two elements represent height
* and width, respectively.
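* Paddings: the code below reads paddings[0] as the height padding and
* paddings[1] as the width padding; no other elements are accessed.
* NOTE(review): callers may start from four values (height_up, height_down,
* width_left, width_right) - confirm they are reduced to two before the call.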
*/
457
/*
 * CUDA backward functor specialised for 2-D max pooling.
 *
 * Both overloads launch KernelMaxPool2DGrad with one thread per OUTPUT
 * element (blocks of 1024 threads) on context.stream(). ksize, strides and
 * paddings are read as [0]=height, [1]=width.
 *
 * The kernel accumulates into input_grad with atomic adds and does not
 * clear it first.
 * NOTE(review): presumably the caller zero-fills input_grad; confirm.
 */
template <typename T>
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  /*
   * NCHW entry point: tensor dims are read as (batch, channels, height,
   * width). Forwards to the data_format overload so the launch logic lives
   * in one place.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            std::string("NCHW"), input_grad);
  }

  /*
   * Layout-aware entry point. data_format == "NHWC" selects the
   * channel-last layout (batch, height, width, channels); any other value
   * is treated as (batch, channels, height, width).
   */
  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, framework::Tensor* input_grad) {
    bool channel_last = (data_format == "NHWC");

    const int batch_size = input.dims()[0];

    const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
    const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
    const int input_width = channel_last ? input.dims()[2] : input.dims()[3];

    const int output_channels =
        channel_last ? output.dims()[3] : output.dims()[1];
    const int output_height =
        channel_last ? output.dims()[1] : output.dims()[2];
    const int output_width = channel_last ? output.dims()[2] : output.dims()[3];

    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];

    const int stride_height = strides[0];
    const int stride_width = strides[1];

    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // Empty output: a launch with zero blocks would be rejected as an
    // invalid configuration, so skip it.
    if (nthreads == 0) {
      return;
    }
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        input_grad_data, channel_last);
  }
};

N
nhzlx 已提交
545 546 547 548 549
// Explicit instantiations of the 2-D pooling functors.

// Raw-pointer functor: float only.
template class Pool2dDirectCUDAFunctor<MaxPool<float>, float>;
template class Pool2dDirectCUDAFunctor<AvgPool<float>, float>;

// float
template class Pool2dFunctor<platform::CUDADeviceContext, MaxPool<float>,
                             float>;
template class Pool2dFunctor<platform::CUDADeviceContext, AvgPool<float>,
                             float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<float>, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;

// double
template class Pool2dFunctor<platform::CUDADeviceContext, MaxPool<double>,
                             double>;
template class Pool2dFunctor<platform::CUDADeviceContext, AvgPool<double>,
                             double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<double>, double>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;
573 574

/*
 * Forward 3-D pooling kernel.
 *
 * Runs a grid-stride loop over `nthreads` flat output indices; every loop
 * iteration produces exactly one element of `output_data`.
 *
 * The flat index is decoded as (batch, channel, depth, row, col) when
 * `channel_last` is false (NCDHW) and as (batch, depth, row, col, channel)
 * when it is true (NDHWC).
 *
 * The pooling window is taken from AdaptStartIndex/AdaptEndIndex when
 * `adaptive` is set, otherwise from stride/padding/ksize and clipped to the
 * input extent. Reduction is delegated to `pool_process` (initial / compute /
 * finalize). The divisor passed to finalize() is the clipped window volume
 * when `exclusive` or `adaptive` is set, and
 * ksize_depth * ksize_height * ksize_width otherwise.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(
    const int nthreads, const T* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decode the flat output index into its coordinates.
    int pw, ph, pd, c, batch_idx;
    if (!channel_last) { /* NCDHW */
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else { /* NDHWC */
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    // Window [dstart, dend) x [hstart, hend) x [wstart, wend).
    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Base pointer for THIS iteration only. The kernel parameter must not be
    // advanced in place: with a grid-stride loop a thread may run several
    // iterations and the offsets would otherwise accumulate.
    int input_data_stride;
    if (!channel_last) { /* NCDHW */
      input_data_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else { /* NDHWC */
      input_data_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    const T* input_slice = input_data + input_data_stride;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          auto input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          pool_process.compute(input_slice[input_data_idx], &ele);
        }
      }
    }
    int pool_size = (exclusive || adaptive)
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

/*
 * Backward kernel for 3-D pooling driven by a PoolProcess functor.
 *
 * Each loop iteration handles one element `index` of the input tensor
 * (nthreads is the number of input elements to process). It visits every
 * output cell in [pdstart, pdend) x [phstart, phend) x [pwstart, pwend),
 * lets pool_process.compute() accumulate into a local gradient, and stores
 * the result in input_grad[index].
 *
 * Layout is NCDHW by default and NDHWC when channel_last is true.
 * The loop is a grid-stride loop, so any 1-D launch configuration covers all
 * nthreads elements.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, PoolProcess pool_process,
    bool exclusive, bool adaptive, T* input_grad, bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decompose the flat input index. The spatial offsets are shifted by the
    // padding, i.e. they are expressed in padded coordinates.
    int w_offset, h_offset, d_offset, offsetC, batch_idx;
    if (!channel_last) { /* "NCDHW" */
      w_offset = index % input_width + padding_width;
      h_offset = (index / input_width) % input_height + padding_height;
      d_offset =
          (index / input_width / input_height) % input_depth + padding_depth;
      offsetC = (index / input_width / input_height / input_depth) % channels;
      batch_idx = index / input_width / input_height / input_depth / channels;
    } else { /* "NDHWC" */
      offsetC = index % channels;
      w_offset = (index / channels) % input_width + padding_width;
      h_offset =
          (index / channels / input_width) % input_height + padding_height;
      d_offset = (index / channels / input_width / input_height) % input_depth +
                 padding_depth;
      batch_idx = index / channels / input_width / input_height / input_depth;
    }

    // Half-open range of output cells visited for this input element.
    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = AdaptStartIndex(d_offset, output_depth, input_depth);
      pdend = AdaptEndIndex(d_offset, output_depth, input_depth);

      phstart = AdaptStartIndex(h_offset, output_height, input_height);
      phend = AdaptEndIndex(h_offset, output_height, input_height);

      pwstart = AdaptStartIndex(w_offset, output_width, input_width);
      pwend = AdaptEndIndex(w_offset, output_width, input_width);
    } else {
      pdstart = (d_offset < ksize_depth)
                    ? 0
                    : (d_offset - ksize_depth) / stride_depth + 1;
      phstart = (h_offset < ksize_height)
                    ? 0
                    : (h_offset - ksize_height) / stride_height + 1;
      pwstart = (w_offset < ksize_width)
                    ? 0
                    : (w_offset - ksize_width) / stride_width + 1;
      pdend = min((d_offset) / stride_depth + 1, output_depth);
      phend = min((h_offset) / stride_height + 1, output_height);
      pwend = min((w_offset) / stride_width + 1, output_width);
    }

    T gradient = 0;
    T input = input_data[index];

    int output_stride;
    if (!channel_last) {
      output_stride = (batch_idx * channels + offsetC) * output_depth *
                      output_height * output_width;
    } else {
      output_stride =
          batch_idx * output_depth * output_height * output_width * channels;
    }
    // Offset into local pointers instead of advancing the kernel arguments:
    // this loop may run more than once per thread, and advancing the
    // arguments would make the offsets accumulate across iterations.
    const T* output_base = output_data + output_stride;
    const T* output_grad_base = output_grad + output_stride;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // figure out the pooling size
          int pool_size;
          if (adaptive) {
            // NOTE(review): the forward kernel (KernelPool3D) divides by the
            // actual window extent in adaptive mode, while this uses
            // ceil(input / ksize) per axis -- confirm the two agree when the
            // input size is not a multiple of the output size.
            pool_size =
                static_cast<int>(
                    ceil(static_cast<double>(input_depth) / ksize_depth)) *
                static_cast<int>(
                    ceil(static_cast<double>(input_height) / ksize_height)) *
                static_cast<int>(
                    ceil(static_cast<double>(input_width) / ksize_width));
          } else {
            int dstart = pd * stride_depth - padding_depth;
            int hstart = ph * stride_height - padding_height;
            int wstart = pw * stride_width - padding_width;
            int dend = min(dstart + ksize_depth, input_depth);
            int hend = min(hstart + ksize_height, input_height);
            int wend = min(wstart + ksize_width, input_width);
            dstart = max(dstart, 0);
            hstart = max(hstart, 0);
            wstart = max(wstart, 0);
            pool_size =
                exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                          : ksize_depth * ksize_height * ksize_width;
          }

          int output_sub_idx =
              channel_last
                  ? ((pd * output_height + ph) * output_width + pw) * channels +
                        offsetC
                  : (pd * output_height + ph) * output_width + pw;

          pool_process.compute(input, output_base[output_sub_idx],
                               output_grad_base[output_sub_idx],
                               static_cast<T>(1.0 / pool_size), &gradient);
        }
      }
    }
    input_grad[index] = gradient;
  }
}

773
/*
 * Backward kernel for 3-D max pooling.
 *
 * Each loop iteration handles one output element `index` (nthreads is the
 * number of output elements). It scans that element's pooling window in the
 * input, stops at the first input value equal to output_data[index], and
 * atomically adds output_grad[index] to the matching input_grad entry.
 * If no input value matches, nothing is written.
 *
 * Layout is NCDHW by default and NDHWC when channel_last is true.
 * input_grad is only accumulated into here, never cleared -- presumably the
 * caller zero-fills it before the launch; verify against the caller.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad,
    bool channel_last = false) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Decompose the flat output index into (batch, channel, pd, ph, pw).
    int pw, ph, pd, c, batch_idx;

    if (!channel_last) { /*NCDHW*/
      pw = index % output_width;
      ph = (index / output_width) % output_height;
      pd = (index / output_width / output_height) % output_depth;
      c = (index / output_width / output_height / output_depth) % channels;
      batch_idx =
          index / output_width / output_height / output_depth / channels;
    } else { /*NDHWC*/
      c = index % channels;
      pw = (index / channels) % output_width;
      ph = (index / channels / output_width) % output_height;
      pd = (index / channels / output_width / output_height) % output_depth;
      batch_idx =
          index / channels / output_width / output_height / output_depth;
    }

    // Pooling window in input coordinates, clipped to the input extent.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;

    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);

    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;

    int input_stride;
    if (!channel_last) {
      input_stride =
          (batch_idx * channels + c) * input_depth * input_height * input_width;
    } else {
      input_stride =
          batch_idx * input_depth * input_height * input_width * channels;
    }
    // Offset into local pointers instead of advancing the kernel arguments:
    // this loop may run more than once per thread, and advancing the
    // arguments would make the offsets accumulate across iterations.
    const T* input_base = input_data + input_stride;
    T* input_grad_base = input_grad + input_stride;

    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          int input_data_idx =
              channel_last
                  ? ((d * input_height + h) * input_width + w) * channels + c
                  : (d * input_height + h) * input_width + w;
          // The first element equal to the pooled output receives the
          // gradient.
          if (ele == input_base[input_data_idx]) {
            stop = true;
            maxIdx = input_data_idx;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // atomic add: pooling windows of different outputs may overlap
      platform::CudaAtomicAdd(input_grad_base + maxIdx, output_grad[index]);
    }
  }
}

C
chengduoZH 已提交
850
/*
* Tensors are in NCDHW or NDHWC format.
* Ksize, strides and paddings are read as three elements each. These three
* elements represent depth, height and width, respectively.
* NOTE: at the operator level paddings may be given as six elements
* (depth_front, depth_back, height_up, height_down, width_left and
* width_right); the code below reads only paddings[0], paddings[1] and
* paddings[2] -- TODO confirm how the caller reduces six values to three.
*/
858
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // Forward 3-D pooling for an NCDHW input. Equivalent to the data_format
  // overload below with "NCDHW"; kept for callers that do not pass a layout.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive, framework::Tensor* output) {
    (*this)(context, input, ksize, strides, paddings, std::string("NCDHW"),
            pool_process, exclusive, adaptive, output);
  }

  // Forward 3-D pooling. data_format "NDHWC" selects channel-last indexing;
  // any other value is treated as NCDHW. ksize, strides and paddings are read
  // at indices 0..2 as depth, height and width. The result is written to
  // *output, whose dims must already be set (they are read here to size the
  // launch). The kernel is launched on context.stream().
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  const std::string data_format, PoolProcess pool_process,
                  bool exclusive, bool adaptive, framework::Tensor* output) {
    const bool channel_last = (data_format == "NDHWC");
    // Position of the channel axis and of the first spatial (depth) axis.
    const int channel_axis = channel_last ? 4 : 1;
    const int depth_axis = channel_last ? 1 : 2;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[channel_axis];
    const int input_depth = input.dims()[depth_axis];
    const int input_height = input.dims()[depth_axis + 1];
    const int input_width = input.dims()[depth_axis + 2];

    const int output_channels = output->dims()[channel_axis];
    const int output_depth = output->dims()[depth_axis];
    const int output_height = output->dims()[depth_axis + 1];
    const int output_width = output->dims()[depth_axis + 2];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // One loop iteration per output element.
    const int nthreads = batch_size * output_channels * output_depth *
                         output_height * output_width;
    // Nothing to compute; a launch with zero blocks is an invalid
    // configuration, so skip it.
    if (nthreads <= 0) return;

    const int threads_per_block = 1024;
    const int blocks = (nthreads + threads_per_block - 1) / threads_per_block;
    dim3 threads(threads_per_block, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, pool_process, exclusive,
        adaptive, output_data, channel_last);
  }
};

C
chengduoZH 已提交
954
/*
* Tensors are in NCDHW or NDHWC format.
* Ksize, strides and paddings are read as three elements each. These three
* elements represent depth, height and width, respectively.
* NOTE: at the operator level paddings may be given as six elements
* (depth_front, depth_back, height_up, height_down, width_left and
* width_right); the code below reads only paddings[0], paddings[1] and
* paddings[2] -- TODO confirm how the caller reduces six values to three.
*/
962
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // Backward 3-D pooling for NCDHW tensors. Equivalent to the data_format
  // overload below with "NCDHW"; kept for callers that do not pass a layout.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  bool exclusive, bool adaptive,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            std::string("NCDHW"), pool_process, exclusive, adaptive,
            input_grad);
  }

  // Backward 3-D pooling. data_format "NDHWC" selects channel-last indexing;
  // any other value is treated as NCDHW. ksize, strides and paddings are read
  // at indices 0..2 as depth, height and width. Every element of *input_grad
  // is overwritten by the kernel (one loop iteration per input element). The
  // kernel is launched on context.stream().
  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, PoolProcess pool_process, bool exclusive,
      bool adaptive, framework::Tensor* input_grad) {
    const bool channel_last = (data_format == "NDHWC");
    // Position of the channel axis and of the first spatial (depth) axis.
    const int channel_axis = channel_last ? 4 : 1;
    const int depth_axis = channel_last ? 1 : 2;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[channel_axis];
    const int input_depth = input.dims()[depth_axis];
    const int input_height = input.dims()[depth_axis + 1];
    const int input_width = input.dims()[depth_axis + 2];

    const int output_depth = output.dims()[depth_axis];
    const int output_height = output.dims()[depth_axis + 1];
    const int output_width = output.dims()[depth_axis + 2];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One loop iteration per input element.
    const int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Nothing to compute; a launch with zero blocks is an invalid
    // configuration, so skip it.
    if (nthreads <= 0) return;

    const int threads_per_block = 1024;
    const int blocks = (nthreads + threads_per_block - 1) / threads_per_block;
    dim3 threads(threads_per_block, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, exclusive, adaptive, input_grad_data,
        channel_last);
  }
};

C
chengduoZH 已提交
1066
/*
* Tensors are in NCDHW or NDHWC format.
* Ksize, strides and paddings are read as three elements each. These three
* elements represent depth, height and width, respectively.
* NOTE: at the operator level paddings may be given as six elements
* (depth_front, depth_back, height_up, height_down, width_left and
* width_right); the code below reads only paddings[0], paddings[1] and
* paddings[2] -- TODO confirm how the caller reduces six values to three.
*/
1074
template <class T>
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Backward 3-D max pooling for NCDHW tensors. Equivalent to the data_format
  // overload below with "NCDHW"; kept for callers that do not pass a layout.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    (*this)(context, input, output, output_grad, ksize, strides, paddings,
            std::string("NCDHW"), input_grad);
  }

  // Backward 3-D max pooling. data_format "NDHWC" selects channel-last
  // indexing; any other value is treated as NCDHW. ksize, strides and
  // paddings are read at indices 0..2 as depth, height and width. The kernel
  // runs one loop iteration per output element and atomically adds into
  // *input_grad; this functor does not clear *input_grad first -- presumably
  // the caller zero-fills it; verify against the caller. The kernel is
  // launched on context.stream().
  void operator()(
      const platform::CUDADeviceContext& context,
      const framework::Tensor& input, const framework::Tensor& output,
      const framework::Tensor& output_grad, const std::vector<int>& ksize,
      const std::vector<int>& strides, const std::vector<int>& paddings,
      const std::string data_format, framework::Tensor* input_grad) {
    const bool channel_last = (data_format == "NDHWC");
    // Position of the channel axis and of the first spatial (depth) axis.
    const int channel_axis = channel_last ? 4 : 1;
    const int depth_axis = channel_last ? 1 : 2;

    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[channel_axis];
    const int input_depth = input.dims()[depth_axis];
    const int input_height = input.dims()[depth_axis + 1];
    const int input_width = input.dims()[depth_axis + 2];

    const int output_channels = output.dims()[channel_axis];
    const int output_depth = output.dims()[depth_axis];
    const int output_height = output.dims()[depth_axis + 1];
    const int output_width = output.dims()[depth_axis + 2];

    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];

    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];

    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One loop iteration per output element.
    const int nthreads = batch_size * output_channels * output_depth *
                         output_height * output_width;
    // Nothing to scatter; a launch with zero blocks is an invalid
    // configuration, so skip it.
    if (nthreads <= 0) return;

    const int threads_per_block = 1024;
    const int blocks = (nthreads + threads_per_block - 1) / threads_per_block;
    dim3 threads(threads_per_block, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, input_grad_data, channel_last);
  }
};

Q
QI JUN 已提交
1175 1176
// Explicit instantiations of the CUDA 3-D pooling functors for float and
// double.

// Forward pooling (max / average).
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;

// Backward pooling driven by a PoolProcess functor.
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;

// Dedicated max-pooling backward functor.
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;
1198

C
chengduoZH 已提交
1199
/*
 * Forward kernel for 2-D max pooling that also records the arg-max position.
 *
 * Each loop iteration handles one output element `index`, decomposed as
 * (batch, channel, ph, pw) in NCHW order. It scans the pooling window in the
 * input, writes the maximum to output_data[index] and writes to
 * mask_data[index] the position (h * input_width + w) of that maximum within
 * the element's own input feature map. If the window is empty, or no value
 * exceeds the -FLT_MAX starting value, the mask entry is -1.
 *
 * The loop is a grid-stride loop, so any 1-D launch configuration covers all
 * nthreads elements.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    // Half-open pooling window [hstart, hend) x [wstart, wend) in the input.
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      hstart = ph * stride_height - padding_height;
      hend = min(hstart + ksize_height, input_height);
      hstart = max(hstart, 0);

      wstart = pw * stride_width - padding_width;
      wend = min(wstart + ksize_width, input_width);
      wstart = max(wstart, 0);
    }

    // Offset into a local pointer instead of advancing the kernel argument:
    // this loop may run more than once per thread, and advancing the
    // argument would make the offsets accumulate across iterations.
    const T1* input_base =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        // Strict comparison: the first maximum in scan order is kept.
        if (ele < input_base[input_index]) {
          max_index = input_index;
          ele = input_base[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1248
/*
 * Backward kernel for 2-D max pooling with a recorded arg-max mask.
 *
 * Each loop iteration handles one input element `index`, decomposed as
 * (batch, channel, h_offset, w_offset) in NCHW order. It visits the output
 * cells in [phstart, phend) x [pwstart, pwend) and sums output_grad over
 * those cells whose mask entry equals this element's position
 * (h_offset * input_width + w_offset) within its feature map. The sum is
 * stored to input_grad[index].
 *
 * The loop is a grid-stride loop, so any 1-D launch configuration covers all
 * nthreads elements.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, bool adaptive,
    T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Half-open range of output cells visited for this input element.
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    T1 gradient = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;

    // Offset into local pointers instead of advancing the kernel arguments:
    // this loop may run more than once per thread, and advancing the
    // arguments would make the offsets accumulate across iterations.
    const T2* mask_base = mask_data + output_idx;
    const T1* output_grad_base = output_grad + output_idx;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // Only outputs whose recorded arg-max is this element contribute.
        if (mask_base[ph * output_width + pw] == input_current_featuremap_idx)
          gradient += output_grad_base[ph * output_width + pw];
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
1303 1304 1305 1306 1307
/*
 * GPU forward pass of 2-D max pooling that also emits an index mask.
 *
 * All tensors are in NCHW format.
 * ksize, strides and paddings each hold two elements: the height entry
 * followed by the width entry.
 *
 * context   supplies the CUDA stream and the place used for allocation.
 * input     tensor to pool; dims [0..3] are read as N, C, H, W.
 * adaptive  forwarded unchanged to KernelMaxPool2dWithIdx.
 * output    receives the pooled values (allocated here as T1).
 * mask      receives the per-output index written by the kernel (allocated
 *           here as T2).
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    // Input geometry.
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];

    // Output geometry.
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];

    // Pooling window description, height first and width second.
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    // One work item per output element.
    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    // An empty output would yield a grid of zero blocks, which is not a valid
    // launch configuration; there is nothing to compute in that case.
    if (nthreads == 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

C
chengduoZH 已提交
1347 1348 1349 1350 1351
/*
 * GPU backward pass of 2-D max pooling with index mask.
 *
 * All tensors are in NCHW format.
 * ksize, strides and paddings each hold two elements: the height entry
 * followed by the width entry.
 *
 * context      supplies the CUDA stream and the place used for allocation.
 * output_grad  gradient with respect to the pooled output; dims [2..3] are
 *              read as the output height and width.
 * mask         index mask read as T2 and handed to the kernel.
 * adaptive     forwarded unchanged to KernelMaxPool2DWithIdxGrad.
 * input_grad   receives the gradient with respect to the input (allocated
 *              here as T1); its dims [0..3] are read as N, C, H, W.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    // Geometry of the gradient being produced.
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];

    // Geometry of the incoming gradient.
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];

    // Pooling window description, height first and width second.
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    // One work item per input-gradient element.
    const int nthreads =
        batch_size * input_channels * input_height * input_width;
    // An empty input gradient would yield a grid of zero blocks, which is not
    // a valid launch configuration; there is nothing to compute in that case.
    if (nthreads == 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

Q
QI JUN 已提交
1391 1392 1393 1394 1395 1396 1397 1398
// Explicit instantiations of the 2-D max-pool-with-index functors on CUDA:
// float and double element types, each paired with an int mask.

// Forward.
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext,
                                         float, int>;
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext,
                                         double, int>;

// Backward.
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             float, int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1399

C
chengduoZH 已提交
1400
/*
 * Seed for the running maximum in KernelMaxPool3DWithIdx.
 * The generic version keeps the -FLT_MAX seed; the double specialization uses
 * -DBL_MAX so that doubles smaller than -FLT_MAX can still be selected.
 */
template <typename T>
__device__ __forceinline__ T MaxPool3DWithIdxLowest() {
  return static_cast<T>(-FLT_MAX);
}

template <>
__device__ __forceinline__ double MaxPool3DWithIdxLowest<double>() {
  return -DBL_MAX;
}

/*
 * Forward kernel of 3-D max pooling that also records where the maximum was
 * found.
 *
 * Work is distributed with a grid-stride loop over `nthreads` flat output
 * indices; each index is decomposed as (batch, channel, depth, height, width)
 * with width varying fastest.
 *
 * For every output element the kernel writes:
 *   output_data[index]  the largest input value inside the pooling window, or
 *                       the seed value when no element was selected;
 *   mask_data[index]    (d * input_height + h) * input_width + w of that
 *                       value inside its own (batch, channel) volume, or -1
 *                       when no element was selected.
 *
 * When `adaptive` is true the window bounds come from AdaptStartIndex /
 * AdaptEndIndex; otherwise they come from stride, padding and ksize and are
 * clipped to the input extent.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    bool adaptive, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat output index into its coordinates.
    const int pw = index % output_width;
    const int ph = (index / output_width) % output_height;
    const int pd = (index / output_width / output_height) % output_depth;
    const int c =
        (index / output_width / output_height / output_depth) % channels;
    const int batch_idx =
        index / output_width / output_height / output_depth / channels;

    // Half-open window [start, end) along each axis of the input.
    int dstart, dend;
    int hstart, hend;
    int wstart, wend;
    if (adaptive) {
      dstart = AdaptStartIndex(pd, input_depth, output_depth);
      dend = AdaptEndIndex(pd, input_depth, output_depth);

      hstart = AdaptStartIndex(ph, input_height, output_height);
      hend = AdaptEndIndex(ph, input_height, output_height);

      wstart = AdaptStartIndex(pw, input_width, output_width);
      wend = AdaptEndIndex(pw, input_width, output_width);
    } else {
      dstart = pd * stride_depth - padding_depth;
      hstart = ph * stride_height - padding_height;
      wstart = pw * stride_width - padding_width;
      // The end is derived from the unclipped start, then both ends are
      // clipped to the valid input range.
      dend = min(dstart + ksize_depth, input_depth);
      hend = min(hstart + ksize_height, input_height);
      wend = min(wstart + ksize_width, input_width);
      dstart = max(dstart, 0);
      hstart = max(hstart, 0);
      wstart = max(wstart, 0);
    }

    // Base of this (batch, channel) volume. A local pointer is used instead
    // of advancing `input_data`, so that a thread executing more than one
    // loop iteration does not accumulate offsets across iterations.
    const T1* feature_map =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    T1 ele = MaxPool3DWithIdxLowest<T1>();
    int max_index = -1;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          const int offset = (d * input_height + h) * input_width + w;
          // Strict comparison: the first occurrence of the maximum wins.
          if (ele < feature_map[offset]) {
            max_index = offset;
            ele = feature_map[offset];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
1462
/*
 * Backward kernel of 3-D max pooling with index mask.
 *
 * Work is distributed with a grid-stride loop over `nthreads` flat
 * input-gradient indices; each index is decomposed as
 * (batch, channel, depth, height, width) with width varying fastest.
 *
 * For every input position the kernel scans a range of candidate output
 * positions in the same (batch, channel) volume and sums output_grad over
 * those whose mask entry equals this position's flattened offset
 * (d * input_height + h) * input_width + w. The sum is stored to
 * input_grad[index], so every element of input_grad in [0, nthreads) is
 * written.
 *
 * Candidate ranges:
 *   adaptive      derived from the output/input size ratio along each axis;
 *   non-adaptive  the output positions p with
 *                 (x + pad - ksize) / stride < p <= (x + pad) / stride,
 *                 clipped to [0, output extent).
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, bool adaptive, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat input index into its coordinates.
    const int w_offset = index % input_width;
    const int h_offset = (index / input_width) % input_height;
    const int d_offset = (index / input_width / input_height) % input_depth;
    const int offsetC =
        (index / input_width / input_height / input_depth) % channels;
    const int batch_idx =
        index / input_width / input_height / input_depth / channels;

    // Half-open range [start, end) of candidate output positions per axis.
    int pdstart, pdend;
    int phstart, phend;
    int pwstart, pwend;
    if (adaptive) {
      pdstart = d_offset * output_depth / input_depth;
      pdend =
          min((d_offset + 1) * output_depth / input_depth + 1, output_depth);
      phstart = h_offset * output_height / input_height;
      phend =
          min((h_offset + 1) * output_height / input_height + 1, output_height);
      pwstart = w_offset * output_width / input_width;
      pwend =
          min((w_offset + 1) * output_width / input_width + 1, output_width);
    } else {
      // The explicit `< ksize` test avoids dividing a negative numerator.
      pdstart =
          (d_offset + padding_depth < ksize_depth)
              ? 0
              : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
      phstart =
          (h_offset + padding_height < ksize_height)
              ? 0
              : (h_offset + padding_height - ksize_height) / stride_height + 1;
      pwstart =
          (w_offset + padding_width < ksize_width)
              ? 0
              : (w_offset + padding_width - ksize_width) / stride_width + 1;
      pdend = min((d_offset + padding_depth) / stride_depth + 1, output_depth);
      phend =
          min((h_offset + padding_height) / stride_height + 1, output_height);
      pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
    }

    // Offset of this input position inside its own (batch, channel) volume;
    // this is the value the mask is compared against.
    const int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;

    // Base of the matching (batch, channel) volume in mask / output_grad.
    // Local pointers are used instead of advancing the kernel parameters, so
    // that a thread executing more than one loop iteration does not
    // accumulate offsets across iterations.
    const int output_idx = (batch_idx * channels + offsetC) * output_depth *
                           output_height * output_width;
    const T2* mask_map = mask + output_idx;
    const T1* grad_map = output_grad + output_idx;

    T1 gradient = 0;
    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          const int out_offset = (pd * output_height + ph) * output_width + pw;
          if (mask_map[out_offset] == input_current_feature_map_idx) {
            gradient += grad_map[out_offset];
          }
        }
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
1533 1534 1535 1536 1537
/*
 * GPU forward pass of 3-D max pooling that also emits an index mask.
 *
 * All tensors are in NCDHW format.
 * ksize, strides and paddings each hold three elements: the depth, height
 * and width entries, in that order.
 *
 * context   supplies the CUDA stream and the place used for allocation.
 * input     tensor to pool; dims [0..4] are read as N, C, D, H, W.
 * adaptive  forwarded unchanged to KernelMaxPool3DWithIdx.
 * output    receives the pooled values (allocated here as T1).
 * mask      receives the per-output index written by the kernel (allocated
 *           here as T2).
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* output, framework::Tensor* mask) {
    // Input geometry.
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];

    // Output geometry.
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];

    // Pooling window description: depth, height, width.
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    // One work item per output element.
    const int nthreads = batch_size * output_channels * output_depth *
                         output_height * output_width;
    // An empty output would yield a grid of zero blocks, which is not a valid
    // launch configuration; there is nothing to compute in that case.
    if (nthreads == 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, adaptive, output_data,
        mask_data);
  }
};

C
chengduoZH 已提交
1584 1585 1586 1587 1588
/*
 * GPU backward pass of 3-D max pooling with index mask.
 *
 * All tensors are in NCDHW format.
 * ksize, strides and paddings each hold three elements: the depth, height
 * and width entries, in that order.
 *
 * context      supplies the CUDA stream and the place used for allocation.
 * output_grad  gradient with respect to the pooled output; dims [2..4] are
 *              read as the output depth, height and width.
 * mask         index mask read as T2 and handed to the kernel.
 * adaptive     forwarded unchanged to KernelMaxPool3DWithIdxGrad.
 * input_grad   receives the gradient with respect to the input (allocated
 *              here as T1); its dims [0..4] are read as N, C, D, H, W.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, bool adaptive,
                  framework::Tensor* input_grad) {
    // Geometry of the gradient being produced.
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];

    // Geometry of the incoming gradient.
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];

    // Pooling window description: depth, height, width.
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    // One work item per input-gradient element.
    const int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // An empty input gradient would yield a grid of zero blocks, which is not
    // a valid launch configuration; there is nothing to compute in that case.
    if (nthreads == 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width, adaptive,
        input_grad_data);
  }
};

Q
QI JUN 已提交
1635 1636 1637 1638 1639 1640 1641 1642
// Explicit instantiations of the 3-D max-pool-with-index functors on CUDA:
// float and double element types, each paired with an int mask.

// Forward.
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext,
                                         float, int>;
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext,
                                         double, int>;

// Backward.
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             float, int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1643 1644 1645 1646

}  // namespace math
}  // namespace operators
}  // namespace paddle