pooling.cu 44.4 KB
Newer Older
C
chengduoZH 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/pooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

22
/*
 * Forward 2-D pooling kernel: every loop iteration produces one output
 * element.
 *
 * `index` is decoded as
 *   ((batch_idx * channels + c) * output_height + ph) * output_width + pw
 * and the input is addressed as
 *   ((batch_idx * channels + c) * input_height + h) * input_width + w.
 *
 * The pooling window is clipped to the input extent, so padded positions are
 * never read and are not counted in the size handed to finalize().
 *
 * PoolProcess must provide initial(), compute(acc, value) and
 * finalize(acc, pool_size).
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, const int output_height,
                             const int output_width, const int ksize_height,
                             const int ksize_width, const int stride_height,
                             const int stride_width, const int padding_height,
                             const int padding_width, PoolProcess pool_process,
                             T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    // Window bounds along height, clipped to [0, input_height).
    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    // Window bounds along width, clipped to [0, input_width).
    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Keep `input_data` at the tensor base and derive a per-iteration pointer.
    // Advancing the parameter itself would accumulate offsets whenever a
    // thread executes more than one iteration of this loop.
    const T* input_slice =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        pool_process.compute(ele, input_slice[h * input_width + w]);
      }
    }
    int pool_size = (hend - hstart) * (wend - wstart);
    pool_process.finalize(ele, (static_cast<T>(pool_size)));
    output_data[index] = ele;
  }
}

/*
 * Backward 2-D pooling kernel: every loop iteration computes the gradient of
 * one input element.
 *
 * `index` is decoded as
 *   ((batch_idx * channels + offsetC) * input_height + h) * input_width + w.
 * For that element the kernel visits every output position whose (clipped)
 * pooling window contains it and lets
 *   pool_process.compute(input, output, output_grad, gradient, scale)
 * accumulate into `gradient`, where scale = 1 / (clipped window size).
 * The result is written to input_grad[index].
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    PoolProcess pool_process, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Coordinates of this input element, shifted by the padding.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Half-open range [start, end) of output positions covering the element.
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int phend = min(offsetH / stride_height + 1, output_height);
    int pwend = min(offsetW / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];

    // Keep `output_data` / `output_grad` at the tensor base and derive
    // per-iteration pointers. Advancing the parameters themselves would
    // accumulate offsets whenever a thread runs more than one iteration.
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;
    const T* output_slice = output_data + output_idx;
    const T* output_grad_slice = output_grad + output_idx;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // Recompute the clipped window of this output position to get its
        // size.
        int hstart = ph * stride_height - padding_height;
        int wstart = pw * stride_width - padding_width;
        int hend = min(hstart + ksize_height, input_height);
        int wend = min(wstart + ksize_width, input_width);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
        int pool_size = (hend - hstart) * (wend - wstart);
        int output_sub_idx = ph * output_width + pw;
        pool_process.compute(input, output_slice[output_sub_idx],
                             output_grad_slice[output_sub_idx], gradient,
                             static_cast<T>(1.0 / pool_size));
      }
    }
    input_grad[index] = gradient;
  }
}

107
/*
 * Backward kernel for 2-D max pooling: every loop iteration handles one
 * output element.
 *
 * The (clipped) pooling window of the output element is scanned in row-major
 * order for the first input value equal to output_data[index]; if one is
 * found, output_grad[index] is atomically added to the matching position of
 * input_grad.
 *
 * NOTE(review): this kernel only accumulates into input_grad and never clears
 * it, so the caller presumably zero-fills input_grad beforehand -- confirm.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Keep `input_data` / `input_grad` at the tensor base and derive
    // per-iteration pointers. Advancing the parameters themselves would
    // accumulate offsets whenever a thread runs more than one iteration.
    const int plane_offset =
        (batch_idx * channels + c) * input_height * input_width;
    const T* input_slice = input_data + plane_offset;
    T* input_grad_slice = input_grad + plane_offset;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    // Stop at the first element that matches the pooled maximum.
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        if (ele == input_slice[h * input_width + w]) {
          maxIndex = h * input_width + w;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // Atomic add: the same input position may be written while handling
      // several output elements.
      platform::CudaAtomicAdd(input_grad_slice + maxIndex, output_grad[index]);
    }
  }
}

C
chengduoZH 已提交
152 153 154 155 156
/*
 * GPU specialization of the 2-D pooling forward functor.
 *
 * Tensors use the NCHW layout. ksize, strides and paddings each carry two
 * entries, ordered {height, width}.
 *
 * Launches KernelPool2D with 1024 threads per block and enough blocks to give
 * every output element its own thread, on the stream of the
 * CUDADeviceContext that `context` is cast to.
 */
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process, framework::Tensor* output) {
    // Input extents: N, C, H, W.
    const int num_batches = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];
    // Output extents: C, H, W.
    const int out_c = output->dims()[1];
    const int out_h = output->dims()[2];
    const int out_w = output->dims()[3];
    // Window geometry, {height, width}.
    const int k_h = ksize[0];
    const int k_w = ksize[1];
    const int s_h = strides[0];
    const int s_w = strides[1];
    const int p_h = paddings[0];
    const int p_w = paddings[1];

    const T* in_ptr = input.data<T>();
    T* out_ptr = output->mutable_data<T>(context.GetPlace());

    // One thread per output element, rounded up to whole blocks.
    const int kThreadsPerBlock = 1024;
    int total = num_batches * out_c * out_h * out_w;
    int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 block_dim(kThreadsPerBlock, 1);
    dim3 grid_dim(num_blocks, 1);

    KernelPool2D<PoolProcess, T><<<
        grid_dim, block_dim, 0,
        reinterpret_cast<const platform::CUDADeviceContext&>(context)
            .stream()>>>(total, in_ptr, in_c, in_h, in_w, out_h, out_w, k_h,
                         k_w, s_h, s_w, p_h, p_w, pool_process, out_ptr);
  }
};

