pool_grad_kernel.cu 17.8 KB
Newer Older
F
From00 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include "paddle/phi/kernels/pool_grad_kernel.h"

F
From00 已提交
17 18 19 20 21
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/pooling.h"
22
#include "paddle/phi/kernels/gpudnn/pool_gpudnn.h"
F
From00 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307
#include "paddle/phi/kernels/pool_kernel.h"

#ifdef PADDLE_WITH_HIP
#include "paddle/phi/kernels/impl/pool_grad_kernel_impl.h"  //  PoolGradRawGPUDNNKernel will call PoolGradRawKernel for pooling type "max" in ROCm
#endif

namespace phi {

template <typename T, typename Context>
// Computes the pooling input gradient (dx) with cuDNN (CUDA) or MIOpen (ROCm).
//
// Params:
//   x, out, dout      - forward input, forward output and output gradient.
//   kernel_size, strides, paddings - pooling window geometry; paddings and
//                       kernel_size are normalized locally (see below).
//   exclusive         - for average pooling, selects kAverageExclusive vs
//                       kAverageInclusive.
//   data_format       - "NHWC"/"NDHWC" are treated as channel-last; "NDHWC"
//                       (and "NHWC" on ROCm) is transposed to channel-first
//                       before the DNN call and transposed back afterwards.
//   pooling_type      - "max" or an average mode.
//   global_pooling, adaptive, padding_algorithm - forwarded to
//                       funcs::UpdatePadding / funcs::UpdateKernelSize.
//   dx                - output gradient w.r.t. x; if null, nothing is computed
//                       (except on ROCm max pooling, which is delegated).
void PoolGradRawGPUDNNKernel(const Context& ctx,
                             const DenseTensor& x,
                             const DenseTensor& out,
                             const DenseTensor& dout,
                             const std::vector<int>& kernel_size,
                             const std::vector<int>& strides,
                             const std::vector<int>& paddings,
                             bool exclusive,
                             const std::string& data_format,
                             const std::string& pooling_type,
                             bool global_pooling,
                             bool adaptive,
                             const std::string& padding_algorithm,
                             DenseTensor* dx) {
  PADDLE_ENFORCE_EQ(
      paddle::platform::is_gpu_place(ctx.GetPlace()),
      true,
      errors::InvalidArgument("Pool operator CUDA kernel must use CUDAPlace "
                              "rather than CPUPlace."));

  const DenseTensor* input = &x;
  const DenseTensor* output = &out;
  const DenseTensor* output_grad = &dout;
  DenseTensor* input_grad = dx;
  std::vector<int> paddings_ = paddings;
  std::vector<int> kernel_size_ = kernel_size;

#ifdef PADDLE_WITH_HIP
  // On ROCm, max pooling is delegated to the generic PoolGradRawKernel
  // instead of MIOpen.
  if (pooling_type == "max") {
    PoolGradRawKernel<T, GPUContext>(ctx,
                                     x,
                                     out,
                                     dout,
                                     kernel_size,
                                     strides,
                                     paddings_,
                                     exclusive,
                                     data_format,
                                     pooling_type,
                                     global_pooling,
                                     adaptive,
                                     padding_algorithm,
                                     dx);
    return;
  }
#endif

  // No gradient requested: nothing to do. This check must precede any use of
  // input_grad (it is dereferenced by Alloc / type() below).
  if (input_grad == nullptr) {
    return;
  }

  const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

  // ------- update paddings / kernel size --------------
  auto in_x_dims = input->dims();
  DDim data_dims;
  if (channel_last) {
    data_dims = slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
  } else {
    data_dims = slice_ddim(in_x_dims, 2, in_x_dims.size());
  }
  funcs::UpdatePadding(&paddings_,
                       global_pooling,
                       adaptive,
                       padding_algorithm,
                       data_dims,
                       strides,
                       kernel_size_);
  // When paddings_ holds a pair per spatial dim, keep only the leading value
  // of each pair ([b0, e0, b1, e1] -> [b0, b1]) before building the pooling
  // descriptor. NOTE(review): trailing (asymmetric) paddings are dropped.
  if (data_dims.size() * 2 == static_cast<int>(paddings_.size())) {
    for (int i = 0; i < data_dims.size(); ++i) {
      paddings_.erase(paddings_.begin() + i + 1);
    }
  }

  if (global_pooling) {
    funcs::UpdateKernelSize(&kernel_size_, data_dims);
  }

  // ------- tensor grad --------------
  DenseTensor transformed_input(input->type());
  DenseTensor transformed_output(output->type());
  DenseTensor transformed_output_grad(output_grad->type());

  ctx.template Alloc<T>(input_grad);
  DenseTensor transformed_input_grad(input_grad->type());
  GPUDNNDataLayout layout;
  const std::string str_NHWC = "NHWC";
  const std::string str_NDHWC = "NDHWC";
  if (data_format == str_NDHWC) {
    // Transpose NDHWC -> NCDHW for the DNN call.
    layout = GPUDNNDataLayout::kNCDHW;
    std::vector<int> axis{0, 4, 1, 2, 3};

    // input
    auto in_dims_vec = vectorize(input->dims());
    in_dims_vec[1] = input->dims()[4];
    in_dims_vec[2] = input->dims()[1];
    in_dims_vec[3] = input->dims()[2];
    in_dims_vec[4] = input->dims()[3];
    transformed_input.Resize(make_ddim(in_dims_vec));
    ctx.Alloc(&transformed_input, input->type());

    funcs::Transpose<Context, T, 5> trans5;
    trans5(ctx, *input, &transformed_input, axis);

    // output
    auto out_dims_vec = vectorize(output->dims());
    out_dims_vec[1] = output->dims()[4];
    out_dims_vec[2] = output->dims()[1];
    out_dims_vec[3] = output->dims()[2];
    out_dims_vec[4] = output->dims()[3];
    transformed_output.Resize(make_ddim(out_dims_vec));

    ctx.Alloc(&transformed_output, output->type());

    funcs::Transpose<Context, T, 5> trans5_v2;
    trans5_v2(ctx, *output, &transformed_output, axis);

    // output grad
    transformed_output_grad.Resize(make_ddim(out_dims_vec));
    ctx.Alloc(&transformed_output_grad, output_grad->type());

    funcs::Transpose<Context, T, 5> trans5_v3;
    trans5_v3(ctx, *output_grad, &transformed_output_grad, axis);

    // input grad
    transformed_input_grad.Resize(make_ddim(in_dims_vec));

#ifdef PADDLE_WITH_HIP
    // MIOpen does not support the NHWC data layout: transpose NHWC -> NCHW.
  } else if (data_format == str_NHWC) {
    layout = GPUDNNDataLayout::kNCHW;

    std::vector<int> axis{0, 3, 1, 2};

    // input
    auto in_dims_vec = vectorize(input->dims());
    in_dims_vec[1] = input->dims()[3];
    in_dims_vec[2] = input->dims()[1];
    in_dims_vec[3] = input->dims()[2];
    transformed_input.Resize(make_ddim(in_dims_vec));
    ctx.Alloc(&transformed_input, input->type());

    funcs::Transpose<Context, T, 4> trans4;
    trans4(ctx, *input, &transformed_input, axis);

    // output
    auto out_dims_vec = vectorize(output->dims());
    out_dims_vec[1] = output->dims()[3];
    out_dims_vec[2] = output->dims()[1];
    out_dims_vec[3] = output->dims()[2];
    transformed_output.Resize(make_ddim(out_dims_vec));
    ctx.Alloc(&transformed_output, output->type());

    funcs::Transpose<Context, T, 4> trans4_v2;
    trans4_v2(ctx, *output, &transformed_output, axis);

    // output grad
    transformed_output_grad.Resize(make_ddim(out_dims_vec));
    ctx.Alloc(&transformed_output_grad, output_grad->type());

    funcs::Transpose<Context, T, 4> trans4_v3;
    trans4_v3(ctx, *output_grad, &transformed_output_grad, axis);

    // input grad
    transformed_input_grad.Resize(make_ddim(in_dims_vec));
#endif
  } else {
    // Layout is used as-is; the transformed tensors share the originals.
    layout = GetLayoutFromStr(data_format);
    transformed_input = *input;
    transformed_output = *output;
    transformed_output_grad = *output_grad;
    transformed_input_grad = *input_grad;
  }

  const T* input_data = transformed_input.data<T>();
  const T* output_data = transformed_output.data<T>();
  const T* output_grad_data = transformed_output_grad.data<T>();

  // ------------------- cudnn descriptors ---------------------
  ScopedTensorDescriptor input_desc;
  ScopedTensorDescriptor output_desc;
  ScopedPoolingDescriptor pool_desc;

#ifdef PADDLE_WITH_HIP
  miopenTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
      layout, vectorize<int>(transformed_input.dims()));
  miopenTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
      layout, vectorize<int>(transformed_output.dims()));
