pooling.cu 47.2 KB
Newer Older
1
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
Y
Yi Wang 已提交
17
#include "paddle/fluid/operators/math/pooling.h"
D
dzhwinter 已提交
18
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduoZH 已提交
19 20 21 22 23

namespace paddle {
namespace operators {
namespace math {

24
template <typename PoolProcess, typename T>
25
__global__ void KernelPool2D(const int nthreads, const T* input_data,
C
chengduoZH 已提交
26 27 28 29 30 31
                             const int channels, const int input_height,
                             const int input_width, const int output_height,
                             const int output_width, const int ksize_height,
                             const int ksize_width, const int stride_height,
                             const int stride_width, const int padding_height,
                             const int padding_width, PoolProcess pool_process,
32
                             bool exclusive, T* output_data) {
33 34
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
35 36 37 38 39 40 41 42 43 44 45 46 47 48
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    input_data += (batch_idx * channels + c) * input_height * input_width;
49
    T ele = pool_process.initial();
50 51
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
C
chengduo 已提交
52
        pool_process.compute(input_data[h * input_width + w], &ele);
53 54
      }
    }
55
    int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
56
                              : ksize_height * ksize_width;
C
chengduo 已提交
57
    pool_process.finalize(static_cast<T>(pool_size), &ele);
58 59 60 61 62
    output_data[index] = ele;
  }
}

template <typename PoolProcess, typename T>
63
__global__ void KernelPool2DGrad(
64
    const int nthreads, const T* input_data, const T* output_data,
C
chengduoZH 已提交
65 66 67 68
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
69
    PoolProcess pool_process, bool exclusive, T* input_grad) {
70 71
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int phend = min(offsetH / stride_height + 1, output_height);
    int pwend = min(offsetW / stride_width + 1, output_width);
    T gradient = 0;
    T input = input_data[index];
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;
    output_data += output_idx;
    output_grad += output_idx;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        int hstart = ph * stride_height - padding_height;
        int wstart = pw * stride_width - padding_width;
        int hend = min(hstart + ksize_height, input_height);
        int wend = min(wstart + ksize_width, input_width);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
99
        int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
100
                                  : ksize_height * ksize_width;
101
        int output_sub_idx = ph * output_width + pw;
102
        pool_process.compute(input, output_data[output_sub_idx],
C
chengduo 已提交
103 104
                             output_grad[output_sub_idx],
                             static_cast<T>(1.0 / pool_size), &gradient);
105 106 107 108 109 110
      }
    }
    input_grad[index] = gradient;
  }
}

111
template <typename T>
112
__global__ void KernelMaxPool2DGrad(
113
    const int nthreads, const T* input_data, const T* output_data,
C
chengduoZH 已提交
114 115 116 117 118
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad) {
119 120
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    input_data += (batch_idx * channels + c) * input_height * input_width;
    input_grad += (batch_idx * channels + c) * input_height * input_width;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        if (ele == input_data[h * input_width + w]) {
          maxIndex = h * input_width + w;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // atomic add
C
chengduoZH 已提交
151
      platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]);
152 153 154 155
    }
  }
}

N
nhzlx 已提交
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
    const T* input, const std::vector<int>& input_shape,
    const std::vector<int>& output_shape, const std::vector<int>& ksize,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    PoolProcess pool_compute, bool exclusive, T* output, cudaStream_t stream) {
  const int batch_size = input_shape[0];
  const int input_channels = input_shape[1];
  const int input_height = input_shape[2];
  const int input_width = input_shape[3];
  const int output_channels = output_shape[1];
  const int output_height = output_shape[2];
  const int output_width = output_shape[3];
  const int ksize_height = ksize[0];
  const int ksize_width = ksize[1];
  const int stride_height = strides[0];
  const int stride_width = strides[1];
  const int padding_height = paddings[0];
  const int padding_width = paddings[1];

  int nthreads = batch_size * output_channels * output_height * output_width;
  int blocks = (nthreads + 1024 - 1) / 1024;
  dim3 threads(1024, 1);
  dim3 grid(blocks, 1);

  KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
      nthreads, input, input_channels, input_height, input_width, output_height,
      output_width, ksize_height, ksize_width, stride_height, stride_width,
      padding_height, padding_width, pool_compute, exclusive, output);
}