C
chengduoZH 已提交
197 198 199 200 201
/*
 * GPU specialization of the 2-D pooling backward functor.
 *
 * Tensors use the NCHW layout. ksize, strides and paddings each carry two
 * entries, ordered {height, width}.
 *
 * Launches KernelPool2DGrad with 1024 threads per block and enough blocks to
 * give every input element its own thread, on the stream of the
 * CUDADeviceContext that `context` is cast to.
 */
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process, framework::Tensor* input_grad) {
    // Input extents: N, C, H, W.
    const int num_batches = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];
    // Output extents: H, W.
    const int out_h = output.dims()[2];
    const int out_w = output.dims()[3];
    // Window geometry, {height, width}.
    const int k_h = ksize[0];
    const int k_w = ksize[1];
    const int s_h = strides[0];
    const int s_w = strides[1];
    const int p_h = paddings[0];
    const int p_w = paddings[1];

    const T* in_ptr = input.data<T>();
    const T* out_ptr = output.data<T>();
    const T* out_grad_ptr = output_grad.data<T>();
    T* in_grad_ptr = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per input element, rounded up to whole blocks.
    const int kThreadsPerBlock = 1024;
    int total = num_batches * in_c * in_h * in_w;
    int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 block_dim(kThreadsPerBlock, 1);
    dim3 grid_dim(num_blocks, 1);

    KernelPool2DGrad<PoolProcess, T><<<
        grid_dim, block_dim, 0,
        reinterpret_cast<const platform::CUDADeviceContext&>(context)
            .stream()>>>(total, in_ptr, out_ptr, out_grad_ptr, in_c, in_h,
                         in_w, out_h, out_w, k_h, k_w, s_h, s_w, p_h, p_w,
                         pool_process, in_grad_ptr);
  }
};

C
chengduoZH 已提交
246 247 248 249 250
/*
 * GPU specialization of the 2-D max-pooling backward functor.
 *
 * Tensors use the NCHW layout. ksize, strides and paddings each carry two
 * entries, ordered {height, width}.
 *
 * Launches KernelMaxPool2DGrad with 1024 threads per block and enough blocks
 * to give every output element its own thread, on the stream of the
 * CUDADeviceContext that `context` is cast to.
 *
 * NOTE(review): the kernel only adds into input_grad; presumably the caller
 * zero-fills it before invoking this functor -- confirm.
 */