#else
  cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
      layout, vectorize<int>(transformed_input.dims()));
  cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
      layout, vectorize<int>(transformed_output.dims()));
#endif
  PoolingMode pooling_mode;
  if (pooling_type == "max") {
    if (FLAGS_cudnn_deterministic) {
      pooling_mode = PoolingMode::kMaximumDeterministic;
    } else {
      pooling_mode = PoolingMode::kMaximum;
    }
  } else {
    pooling_mode = exclusive ? PoolingMode::kAverageExclusive
                             : PoolingMode::kAverageInclusive;
  }

#ifdef PADDLE_WITH_HIP
  miopenPoolingDescriptor_t cudnn_pool_desc =
      pool_desc.descriptor(pooling_mode, kernel_size_, paddings_, strides);
#else
  cudnnPoolingDescriptor_t cudnn_pool_desc =
      pool_desc.descriptor(pooling_mode, kernel_size_, paddings_, strides);
#endif

  // ------------------- cudnn pool algorithm ---------------------
  auto handle = ctx.cudnn_handle();
  ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
  T* input_grad_data = ctx.template Alloc<T>(&transformed_input_grad);
  // Because beta is zero, it is unnecessary to reset input_grad.
#ifdef PADDLE_WITH_HIP
  size_t pool_worksize = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenPoolingGetWorkSpaceSizeV2(
      cudnn_pool_desc, cudnn_output_desc, &pool_worksize));
  // Owns the MIOpen workspace so it is released even if a call below fails
  // (assuming PADDLE_ENFORCE_GPU_SUCCESS throws on error).
  struct HipWorkspace {
    char* ptr = nullptr;
    ~HipWorkspace() {
      if (ptr != nullptr) {
        (void)hipFree(ptr);
      }
    }
  } pool_workspace;
  PADDLE_ENFORCE_GPU_SUCCESS(hipMalloc(&pool_workspace.ptr, pool_worksize));
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenPoolingBackward(handle,
                                                            cudnn_pool_desc,
                                                            &alpha,
                                                            cudnn_output_desc,
                                                            output_data,
                                                            cudnn_output_desc,
                                                            output_grad_data,
                                                            cudnn_input_desc,
                                                            input_data,
                                                            &beta,
                                                            cudnn_input_desc,
                                                            input_grad_data,
                                                            pool_workspace.ptr));
  // Release ownership before the checked free so the guard cannot double-free.
  char* workspace_to_free = pool_workspace.ptr;
  pool_workspace.ptr = nullptr;
  PADDLE_ENFORCE_GPU_SUCCESS(hipFree(workspace_to_free));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnPoolingBackward(handle,
                                                           cudnn_pool_desc,
                                                           &alpha,
                                                           cudnn_output_desc,
                                                           output_data,
                                                           cudnn_output_desc,
                                                           output_grad_data,
                                                           cudnn_input_desc,
                                                           input_data,
                                                           &beta,
                                                           cudnn_input_desc,
                                                           input_grad_data));