C
chengduoZH 已提交
187 188 189 190 191
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
192
template <typename PoolProcess, typename T>
Q
QI JUN 已提交
193
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
194
 public:
Q
QI JUN 已提交
195
  void operator()(const platform::CUDADeviceContext& context,
C
chengduo 已提交
196 197 198
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
199
                  bool exclusive, framework::Tensor* output) {
200 201 202 203
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
C
chengduoZH 已提交
204 205 206
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
207 208 209 210 211 212 213 214
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
C
chengduoZH 已提交
215
    T* output_data = output->mutable_data<T>(context.GetPlace());
216 217 218 219 220 221

    int nthreads = batch_size * output_channels * output_height * output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
222
    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
223 224
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
225
        stride_width, padding_height, padding_width, pool_process, exclusive,
226
        output_data);
227 228 229
  }
};

C
chengduoZH 已提交
230 231 232 233 234
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
235
template <typename PoolProcess, typename T>
Q
QI JUN 已提交
236
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
237
 public:
Q
QI JUN 已提交
238
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
239
                  const framework::Tensor& input,
240
                  const framework::Tensor& output,
C
chengduo 已提交
241 242 243 244
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
245
                  bool exclusive, framework::Tensor* input_grad) {
246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
C
chengduoZH 已提交
262
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
263 264 265 266 267 268

    int nthreads = batch_size * input_channels * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
269
    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
270 271 272
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
273
        pool_process, exclusive, input_grad_data);
274 275 276
  }
};

C
chengduoZH 已提交
277 278 279 280 281
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
282
template <typename T>
Q
QI JUN 已提交
283
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
284
 public:
Q
QI JUN 已提交
285
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
286
                  const framework::Tensor& input,
287
                  const framework::Tensor& output,
C
chengduo 已提交
288 289 290 291
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
C
chengduoZH 已提交
292
                  framework::Tensor* input_grad) {
293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
C
chengduoZH 已提交
310
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
311 312 313 314 315 316

    int nthreads = batch_size * output_channels * output_height * output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
317
    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
318 319 320 321
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_height, input_width, output_height, output_width, ksize_height,
        ksize_width, stride_height, stride_width, padding_height, padding_width,
        input_grad_data);
322 323 324
  }
};

N
nhzlx 已提交
325 326 327 328 329
template class Pool2dDirectCUDAFunctor<paddle::operators::math::MaxPool<float>,
                                       float>;
template class Pool2dDirectCUDAFunctor<paddle::operators::math::AvgPool<float>,
                                       float>;

Q
QI JUN 已提交
330 331
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;
C
chengduoZH 已提交
332

Q
QI JUN 已提交
333
template class Pool2dFunctor<platform::CUDADeviceContext,
334
                             paddle::operators::math::MaxPool<float>, float>;
Q
QI JUN 已提交
335
template class Pool2dFunctor<platform::CUDADeviceContext,
336
                             paddle::operators::math::AvgPool<float>, float>;
Q
QI JUN 已提交
337 338 339 340 341 342 343
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool2dFunctor<platform::CUDADeviceContext,
344
                             paddle::operators::math::MaxPool<double>, double>;
Q
QI JUN 已提交
345
template class Pool2dFunctor<platform::CUDADeviceContext,
346
                             paddle::operators::math::AvgPool<double>, double>;
Q
QI JUN 已提交
347 348 349 350 351 352
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;
353 354

template <typename PoolProcess, typename T>
355
__global__ void KernelPool3D(
356 357 358
    const int nthreads, const T* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
359
    const int ksize_depth, const int ksize_height, const int ksize_width,
360
    const int stride_depth, const int stride_height, const int stride_width,
361 362
    const int padding_depth, const int padding_height, const int padding_width,
    PoolProcess pool_process, bool exclusive, T* output_data) {
363
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
380
    T ele = pool_process.initial();
381 382 383 384 385
    input_data +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
386
          pool_process.compute(
C
chengduo 已提交
387
              input_data[(d * input_height + h) * input_width + w], &ele);
388 389 390
        }
      }
    }
391 392 393
    int pool_size = exclusive
                        ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
C
chengduo 已提交
394
    pool_process.finalize(static_cast<T>(pool_size), &ele);
395 396 397 398 399
    output_data[index] = ele;
  }
}