template <typename T>
class MaxPool2dGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input extents: N, C, H, W.
    const int num_batches = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];
    // Output extents: C, H, W.
    const int out_c = output.dims()[1];
    const int out_h = output.dims()[2];
    const int out_w = output.dims()[3];
    // Window geometry, {height, width}.
    const int k_h = ksize[0];
    const int k_w = ksize[1];
    const int s_h = strides[0];
    const int s_w = strides[1];
    const int p_h = paddings[0];
    const int p_w = paddings[1];

    const T* in_ptr = input.data<T>();
    const T* out_ptr = output.data<T>();
    const T* out_grad_ptr = output_grad.data<T>();
    T* in_grad_ptr = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per output element, rounded up to whole blocks.
    const int kThreadsPerBlock = 1024;
    int total = num_batches * out_c * out_h * out_w;
    int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 block_dim(kThreadsPerBlock, 1);
    dim3 grid_dim(num_blocks, 1);

    KernelMaxPool2DGrad<T><<<
        grid_dim, block_dim, 0,
        reinterpret_cast<const platform::CUDADeviceContext&>(context)
            .stream()>>>(total, in_ptr, out_ptr, out_grad_ptr, in_c, in_h,
                         in_w, out_h, out_w, k_h, k_w, s_h, s_w, p_h, p_w,
                         in_grad_ptr);
  }
};

C
chengduoZH 已提交
295
// Explicit instantiations of the 2-D GPU pooling functors, single precision.
template class Pool2dFunctor<platform::GPUPlace,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::GPUPlace,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool2dGradFunctor<
    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<
    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class MaxPool2dGradFunctor<platform::GPUPlace, float>;

// Explicit instantiations of the 2-D GPU pooling functors, double precision.
template class Pool2dFunctor<platform::GPUPlace,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::GPUPlace,
                             paddle::operators::math::AvgPool<double>, double>;
template class Pool2dGradFunctor<
    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<
    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
template class MaxPool2dGradFunctor<platform::GPUPlace, double>;
314 315

