/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/pooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

// Forward 2-D pooling kernel. The index math below assumes a contiguous NCHW
// layout for both input and output.
//
// Each grid-stride iteration produces one output element output_data[index].
// `index` enumerates (batch, channel, ph, pw) with pw varying fastest. The
// pooling window is clipped to the input bounds, so padded positions never
// contribute, and `pool_size` counts only in-bounds elements. The reduction
// itself is delegated to `pool_process` (initial / compute / finalize).
//
// Launch: 1-D grid of 1-D blocks. Because of the grid-stride loop, any grid
// size yields correct results.
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
                             T* output_data, const int channels,
                             const int input_height, const int input_width,
                             const int output_height, const int output_width,
                             const int ksize_height, const int ksize_width,
                             const int stride_height, const int stride_width,
                             const int padding_height, const int padding_width,
                             PoolProcess pool_process) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Per-iteration plane pointer. Advancing `input_data` itself would
    // accumulate offsets across grid-stride iterations.
    const T* input_plane =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        pool_process.compute(ele, input_plane[h * input_width + w]);
      }
    }
    int pool_size = (hend - hstart) * (wend - wstart);
    pool_process.finalize(ele, (static_cast<T>(pool_size)));
    output_data[index] = ele;
  }
}

// Backward kernel for generic 2-D pooling. The index math assumes contiguous
// NCHW layout.
//
// Each grid-stride iteration handles one INPUT element input_grad[index]. It
// enumerates every output window whose receptive field covers that element
// and lets `pool_process.compute` fold that window's contribution into
// `gradient`. The scale passed is 1 / pool_size, where pool_size counts only
// the in-bounds elements of that window. The result is written (not
// accumulated) to input_grad[index]. Each input element is owned by exactly
// one iteration, so no atomics are needed.
template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, PoolProcess pool_process) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Coordinates of this input element, shifted into padded space.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Half-open range [start, end) of output windows covering this element.
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int phend = min(offsetH / stride_height + 1, output_height);
    int pwend = min(offsetW / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];

    // Per-iteration plane pointers. Advancing the kernel arguments in place
    // would accumulate offsets across grid-stride iterations.
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;
    const T* output_plane = output_data + output_idx;
    const T* output_grad_plane = output_grad + output_idx;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // In-bounds extent of window (ph, pw), matching the forward pass.
        int hstart = ph * stride_height - padding_height;
        int wstart = pw * stride_width - padding_width;
        int hend = min(hstart + ksize_height, input_height);
        int wend = min(wstart + ksize_width, input_width);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
        int pool_size = (hend - hstart) * (wend - wstart);
        int output_sub_idx = ph * output_width + pw;
        pool_process.compute(input, output_plane[output_sub_idx],
                             output_grad_plane[output_sub_idx], gradient,
                             static_cast<T>(1.0 / pool_size));
      }
    }
    input_grad[index] = gradient;
  }
}

// Backward kernel specialized for 2-D max pooling. The index math assumes
// contiguous NCHW layout.
//
// Each grid-stride iteration handles one OUTPUT element. It re-scans the
// clipped pooling window in row-major order and picks the first input
// position whose value equals the pooled output. It then atomically adds
// output_grad[index] to input_grad at that position. atomicAdd is needed
// because different output windows (for example, overlapping ones) can select
// the same input element.
//
// NOTE(review): the kernel only adds into `input_grad`. Positions that are
// never selected are left untouched, so the caller presumably zero-fills
// `input_grad` beforehand. Confirm at the call site.
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Per-iteration plane pointers. Offsetting the kernel arguments in place
    // would accumulate across grid-stride iterations.
    int plane_offset = (batch_idx * channels + c) * input_height * input_width;
    const T* input_plane = input_data + plane_offset;
    T* input_grad_plane = input_grad + plane_offset;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        if (ele == input_plane[h * input_width + w]) {
          maxIndex = h * input_width + w;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      atomicAdd(input_grad_plane + maxIndex, output_grad[index]);
    }
  }
}

template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  // Runs 2-D pooling on the GPU. It launches KernelPool2D with one logical
  // thread per output element on the stream of `context`.
  //
  // `input` and `output` are 4-D and read as dims()[0..3] =
  // batch, channels, height, width. `ksize`, `strides` and `paddings` hold
  // {height, width}. `output`'s dims are read before mutable_data, so the
  // caller is presumably expected to have shaped it already.
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& output,
                  std::vector<int>& ksize, std::vector<int>& strides,
                  std::vector<int>& paddings, PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    T* output_data = output.mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // Empty output: a zero-block grid is an invalid launch configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelPool2D<
        PoolProcess,
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, output_data, input_channels,
                              input_height, input_width, output_height,
                              output_width, ksize_height, ksize_width,
                              stride_height, stride_width, padding_height,
                              padding_width, pool_process);
  }
};