template <typename PoolProcess, typename T>
400
__global__ void KernelPool3DGrad(
401
    const int nthreads, const T* input_data, const T* output_data,
C
chengduoZH 已提交
402 403 404 405 406 407
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, PoolProcess pool_process,
408
    bool exclusive, T* input_grad) {
409
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
410 411 412 413 414 415 416 417 418 419
       index += blockDim.x * gridDim.x) {
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetD =
        (index / input_width / input_height) % input_depth + padding_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    int pdstart = (offsetD < ksize_depth)
                      ? 0
420
                      : (offsetD - ksize_depth) / stride_depth + 1;
421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int pdend = min((offsetD) / stride_depth + 1, output_depth);
    int phend = min((offsetH) / stride_height + 1, output_height);
    int pwend = min((offsetW) / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;
    output_data += output_idx;
    output_grad += output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // figure out the pooling size
          int dstart = pd * stride_depth - padding_depth;
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int dend = min(dstart + ksize_depth, input_depth);
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          dstart = max(dstart, 0);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
451 452 453
          int pool_size =
              exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
                        : ksize_depth * ksize_height * ksize_width;
454
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
455
          pool_process.compute(input, output_data[output_sub_idx],
C
chengduo 已提交
456 457
                               output_grad[output_sub_idx],
                               static_cast<T>(1.0 / pool_size), &gradient);
458 459 460 461 462 463 464
        }
      }
    }
    input_grad[index] = gradient;
  }
}

465
template <typename T>
466
__global__ void KernelMaxPool3DGrad(
467
    const int nthreads, const T* input_data, const T* output_data,
C
chengduoZH 已提交
468 469 470 471 472 473
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad) {
474
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;
    input_data +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    input_grad +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          if (ele == input_data[(d * input_height + h) * input_width + w]) {
            stop = true;
            maxIdx = (d * input_height + h) * input_width + w;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // atomic add
C
chengduoZH 已提交
511
      platform::CudaAtomicAdd(input_grad + maxIdx, output_grad[index]);
512 513 514 515
    }
  }
}

C
chengduoZH 已提交
516 517 518 519 520
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
521
template <typename PoolProcess, class T>
Q
QI JUN 已提交
522
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
523
 public:
Q
QI JUN 已提交
524
  void operator()(const platform::CUDADeviceContext& context,
C
chengduo 已提交
525 526 527
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
528
                  bool exclusive, framework::Tensor* output) {
529 530 531 532 533
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
C
chengduoZH 已提交
534 535 536 537
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
538 539 540 541 542 543 544 545 546 547 548
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
C
chengduoZH 已提交
549
    T* output_data = output->mutable_data<T>(context.GetPlace());
550 551 552 553 554 555 556

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
557
    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
558 559 560
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
561 562
        padding_depth, padding_height, padding_width, pool_process, exclusive,
        output_data);
563 564 565
  }
};

C
chengduoZH 已提交
566 567 568 569 570
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
571
template <typename PoolProcess, class T>
Q
QI JUN 已提交
572
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
573
 public:
Q
QI JUN 已提交
574
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
575
                  const framework::Tensor& input,
576
                  const framework::Tensor& output,
C
chengduo 已提交
577 578 579 580
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
581
                  bool exclusive, framework::Tensor* input_grad) {
582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
C
chengduoZH 已提交
604
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
605

606 607
    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
608 609 610 611
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
612
    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
613 614 615 616
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
617
        padding_width, pool_process, exclusive, input_grad_data);
618 619 620
  }
};

C
chengduoZH 已提交
621 622 623 624 625
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
626
template <class T>
Q
QI JUN 已提交
627
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
628
 public:
Q
QI JUN 已提交
629
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
630
                  const framework::Tensor& input,
631
                  const framework::Tensor& output,
C
chengduo 已提交
632 633 634 635
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
C
chengduoZH 已提交
636
                  framework::Tensor* input_grad) {
637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
C
chengduoZH 已提交
659
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
660 661 662 663 664 665 666

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
667
    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
668 669 670 671 672
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, input_grad_data);
673 674 675
  }
};

Q
QI JUN 已提交
676 677
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;
C
chengduoZH 已提交
678

Q
QI JUN 已提交
679
template class Pool3dFunctor<platform::CUDADeviceContext,
680
                             paddle::operators::math::MaxPool<float>, float>;