/*
 * Forward 3-D pooling kernel: every loop iteration produces one output
 * element.
 *
 * `index` is decoded as
 *   (((batch_idx * channels + c) * output_depth + pd) * output_height + ph)
 *       * output_width + pw
 * and the input is addressed as
 *   (((batch_idx * channels + c) * input_depth + d) * input_height + h)
 *       * input_width + w.
 *
 * The pooling window is clipped to the input extent, so padded positions are
 * never read and are not counted in the size handed to finalize().
 *
 * PoolProcess must provide initial(), compute(acc, value) and
 * finalize(acc, pool_size).
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(const int nthreads, const T* input_data,
                             const int channels, const int input_depth,
                             const int input_height, const int input_width,
                             const int output_depth, const int output_height,
                             const int output_width, const int ksize_depth,
                             const int ksize_height, const int ksize_width,
                             const int stride_depth, const int stride_height,
                             const int stride_width, const int padding_depth,
                             const int padding_height, const int padding_width,
                             PoolProcess pool_process, T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    // Window bounds per axis, clipped to the input extent.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = pool_process.initial();

    // Keep `input_data` at the tensor base and derive a per-iteration pointer.
    // Advancing the parameter itself would accumulate offsets whenever a
    // thread executes more than one iteration of this loop.
    const T* input_slice =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          pool_process.compute(
              ele, input_slice[(d * input_height + h) * input_width + w]);
        }
      }
    }
    int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
    pool_process.finalize(ele, static_cast<T>(pool_size));
    output_data[index] = ele;
  }
}

/*
 * Backward 3-D pooling kernel: every loop iteration computes the gradient of
 * one input element.
 *
 * `index` is decoded as
 *   (((batch_idx * channels + offsetC) * input_depth + d) * input_height + h)
 *       * input_width + w.
 * For that element the kernel visits every output position whose (clipped)
 * pooling window contains it and lets
 *   pool_process.compute(input, output, output_grad, gradient, scale)
 * accumulate into `gradient`, where scale = 1 / (clipped window size).
 * The result is written to input_grad[index].
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, PoolProcess pool_process,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Coordinates of this input element, shifted by the padding.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetD =
        (index / input_width / input_height) % input_depth + padding_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // Half-open range [start, end) of output positions covering the element.
    int pdstart = (offsetD < ksize_depth)
                      ? 0
                      : (offsetD - ksize_depth) / stride_depth + 1;
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int pdend = min((offsetD) / stride_depth + 1, output_depth);
    int phend = min((offsetH) / stride_height + 1, output_height);
    int pwend = min((offsetW) / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];

    // Keep `output_data` / `output_grad` at the tensor base and derive
    // per-iteration pointers. Advancing the parameters themselves would
    // accumulate offsets whenever a thread runs more than one iteration.
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;
    const T* output_slice = output_data + output_idx;
    const T* output_grad_slice = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // figure out the pooling size
          int dstart = pd * stride_depth - padding_depth;
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int dend = min(dstart + ksize_depth, input_depth);
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          dstart = max(dstart, 0);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
          pool_process.compute(input, output_slice[output_sub_idx],
                               output_grad_slice[output_sub_idx], gradient,
                               static_cast<T>(1.0 / pool_size));
        }
      }
    }
    input_grad[index] = gradient;
  }
}

424
/*
 * Backward kernel for 3-D max pooling: every loop iteration handles one
 * output element.
 *
 * The (clipped) pooling window of the output element is scanned in
 * depth/height/width order for the first input value equal to
 * output_data[index]; if one is found, output_grad[index] is atomically added
 * to the matching position of input_grad.
 *
 * NOTE(review): this kernel only accumulates into input_grad and never clears
 * it, so the caller presumably zero-fills input_grad beforehand -- confirm.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    // Window bounds per axis, clipped to the input extent.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;

    // Keep `input_data` / `input_grad` at the tensor base and derive
    // per-iteration pointers. Advancing the parameters themselves would
    // accumulate offsets whenever a thread runs more than one iteration.
    const int volume_offset =
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    const T* input_slice = input_data + volume_offset;
    T* input_grad_slice = input_grad + volume_offset;

    // Stop at the first element that matches the pooled maximum.
    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          if (ele == input_slice[(d * input_height + h) * input_width + w]) {
            stop = true;
            maxIdx = (d * input_height + h) * input_width + w;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // Atomic add: the same input position may be written while handling
      // several output elements.
      platform::CudaAtomicAdd(input_grad_slice + maxIdx, output_grad[index]);
    }
  }
}

C
chengduoZH 已提交
475 476 477 478 479
/*
 * GPU specialization of the 3-D pooling forward functor.
 *
 * Tensors use the NCDHW layout. ksize, strides and paddings each carry three
 * entries, ordered {depth, height, width}.
 *
 * Launches KernelPool3D with 1024 threads per block and enough blocks to give
 * every output element its own thread, on the stream of the
 * CUDADeviceContext that `context` is cast to.
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process, framework::Tensor* output) {
    // Input extents: N, C, D, H, W.
    const int num_batches = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_d = input.dims()[2];
    const int in_h = input.dims()[3];
    const int in_w = input.dims()[4];
    // Output extents: C, D, H, W.
    const int out_c = output->dims()[1];
    const int out_d = output->dims()[2];
    const int out_h = output->dims()[3];
    const int out_w = output->dims()[4];
    // Window geometry, {depth, height, width}.
    const int k_d = ksize[0];
    const int k_h = ksize[1];
    const int k_w = ksize[2];
    const int s_d = strides[0];
    const int s_h = strides[1];
    const int s_w = strides[2];
    const int p_d = paddings[0];
    const int p_h = paddings[1];
    const int p_w = paddings[2];

    const T* in_ptr = input.data<T>();
    T* out_ptr = output->mutable_data<T>(context.GetPlace());

    // One thread per output element, rounded up to whole blocks.
    const int kThreadsPerBlock = 1024;
    int total = num_batches * out_c * out_d * out_h * out_w;
    int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 block_dim(kThreadsPerBlock, 1);
    dim3 grid_dim(num_blocks, 1);

    KernelPool3D<PoolProcess, T><<<
        grid_dim, block_dim, 0,
        reinterpret_cast<const platform::CUDADeviceContext&>(context)
            .stream()>>>(total, in_ptr, in_c, in_d, in_h, in_w, out_d, out_h,
                         out_w, k_d, k_h, k_w, s_d, s_h, s_w, p_d, p_h, p_w,
                         pool_process, out_ptr);
  }
};

C
chengduoZH 已提交
528 529 530 531 532
/*
 * GPU specialization of the 3-D pooling backward functor.
 *
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 *
 * Launches KernelPool3DGrad with 1024 threads per block and enough blocks to
 * give every input element its own thread, on the stream of the
 * CUDADeviceContext that `context` is cast to.
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process, framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    // The output channel count is not needed here: both the thread count and
    // the kernel's channel argument are based on the input channels.
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per input element, rounded up to whole blocks of 1024.
    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<
        PoolProcess,
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, input_grad_data);
  }
};

C
chengduoZH 已提交
585 586 587 588 589
/*
 * GPU specialization of the 3-D max-pooling backward functor.
 *
 * Tensors use the NCDHW layout. ksize, strides and paddings each carry three
 * entries, ordered {depth, height, width}.
 *
 * Launches KernelMaxPool3DGrad with 1024 threads per block and enough blocks
 * to give every output element its own thread, on the stream of the
 * CUDADeviceContext that `context` is cast to.
 *
 * NOTE(review): the kernel only adds into input_grad; presumably the caller
 * zero-fills it before invoking this functor -- confirm.
 */