template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  // Computes the input gradient of 2-D pooling on the GPU. It launches
  // KernelPool2DGrad with one logical thread per INPUT element on the stream
  // of `context`. Every element of `input_grad` is written (not accumulated)
  // by the kernel.
  //
  // Tensors are 4-D and read as dims()[0..3] = batch, channels, height,
  // width. `ksize`, `strides` and `paddings` hold {height, width}.
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * input_channels * input_height * input_width;
    // Empty input: a zero-block grid is an invalid launch configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelPool2DGrad<
        PoolProcess,
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_height, input_width, output_height, output_width,
        ksize_height, ksize_width, stride_height, stride_width, padding_height,
        padding_width, pool_process);
  }
};

template <typename T>
class MaxPool2dGradFunctor<platform::GPUPlace, T> {
 public:
  // Computes the input gradient of 2-D max pooling on the GPU. It launches
  // KernelMaxPool2DGrad with one logical thread per OUTPUT element on the
  // stream of `context`.
  //
  // The kernel atomically ADDS into `input_grad` and never clears it.
  // NOTE(review): `input_grad` is presumably zero-filled by the caller.
  // Confirm this at the call site.
  //
  // Tensors are 4-D and read as dims()[0..3] = batch, channels, height,
  // width. `ksize`, `strides` and `paddings` hold {height, width}.
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_height * output_width;
    // Empty output: a zero-block grid is an invalid launch configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool2DGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_height, input_width, output_height, output_width,
        ksize_height, ksize_width, stride_height, stride_width, padding_height,
        padding_width);
  }
};

// Explicit instantiations of the 2-D GPU pooling functors.
// Pooling-process types are named unqualified because this code already
// lives inside namespace paddle::operators::math.

// The atomicAdd-based max-pool gradient is instantiated for float only:
// template class MaxPool2dGradFunctor<platform::GPUPlace, double>; // The
// 64-bit floating-point version of atomicAdd() is only supported by devices of
// compute capability 6.x and higher.
template class MaxPool2dGradFunctor<platform::GPUPlace, float>;

// Single precision.
template class Pool2dFunctor<platform::GPUPlace, MaxPool<float>, float>;
template class Pool2dFunctor<platform::GPUPlace, AvgPool<float>, float>;
template class Pool2dGradFunctor<platform::GPUPlace, MaxPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::GPUPlace, AvgPoolGrad<float>,
                                 float>;

// Double precision.
template class Pool2dFunctor<platform::GPUPlace, MaxPool<double>, double>;
template class Pool2dFunctor<platform::GPUPlace, AvgPool<double>, double>;
template class Pool2dGradFunctor<platform::GPUPlace, MaxPoolGrad<double>,
                                 double>;
template class Pool2dGradFunctor<platform::GPUPlace, AvgPoolGrad<double>,
                                 double>;

// Forward 3-D pooling kernel. The index math below assumes a contiguous NCDHW
// layout for both input and output.
//
// Each grid-stride iteration produces one output element output_data[index].
// `index` enumerates (batch, channel, pd, ph, pw) with pw varying fastest. The
// window is clipped to the input bounds in every dimension, and `pool_size`
// counts only in-bounds elements. The reduction is delegated to
// `pool_process` (initial / compute / finalize).
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(
    const int nthreads, const T* input_data, T* output_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // Per-iteration volume pointer. Advancing `input_data` itself would
    // accumulate offsets across grid-stride iterations.
    const T* input_volume =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          pool_process.compute(
              ele, input_volume[(d * input_height + h) * input_width + w]);
        }
      }
    }
    int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
    pool_process.finalize(ele, static_cast<T>(pool_size));
    output_data[index] = ele;
  }
}

// Backward kernel for generic 3-D pooling. The index math assumes contiguous
// NCDHW layout.
//
// Each grid-stride iteration handles one INPUT element input_grad[index]. It
// visits every output window covering that element and lets
// `pool_process.compute` fold that window's contribution into `gradient`,
// with scale 1 / pool_size (in-bounds element count of the window). The
// result is written (not accumulated) to input_grad[index]. No atomics are
// needed.
template <typename PoolProcess, typename T>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Coordinates of this input element, shifted into padded space.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetD =
        (index / input_width / input_height) % input_depth + padding_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // Half-open ranges [start, end) of output windows covering this element.
    int pdstart = (offsetD < ksize_depth)
                      ? 0
                      : (offsetD - ksize_depth) / stride_depth + 1;
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int pdend = min((offsetD) / stride_depth + 1, output_depth);
    int phend = min((offsetH) / stride_height + 1, output_height);
    int pwend = min((offsetW) / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];

    // Per-iteration volume pointers. Advancing the kernel arguments in place
    // would accumulate offsets across grid-stride iterations.
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;
    const T* output_volume = output_data + output_idx;
    const T* output_grad_volume = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // In-bounds extent of window (pd, ph, pw), matching the forward pass.
          int dstart = pd * stride_depth - padding_depth;
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int dend = min(dstart + ksize_depth, input_depth);
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          dstart = max(dstart, 0);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
          pool_process.compute(input, output_volume[output_sub_idx],
                               output_grad_volume[output_sub_idx], gradient,
                               static_cast<T>(1.0 / pool_size));
        }
      }
    }
    input_grad[index] = gradient;
  }
}