Q
QI JUN 已提交
681
template class Pool3dFunctor<platform::CUDADeviceContext,
682
                             paddle::operators::math::AvgPool<float>, float>;
Q
QI JUN 已提交
683 684 685 686 687 688 689
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<float>,
                                 float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<float>,
                                 float>;
template class Pool3dFunctor<platform::CUDADeviceContext,
690
                             paddle::operators::math::MaxPool<double>, double>;
Q
QI JUN 已提交
691
template class Pool3dFunctor<platform::CUDADeviceContext,
692
                             paddle::operators::math::AvgPool<double>, double>;
Q
QI JUN 已提交
693 694 695 696 697 698
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::MaxPoolGrad<double>,
                                 double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 paddle::operators::math::AvgPoolGrad<double>,
                                 double>;
699

C
chengduoZH 已提交
700
template <typename T1, typename T2>
C
chengduoZH 已提交
701
__global__ void KernelMaxPool2dWithIdx(
C
chengduoZH 已提交
702
    const int nthreads, const T1* input_data, const int channels,
C
chengduoZH 已提交
703 704 705
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
C
chengduoZH 已提交
706
    const int padding_width, T1* output_data, T2* mask_data) {
C
chengduoZH 已提交
707
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
C
chengduoZH 已提交
708
       index += blockDim.x * gridDim.x) {
C
chengduoZH 已提交
709 710 711 712 713 714 715 716 717 718 719 720 721 722
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    input_data += (batch_idx * channels + c) * input_height * input_width;
C
chengduoZH 已提交
723
    T1 ele = -FLT_MAX;
C
chengduoZH 已提交
724
    int max_index = -1;
C
chengduoZH 已提交
725 726
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
C
chengduoZH 已提交
727 728 729 730
        int input_index = h * input_width + w;
        if (ele < input_data[input_index]) {
          max_index = input_index;
          ele = input_data[input_index];
C
chengduoZH 已提交
731 732 733 734
        }
      }
    }
    output_data[index] = ele;
C
chengduoZH 已提交
735
    mask_data[index] = max_index;
C
chengduoZH 已提交
736 737 738
  }
}

C
chengduoZH 已提交
739
template <typename T1, typename T2>
C
chengduoZH 已提交
740
__global__ void KernelMaxPool2DWithIdxGrad(
C
chengduoZH 已提交
741
    const int nthreads, const T1* output_grad, const T2* mask_data,
C
chengduoZH 已提交
742 743 744
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
C
chengduoZH 已提交
745
    const int padding_height, const int padding_width, T1* input_grad) {
C
chengduoZH 已提交
746
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
C
chengduoZH 已提交
747 748 749 750
       index += blockDim.x * gridDim.x) {
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int c_offset = (index / input_width / input_height) % channels;
C
chengduoZH 已提交
751 752
    int batch_idx = index / input_width / input_height / channels;

C
chengduoZH 已提交
753 754 755 756 757 758 759 760 761 762 763 764 765
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);

C
chengduoZH 已提交
766
    T1 gradient = 0;
C
chengduoZH 已提交
767
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
C
chengduoZH 已提交
768
    int output_idx =
C
chengduoZH 已提交
769 770
        (batch_idx * channels + c_offset) * output_height * output_width;

C
chengduoZH 已提交
771 772
    mask_data += output_idx;
    output_grad += output_idx;
C
chengduoZH 已提交
773 774 775
    for (int ph = ph_start; ph < ph_end; ++ph) {
      for (int pw = pw_start; pw < pw_end; ++pw) {
        if (mask_data[ph * output_width + pw] == input_current_featuremap_idx)
C
chengduoZH 已提交
776 777 778 779 780 781 782
          gradient += output_grad[ph * output_width + pw];
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
783 784 785 786 787
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
C
chengduoZH 已提交
788
template <typename T1, typename T2>
Q
QI JUN 已提交
789
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
C
chengduoZH 已提交
790
 public:
Q
QI JUN 已提交
791
  void operator()(const platform::CUDADeviceContext& context,
C
chengduo 已提交
792 793 794 795
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* output,
                  framework::Tensor* mask) {
C
chengduoZH 已提交
796 797 798 799
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
C
chengduoZH 已提交
800 801 802
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
C
chengduoZH 已提交
803 804 805 806 807 808 809
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

C
chengduoZH 已提交
810 811 812
    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
C
chengduoZH 已提交
813 814 815 816 817 818

    int nthreads = batch_size * output_channels * output_height * output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
819
    KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
820 821 822
        nthreads, input_data, input_channels, input_height, input_width,
        output_height, output_width, ksize_height, ksize_width, stride_height,
        stride_width, padding_height, padding_width, output_data, mask_data);
C
chengduoZH 已提交
823 824 825
  }
};