template <class T>
class MaxPool3dGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input extents: N, C, D, H, W.
    const int num_batches = input.dims()[0];
    const int in_c = input.dims()[1];
    const int in_d = input.dims()[2];
    const int in_h = input.dims()[3];
    const int in_w = input.dims()[4];
    // Output extents: C, D, H, W.
    const int out_c = output.dims()[1];
    const int out_d = output.dims()[2];
    const int out_h = output.dims()[3];
    const int out_w = output.dims()[4];
    // Window geometry, {depth, height, width}.
    const int k_d = ksize[0];
    const int k_h = ksize[1];
    const int k_w = ksize[2];
    const int s_d = strides[0];
    const int s_h = strides[1];
    const int s_w = strides[2];
    const int p_d = paddings[0];
    const int p_h = paddings[1];
    const int p_w = paddings[2];

    const T* in_ptr = input.data<T>();
    const T* out_ptr = output.data<T>();
    const T* out_grad_ptr = output_grad.data<T>();
    T* in_grad_ptr = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per output element, rounded up to whole blocks.
    const int kThreadsPerBlock = 1024;
    int total = num_batches * out_c * out_d * out_h * out_w;
    int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 block_dim(kThreadsPerBlock, 1);
    dim3 grid_dim(num_blocks, 1);

    KernelMaxPool3DGrad<T><<<
        grid_dim, block_dim, 0,
        reinterpret_cast<const platform::CUDADeviceContext&>(context)
            .stream()>>>(total, in_ptr, out_ptr, out_grad_ptr, in_c, in_d,
                         in_h, in_w, out_d, out_h, out_w, k_d, k_h, k_w, s_d,
                         s_h, s_w, p_d, p_h, p_w, in_grad_ptr);
  }
};

C
chengduoZH 已提交
641
// Explicit instantiations of the 3-D GPU pooling functors, single precision.
template class Pool3dFunctor<platform::GPUPlace,
                             paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::GPUPlace,
                             paddle::operators::math::AvgPool<float>, float>;
template class Pool3dGradFunctor<
    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<
    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class MaxPool3dGradFunctor<platform::GPUPlace, float>;

// Explicit instantiations of the 3-D GPU pooling functors, double precision.
template class Pool3dFunctor<platform::GPUPlace,
                             paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::GPUPlace,
                             paddle::operators::math::AvgPool<double>, double>;
template class Pool3dGradFunctor<
    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<
    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
template class MaxPool3dGradFunctor<platform::GPUPlace, double>;
660

C
chengduoZH 已提交
661
/*
 * 2-D max pooling that also records the position of every maximum.
 *
 * Each loop iteration produces one output element. The flat output index is
 * decomposed into (batch_idx, c, ph, pw); the pooling window of that element
 * is clipped to the input bounds and scanned, then
 *   output_data[index] = largest value found in the window
 *   mask_data[index]   = h * input_width + w of that value, i.e. its offset
 *                        inside its own input_height x input_width feature map
 * The position is stored converted to T.
 *
 * The running maximum starts at -FLT_MAX for every T and is only replaced by a
 * strictly greater value. If the window holds no such value (for example it is
 * empty after clipping) the outputs are -FLT_MAX and -1. The first maximum in
 * scan order wins ties.
 *
 * nthreads is the number of output elements. The loop is a grid-stride loop
 * over a 1-D launch, so any grid/block size covers all nthreads elements.
 */
