// conv_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/conv_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

#ifdef PADDLE_WITH_HIP
#include "paddle/phi/kernels/gpudnn/conv_miopen_helper.h"
#else
#include "paddle/phi/kernels/gpudnn/conv_cudnn_v7.h"
#endif

#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/padding.h"
#include "paddle/phi/kernels/impl/conv_cudnn_impl.h"

namespace phi {

// Forward convolution (2-D or 3-D) running on cuDNN, or on MIOpen when built
// with PADDLE_WITH_HIP.
//
// Steps:
//   1. Choose the compute layout. A channel-last input is transposed to
//      NCHW, except on CUDA when T maps to CUDNN_DATA_HALF on a Volta-or-later
//      GPU, in which case the computation stays in NHWC and only the filter
//      is transposed to channel-last.
//   2. Resolve `padding_algorithm` into explicit paddings/dilations. Any
//      asymmetric part of the padding is applied by padding the input
//      explicitly, so the DNN descriptors only see one padding per axis.
//   3. Build the DNN descriptors, search for a forward algorithm and run it.
//   4. If step 1 transposed the input, transpose the result back into
//      `output`.
//
// Args:
//   ctx: device context; provides allocation, DNN handles and the optional
//     `exhaustive_search` DNN attribute.
//   input, filter: input and filter tensors; `data_format` describes `input`.
//   strides, paddings_t, dilations_t: convolution hyper-parameters; the
//     number of strides decides whether this is a 2-D or 3-D convolution.
//   padding_algorithm: forwarded to UpdatePaddingAndDilation.
//   groups: number of convolution groups.
//   data_format: "NHWC" / "NDHWC" are treated as channel-last, anything else
//     as channel-first.
//   output: allocated here and filled with the convolution result.
//
// Raises InvalidArgument when exhaustive search and
// FLAGS_cudnn_deterministic are both enabled, or when explicit input padding
// is needed for a tensor whose rank is neither 4 nor 5.
template <typename T, typename Context>
void ConvCudnnKernel(const Context& ctx,
                     const DenseTensor& input,
                     const DenseTensor& filter,
                     const std::vector<int>& strides,
                     const std::vector<int>& paddings_t,
                     const std::string& padding_algorithm,
                     const std::vector<int>& dilations_t,
                     int groups,
                     const std::string& data_format,
                     DenseTensor* output) {
  ctx.template Alloc<T>(output);
  // Local copies: UpdatePaddingAndDilation rewrites them in place.
  std::vector<int> paddings = paddings_t;
  std::vector<int> dilations = dilations_t;

  // Exhaustive algorithm search is enabled either globally through
  // FLAGS_cudnn_exhaustive_search or per call through the optional
  // `exhaustive_search` DNN attribute on the context.
  bool has_exhaustive_search = ctx.HasDnnAttr("exhaustive_search");
  VLOG(4) << "GPUContext contains `exhaustive_search`: "
          << has_exhaustive_search;
  bool exhaustive_search_attr =
      has_exhaustive_search
          ? PADDLE_GET_CONST(bool, ctx.GetDnnAttr("exhaustive_search"))
          : false;
  bool exhaustive_search =
      FLAGS_cudnn_exhaustive_search || exhaustive_search_attr;
  bool deterministic = FLAGS_cudnn_deterministic;
  PADDLE_ENFORCE_EQ(exhaustive_search && deterministic,
                    false,
                    phi::errors::InvalidArgument(
                        "Cann't set exhaustive_search True and "
                        "FLAGS_cudnn_deterministic True at same time."));

  const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
  auto dtype = paddle::platform::CudnnDataType<T>::type;

#ifdef PADDLE_WITH_HIP
  // HIP MIOpen only supports the NCHW format.
  auto compute_format = paddle::platform::DataLayout::kNCHW;
#else
  // Tensor Cores, introduced with Volta GPUs, run FP16 convolutions faster in
  // the NHWC data format.
  const bool compute_in_nhwc = dtype == CUDNN_DATA_HALF && IsVoltaOrLater(ctx);
  // Only NHWC -> NCHW conversion is done here; cuDNN converts NCHW to NHWC
  // on Tensor Cores by itself.
  auto compute_format = compute_in_nhwc && channel_last
                            ? paddle::platform::DataLayout::kNHWC
                            : paddle::platform::DataLayout::kNCHW;
#endif
  VLOG(3) << "Compute ConvOp with cuDNN:"
          << " data_format=" << data_format << " compute_format="
          << (compute_format == paddle::platform::DataLayout::kNHWC ? "NHWC"
                                                                    : "NCHW");

  // ------------ transformed tensor -----------
  DenseTensor transformed_input_channel(input.type());
  DenseTensor transformed_output(output->type());
  DenseTensor transformed_filter_channel(filter.type());
  if (channel_last && compute_format == paddle::platform::DataLayout::kNCHW) {
    // Channel-last input computed in NCHW: transpose the input now and
    // compute into `transformed_output`, which is transposed back into
    // `output` at the end of this function.
    VLOG(3) << "Transform input tensor from NHWC to NCHW.";
    ResizeToChannelFirst<Context, T>(ctx, &input, &transformed_input_channel);
    TransToChannelFirst<Context, T>(ctx, &input, &transformed_input_channel);

    ResizeToChannelFirst<Context, T>(ctx, output, &transformed_output);
  } else {
    transformed_input_channel.ShareDataWith(input);
    transformed_output.ShareDataWith(*output);
  }
  if (compute_format == paddle::platform::DataLayout::kNHWC) {
    VLOG(3) << "Transform filter tensor from NCHW to NHWC.";
    ResizeToChannelLast<Context, T>(ctx, &filter, &transformed_filter_channel);
    TransToChannelLast<Context, T>(ctx, &filter, &transformed_filter_channel);
  } else {
    transformed_filter_channel.ShareDataWith(filter);
  }
  T* output_data = transformed_output.data<T>();

  // update padding and dilation
  auto in_dims = transformed_input_channel.dims();
  auto filter_dims = transformed_filter_channel.dims();
  DDim in_data_dims;
  DDim filter_data_dims;

  if (compute_format == paddle::platform::DataLayout::kNCHW) {
    in_data_dims = slice_ddim(in_dims, 2, in_dims.size());
    filter_data_dims = slice_ddim(filter_dims, 2, filter_dims.size());
  } else {
    in_data_dims = slice_ddim(in_dims, 1, in_dims.size() - 1);
    filter_data_dims = slice_ddim(filter_dims, 1, filter_dims.size() - 1);
  }

  std::vector<int> ksize = vectorize<int>(filter_data_dims);
  UpdatePaddingAndDilation(
      &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);

  int data_dim = strides.size();  // 2d or 3d
  bool is_sys_pad = funcs::IsSymmetricPadding(paddings, data_dim);

  DenseTensor transformed_input;
  std::vector<int> padding_common(data_dim, 0);
  if (!is_sys_pad) {
    // The DNN descriptors take a single padding value per spatial axis
    // (padding_common). Give them the smaller side and pad the input
    // explicitly by the remaining difference.
    std::vector<int> padding_diff(data_dim);
    std::vector<int> new_input_shape_vec(data_dim + 2);
    new_input_shape_vec[0] = transformed_input_channel.dims()[0];

    if (compute_format == paddle::platform::DataLayout::kNCHW) {
      new_input_shape_vec[1] = transformed_input_channel.dims()[1];
    } else {
      new_input_shape_vec[data_dim + 1] =
          transformed_input_channel.dims()[data_dim + 1];
    }

    // Two entries per tensor axis.
    std::vector<int> input_pad(transformed_input_channel.dims().size() * 2, 0);
    for (int i = 0; i < data_dim; ++i) {
      padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
      padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
      if (compute_format == paddle::platform::DataLayout::kNCHW) {
        // Spatial axes start at index 2: skip the entries for N and C.
        new_input_shape_vec[i + 2] =
            transformed_input_channel.dims()[i + 2] + padding_diff[i];
        input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
        input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
      } else {
        // Spatial axes start at index 1: skip only the entries for N.
        new_input_shape_vec[i + 1] =
            transformed_input_channel.dims()[i + 1] + padding_diff[i];
        input_pad[2 * i + 2] = paddings[2 * i] - padding_common[i];
        input_pad[2 * i + 2 + 1] = paddings[2 * i + 1] - padding_common[i];
      }
    }
    DDim new_input_shape(make_ddim(new_input_shape_vec));
    transformed_input.Resize(new_input_shape);
    ctx.template Alloc<T>(&transformed_input);

    const int rank = transformed_input_channel.dims().size();
    T pad_value(0.0);
    switch (rank) {
      case 4: {
        funcs::PadFunction<Context, T, 4>(ctx,
                                          input_pad,
                                          transformed_input_channel,
                                          pad_value,
                                          &transformed_input);
      } break;
      case 5: {
        funcs::PadFunction<Context, T, 5>(ctx,
                                          input_pad,
                                          transformed_input_channel,
                                          pad_value,
                                          &transformed_input);
      } break;
      default:
        PADDLE_THROW(phi::errors::InvalidArgument(
            "ConvOp only support tensors with 4 or 5 dimensions."));
    }

  } else {
    transformed_input.ShareDataWith(transformed_input_channel);
    // Symmetric padding may come as one value per axis or as (before, after)
    // pairs; take the first value of each pair in the latter case.
    if (paddings.size() == static_cast<size_t>(data_dim)) {
      for (int i = 0; i < data_dim; ++i) {
        padding_common[i] = paddings[i];
      }
    } else {
      for (int i = 0; i < data_dim; ++i) {
        padding_common[i] = paddings[2 * i];
      }
    }
  }

  const T* input_data = transformed_input.data<T>();
  const T* filter_data = transformed_filter_channel.data<T>();

  auto handle = ctx.cudnn_handle();
  auto workspace_handle = ctx.cudnn_workspace_handle();

  // ------------------- cudnn descriptors ---------------------
  ConvArgs args{handle,
                &transformed_input,
                &transformed_filter_channel,
                &transformed_output,
                strides,
                padding_common,
                dilations,
                dtype,
                groups,
                compute_format};

  paddle::platform::DataLayout layout =
      compute_format == paddle::platform::DataLayout::kNHWC
          ? paddle::platform::DataLayout::kNHWC
          : paddle::platform::DataLayout::kNCHW;
  if (transformed_input.dims().size() == 5) {
    layout = compute_format == paddle::platform::DataLayout::kNHWC
                 ? paddle::platform::DataLayout::kNDHWC
                 : paddle::platform::DataLayout::kNCDHW;
  }
  auto layout_format = paddle::platform::GetCudnnTensorFormat(layout);

#ifdef PADDLE_WITH_HIP
  // MIOpen takes the group count on the convolution descriptor
  // (see miopen_desc.h).
  args.cdesc.set(dtype,
                 padding_common,
                 strides,
                 dilations,
                 paddle::platform::AllowTF32Cudnn(),
                 groups);
#else
  args.cdesc.set(dtype,
                 padding_common,
                 strides,
                 dilations,
                 paddle::platform::AllowTF32Cudnn());
#endif

#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1)
  // cuDNN 7 supports groups natively, so there is no need to split the
  // convolution manually.
  // FIXME(typhoonzero): find a better way to disable groups
  // rather than setting it to 1.
  PADDLE_ENFORCE_GPU_SUCCESS(
      paddle::platform::dynload::cudnnSetConvolutionGroupCount(
          args.cdesc.desc(), groups));
  groups = 1;
#endif
#ifdef PADDLE_WITH_HIP
  // MIOpen must not see groups on the filter descriptor once they are set
  // on the convolution descriptor.
  groups = 1;
#endif
  args.idesc.set(transformed_input, layout_format);
  args.wdesc.set(transformed_filter_channel, layout_format, groups);
  args.odesc.set(transformed_output, layout_format);

  // The 4-D layout enum is passed to GetNCDHW for both 4-D and 5-D tensors.
  const auto ncdhw_layout =
      compute_format == paddle::platform::DataLayout::kNHWC
          ? paddle::platform::DataLayout::kNHWC
          : paddle::platform::DataLayout::kNCHW;
  int i_n, i_c, i_d, i_h, i_w;
  int o_n, o_c, o_d, o_h, o_w;
  GetNCDHW(transformed_input.dims(),
           ncdhw_layout,
           &i_n,
           &i_c,
           &i_d,
           &i_h,
           &i_w);
  GetNCDHW(transformed_output.dims(),
           ncdhw_layout,
           &o_n,
           &o_c,
           &o_d,
           &o_h,
           &o_w);

  // Per-group element offsets, computed with the (possibly reset) `groups`.
  int group_offset_in = i_c / groups * i_h * i_w * i_d;
  int group_offset_out = o_c / groups * o_h * o_w * o_d;
  int group_offset_filter = transformed_filter_channel.numel() / groups;
  // ------------------- cudnn conv workspace ---------------------
  size_t workspace_size = 0;  // final workspace to allocate.
// ------------------- cudnn conv algorithm ---------------------
#ifdef PADDLE_WITH_HIP
  SearchResult<miopenConvFwdAlgorithm_t> fwd_result;
  using search = SearchAlgorithm<miopenConvFwdAlgorithm_t>;
  workspace_size = search::GetWorkspaceSize(args);
  fwd_result.algo = search::Find<T>(
      args, exhaustive_search, deterministic, workspace_size, ctx);
#else
  SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
  using search = SearchAlgorithm<ConvKind::kForward>;
  fwd_result = search::Find<T>(ctx, args, exhaustive_search, deterministic);
  workspace_size = fwd_result.workspace_size;
#endif

#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1)
  // When groups > 1, SearchAlgorithm may pick
  // CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, which is unstable in forward
  // computation; the cast of 0 below is meant to select
  // CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM instead.
  // NOTE(review): `groups` was already reset to 1 above under this same
  // preprocessor condition, so this branch is never taken as written. The
  // original group count would have to be captured before that reset for
  // this override to take effect.
  if (groups > 1) {
    fwd_result.algo = static_cast<cudnnConvolutionFwdAlgo_t>(0);
  }
#endif

  // ------------------- cudnn conv forward ---------------------
  ScalingParamType<T> alpha = 1.0f;
  ScalingParamType<T> beta = 0.0f;

  // NOTE(zhiqiu): inplace addto is not supported in double grad yet.
  // ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f : 0.0f;
  // VLOG(4) << "Conv: use_addto = " << ctx.Attr<bool>("use_addto");

#ifdef PADDLE_WITH_HIP
  workspace_handle.RunFunc(
      [&](void* workspace_ptr) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            paddle::platform::dynload::miopenConvolutionForward(
                handle,
                &alpha,
                args.idesc.desc(),
                input_data,
                args.wdesc.desc(),
                filter_data,
                args.cdesc.desc(),
                fwd_result.algo,
                &beta,
                args.odesc.desc(),
                output_data,
                workspace_ptr,
                workspace_size));
      },
      workspace_size);
#else
  ConvRunner<T, ConvKind::kForward>::Apply(ctx,
                                           args,
                                           fwd_result,
                                           input_data,
                                           filter_data,
                                           output_data,
                                           groups,
                                           group_offset_in,
                                           group_offset_filter,
                                           group_offset_out,
                                           workspace_size,
                                           &workspace_handle,
                                           false);
#endif

  if (channel_last && compute_format == paddle::platform::DataLayout::kNCHW) {
    TransToChannelLast<Context, T>(ctx, &transformed_output, output);
  }
}

// 3-D convolution entry point. Its signature takes `groups` before
// `dilations`; it only reorders those arguments and forwards everything to
// ConvCudnnKernel.
template <typename T, typename Context>
void Conv3DCudnnKernel(const Context& dev_ctx,
                       const DenseTensor& input,
                       const DenseTensor& filter,
                       const std::vector<int>& strides,
                       const std::vector<int>& paddings,
                       const std::string& padding_algorithm,
                       int groups,
                       const std::vector<int>& dilations,
                       const std::string& data_format,
                       DenseTensor* out) {
  ConvCudnnKernel<T>(dev_ctx,
                     input,
                     filter,
                     strides,
                     paddings,
                     padding_algorithm,
                     dilations,
                     groups,
                     data_format,
                     out);
}

// Depthwise 2-D convolution entry point. Its signature takes `groups` before
// `dilations`; it only reorders those arguments and forwards everything to
// ConvCudnnKernel.
template <typename T, typename Context>
void DepthwiseConvCudnnKernel(const Context& dev_ctx,
                              const DenseTensor& input,
                              const DenseTensor& filter,
                              const std::vector<int>& strides,
                              const std::vector<int>& paddings,
                              const std::string& padding_algorithm,
                              int groups,
                              const std::vector<int>& dilations,
                              const std::string& data_format,
                              DenseTensor* out) {
  ConvCudnnKernel<T>(dev_ctx,
                     input,
                     filter,
                     strides,
                     paddings,
                     padding_algorithm,
                     dilations,
                     groups,
                     data_format,
                     out);
}

}  // namespace phi

#ifdef PADDLE_WITH_HIP
// HIP / MIOpen build: float and float16 only. depthwise_conv2d is registered
// for the GPUDNN backend only in this build.
PD_REGISTER_KERNEL(conv2d,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::ConvCudnnKernel,
                   float,
                   phi::dtype::float16) {}

PD_REGISTER_KERNEL(conv3d,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Conv3DCudnnKernel,
                   float,
                   phi::dtype::float16) {}

PD_REGISTER_KERNEL(depthwise_conv2d,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::DepthwiseConvCudnnKernel,
                   float,
                   phi::dtype::float16) {}

#else
#if CUDNN_VERSION_MIN(8, 1, 0)
// cuDNN >= 8.1.0: bfloat16 is registered in addition to float/double/float16.
PD_REGISTER_KERNEL(conv2d,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::ConvCudnnKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(conv3d,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Conv3DCudnnKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
#else
// Older cuDNN: no bfloat16 registration.
PD_REGISTER_KERNEL(conv2d,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::ConvCudnnKernel,
                   float,
                   double,
                   phi::dtype::float16) {}

PD_REGISTER_KERNEL(conv3d,
                   GPUDNN,
                   ALL_LAYOUT,
                   phi::Conv3DCudnnKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
#endif

#endif

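// TODO: register bfloat16 for the HIP build and for cuDNN < 8.1.0 as well
// (it is currently registered only when CUDNN_VERSION_MIN(8, 1, 0) holds).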