C
chengduoZH 已提交
826 827 828 829 830
/*
 * All tensors are in NCHW format.
 * Ksize, strides, paddings are two elements. These two elements represent
 * height and width, respectively.
 */
C
chengduoZH 已提交
831
template <typename T1, typename T2>
Q
QI JUN 已提交
832
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
C
chengduoZH 已提交
833
 public:
Q
QI JUN 已提交
834
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
835
                  const framework::Tensor& output_grad,
C
chengduo 已提交
836 837 838
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
C
chengduoZH 已提交
839 840 841 842 843
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_height = input_grad->dims()[2];
    const int input_width = input_grad->dims()[3];
C
chengduoZH 已提交
844 845 846 847 848 849 850 851 852
    const int output_height = output_grad.dims()[2];
    const int output_width = output_grad.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

C
chengduoZH 已提交
853 854 855
    const T2* mask_data = mask.data<T2>();
    const T1* output_grad_data = output_grad.data<T1>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
C
chengduoZH 已提交
856 857 858 859 860 861

    int nthreads = batch_size * input_channels * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
862
    KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
863 864 865 866
        nthreads, output_grad_data, mask_data, input_channels, input_height,
        input_width, output_height, output_width, ksize_height, ksize_width,
        stride_height, stride_width, padding_height, padding_width,
        input_grad_data);
C
chengduoZH 已提交
867 868 869
  }
};

Q
QI JUN 已提交
870 871 872 873 874 875 876 877
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
878

C
chengduoZH 已提交
879
template <typename T1, typename T2>
C
chengduoZH 已提交
880
__global__ void KernelMaxPool3DWithIdx(
C
chengduoZH 已提交
881
    const int nthreads, const T1* input_data, const int channels,
C
chengduoZH 已提交
882 883 884 885 886
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
C
chengduoZH 已提交
887
    T1* output_data, T2* mask_data) {
C
chengduoZH 已提交
888
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
C
chengduoZH 已提交
889 890 891 892 893 894 895
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;
C
chengduoZH 已提交
896

C
chengduoZH 已提交
897 898 899 900 901 902 903 904 905
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
C
chengduoZH 已提交
906

C
chengduoZH 已提交
907
    T1 ele = -FLT_MAX;
C
chengduoZH 已提交
908
    int max_index = -1;
C
chengduoZH 已提交
909 910 911 912 913 914 915
    input_data +=
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          if (ele < input_data[(d * input_height + h) * input_width + w]) {
C
chengduoZH 已提交
916 917
            max_index = (d * input_height + h) * input_width + w;
            ele = input_data[max_index];
C
chengduoZH 已提交
918 919 920 921 922
          }
        }
      }
    }
    output_data[index] = ele;
C
chengduoZH 已提交
923
    mask_data[index] = max_index;
C
chengduoZH 已提交
924 925 926
  }
}

C
chengduoZH 已提交
927
template <typename T1, typename T2>
C
chengduoZH 已提交
928
__global__ void KernelMaxPool3DWithIdxGrad(
C
chengduoZH 已提交
929 930 931 932 933 934 935
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, T1* input_grad) {
C
chengduoZH 已提交
936
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
C
chengduoZH 已提交
937
       index += blockDim.x * gridDim.x) {
C
chengduoZH 已提交
938 939 940 941 942
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
C
chengduoZH 已提交
943 944
    int batch_idx = index / input_width / input_height / input_depth / channels;

C
chengduoZH 已提交
945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962
    int pd_start =
        (d_offset + padding_depth < ksize_depth)
            ? 0
            : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int pd_end =
        min((d_offset + padding_depth) / stride_depth + 1, output_depth);
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);
C
chengduoZH 已提交
963

C
chengduoZH 已提交
964
    T1 gradient = 0;
C
chengduoZH 已提交
965 966 967
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + c_offset) * output_depth *
C
chengduoZH 已提交
968 969 970 971
                     output_height * output_width;
    mask += output_idx;
    output_grad += output_idx;