#endif

  // Transpose the channel-first gradient back to the caller's layout.
  if (data_format == str_NDHWC) {
    std::vector<int> axis{0, 2, 3, 4, 1};
    funcs::Transpose<Context, T, 5> trans5_v4;
    trans5_v4(ctx, transformed_input_grad, input_grad, axis);
  }
#ifdef PADDLE_WITH_HIP
  // MIOpen does not support the NHWC data layout.
  if (data_format == str_NHWC) {
    std::vector<int> axis{0, 2, 3, 1};
    funcs::Transpose<Context, T, 4> trans4_v4;
    trans4_v4(ctx, transformed_input_grad, input_grad, axis);
  }
#endif
}

// Backward of 2-D pooling through cuDNN/MIOpen.
// Converts the IntArray kernel size to std::vector<int> and forwards to
// PoolGradRawGPUDNNKernel. `ceil_mode` is accepted but not forwarded, because
// the raw kernel has no such parameter.
template <typename T, typename Context>
void Pool2dGradGPUDNNKernel(const Context& ctx,
                            const DenseTensor& x,
                            const DenseTensor& out,
                            const DenseTensor& dout,
                            const IntArray& kernel_size,
                            const std::vector<int>& strides,
                            const std::vector<int>& paddings,
                            bool ceil_mode,
                            bool exclusive,
                            const std::string& data_format,
                            const std::string& pooling_type,
                            bool global_pooling,
                            bool adaptive,
                            const std::string& padding_algorithm,
                            DenseTensor* dx) {
  // Bind the data once so begin()/end() are guaranteed to come from the same
  // container regardless of whether GetData() returns by reference or value.
  const auto& kernel_size_data = kernel_size.GetData();
  std::vector<int> kernel_size_val(kernel_size_data.begin(),
                                   kernel_size_data.end());
  PoolGradRawGPUDNNKernel<T, Context>(ctx,
                                      x,
                                      out,
                                      dout,
                                      kernel_size_val,
                                      strides,
                                      paddings,
                                      exclusive,
                                      data_format,
                                      pooling_type,
                                      global_pooling,
                                      adaptive,
                                      padding_algorithm,
                                      dx);
}