template <typename T>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, T* output_data, T* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat output index into its coordinates.
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    // Window bounds; the part that falls into the padding is cut off.
    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Use a per-iteration pointer to the current feature map. Advancing
    // input_data itself would accumulate offsets whenever one thread executes
    // more than one iteration of the grid-stride loop.
    const T* feature_map =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    T ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        if (ele < feature_map[input_index]) {
          max_index = input_index;
          ele = feature_map[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

/*
 * Backward pass of 2-D max pooling with index.
 *
 * Each loop iteration produces one element of input_grad. The flat input
 * index is decomposed into (batch_idx, c_offset, h_offset, w_offset). The
 * kernel visits every output position whose pooling window
 * [p * stride - padding, p * stride - padding + ksize) contains that input
 * coordinate and adds output_grad at that position whenever the mask entry
 * equals the input element's offset inside its feature map
 * (h_offset * input_width + w_offset). The sum is written to
 * input_grad[index]; elements selected by no window receive 0.
 *
 * nthreads is the number of input elements. The loop is a grid-stride loop
 * over a 1-D launch, so any grid/block size covers all nthreads elements.
 */
template <typename T>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T* output_grad, const T* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat input index into its coordinates.
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int c_offset = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Half-open range [start, end) of output rows / columns whose window
    // covers this input element.
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);

    T gradient = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    int output_idx =
        (batch_idx * channels + c_offset) * output_height * output_width;

    // Use per-iteration pointers to the matching output feature map. Advancing
    // mask_data / output_grad themselves would accumulate offsets whenever one
    // thread executes more than one iteration of the grid-stride loop.
    const T* mask_map = mask_data + output_idx;
    const T* grad_map = output_grad + output_idx;

    for (int ph = ph_start; ph < ph_end; ++ph) {
      for (int pw = pw_start; pw < pw_end; ++pw) {
        int out_pos = ph * output_width + pw;
        if (mask_map[out_pos] == input_current_featuremap_idx)
          gradient += grad_map[out_pos];
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
744 745 746 747 748
/*
 * GPU forward pass of 2-D max pooling with index.
 * Every tensor uses the NCHW layout. ksize, strides and paddings each hold
 * two values, ordered {height, width}.
 * Allocates output and mask on the context's place, then launches
 * KernelMaxPool2dWithIdx with one thread per output element on the stream of
 * the CUDA device context.
 */
template <typename T>
class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* output, framework::Tensor* mask) {
    // Input geometry (N, C, H, W).
    const int num_batches = input.dims()[0];
    const int in_channels = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];

    // Output geometry (C, H, W).
    const int out_channels = output->dims()[1];
    const int out_h = output->dims()[2];
    const int out_w = output->dims()[3];

    // Pooling window, step and padding, {height, width}.
    const int window_h = ksize[0];
    const int window_w = ksize[1];
    const int step_h = strides[0];
    const int step_w = strides[1];
    const int pad_h = paddings[0];
    const int pad_w = paddings[1];

    const T* src = input.data<T>();
    T* dst = output->mutable_data<T>(context.GetPlace());
    T* argmax = mask->mutable_data<T>(context.GetPlace());

    // One thread per output element, rounded up to whole blocks.
    const int kThreadsPerBlock = 1024;
    int total = num_batches * out_channels * out_h * out_w;
    int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(num_blocks, 1);

    const platform::CUDADeviceContext& cuda_context =
        reinterpret_cast<const platform::CUDADeviceContext&>(context);

    KernelMaxPool2dWithIdx<T><<<grid, threads, 0, cuda_context.stream()>>>(
        total, src, in_channels, in_h, in_w, out_h, out_w, window_h, window_w,
        step_h, step_w, pad_h, pad_w, dst, argmax);
  }
};

C
chengduoZH 已提交
789 790 791 792 793
/*
 * GPU backward pass of 2-D max pooling with index.
 * Every tensor uses the NCHW layout. ksize, strides and paddings each hold
 * two values, ordered {height, width}.
 * Allocates input_grad on the context's place, then launches
 * KernelMaxPool2DWithIdxGrad with one thread per input element on the stream
 * of the CUDA device context.
 */
template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input-gradient geometry (N, C, H, W).
    const int num_batches = input_grad->dims()[0];
    const int in_channels = input_grad->dims()[1];
    const int in_h = input_grad->dims()[2];
    const int in_w = input_grad->dims()[3];

    // Output-gradient geometry (H, W).
    const int out_h = output_grad.dims()[2];
    const int out_w = output_grad.dims()[3];

    // Pooling window, step and padding, {height, width}.
    const int window_h = ksize[0];
    const int window_w = ksize[1];
    const int step_h = strides[0];
    const int step_w = strides[1];
    const int pad_h = paddings[0];
    const int pad_w = paddings[1];

    const T* argmax = mask.data<T>();
    const T* dout = output_grad.data<T>();
    T* din = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per input element, rounded up to whole blocks.
    const int kThreadsPerBlock = 1024;
    int total = num_batches * in_channels * in_h * in_w;
    int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(num_blocks, 1);

    const platform::CUDADeviceContext& cuda_context =
        reinterpret_cast<const platform::CUDADeviceContext&>(context);

    KernelMaxPool2DWithIdxGrad<T><<<grid, threads, 0, cuda_context.stream()>>>(
        total, dout, argmax, in_channels, in_h, in_w, out_h, out_w, window_h,
        window_w, step_h, step_w, pad_h, pad_w, din);
  }
};