C
chengduoZH 已提交
972 973 974 975 976
    for (int pd = pd_start; pd < pd_end; ++pd) {
      for (int ph = ph_start; ph < ph_end; ++ph) {
        for (int pw = pw_start; pw < pw_end; ++pw) {
          if (mask[(pd * output_height + ph) * output_width + pw] ==
              input_current_feature_map_idx)
C
chengduoZH 已提交
977 978 979 980 981 982 983 984 985
            gradient +=
                output_grad[(pd * output_height + ph) * output_width + pw];
        }
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
986 987 988 989 990
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
C
chengduoZH 已提交
991
template <typename T1, typename T2>
Q
QI JUN 已提交
992
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
C
chengduoZH 已提交
993
 public:
Q
QI JUN 已提交
994
  void operator()(const platform::CUDADeviceContext& context,
C
chengduo 已提交
995 996 997 998
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* output,
                  framework::Tensor* mask) {
C
chengduoZH 已提交
999 1000 1001 1002 1003
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
C
chengduoZH 已提交
1004 1005 1006 1007
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
C
chengduoZH 已提交
1008 1009 1010 1011 1012 1013 1014 1015 1016 1017
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

C
chengduoZH 已提交
1018 1019 1020
    const T1* input_data = input.data<T1>();
    T1* output_data = output->mutable_data<T1>(context.GetPlace());
    T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
C
chengduoZH 已提交
1021 1022 1023 1024 1025 1026 1027

    int nthreads = batch_size * output_channels * output_depth * output_height *
                   output_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
1028
    KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
1029 1030 1031 1032
        nthreads, input_data, input_channels, input_depth, input_height,
        input_width, output_depth, output_height, output_width, ksize_depth,
        ksize_height, ksize_width, stride_depth, stride_height, stride_width,
        padding_depth, padding_height, padding_width, output_data, mask_data);
C
chengduoZH 已提交
1033 1034 1035
  }
};

C
chengduoZH 已提交
1036 1037 1038 1039 1040
/*
 * All tensors are in NCDHW format.
 * Ksize, strides, paddings are three elements. These three elements represent
 * depth, height and width, respectively.
 */
C
chengduoZH 已提交
1041
template <typename T1, typename T2>
Q
QI JUN 已提交
1042
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
C
chengduoZH 已提交
1043
 public:
Q
QI JUN 已提交
1044
  void operator()(const platform::CUDADeviceContext& context,
C
chengduoZH 已提交
1045
                  const framework::Tensor& output_grad,
C
chengduo 已提交
1046 1047 1048
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
C
chengduoZH 已提交
1049 1050 1051 1052 1053 1054
                  framework::Tensor* input_grad) {
    const int batch_size = input_grad->dims()[0];
    const int input_channels = input_grad->dims()[1];
    const int input_depth = input_grad->dims()[2];
    const int input_height = input_grad->dims()[3];
    const int input_width = input_grad->dims()[4];
C
chengduoZH 已提交
1055 1056 1057
    const int output_depth = output_grad.dims()[2];
    const int output_height = output_grad.dims()[3];
    const int output_width = output_grad.dims()[4];
C
chengduoZH 已提交
1058 1059 1060 1061 1062 1063 1064 1065 1066 1067
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

C
chengduoZH 已提交
1068 1069 1070
    const T1* output_grad_data = output_grad.data<T1>();
    const T2* mask_data = mask.data<T2>();
    T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
C
chengduoZH 已提交
1071 1072 1073 1074 1075 1076 1077

    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

Q
QI JUN 已提交
1078
    KernelMaxPool3DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
C
chengduoZH 已提交
1079 1080 1081 1082 1083
        nthreads, output_grad_data, mask_data, input_channels, input_depth,
        input_height, input_width, output_depth, output_height, output_width,
        ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
        stride_width, padding_depth, padding_height, padding_width,
        input_grad_data);
C
chengduoZH 已提交
1084 1085 1086
  }
};

Q
QI JUN 已提交
1087 1088 1089 1090 1091 1092 1093 1094
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, float,
                                         int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, float,
                                             int>;
template class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, double,
                                         int>;
template class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext,
                                             double, int>;
C
chengduoZH 已提交
1095 1096 1097 1098

}  // namespace math
}  // namespace operators
}  // namespace paddle