// Backward kernel specialized for 3-D max pooling. The index math assumes
// contiguous NCDHW layout.
//
// Each grid-stride iteration handles one OUTPUT element. It scans the clipped
// window in (d, h, w) order and picks the first input position equal to the
// pooled output. It then atomically adds output_grad[index] to input_grad at
// that position. atomicAdd is needed because different output windows can
// select the same input element.
//
// NOTE(review): the kernel only adds into `input_grad`, so the caller
// presumably zero-fills it beforehand. Confirm at the call site.
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height,
    const int padding_width) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // Per-iteration volume pointers. Offsetting the kernel arguments in place
    // would accumulate across grid-stride iterations.
    int volume_offset =
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    const T* input_volume = input_data + volume_offset;
    T* input_grad_volume = input_grad + volume_offset;

    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;
    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          if (ele == input_volume[(d * input_height + h) * input_width + w]) {
            stop = true;
            maxIdx = (d * input_height + h) * input_width + w;
          }
        }
      }
    }
    if (maxIdx != -1) {
      atomicAdd(input_grad_volume + maxIdx, output_grad[index]);
    }
  }
}

template <typename PoolProcess, class T>
class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  // Runs 3-D pooling on the GPU. It launches KernelPool3D with one logical
  // thread per output element on the stream of `context`.
  //
  // `input` and `output` are 5-D and read as dims()[0..4] =
  // batch, channels, depth, height, width. `ksize`, `strides` and `paddings`
  // hold {depth, height, width}.
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& output,
                  std::vector<int>& ksize, std::vector<int>& strides,
                  std::vector<int>& paddings, PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    T* output_data = output.mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: a zero-block grid is an invalid launch configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelPool3D<
        PoolProcess,
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(
        nthreads, input_data, output_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width,
        pool_process);
  }
};

template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
 public:
  // Computes the input gradient of 3-D pooling on the GPU. It launches
  // KernelPool3DGrad with one logical thread per INPUT element on the stream
  // of `context`. Every element of `input_grad` is written (not accumulated).
  //
  // Tensors are 5-D and read as dims()[0..4] = batch, channels, depth,
  // height, width. `ksize`, `strides` and `paddings` hold
  // {depth, height, width}.
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
                  PoolProcess pool_process) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    // Empty input: a zero-block grid is an invalid launch configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<
        PoolProcess,
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_depth, input_height, input_width, output_depth,
        output_height, output_width, ksize_depth, ksize_height, ksize_width,
        stride_depth, stride_height, stride_width, padding_depth,
        padding_height, padding_width, pool_process);
  }
};

template <class T>
class MaxPool3dGradFunctor<platform::GPUPlace, T> {
 public:
  // Computes the input gradient of 3-D max pooling on the GPU. It launches
  // KernelMaxPool3DGrad with one logical thread per OUTPUT element on the
  // stream of `context`.
  //
  // The kernel atomically ADDS into `input_grad` and never clears it.
  // NOTE(review): `input_grad` is presumably zero-filled by the caller.
  // Confirm this at the call site.
  //
  // Tensors are 5-D and read as dims()[0..4] = batch, channels, depth,
  // height, width. `ksize`, `strides` and `paddings` hold
  // {depth, height, width}.
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    // Empty output: a zero-block grid is an invalid launch configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxPool3DGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_depth, input_height, input_width, output_depth,
        output_height, output_width, ksize_depth, ksize_height, ksize_width,
        stride_depth, stride_height, stride_width, padding_depth,
        padding_height, padding_width);
  }
};

// Explicit instantiations of the 3-D GPU pooling functors.
// Pooling-process types are named unqualified because this code already
// lives inside namespace paddle::operators::math.

// The atomicAdd-based max-pool gradient is instantiated for float only:
// template class MaxPool3dGradFunctor<platform::GPUPlace, double>;  // The
// 64-bit floating-point version of atomicAdd() is only supported by devices of
// compute capability 6.x and higher.
template class MaxPool3dGradFunctor<platform::GPUPlace, float>;

// Single precision.
template class Pool3dFunctor<platform::GPUPlace, MaxPool<float>, float>;
template class Pool3dFunctor<platform::GPUPlace, AvgPool<float>, float>;
template class Pool3dGradFunctor<platform::GPUPlace, MaxPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::GPUPlace, AvgPoolGrad<float>,
                                 float>;

// Double precision.
template class Pool3dFunctor<platform::GPUPlace, MaxPool<double>, double>;
template class Pool3dFunctor<platform::GPUPlace, AvgPool<double>, double>;
template class Pool3dGradFunctor<platform::GPUPlace, MaxPoolGrad<double>,
                                 double>;
template class Pool3dGradFunctor<platform::GPUPlace, AvgPoolGrad<double>,
                                 double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle