pooling.cu 45.3 KB
Newer Older
1
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduo 已提交
15 16
#include <algorithm>
#include <vector>
Y
Yi Wang 已提交
17
#include "paddle/fluid/operators/math/pooling.h"
D
dzhwinter 已提交
18
#include "paddle/fluid/platform/cuda_primitives.h"
C
chengduoZH 已提交
19 20 21 22 23

namespace paddle {
namespace operators {
namespace math {

24
/*
 * Forward 2-D pooling kernel for NCHW data.
 *
 * Every value of `index` in [0, nthreads) identifies one output element. It
 * is decomposed (fastest-varying first) into output column, output row,
 * channel and batch. The pooling window of that element is clipped to the
 * input bounds, so padded positions are never read and `pool_size` counts
 * only the in-bounds elements.
 *
 * PoolProcess must provide, as used below: initial(), compute(value, &acc)
 * and finalize(pool_size, &acc).
 *
 * The loop is a grid-stride loop: a thread may process several indices, so
 * no state may be carried over from one iteration to the next.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, const int output_height,
                             const int output_width, const int ksize_height,
                             const int ksize_width, const int stride_height,
                             const int stride_width, const int padding_height,
                             const int padding_width, PoolProcess pool_process,
                             T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    // Window [hstart, hend) x [wstart, wend), clipped to the input plane.
    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Base of the (batch_idx, c) input plane. This must be a per-iteration
    // local: advancing `input_data` itself would accumulate the offset across
    // grid-stride iterations and read from the wrong plane.
    const T* input_plane =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    T ele = pool_process.initial();
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        pool_process.compute(input_plane[h * input_width + w], &ele);
      }
    }
    int pool_size = (hend - hstart) * (wend - wstart);
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

/*
 * Backward 2-D pooling kernel for NCHW data (generic PoolProcess).
 *
 * Every value of `index` in [0, nthreads) identifies one input element. The
 * kernel finds the range of output positions [phstart, phend) x
 * [pwstart, pwend) whose pooling windows contain that element and lets
 * pool_process.compute(input, output, output_grad, scale, &gradient)
 * accumulate each contribution, where scale is 1 / (number of in-bounds
 * elements of that window). The accumulated value is stored to
 * input_grad[index].
 *
 * The loop is a grid-stride loop: a thread may process several indices, so
 * no state may be carried over from one iteration to the next.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    PoolProcess pool_process, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Coordinates of this input element, shifted into the padded frame.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetC = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // First and one-past-last output positions whose window covers it.
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int phend = min(offsetH / stride_height + 1, output_height);
    int pwend = min(offsetW / stride_width + 1, output_width);
    T gradient = 0;
    T input = input_data[index];

    // Bases of the (batch_idx, offsetC) output planes. These must be
    // per-iteration locals: advancing `output_data` / `output_grad`
    // themselves would accumulate the offset across grid-stride iterations.
    int output_idx =
        (batch_idx * channels + offsetC) * output_height * output_width;
    const T* output_plane = output_data + output_idx;
    const T* output_grad_plane = output_grad + output_idx;

    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // Recompute the clipped window of (ph, pw) to get its element count.
        int hstart = ph * stride_height - padding_height;
        int wstart = pw * stride_width - padding_width;
        int hend = min(hstart + ksize_height, input_height);
        int wend = min(wstart + ksize_width, input_width);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
        int pool_size = (hend - hstart) * (wend - wstart);
        int output_sub_idx = ph * output_width + pw;
        pool_process.compute(input, output_plane[output_sub_idx],
                             output_grad_plane[output_sub_idx],
                             static_cast<T>(1.0 / pool_size), &gradient);
      }
    }
    input_grad[index] = gradient;
  }
}

109
/*
 * Backward 2-D max-pooling kernel for NCHW data.
 *
 * Every value of `index` in [0, nthreads) identifies one output element. The
 * kernel scans that element's (clipped) pooling window in row-major order
 * for the first input value equal to the pooled output, and atomically adds
 * output_grad[index] to the matching position of input_grad. If no value
 * matches, nothing is written.
 *
 * NOTE(review): input_grad is only ever accumulated into, never stored to,
 * so it presumably has to be zero-filled by the caller before the launch --
 * confirm against the call site.
 *
 * The loop is a grid-stride loop: a thread may process several indices, so
 * no state may be carried over from one iteration to the next.
 */
template <typename T>
__global__ void KernelMaxPool2DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_height,
    const int input_width, const int output_height, const int output_width,
    const int ksize_height, const int ksize_width, const int stride_height,
    const int stride_width, const int padding_height, const int padding_width,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    // Window [hstart, hend) x [wstart, wend), clipped to the input plane.
    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Bases of the (batch_idx, c) input planes. These must be per-iteration
    // locals: advancing `input_data` / `input_grad` themselves would
    // accumulate the offset across grid-stride iterations.
    const int plane_offset =
        (batch_idx * channels + c) * input_height * input_width;
    const T* input_plane = input_data + plane_offset;
    T* input_grad_plane = input_grad + plane_offset;

    T ele = output_data[index];
    int maxIndex = -1;
    bool stop = false;
    for (int h = hstart; h < hend && !stop; ++h) {
      for (int w = wstart; w < wend && !stop; ++w) {
        if (ele == input_plane[h * input_width + w]) {
          maxIndex = h * input_width + w;
          stop = true;
        }
      }
    }

    if (maxIndex != -1) {
      // Overlapping windows may select the same input element, hence atomic.
      platform::CudaAtomicAdd(input_grad_plane + maxIndex, output_grad[index]);
    }
  }
}

C
chengduoZH 已提交
154 155 156 157 158
/*
 * CUDA specialization of the 2-D pooling forward functor.
 * All tensors are in NCHW format.
 * ksize, strides and paddings each hold two elements: the height value
 * followed by the width value.
 */
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  framework::Tensor* output) {
    // Pooling geometry: {height, width} pairs.
    const int kernel_h = ksize[0];
    const int kernel_w = ksize[1];
    const int stride_h = strides[0];
    const int stride_w = strides[1];
    const int pad_h = paddings[0];
    const int pad_w = paddings[1];

    // Input extents (N, C, H, W).
    const int num_batches = input.dims()[0];
    const int in_channels = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];

    // Output extents; the batch extent is taken from the input.
    const int out_channels = output->dims()[1];
    const int out_h = output->dims()[2];
    const int out_w = output->dims()[3];

    const T* src = input.data<T>();
    T* dst = output->mutable_data<T>(context.GetPlace());

    // One kernel index per output element, 1024 threads per block.
    const int kThreadsPerBlock = 1024;
    const int total = num_batches * out_channels * out_h * out_w;
    const int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(num_blocks, 1);

    KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        total, src, in_channels, in_h, in_w, out_h, out_w, kernel_h, kernel_w,
        stride_h, stride_w, pad_h, pad_w, pool_process, dst);
  }
};

C
chengduoZH 已提交
196 197 198 199 200
/*
 * CUDA specialization of the 2-D pooling backward functor.
 * All tensors are in NCHW format.
 * ksize, strides and paddings each hold two elements: the height value
 * followed by the width value.
 */
template <typename PoolProcess, typename T>
class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  framework::Tensor* input_grad) {
    // Pooling geometry: {height, width} pairs.
    const int kernel_h = ksize[0];
    const int kernel_w = ksize[1];
    const int stride_h = strides[0];
    const int stride_w = strides[1];
    const int pad_h = paddings[0];
    const int pad_w = paddings[1];

    // Input extents (N, C, H, W) and output spatial extents.
    const int num_batches = input.dims()[0];
    const int in_channels = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];
    const int out_h = output.dims()[2];
    const int out_w = output.dims()[3];

    const T* x = input.data<T>();
    const T* y = output.data<T>();
    const T* dy = output_grad.data<T>();
    T* dx = input_grad->mutable_data<T>(context.GetPlace());

    // One kernel index per input element, 1024 threads per block.
    const int kThreadsPerBlock = 1024;
    const int total = num_batches * in_channels * in_h * in_w;
    const int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(num_blocks, 1);

    KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        total, x, y, dy, in_channels, in_h, in_w, out_h, out_w, kernel_h,
        kernel_w, stride_h, stride_w, pad_h, pad_w, pool_process, dx);
  }
};

C
chengduoZH 已提交
243 244 245 246 247
/*
 * CUDA specialization of the 2-D max-pooling backward functor.
 * All tensors are in NCHW format.
 * ksize, strides and paddings each hold two elements: the height value
 * followed by the width value.
 *
 * NOTE(review): KernelMaxPool2DGrad only adds into input_grad atomically, so
 * input_grad presumably has to be zero-filled by the caller -- confirm.
 */
template <typename T>
class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Pooling geometry: {height, width} pairs.
    const int kernel_h = ksize[0];
    const int kernel_w = ksize[1];
    const int stride_h = strides[0];
    const int stride_w = strides[1];
    const int pad_h = paddings[0];
    const int pad_w = paddings[1];

    // Input extents (N, C, H, W).
    const int num_batches = input.dims()[0];
    const int in_channels = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];

    // Output extents; the batch extent is taken from the input.
    const int out_channels = output.dims()[1];
    const int out_h = output.dims()[2];
    const int out_w = output.dims()[3];

    const T* x = input.data<T>();
    const T* y = output.data<T>();
    const T* dy = output_grad.data<T>();
    T* dx = input_grad->mutable_data<T>(context.GetPlace());

    // One kernel index per output element, 1024 threads per block.
    const int kThreadsPerBlock = 1024;
    const int total = num_batches * out_channels * out_h * out_w;
    const int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(num_blocks, 1);

    KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
        total, x, y, dy, in_channels, in_h, in_w, out_h, out_w, kernel_h,
        kernel_w, stride_h, stride_w, pad_h, pad_w, dx);
  }
};

Q
QI JUN 已提交
291 292
// Explicit instantiations of the 2-D CUDA pooling functors, grouped by
// element type. The pool-process types are named unqualified because this
// code already sits inside paddle::operators::math.

// float
template class Pool2dFunctor<platform::CUDADeviceContext, MaxPool<float>,
                             float>;
template class Pool2dFunctor<platform::CUDADeviceContext, AvgPool<float>,
                             float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<float>, float>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, float>;

// double
template class Pool2dFunctor<platform::CUDADeviceContext, MaxPool<double>,
                             double>;
template class Pool2dFunctor<platform::CUDADeviceContext, AvgPool<double>,
                             double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<double>, double>;
template class MaxPool2dGradFunctor<platform::CUDADeviceContext, double>;
314 315

/*
 * Forward 3-D pooling kernel for NCDHW data.
 *
 * Every value of `index` in [0, nthreads) identifies one output element. It
 * is decomposed (fastest-varying first) into output column, row, depth,
 * channel and batch. The pooling window is clipped to the input bounds, so
 * padded positions are never read and `pool_size` counts only the in-bounds
 * elements.
 *
 * PoolProcess must provide, as used below: initial(), compute(value, &acc)
 * and finalize(pool_size, &acc).
 *
 * The loop is a grid-stride loop: a thread may process several indices, so
 * no state may be carried over from one iteration to the next.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(const int nthreads, const T* input_data,
                             const int channels, const int input_depth,
                             const int input_height, const int input_width,
                             const int output_depth, const int output_height,
                             const int output_width, const int ksize_depth,
                             const int ksize_height, const int ksize_width,
                             const int stride_depth, const int stride_height,
                             const int stride_width, const int padding_depth,
                             const int padding_height, const int padding_width,
                             PoolProcess pool_process, T* output_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    // Window [dstart, dend) x [hstart, hend) x [wstart, wend), clipped to
    // the input volume.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    // Base of the (batch_idx, c) input volume. This must be a per-iteration
    // local: advancing `input_data` itself would accumulate the offset across
    // grid-stride iterations and read from the wrong volume.
    const T* input_volume =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    T ele = pool_process.initial();
    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          pool_process.compute(
              input_volume[(d * input_height + h) * input_width + w], &ele);
        }
      }
    }
    int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
    pool_process.finalize(static_cast<T>(pool_size), &ele);
    output_data[index] = ele;
  }
}

/*
 * Backward 3-D pooling kernel for NCDHW data (generic PoolProcess).
 *
 * Every value of `index` in [0, nthreads) identifies one input element. The
 * kernel finds the range of output positions whose pooling windows contain
 * that element and lets
 * pool_process.compute(input, output, output_grad, scale, &gradient)
 * accumulate each contribution, where scale is 1 / (number of in-bounds
 * elements of that window). The accumulated value is stored to
 * input_grad[index].
 *
 * The loop is a grid-stride loop: a thread may process several indices, so
 * no state may be carried over from one iteration to the next.
 */
template <typename PoolProcess, typename T>
__global__ void KernelPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, PoolProcess pool_process,
    T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Coordinates of this input element, shifted into the padded frame.
    int offsetW = index % input_width + padding_width;
    int offsetH = (index / input_width) % input_height + padding_height;
    int offsetD =
        (index / input_width / input_height) % input_depth + padding_depth;
    int offsetC = (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // First and one-past-last output positions whose window covers it.
    int pdstart = (offsetD < ksize_depth)
                      ? 0
                      : (offsetD - ksize_depth) / stride_depth + 1;
    int phstart = (offsetH < ksize_height)
                      ? 0
                      : (offsetH - ksize_height) / stride_height + 1;
    int pwstart = (offsetW < ksize_width)
                      ? 0
                      : (offsetW - ksize_width) / stride_width + 1;
    int pdend = min((offsetD) / stride_depth + 1, output_depth);
    int phend = min((offsetH) / stride_height + 1, output_height);
    int pwend = min((offsetW) / stride_width + 1, output_width);

    T gradient = 0;
    T input = input_data[index];

    // Bases of the (batch_idx, offsetC) output volumes. These must be
    // per-iteration locals: advancing `output_data` / `output_grad`
    // themselves would accumulate the offset across grid-stride iterations.
    int output_idx = (batch_idx * channels + offsetC) * output_depth *
                     output_height * output_width;
    const T* output_volume = output_data + output_idx;
    const T* output_grad_volume = output_grad + output_idx;

    for (int pd = pdstart; pd < pdend; ++pd) {
      for (int ph = phstart; ph < phend; ++ph) {
        for (int pw = pwstart; pw < pwend; ++pw) {
          // Recompute the clipped window of (pd, ph, pw) to get its size.
          int dstart = pd * stride_depth - padding_depth;
          int hstart = ph * stride_height - padding_height;
          int wstart = pw * stride_width - padding_width;
          int dend = min(dstart + ksize_depth, input_depth);
          int hend = min(hstart + ksize_height, input_height);
          int wend = min(wstart + ksize_width, input_width);
          dstart = max(dstart, 0);
          hstart = max(hstart, 0);
          wstart = max(wstart, 0);
          int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
          int output_sub_idx = (pd * output_height + ph) * output_width + pw;
          pool_process.compute(input, output_volume[output_sub_idx],
                               output_grad_volume[output_sub_idx],
                               static_cast<T>(1.0 / pool_size), &gradient);
        }
      }
    }
    input_grad[index] = gradient;
  }
}

424
/*
 * Backward 3-D max-pooling kernel for NCDHW data.
 *
 * Every value of `index` in [0, nthreads) identifies one output element. The
 * kernel scans that element's (clipped) pooling window, depth-major then
 * row-major, for the first input value equal to the pooled output, and
 * atomically adds output_grad[index] to the matching position of input_grad.
 * If no value matches, nothing is written.
 *
 * NOTE(review): input_grad is only ever accumulated into, never stored to,
 * so it presumably has to be zero-filled by the caller before the launch --
 * confirm against the call site.
 *
 * The loop is a grid-stride loop: a thread may process several indices, so
 * no state may be carried over from one iteration to the next.
 */
template <typename T>
__global__ void KernelMaxPool3DGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, const int channels, const int input_depth,
    const int input_height, const int input_width, const int output_depth,
    const int output_height, const int output_width, const int ksize_depth,
    const int ksize_height, const int ksize_width, const int stride_depth,
    const int stride_height, const int stride_width, const int padding_depth,
    const int padding_height, const int padding_width, T* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    // Window [dstart, dend) x [hstart, hend) x [wstart, wend), clipped to
    // the input volume.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T ele = output_data[index];
    bool stop = false;
    int maxIdx = -1;

    // Bases of the (batch_idx, c) input volumes. These must be per-iteration
    // locals: advancing `input_data` / `input_grad` themselves would
    // accumulate the offset across grid-stride iterations.
    const int volume_offset =
        (batch_idx * channels + c) * input_depth * input_height * input_width;
    const T* input_volume = input_data + volume_offset;
    T* input_grad_volume = input_grad + volume_offset;

    for (int d = dstart; d < dend && !stop; ++d) {
      for (int h = hstart; h < hend && !stop; ++h) {
        for (int w = wstart; w < wend && !stop; ++w) {
          if (ele == input_volume[(d * input_height + h) * input_width + w]) {
            stop = true;
            maxIdx = (d * input_height + h) * input_width + w;
          }
        }
      }
    }
    if (maxIdx != -1) {
      // Overlapping windows may select the same input element, hence atomic.
      platform::CudaAtomicAdd(input_grad_volume + maxIdx, output_grad[index]);
    }
  }
}

C
chengduoZH 已提交
475 476 477 478 479
/*
 * CUDA specialization of the 3-D pooling forward functor.
 * All tensors are in NCDHW format.
 * ksize, strides and paddings each hold three elements: the depth, height
 * and width values, in that order.
 */
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  framework::Tensor* output) {
    // Pooling geometry: {depth, height, width} triples.
    const int kernel_d = ksize[0];
    const int kernel_h = ksize[1];
    const int kernel_w = ksize[2];
    const int stride_d = strides[0];
    const int stride_h = strides[1];
    const int stride_w = strides[2];
    const int pad_d = paddings[0];
    const int pad_h = paddings[1];
    const int pad_w = paddings[2];

    // Input extents (N, C, D, H, W).
    const int num_batches = input.dims()[0];
    const int in_channels = input.dims()[1];
    const int in_d = input.dims()[2];
    const int in_h = input.dims()[3];
    const int in_w = input.dims()[4];

    // Output extents; the batch extent is taken from the input.
    const int out_channels = output->dims()[1];
    const int out_d = output->dims()[2];
    const int out_h = output->dims()[3];
    const int out_w = output->dims()[4];

    const T* src = input.data<T>();
    T* dst = output->mutable_data<T>(context.GetPlace());

    // One kernel index per output element, 1024 threads per block.
    const int kThreadsPerBlock = 1024;
    const int total = num_batches * out_channels * out_d * out_h * out_w;
    const int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(num_blocks, 1);

    KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        total, src, in_channels, in_d, in_h, in_w, out_d, out_h, out_w,
        kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w, pad_d,
        pad_h, pad_w, pool_process, dst);
  }
};

C
chengduoZH 已提交
525 526 527 528 529
/*
 * CUDA specialization of the 3-D pooling backward functor.
 * All tensors are in NCDHW format.
 * ksize, strides and paddings each hold three elements: the depth, height
 * and width values, in that order.
 */
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, PoolProcess pool_process,
                  framework::Tensor* input_grad) {
    // Input extents (N, C, D, H, W).
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    // Only the spatial output extents are needed: the kernel is indexed by
    // input element, so the output channel count is not read here.
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One kernel index per input element, 1024 threads per block.
    int nthreads =
        batch_size * input_channels * input_depth * input_height * input_width;
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_channels,
        input_depth, input_height, input_width, output_depth, output_height,
        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
        stride_height, stride_width, padding_depth, padding_height,
        padding_width, pool_process, input_grad_data);
  }
};

C
chengduoZH 已提交
580 581 582 583 584
/*
 * CUDA specialization of the 3-D max-pooling backward functor.
 * All tensors are in NCDHW format.
 * ksize, strides and paddings each hold three elements: the depth, height
 * and width values, in that order.
 *
 * NOTE(review): KernelMaxPool3DGrad only adds into input_grad atomically, so
 * input_grad presumably has to be zero-filled by the caller -- confirm.
 */
template <class T>
class MaxPool3dGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Pooling geometry: {depth, height, width} triples.
    const int kernel_d = ksize[0];
    const int kernel_h = ksize[1];
    const int kernel_w = ksize[2];
    const int stride_d = strides[0];
    const int stride_h = strides[1];
    const int stride_w = strides[2];
    const int pad_d = paddings[0];
    const int pad_h = paddings[1];
    const int pad_w = paddings[2];

    // Input extents (N, C, D, H, W).
    const int num_batches = input.dims()[0];
    const int in_channels = input.dims()[1];
    const int in_d = input.dims()[2];
    const int in_h = input.dims()[3];
    const int in_w = input.dims()[4];

    // Output extents; the batch extent is taken from the input.
    const int out_channels = output.dims()[1];
    const int out_d = output.dims()[2];
    const int out_h = output.dims()[3];
    const int out_w = output.dims()[4];

    const T* x = input.data<T>();
    const T* y = output.data<T>();
    const T* dy = output_grad.data<T>();
    T* dx = input_grad->mutable_data<T>(context.GetPlace());

    // One kernel index per output element, 1024 threads per block.
    const int kThreadsPerBlock = 1024;
    const int total = num_batches * out_channels * out_d * out_h * out_w;
    const int num_blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(num_blocks, 1);

    KernelMaxPool3DGrad<T><<<grid, threads, 0, context.stream()>>>(
        total, x, y, dy, in_channels, in_d, in_h, in_w, out_d, out_h, out_w,
        kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w, pad_d,
        pad_h, pad_w, dx);
  }
};

Q
QI JUN 已提交
635 636
// Explicit instantiations of the 3-D CUDA pooling functors, grouped by
// element type. The pool-process types are named unqualified because this
// code already sits inside paddle::operators::math.

// float
template class Pool3dFunctor<platform::CUDADeviceContext, MaxPool<float>,
                             float>;
template class Pool3dFunctor<platform::CUDADeviceContext, AvgPool<float>,
                             float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<float>, float>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, float>;

// double
template class Pool3dFunctor<platform::CUDADeviceContext, MaxPool<double>,
                             double>;
template class Pool3dFunctor<platform::CUDADeviceContext, AvgPool<double>,
                             double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<platform::CUDADeviceContext,
                                 AvgPoolGrad<double>, double>;
template class MaxPool3dGradFunctor<platform::CUDADeviceContext, double>;
658

C
chengduoZH 已提交
659
/*
 * Max pooling over NCHW data that also records where each maximum was found.
 *
 * Work is distributed with a grid-stride loop: each iteration handles one
 * output element `index` in [0, nthreads), which is decomposed as
 * ((batch * channels + c) * output_height + ph) * output_width + pw.
 *
 * output_data[index] receives the maximum of the pooling window after the
 * window has been clipped to the input extent. mask_data[index] receives the
 * offset (h * input_width + w) of that maximum inside its (batch, channel)
 * input plane. If the clipped window is empty, or no element compares greater
 * than -FLT_MAX, the results are -FLT_MAX and -1 respectively.
 *
 * NOTE(review): -FLT_MAX is the initial value for every T1, so for a T1 wider
 * than float, inputs below -FLT_MAX are never selected.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2dWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_height, const int input_width, const int output_height,
    const int output_width, const int ksize_height, const int ksize_width,
    const int stride_height, const int stride_width, const int padding_height,
    const int padding_width, T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat output index into (batch, channel, ph, pw).
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int c = (index / output_width / output_height) % channels;
    int batch_idx = index / output_width / output_height / channels;

    // Pooling window in input coordinates, clipped to the input extent.
    int hstart = ph * stride_height - padding_height;
    int hend = min(hstart + ksize_height, input_height);
    hstart = max(hstart, 0);

    int wstart = pw * stride_width - padding_width;
    int wend = min(wstart + ksize_width, input_width);
    wstart = max(wstart, 0);

    // Address the current plane through a local pointer. Advancing the
    // `input_data` parameter itself would accumulate offsets whenever a
    // thread executes more than one iteration of this loop.
    const T1* input_plane =
        input_data + (batch_idx * channels + c) * input_height * input_width;

    T1 ele = -FLT_MAX;
    int max_index = -1;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * input_width + w;
        // Strict comparison: the first occurrence of the maximum wins.
        if (ele < input_plane[input_index]) {
          max_index = input_index;
          ele = input_plane[input_index];
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
698
/*
 * Backward pass of 2-D max pooling with index (NCHW).
 *
 * Work is distributed with a grid-stride loop: each iteration handles one
 * input-gradient element `index` in [0, nthreads), which is decomposed as
 * ((batch * channels + c) * input_height + h) * input_width + w.
 *
 * For that element the kernel visits every output position whose pooling
 * window can contain (h, w) and adds output_grad at the positions whose mask
 * value equals this element's in-plane offset (h * input_width + w). The sum
 * is written to input_grad[index]; elements selected by no window get 0.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool2DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask_data,
    const int channels, const int input_height, const int input_width,
    const int output_height, const int output_width, const int ksize_height,
    const int ksize_width, const int stride_height, const int stride_width,
    const int padding_height, const int padding_width, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat input index into (batch, channel, h, w).
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int c_offset = (index / input_width / input_height) % channels;
    int batch_idx = index / input_width / input_height / channels;

    // Range of output positions whose window may cover (h_offset, w_offset).
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);

    T1 gradient = 0;
    int input_current_featuremap_idx = h_offset * input_width + w_offset;
    int output_idx =
        (batch_idx * channels + c_offset) * output_height * output_width;

    // Address the current output plane through local pointers. Advancing the
    // `mask_data` / `output_grad` parameters themselves would accumulate
    // offsets whenever a thread executes more than one loop iteration.
    const T2* mask_plane = mask_data + output_idx;
    const T1* grad_plane = output_grad + output_idx;
    for (int ph = ph_start; ph < ph_end; ++ph) {
      for (int pw = pw_start; pw < pw_end; ++pw) {
        if (mask_plane[ph * output_width + pw] == input_current_featuremap_idx)
          gradient += grad_plane[ph * output_width + pw];
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
742 743 744 745 746
/*
 * CUDA max pooling with index for tensors in NCHW layout.
 * ksize, strides and paddings each hold two entries, ordered as
 * (height, width).
 *
 * Writable buffers for `output` and `mask` are obtained through
 * mutable_data on context.GetPlace(), and KernelMaxPool2dWithIdx is launched
 * on context.stream() with enough 1024-thread blocks to cover every output
 * element.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* output,
                  framework::Tensor* mask) {
    // Input extents: (N, C, H, W).
    const int num_batches = input.dims()[0];
    const int in_channels = input.dims()[1];
    const int in_h = input.dims()[2];
    const int in_w = input.dims()[3];

    // Output extents; the batch extent is taken from the input.
    const int out_channels = output->dims()[1];
    const int out_h = output->dims()[2];
    const int out_w = output->dims()[3];

    // Window geometry, (height, width) order.
    const int kernel_h = ksize[0];
    const int kernel_w = ksize[1];
    const int stride_h = strides[0];
    const int stride_w = strides[1];
    const int pad_h = paddings[0];
    const int pad_w = paddings[1];

    const T1* src = input.data<T1>();
    T1* dst = output->mutable_data<T1>(context.GetPlace());
    T2* argmax = mask->mutable_data<T2>(context.GetPlace());

    // One unit of work per output element.
    const int kBlockSize = 1024;
    int total = num_batches * out_channels * out_h * out_w;
    int num_blocks = (total + kBlockSize - 1) / kBlockSize;
    dim3 block_dim(kBlockSize, 1);
    dim3 grid_dim(num_blocks, 1);

    KernelMaxPool2dWithIdx<T1, T2>
        <<<grid_dim, block_dim, 0, context.stream()>>>(
            total, src, in_channels, in_h, in_w, out_h, out_w, kernel_h,
            kernel_w, stride_h, stride_w, pad_h, pad_w, dst, argmax);
  }
};

C
chengduoZH 已提交
785 786 787 788 789
/*
 * CUDA backward pass of max pooling with index for tensors in NCHW layout.
 * ksize, strides and paddings each hold two entries, ordered as
 * (height, width).
 *
 * A writable buffer for `input_grad` is obtained through mutable_data on
 * context.GetPlace(), and KernelMaxPool2DWithIdxGrad is launched on
 * context.stream() with enough 1024-thread blocks to cover every
 * input-gradient element.
 */
template <typename T1, typename T2>
class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input-gradient extents: (N, C, H, W).
    const int num_batches = input_grad->dims()[0];
    const int in_channels = input_grad->dims()[1];
    const int in_h = input_grad->dims()[2];
    const int in_w = input_grad->dims()[3];

    // Spatial extents of the output gradient.
    const int out_h = output_grad.dims()[2];
    const int out_w = output_grad.dims()[3];

    // Window geometry, (height, width) order.
    const int kernel_h = ksize[0];
    const int kernel_w = ksize[1];
    const int stride_h = strides[0];
    const int stride_w = strides[1];
    const int pad_h = paddings[0];
    const int pad_w = paddings[1];

    const T2* argmax = mask.data<T2>();
    const T1* dout = output_grad.data<T1>();
    T1* din = input_grad->mutable_data<T1>(context.GetPlace());

    // One unit of work per input-gradient element.
    const int kBlockSize = 1024;
    int total = num_batches * in_channels * in_h * in_w;
    int num_blocks = (total + kBlockSize - 1) / kBlockSize;
    dim3 block_dim(kBlockSize, 1);
    dim3 grid_dim(num_blocks, 1);

    KernelMaxPool2DWithIdxGrad<T1, T2>
        <<<grid_dim, block_dim, 0, context.stream()>>>(
            total, dout, argmax, in_channels, in_h, in_w, out_h, out_w,
            kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, din);
  }
};

Q
QI JUN 已提交
829 830 831 832 833 834 835 836
// Explicit instantiations of the 2-D max-pool-with-index functors
// (element type float or double, mask type int).

// Forward.
template class MaxPool2dWithIndexFunctor<
    platform::CUDADeviceContext, float, int>;
template class MaxPool2dWithIndexFunctor<
    platform::CUDADeviceContext, double, int>;

// Backward.
template class MaxPool2dWithIndexGradFunctor<
    platform::CUDADeviceContext, float, int>;
template class MaxPool2dWithIndexGradFunctor<
    platform::CUDADeviceContext, double, int>;
C
chengduoZH 已提交
837

C
chengduoZH 已提交
838
/*
 * Max pooling over NCDHW data that also records where each maximum was found.
 *
 * Work is distributed with a grid-stride loop: each iteration handles one
 * output element `index` in [0, nthreads), which is decomposed as
 * (((batch * channels + c) * output_depth + pd) * output_height + ph) *
 * output_width + pw.
 *
 * output_data[index] receives the maximum of the pooling window after the
 * window has been clipped to the input extent. mask_data[index] receives the
 * offset ((d * input_height + h) * input_width + w) of that maximum inside
 * its (batch, channel) input volume. If the clipped window is empty, or no
 * element compares greater than -FLT_MAX, the results are -FLT_MAX and -1
 * respectively.
 *
 * NOTE(review): -FLT_MAX is the initial value for every T1, so for a T1 wider
 * than float, inputs below -FLT_MAX are never selected.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdx(
    const int nthreads, const T1* input_data, const int channels,
    const int input_depth, const int input_height, const int input_width,
    const int output_depth, const int output_height, const int output_width,
    const int ksize_depth, const int ksize_height, const int ksize_width,
    const int stride_depth, const int stride_height, const int stride_width,
    const int padding_depth, const int padding_height, const int padding_width,
    T1* output_data, T2* mask_data) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat output index into (batch, channel, pd, ph, pw).
    int pw = index % output_width;
    int ph = (index / output_width) % output_height;
    int pd = (index / output_width / output_height) % output_depth;
    int c = (index / output_width / output_height / output_depth) % channels;
    int batch_idx =
        index / output_width / output_height / output_depth / channels;

    // Pooling window in input coordinates, clipped to the input extent.
    int dstart = pd * stride_depth - padding_depth;
    int hstart = ph * stride_height - padding_height;
    int wstart = pw * stride_width - padding_width;
    int dend = min(dstart + ksize_depth, input_depth);
    int hend = min(hstart + ksize_height, input_height);
    int wend = min(wstart + ksize_width, input_width);
    dstart = max(dstart, 0);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);

    T1 ele = -FLT_MAX;
    int max_index = -1;

    // Address the current volume through a local pointer. Advancing the
    // `input_data` parameter itself would accumulate offsets whenever a
    // thread executes more than one iteration of this loop.
    const T1* input_volume =
        input_data +
        (batch_idx * channels + c) * input_depth * input_height * input_width;

    for (int d = dstart; d < dend; ++d) {
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          int input_index = (d * input_height + h) * input_width + w;
          // Strict comparison: the first occurrence of the maximum wins.
          if (ele < input_volume[input_index]) {
            max_index = input_index;
            ele = input_volume[max_index];
          }
        }
      }
    }
    output_data[index] = ele;
    mask_data[index] = max_index;
  }
}

C
chengduoZH 已提交
886
/*
 * Backward pass of 3-D max pooling with index (NCDHW).
 *
 * Work is distributed with a grid-stride loop: each iteration handles one
 * input-gradient element `index` in [0, nthreads), which is decomposed as
 * (((batch * channels + c) * input_depth + d) * input_height + h) *
 * input_width + w.
 *
 * For that element the kernel visits every output position whose pooling
 * window can contain (d, h, w) and adds output_grad at the positions whose
 * mask value equals this element's in-volume offset
 * ((d * input_height + h) * input_width + w). The sum is written to
 * input_grad[index]; elements selected by no window get 0.
 */
template <typename T1, typename T2>
__global__ void KernelMaxPool3DWithIdxGrad(
    const int nthreads, const T1* output_grad, const T2* mask,
    const int channels, const int input_depth, const int input_height,
    const int input_width, const int output_depth, const int output_height,
    const int output_width, const int ksize_depth, const int ksize_height,
    const int ksize_width, const int stride_depth, const int stride_height,
    const int stride_width, const int padding_depth, const int padding_height,
    const int padding_width, T1* input_grad) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    // Split the flat input index into (batch, channel, d, h, w).
    int w_offset = index % input_width;
    int h_offset = (index / input_width) % input_height;
    int d_offset = (index / input_width / input_height) % input_depth;
    int c_offset =
        (index / input_width / input_height / input_depth) % channels;
    int batch_idx = index / input_width / input_height / input_depth / channels;

    // Range of output positions whose window may cover this input element.
    int pd_start =
        (d_offset + padding_depth < ksize_depth)
            ? 0
            : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
    int ph_start =
        (h_offset + padding_height < ksize_height)
            ? 0
            : (h_offset + padding_height - ksize_height) / stride_height + 1;
    int pw_start =
        (w_offset + padding_width < ksize_width)
            ? 0
            : (w_offset + padding_width - ksize_width) / stride_width + 1;
    int pd_end =
        min((d_offset + padding_depth) / stride_depth + 1, output_depth);
    int ph_end =
        min((h_offset + padding_height) / stride_height + 1, output_height);
    int pw_end =
        min((w_offset + padding_width) / stride_width + 1, output_width);

    T1 gradient = 0;
    int input_current_feature_map_idx =
        (d_offset * input_height + h_offset) * input_width + w_offset;
    int output_idx = (batch_idx * channels + c_offset) * output_depth *
                     output_height * output_width;

    // Address the current output volume through local pointers. Advancing the
    // `mask` / `output_grad` parameters themselves would accumulate offsets
    // whenever a thread executes more than one loop iteration.
    const T2* mask_volume = mask + output_idx;
    const T1* grad_volume = output_grad + output_idx;

    for (int pd = pd_start; pd < pd_end; ++pd) {
      for (int ph = ph_start; ph < ph_end; ++ph) {
        for (int pw = pw_start; pw < pw_end; ++pw) {
          int out_offset = (pd * output_height + ph) * output_width + pw;
          if (mask_volume[out_offset] == input_current_feature_map_idx)
            gradient += grad_volume[out_offset];
        }
      }
    }
    input_grad[index] = gradient;
  }
}

C
chengduoZH 已提交
945 946 947 948 949
/*
 * CUDA max pooling with index for tensors in NCDHW layout.
 * ksize, strides and paddings each hold three entries, ordered as
 * (depth, height, width).
 *
 * Writable buffers for `output` and `mask` are obtained through
 * mutable_data on context.GetPlace(), and KernelMaxPool3DWithIdx is launched
 * on context.stream() with enough 1024-thread blocks to cover every output
 * element.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings, framework::Tensor* output,
                  framework::Tensor* mask) {
    // Input extents: (N, C, D, H, W).
    const int num_batches = input.dims()[0];
    const int in_channels = input.dims()[1];
    const int in_d = input.dims()[2];
    const int in_h = input.dims()[3];
    const int in_w = input.dims()[4];

    // Output extents; the batch extent is taken from the input.
    const int out_channels = output->dims()[1];
    const int out_d = output->dims()[2];
    const int out_h = output->dims()[3];
    const int out_w = output->dims()[4];

    // Window geometry, (depth, height, width) order.
    const int kernel_d = ksize[0];
    const int kernel_h = ksize[1];
    const int kernel_w = ksize[2];
    const int stride_d = strides[0];
    const int stride_h = strides[1];
    const int stride_w = strides[2];
    const int pad_d = paddings[0];
    const int pad_h = paddings[1];
    const int pad_w = paddings[2];

    const T1* src = input.data<T1>();
    T1* dst = output->mutable_data<T1>(context.GetPlace());
    T2* argmax = mask->mutable_data<T2>(context.GetPlace());

    // One unit of work per output element.
    const int kBlockSize = 1024;
    int total = num_batches * out_channels * out_d * out_h * out_w;
    int num_blocks = (total + kBlockSize - 1) / kBlockSize;
    dim3 block_dim(kBlockSize, 1);
    dim3 grid_dim(num_blocks, 1);

    KernelMaxPool3DWithIdx<T1, T2>
        <<<grid_dim, block_dim, 0, context.stream()>>>(
            total, src, in_channels, in_d, in_h, in_w, out_d, out_h, out_w,
            kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w, pad_d,
            pad_h, pad_w, dst, argmax);
  }
};

C
chengduoZH 已提交
995 996 997 998 999
/*
 * CUDA backward pass of max pooling with index for tensors in NCDHW layout.
 * ksize, strides and paddings each hold three entries, ordered as
 * (depth, height, width).
 *
 * A writable buffer for `input_grad` is obtained through mutable_data on
 * context.GetPlace(), and KernelMaxPool3DWithIdxGrad is launched on
 * context.stream() with enough 1024-thread blocks to cover every
 * input-gradient element.
 */
template <typename T1, typename T2>
class MaxPool3dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& output_grad,
                  const framework::Tensor& mask, const std::vector<int>& ksize,
                  const std::vector<int>& strides,
                  const std::vector<int>& paddings,
                  framework::Tensor* input_grad) {
    // Input-gradient extents: (N, C, D, H, W).
    const int num_batches = input_grad->dims()[0];
    const int in_channels = input_grad->dims()[1];
    const int in_d = input_grad->dims()[2];
    const int in_h = input_grad->dims()[3];
    const int in_w = input_grad->dims()[4];

    // Spatial extents of the output gradient.
    const int out_d = output_grad.dims()[2];
    const int out_h = output_grad.dims()[3];
    const int out_w = output_grad.dims()[4];

    // Window geometry, (depth, height, width) order.
    const int kernel_d = ksize[0];
    const int kernel_h = ksize[1];
    const int kernel_w = ksize[2];
    const int stride_d = strides[0];
    const int stride_h = strides[1];
    const int stride_w = strides[2];
    const int pad_d = paddings[0];
    const int pad_h = paddings[1];
    const int pad_w = paddings[2];

    const T1* dout = output_grad.data<T1>();
    const T2* argmax = mask.data<T2>();
    T1* din = input_grad->mutable_data<T1>(context.GetPlace());

    // One unit of work per input-gradient element.
    const int kBlockSize = 1024;
    int total = num_batches * in_channels * in_d * in_h * in_w;
    int num_blocks = (total + kBlockSize - 1) / kBlockSize;
    dim3 block_dim(kBlockSize, 1);
    dim3 grid_dim(num_blocks, 1);

    KernelMaxPool3DWithIdxGrad<T1, T2>
        <<<grid_dim, block_dim, 0, context.stream()>>>(
            total, dout, argmax, in_channels, in_d, in_h, in_w, out_d, out_h,
            out_w, kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w,
            pad_d, pad_h, pad_w, din);
  }
};

Q
QI JUN 已提交
1046 1047 1048 1049 1050 1051 1052 1053
// Explicit instantiations of the 3-D max-pool-with-index functors
// (element type float or double, mask type int).

// Forward.
template class MaxPool3dWithIndexFunctor<
    platform::CUDADeviceContext, float, int>;
template class MaxPool3dWithIndexFunctor<
    platform::CUDADeviceContext, double, int>;

// Backward.
template class MaxPool3dWithIndexGradFunctor<
    platform::CUDADeviceContext, float, int>;
template class MaxPool3dWithIndexGradFunctor<
    platform::CUDADeviceContext, double, int>;
C
chengduoZH 已提交
1054 1055 1056 1057

}  // namespace math
}  // namespace operators
}  // namespace paddle