/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15 16
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/cuda_helper.h"
C
chengduoZH 已提交
17 18 19 20 21

namespace paddle {
namespace operators {
namespace math {

/*
 * Forward 2-D pooling kernel (NCHW).
 *
 * Each iteration of the grid-stride loop computes one output element
 * output_data[index]. index enumerates (batch, channel, ph, pw) with pw
 * varying fastest, so nthreads must equal
 * batch * channels * output_height * output_width.
 *
 * The pooling window is clipped to the un-padded input. The reduction is
 * delegated to PoolProcess:
 *   - initial() seeds the accumulator;
 *   - compute() folds in each covered input value;
 *   - finalize() receives the number of input elements in the clipped window.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, const int output_height,
                             const int output_width, const int ksize_height,
                             const int ksize_width, const int stride_height,
                             const int stride_width, const int padding_height,
                             const int padding_width, PoolProcess pool_process,
                             T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Use a per-iteration pointer. Advancing input_data itself would pile up
    // offsets across grid-stride iterations and read the wrong plane whenever
    // a thread handles more than one output element.
    const T* input_plane =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        pool_process.compute(ele, input_plane[h * input_width + w]);
      }
    }
    int pool_size = (hend - hstart) * (wend - wstart);
    pool_process.finalize(ele, (static_cast<T>(pool_size)));
    output_data[index] = ele;
  }
}

/*
 * Backward 2-D pooling kernel (NCHW) driven by a PoolProcess gradient policy.
 *
 * Each iteration of the grid-stride loop produces input_grad[index] for one
 * input element. index enumerates (batch, channel, h, w) with w varying
 * fastest, so nthreads must equal
 * batch * channels * input_height * input_width.
 *
 * For that element, the kernel visits every output position (ph, pw) whose
 * window covers it and calls
 *   pool_process.compute(input, output, output_grad, gradient, 1 / pool_size)
 * where pool_size is the element count of that window after clipping to the
 * un-padded input.
 *
 * The accumulated gradient is stored (not added) into input_grad[index].
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    PoolProcess pool_process, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Coordinates of this input element in the padded input frame.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // [start, end) range of output positions whose window covers offsetH/W.
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int phend = min(offsetH / stride_height + 1, output_height);
    int pwend = min(offsetW / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];

    // Per-iteration plane pointers. Advancing output_data / output_grad in
    // place would pile up offsets across grid-stride iterations.
    const int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;
    const T* output_plane = output_data + output_idx;
    const T* output_grad_plane = output_grad + output_idx;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // Clipped window of output position (ph, pw), used for its size.
        int hstart = ph * stride_height - padding_height;
        int wstart = pw * stride_width - padding_width;
        int hend = min(hstart + ksize_height, input_height);
        int wend = min(wstart + ksize_width, input_width);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
        int pool_size = (hend - hstart) * (wend - wstart);
        int output_sub_idx = ph * output_width + pw;
        pool_process.compute(input, output_plane[output_sub_idx],
                             output_grad_plane[output_sub_idx], gradient,
                             static_cast<T>(1.0 / pool_size));
      }
    }
    input_grad[index] = gradient;
  }
}

/*
 * Backward kernel for 2-D max pooling (NCHW).
 *
 * Each iteration of the grid-stride loop handles one output element, so
 * nthreads must equal batch * channels * output_height * output_width.
 *
 * The kernel scans the output element's clipped window in row-major order,
 * takes the first input value equal to output_data[index], and adds
 * output_grad[index] to the matching input_grad entry. Overlapping windows
 * (stride < ksize) can select the same input element, hence the atomic add.
 *
 * input_grad is accumulated into, never overwritten.
 * NOTE(review): the caller presumably zero-fills input_grad first — confirm.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Per-iteration plane pointers. Advancing input_data / input_grad in
    // place would pile up offsets across grid-stride iterations and make the
    // atomic add below target the wrong address.
    const int plane_offset =
        (batch_idx * channels + c) * input_height * input_width;
    const T* input_plane = input_data + plane_offset;
    T* input_grad_plane = input_grad + plane_offset;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        if (ele == input_plane[h * input_width + w]) {
          maxIndex = h * input_width + w;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // Atomic: another output window may route gradient to the same input.
      platform::CudaAtomicAdd(input_grad_plane + maxIndex, output_grad[index]);
    }
  }
}

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // Launches KernelPool2D on context.stream() with one thread per output
  // element. output's dims are read here, so they must already be set; its
  // buffer is obtained through mutable_data on context.GetPlace().
  // The launch is asynchronous and no launch error is checked here.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    // An empty output needs no work, and a zero-block grid is an invalid
    // launch configuration, so skip the launch entirely.
    if (nthreads == 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, pool_process, output_data);
  }
};

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  // Launches KernelPool2DGrad on context.stream() with one thread per input
  // element; that kernel stores (not adds) each input_grad element.
  // input_grad's buffer is obtained through mutable_data on
  // context.GetPlace(). The launch is asynchronous and unchecked here.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process, framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * input_channels * input_height * input_width;
    // Nothing to compute for an empty input; a zero-block grid would also be
    // an invalid launch configuration.
    if (nthreads == 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        pool_process, input_grad_data);
  }
};

/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T>
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Launches KernelMaxPool2DGrad on context.stream() with one thread per
  // output element. That kernel atomically adds into input_grad rather than
  // overwriting it.
  // NOTE(review): callers presumably zero input_grad beforehand — confirm.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * output_channels * output_height * output_width;
    // No output elements means no gradient to route; a zero-block grid would
    // also be an invalid launch configuration.
    if (nthreads == 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        input_grad_data);
  }
};

// Explicit instantiations of the 2-D CUDA pooling functors, grouped by
// functor, each for float and double.

// Forward pooling: max and average.
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;

// PoolProcess-driven backward pooling: max and average.
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;

// Dedicated max-pooling backward (routes gradient to the arg-max input).
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;

/*
 * Forward 3-D pooling kernel (NCDHW).
 *
 * Each iteration of the grid-stride loop computes one output element
 * output_data[index]. index enumerates (batch, channel, pd, ph, pw) with pw
 * varying fastest, so nthreads must equal
 * batch * channels * output_depth * output_height * output_width.
 *
 * The window is clipped to the un-padded input. PoolProcess::initial() seeds
 * the accumulator, compute() folds in each covered input value, and
 * finalize() receives the clipped window's element count.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(const int nthreads, const T* input_data,
                             const int channels, const int input_depth,
                             const int input_height, const int input_width,
                             const int output_depth, const int output_height,
                             const int output_width, const int ksize_depth,
                             const int ksize_height, const int ksize_width,
                             const int stride_depth, const int stride_height,
                             const int stride_width, const int padding_depth,
                             const int padding_height, const int padding_width,
                             PoolProcess pool_process, T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // Per-iteration volume pointer. Advancing input_data in place would pile
    // up offsets across grid-stride iterations.
    const T* input_volume =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          pool_process.compute(
              ele, input_volume[(d * input_height + h) * input_width + w]);
        }
      }
    }
    int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
    pool_process.finalize(ele, static_cast<T>(pool_size));
    output_data[index] = ele;
  }
}

/*
 * Backward 3-D pooling kernel (NCDHW) driven by a PoolProcess gradient policy.
 *
 * Each iteration of the grid-stride loop produces input_grad[index] for one
 * input element. index enumerates (batch, channel, d, h, w) with w varying
 * fastest, so nthreads must equal
 * batch * channels * input_depth * input_height * input_width.
 *
 * For that element, the kernel visits every output position whose window
 * covers it and calls
 *   pool_process.compute(input, output, output_grad, gradient, 1 / pool_size)
 * where pool_size is the element count of that window clipped to the
 * un-padded input.
 *
 * The result is stored (not added) into input_grad[index].
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, PoolProcess pool_process,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Coordinates of this input element in the padded input frame.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetD =
        (index / input_width / input_height) % input_depth + padding_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // [start, end) range of output positions whose window covers this element.
    int pdstart = (offsetD < ksize_depth)
                      ? 0
                      : (offsetD - ksize_depth) / stride_depth + 1;
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int pdend = min((offsetD) / stride_depth + 1, output_depth);
    int phend = min((offsetH) / stride_height + 1, output_height);
    int pwend = min((offsetW) / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];

    // Per-iteration volume pointers. Advancing output_data / output_grad in
    // place would pile up offsets across grid-stride iterations.
    const int output_idx = (batch_idx * channels + offsetC) * output_depth *
                           output_height * output_width;
    const T* output_volume = output_data + output_idx;
    const T* output_grad_volume = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // Clipped window of output position (pd, ph, pw), for its size.
          int dstart = pd * stride_depth - padding_depth;
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int dend = min(dstart + ksize_depth, input_depth);
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          dstart = max(dstart, 0);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
          pool_process.compute(input, output_volume[output_sub_idx],
                               output_grad_volume[output_sub_idx], gradient,
                               static_cast<T>(1.0 / pool_size));
        }
      }
    }
    input_grad[index] = gradient;
  }
}

/*
 * Backward kernel for 3-D max pooling (NCDHW).
 *
 * Each iteration of the grid-stride loop handles one output element, so
 * nthreads must equal
 * batch * channels * output_depth * output_height * output_width.
 *
 * The kernel scans the clipped window in (d, h, w) order, takes the first
 * input value equal to output_data[index], and atomically adds
 * output_grad[index] to that input's gradient. Overlapping windows can pick
 * the same input element.
 *
 * input_grad is accumulated into, never overwritten.
 * NOTE(review): the caller presumably zero-fills input_grad first — confirm.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad);

C
chengduoZH 已提交
468 469 470 471 472
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
473
template <typename PoolProcess, class T>
Q
QI JUN 已提交
474
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
475
 public:
Q
QI JUN 已提交
476
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
477 478 479
                  const framework::Tensor& input, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process, framework::Tensor* output) {
480 481 482 483 484
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
C
chengduoZH 已提交
485 486 487 488
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
489 490 491 492 493 494 495 496 497 498 499
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
C
chengduoZH 已提交
500
    T* output_data = output->mutable_data<T>(context.GetPlace());
501 502 503 504 505 506 507

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
508
    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
509 510 511 512 513
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, pool_process,
        output_data);
514 515 516
  }
};

C
chengduoZH 已提交
517 518 519 520 521
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
522
template <typename PoolProcess, class T>
Q
QI JUN 已提交
523
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
524
 public:
Q
QI JUN 已提交
525
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
526
                  const framework::Tensor& input,
527 528 529
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
C
chengduoZH 已提交
530
                  PoolProcess pool_process, framework::Tensor* input_grad) {
531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
C
chengduoZH 已提交
553
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
554

555 556
    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
557 558 559 560
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
561
    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
562 563 564 565 566
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, input_grad_data);
567 568 569
  }
};

C
chengduoZH 已提交
570 571 572 573 574
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
575
template <class T>
Q
QI JUN 已提交
576
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
577
 public:
Q
QI JUN 已提交
578
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
579
                  const framework::Tensor& input,
580 581
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
C
chengduoZH 已提交
582 583
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
C
chengduoZH 已提交
606
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
607 608 609 610 611 612 613

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
614
    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
615 616 617 618 619
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, input_grad_data);
620 621 622
  }
};

Q
QI JUN 已提交
623 624
// Explicit instantiations of the 3-D pooling functors for the CUDA device.

// Dedicated max-pool backward (recomputes the argmax from input/output).
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;

// Generic forward pooling: max and average, float then double.
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::CUDADeviceContext,
                             paddle::operators::math::AvgPool<double>, double>;

// Generic backward pooling: max and average, float then double.
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;
646

C
chengduoZH 已提交
647
/*
 * Forward 2-D max pooling over an NCHW input that also records, for each
 * output element, the flat offset (h * input_width + w) of the selected input
 * element inside its own feature map.
 *
 * Launch: 1-D grid. Each thread walks output indices with a stride of
 * blockDim.x * gridDim.x, so any grid size covers all `nthreads` outputs
 * (nthreads = batch * channels * output_height * output_width).
 * No shared memory is used.
 *
 * The running max is seeded from the first element of the window, and
 * comparisons use strict '<', so the first maximum in row-major order wins on
 * ties. If that seed element is NaN, the NaN is propagated to the output.
 * An empty window (no in-bounds element) writes -FLT_MAX and mask -1.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Use a per-iteration pointer. Advancing `input_data` itself would
    // accumulate offsets across iterations of the stride loop.
    const T1* input_map =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    // Seed from the window instead of relying on -FLT_MAX: -FLT_MAX is not
    // the lowest double, and a window of values <= -FLT_MAX (e.g. -inf) would
    // otherwise leave max_index at -1 and lose its gradient.
    T1 ele = -FLT_MAX;
    int max_index = -1;
    if (hstart < hend && wstart < wend) {
      max_index = hstart * input_width + wstart;
      ele = input_map[max_index];
    }
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        T1 value = input_map[input_index];
        if (ele < value) {
          max_index = input_index;
          ele = value;
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
686
/*
 * Backward pass of 2-D max pooling with index (NCHW).
 *
 * Each thread owns one input-gradient element. It visits every output window
 * that can contain that element and sums the output gradients whose recorded
 * mask equals the element's flat offset (h * input_width + w) within its
 * feature map. Every input_grad element is written (0 where never selected),
 * so no atomics or pre-zeroing are required.
 *
 * Launch: 1-D grid. Each thread walks input indices with a stride of
 * blockDim.x * gridDim.x
 * (nthreads = batch * channels * input_height * input_width).
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int c_offset = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Range of output positions whose window covers (h_offset, w_offset).
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);

    T1 gradient = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    int output_idx =
        (batch_idx * channels + c_offset) * output_height * output_width;

    // Use per-iteration pointers. Bumping the kernel parameters would
    // accumulate offsets across iterations of the stride loop.
    const T2* mask_map = mask_data + output_idx;
    const T1* output_grad_map = output_grad + output_idx;

    for (int ph = ph_start; ph < ph_end; ++ph) {
      for (int pw = pw_start; pw < pw_end; ++pw) {
        int output_offset = ph * output_width + pw;
        if (mask_map[output_offset] == input_current_featuremap_idx)
          gradient += output_grad_map[output_offset];
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
730 731 732 733 734
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  // Fills `output` with window maxima and `mask` with, per output element,
  // the flat (h * input_width + w) offset of the chosen input element within
  // its feature map. The kernel is enqueued on context.stream(); this call
  // does not synchronize or check for launch errors.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // An empty output would produce a 0-block grid, which CUDA rejects as an
    // invalid launch configuration. There is also nothing to compute.
    if (nthreads == 0) return;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, output_data, mask_data);
  }
};

C
chengduoZH 已提交
772 773 774 775 776
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  // Routes `output_grad` back into `input_grad` using the `mask` produced by
  // the forward pass. Every input_grad element is written (0 where the element
  // was never selected). The kernel is enqueued on context.stream() without
  // synchronization or error checking.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    // An empty input_grad would produce a 0-block grid, which CUDA rejects as
    // an invalid launch configuration. There is also nothing to write.
    if (nthreads == 0) return;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width,
        input_grad_data);
  }
};

Q
QI JUN 已提交
815 816 817 818 819 820 821 822
// 2-D max pooling with index: forward functors (value type, int mask) ...
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
// ... and the matching backward functors.
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
823

C
chengduoZH 已提交
824
/*
 * Forward 3-D max pooling over an NCDHW input that also records, for each
 * output element, the flat offset ((d * input_height + h) * input_width + w)
 * of the selected input element inside its own feature volume.
 *
 * Launch: 1-D grid. Each thread walks output indices with a stride of
 * blockDim.x * gridDim.x
 * (nthreads = batch * channels * output_depth * output_height * output_width).
 * No shared memory is used.
 *
 * The running max is seeded from the first element of the window, and
 * comparisons use strict '<', so the first maximum in (d, h, w) order wins on
 * ties. If that seed element is NaN, the NaN is propagated to the output.
 * An empty window writes -FLT_MAX and mask -1.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // Use a per-iteration pointer. Advancing `input_data` itself would
    // accumulate offsets across iterations of the stride loop.
    const T1* input_map =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    // Seed from the window instead of relying on -FLT_MAX: -FLT_MAX is not
    // the lowest double, and a window of values <= -FLT_MAX (e.g. -inf) would
    // otherwise leave max_index at -1 and lose its gradient.
    T1 ele = -FLT_MAX;
    int max_index = -1;
    if (dstart < dend && hstart < hend && wstart < wend) {
      max_index = (dstart * input_height + hstart) * input_width + wstart;
      ele = input_map[max_index];
    }
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          int input_index = (d * input_height + h) * input_width + w;
          T1 value = input_map[input_index];
          if (ele < value) {
            max_index = input_index;
            ele = value;
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
872
/*
 * Backward pass of 3-D max pooling with index (NCDHW).
 *
 * Each thread owns one input-gradient element. It visits every output window
 * that can contain that element and sums the output gradients whose recorded
 * mask equals the element's flat offset
 * ((d * input_height + h) * input_width + w) within its feature volume. Every
 * input_grad element is written (0 where never selected), so no atomics or
 * pre-zeroing are required.
 *
 * Launch: 1-D grid. Each thread walks input indices with a stride of
 * blockDim.x * gridDim.x
 * (nthreads = batch * channels * input_depth * input_height * input_width).
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // Range of output positions whose window covers this input element.
    int pd_start =
        (d_offset + padding_depth < ksize_depth)
            ? 0
            : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int pd_end =
        min((d_offset + padding_depth) / stride_depth + 1, output_depth);
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);

    T1 gradient = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + c_offset) * output_depth *
                     output_height * output_width;

    // Use per-iteration pointers. Bumping the kernel parameters would
    // accumulate offsets across iterations of the stride loop.
    const T2* mask_map = mask + output_idx;
    const T1* output_grad_map = output_grad + output_idx;

    for (int pd = pd_start; pd < pd_end; ++pd) {
      for (int ph = ph_start; ph < ph_end; ++ph) {
        for (int pw = pw_start; pw < pw_end; ++pw) {
          int output_offset = (pd * output_height + ph) * output_width + pw;
          if (mask_map[output_offset] == input_current_feature_map_idx)
            gradient += output_grad_map[output_offset];
        }
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
931 932 933 934 935
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  // Fills `output` with window maxima and `mask` with, per output element,
  // the flat ((d * H + h) * W + w) offset of the chosen input element within
  // its feature volume. The kernel is enqueued on context.stream(); this call
  // does not synchronize or check for launch errors.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* output, framework::Tensor* mask) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // An empty output would produce a 0-block grid, which CUDA rejects as an
    // invalid launch configuration. There is also nothing to compute.
    if (nthreads == 0) return;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, output_data, mask_data);
  }
};

C
chengduoZH 已提交
980 981 982 983 984
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  // Routes `output_grad` back into `input_grad` using the `mask` produced by
  // the forward pass. Every input_grad element is written (0 where the element
  // was never selected). The kernel is enqueued on context.stream() without
  // synchronization or error checking.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // An empty input_grad would produce a 0-block grid, which CUDA rejects as
    // an invalid launch configuration. There is also nothing to write.
    if (nthreads == 0) return;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
        nthreads, output_grad_data, mask_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width,
        input_grad_data);
  }
};

Q
QI JUN 已提交
1030 1031 1032 1033 1034 1035 1036 1037
// 3-D max pooling with index: forward functors (value type, int mask) ...
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
// ... and the matching backward functors.
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1038 1039 1040 1041

}  // namespace math
}  // namespace operators
}  // namespace paddle