// Double-grad of 2-D pooling through cuDNN/MIOpen.
// Only average pooling is supported: for it, the double grad is computed by
// running the forward pooling kernel (Pool2dGPUDNNKernel) on `x`, writing
// into `out`. Max pooling raises InvalidArgument.
template <typename T, typename Context>
void Pool2dDoubleGradGPUDNNKernel(const Context& ctx,
                                  const DenseTensor& x,
                                  const IntArray& kernel_size,
                                  const std::vector<int>& strides,
                                  const std::vector<int>& paddings,
                                  bool ceil_mode,
                                  bool exclusive,
                                  const std::string& data_format,
                                  const std::string& pooling_type,
                                  bool global_pooling,
                                  bool adaptive,
                                  const std::string& padding_algorithm,
                                  DenseTensor* out) {
  if (pooling_type == "max") {
    PADDLE_THROW(
        errors::InvalidArgument("Pool op grad grad only supports avgpool."));
  }
  Pool2dGPUDNNKernel<T, Context>(ctx,
                                 x,
                                 kernel_size,
                                 strides,
                                 paddings,
                                 ceil_mode,
                                 exclusive,
                                 data_format,
                                 pooling_type,
                                 global_pooling,
                                 adaptive,
                                 padding_algorithm,
                                 out);
}

// Backward of 3-D pooling through cuDNN/MIOpen.
// Thin adapter over PoolGradRawGPUDNNKernel: every argument except
// `ceil_mode` is passed straight through (the raw kernel has no ceil_mode).
template <typename T, typename Context>
void Pool3dGradGPUDNNKernel(const Context& ctx,
                            const DenseTensor& x,
                            const DenseTensor& out,
                            const DenseTensor& dout,
                            const std::vector<int>& kernel_size,
                            const std::vector<int>& strides,
                            const std::vector<int>& paddings,
                            bool ceil_mode,
                            bool exclusive,
                            const std::string& data_format,
                            const std::string& pooling_type,
                            bool global_pooling,
                            bool adaptive,
                            const std::string& padding_algorithm,
                            DenseTensor* dx) {
  PoolGradRawGPUDNNKernel<T, Context>(
      ctx,
      x, out, dout,                                 // forward tensors + grad
      kernel_size, strides, paddings,               // window geometry
      exclusive, data_format, pooling_type,         // pooling semantics
      global_pooling, adaptive, padding_algorithm,  // padding/kernel updates
      dx);                                          // result
}

}  // namespace phi

using phi::dtype::float16;

// Register the GPUDNN backward pooling kernels for all layouts.
// The registered dtype list differs between ROCm and CUDA builds.
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double, so only float and float16 are registered.
PD_REGISTER_KERNEL(pool2d_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Pool2dGradGPUDNNKernel,
                   float,
                   float16) {}
PD_REGISTER_KERNEL(pool2d_double_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Pool2dDoubleGradGPUDNNKernel,
                   float,
                   float16) {}
PD_REGISTER_KERNEL(pool3d_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Pool3dGradGPUDNNKernel,
                   float,
                   float16) {}
#else
// CUDA/cuDNN build: float, double and float16 are registered.
PD_REGISTER_KERNEL(pool2d_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Pool2dGradGPUDNNKernel,
                   float,
                   double,
                   float16) {}
PD_REGISTER_KERNEL(pool2d_double_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Pool2dDoubleGradGPUDNNKernel,
                   float,
                   double,
                   float16) {}
PD_REGISTER_KERNEL(pool3d_grad,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Pool3dGradGPUDNNKernel,
                   float,
                   double,
                   float16) {}
#endif