// Explicit instantiations of the 2-D max-pool-with-index functors for the GPU
// place: forward functor first, then its gradient, for float and double.
template class MaxPool2dWithIndexFunctor<platform::GPUPlace, float>;
template class MaxPool2dWithIndexFunctor<platform::GPUPlace, double>;
template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, float>;
template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, double>;

/*
 * 3-D max pooling that also records the position of every maximum.
 *
 * Each loop iteration produces one output element. The flat output index is
 * decomposed into (batch_idx, c, pd, ph, pw); the pooling window of that
 * element is clipped to the input bounds and scanned, then
 *   output_data[index] = largest value found in the window
 *   mask_data[index]   = (d * input_height + h) * input_width + w of that
 *                        value, i.e. its offset inside its own
 *                        input_depth x input_height x input_width feature map
 * The position is stored converted to T.
 *
 * The running maximum starts at -FLT_MAX for every T and is only replaced by a
 * strictly greater value. If the window holds no such value (for example it is
 * empty after clipping) the outputs are -FLT_MAX and -1. The first maximum in
 * scan order wins ties.
 *
 * nthreads is the number of output elements. The loop is a grid-stride loop
 * over a 1-D launch, so any grid/block size covers all nthreads elements.
 */
template <typename T>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    T* output_data, T* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat output index into its coordinates.
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    // Window bounds; the part that falls into the padding is cut off.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = -FLT_MAX;
    int max_index = -1;

    // Use a per-iteration pointer to the current feature map. Advancing
    // input_data itself would accumulate offsets whenever one thread executes
    // more than one iteration of the grid-stride loop.
    const T* feature_map =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          int input_index = (d * input_height + h) * input_width + w;
          if (ele < feature_map[input_index]) {
            max_index = input_index;
            ele = feature_map[input_index];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

/*
 * Backward pass of 3-D max pooling with index.
 *
 * Each loop iteration produces one element of input_grad. The flat input
 * index is decomposed into (batch_idx, c_offset, d_offset, h_offset,
 * w_offset). The kernel visits every output position whose pooling window
 * [p * stride - padding, p * stride - padding + ksize) contains that input
 * coordinate and adds output_grad at that position whenever the mask entry
 * equals the input element's offset inside its feature map
 * ((d_offset * input_height + h_offset) * input_width + w_offset). The sum is
 * written to input_grad[index]; elements selected by no window receive 0.
 *
 * nthreads is the number of input elements. The loop is a grid-stride loop
 * over a 1-D launch, so any grid/block size covers all nthreads elements.
 */
template <typename T>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T* output_grad, const T* mask, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat input index into its coordinates.
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // Half-open range [start, end) of output planes / rows / columns whose
    // window covers this input element.
    int pd_start =
        (d_offset + padding_depth < ksize_depth)
            ? 0
            : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int pd_end =
        min((d_offset + padding_depth) / stride_depth + 1, output_depth);
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);

    T gradient = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + c_offset) * output_depth *
                     output_height * output_width;

    // Use per-iteration pointers to the matching output feature map. Advancing
    // mask / output_grad themselves would accumulate offsets whenever one
    // thread executes more than one iteration of the grid-stride loop.
    const T* mask_map = mask + output_idx;
    const T* grad_map = output_grad + output_idx;

    for (int pd = pd_start; pd < pd_end; ++pd) {
      for (int ph = ph_start; ph < ph_end; ++ph) {
        for (int pw = pw_start; pw < pw_end; ++pw) {
          int out_pos = (pd * output_height + ph) * output_width + pw;
          if (mask_map[out_pos] == input_current_feature_map_idx)
            gradient += grad_map[out_pos];
        }
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
947 948 949 950 951
/*
 * GPU forward pass of 3-D max pooling with index.
 * Every tensor uses the NCDHW layout. ksize, strides and paddings each hold
 * three values, ordered {depth, height, width}.
 * Allocates output and mask on the context's place, then launches
 * KernelMaxPool3DWithIdx with one thread per output element on the stream of
 * the CUDA device context.
 */
template <typename T>
class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* output, framework::Tensor* mask) {
    // Input geometry (N, C, D, H, W).
    const int num_batches = input.dims()[0];
    const int in_channels = input.dims()[1];
    const int in_d = input.dims()[2];
    const int in_h = input.dims()[3];
    const int in_w = input.dims()[4];

    // Output geometry (C, D, H, W).
    const int out_channels = output->dims()[1];
    const int out_d = output->dims()[2];
    const int out_h = output->dims()[3];
    const int out_w = output->dims()[4];

    // Pooling window, step and padding, {depth, height, width}.
    const int window_d = ksize[0];
    const int window_h = ksize[1];
    const int window_w = ksize[2];
    const int step_d = strides[0];
    const int step_h = strides[1];
    const int step_w = strides[2];
    const int pad_d = paddings[0];
    const int pad_h = paddings[1];
    const int pad_w = paddings[2];

    const T* src = input.data<T>();
    T* dst = output->mutable_data<T>(context.GetPlace());
    T* argmax = mask->mutable_data<T>(context.GetPlace());

    // One thread per output element, rounded up to whole blocks.
    const int kThreadsPerBlock = 1024;
    int total = num_batches * out_channels * out_d * out_h * out_w;
    int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(num_blocks, 1);

    const platform::CUDADeviceContext& cuda_context =
        reinterpret_cast<const platform::CUDADeviceContext&>(context);

    KernelMaxPool3DWithIdx<T><<<grid, threads, 0, cuda_context.stream()>>>(
        total, src, in_channels, in_d, in_h, in_w, out_d, out_h, out_w,
        window_d, window_h, window_w, step_d, step_h, step_w, pad_d, pad_h,
        pad_w, dst, argmax);
  }
};

C
chengduoZH 已提交
999 1000 1001 1002 1003
/*
 * GPU backward pass of 3-D max pooling with index.
 * Every tensor uses the NCDHW layout. ksize, strides and paddings each hold
 * three values, ordered {depth, height, width}.
 * Allocates input_grad on the context's place, then launches
 * KernelMaxPool3DWithIdxGrad with one thread per input element on the stream
 * of the CUDA device context.
 */
template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input-gradient geometry (N, C, D, H, W).
    const int num_batches = input_grad->dims()[0];
    const int in_channels = input_grad->dims()[1];
    const int in_d = input_grad->dims()[2];
    const int in_h = input_grad->dims()[3];
    const int in_w = input_grad->dims()[4];

    // Output-gradient geometry (D, H, W).
    const int out_d = output_grad.dims()[2];
    const int out_h = output_grad.dims()[3];
    const int out_w = output_grad.dims()[4];

    // Pooling window, step and padding, {depth, height, width}.
    const int window_d = ksize[0];
    const int window_h = ksize[1];
    const int window_w = ksize[2];
    const int step_d = strides[0];
    const int step_h = strides[1];
    const int step_w = strides[2];
    const int pad_d = paddings[0];
    const int pad_h = paddings[1];
    const int pad_w = paddings[2];

    const T* dout = output_grad.data<T>();
    const T* argmax = mask.data<T>();
    T* din = input_grad->mutable_data<T>(context.GetPlace());

    // One thread per input element, rounded up to whole blocks.
    const int kThreadsPerBlock = 1024;
    int total = num_batches * in_channels * in_d * in_h * in_w;
    int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(num_blocks, 1);

    const platform::CUDADeviceContext& cuda_context =
        reinterpret_cast<const platform::CUDADeviceContext&>(context);

    KernelMaxPool3DWithIdxGrad<T><<<grid, threads, 0, cuda_context.stream()>>>(
        total, dout, argmax, in_channels, in_d, in_h, in_w, out_d, out_h, out_w,
        window_d, window_h, window_w, step_d, step_h, step_w, pad_d, pad_h,
        pad_w, din);
  }
};

// Explicit instantiations of the 3-D max-pool-with-index functors for the GPU
// place: forward functor first, then its gradient, for float and double.
template class MaxPool3dWithIndexFunctor<platform::GPUPlace, float>;
template class MaxPool3dWithIndexFunctor<platform::GPUPlace, double>;
template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, float>;
template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle