// conv_cudnn_frontend.h
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <array>
#include <cstdint>
#include <vector>

#include "glog/logging.h"

#include "paddle/phi/backends/dynload/cudnn_frontend.h"
#include "paddle/phi/backends/gpu/cuda/cudnn_desc.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"

namespace phi {

class CudnnFrontendConvHelper {
 public:
  // Reports whether `engine_config` carries the
  // CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC numerical note. FindExecutionPlans
  // hands this to cudnn_frontend::filter when deterministic execution is
  // requested.
  static bool IsNonDeterministic(cudnnBackendDescriptor_t engine_config) {
    const bool has_nondeterministic_note = cudnn_frontend::hasNumericalNote<
        CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(engine_config);
    return has_nondeterministic_note;
  }
  // Filter callback that ignores `engine_config` and always returns false.
  // FindExecutionPlans hands this to cudnn_frontend::filter when determinism
  // is not requested.
  static bool AllowAll(cudnnBackendDescriptor_t engine_config) {
    static_cast<void>(engine_config);  // intentionally unused
    return false;
  }

  // Returns the byte alignment of the tensor's data address: the largest
  // power of two that divides the address, capped at 16.
  //
  // The alignment is only doubled after checking that the address is
  // divisible by the *doubled* value. Checking divisibility by the current
  // value instead (and then doubling) reports twice the real alignment, e.g.
  // 16 for an address that is only 8-byte aligned.
  static uint8_t GetAlignment(const phi::DenseTensor* tensor) {
    constexpr uint8_t kMaxAlignment = 16;  // upper bound, in bytes
    const uint64_t address = reinterpret_cast<uint64_t>(tensor->data());
    uint8_t alignment = 1;
    while (alignment < kMaxAlignment && address % (alignment * 2) == 0) {
      alignment *= 2;
    }
    return alignment;
  }

  // Returns a copy of `in_array` with every element widened from int to
  // int64_t, preserving order. An empty input yields an empty vector.
  static std::vector<int64_t> GetInt64Array(const std::vector<int>& in_array) {
    // The range constructor converts element-wise and avoids comparing a
    // signed loop index against the unsigned size().
    return std::vector<int64_t>(in_array.begin(), in_array.end());
  }

  // Computes fully-packed strides for the extents in `dim`.
  // ref:
  // https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/helpers.cpp
  //  - CUDNN_TENSOR_NCHW: the last axis has stride 1 and every earlier axis
  //    strides over all axes after it.
  //  - any other format (assumed to be CUDNN_TENSOR_NHWC): axis 1 has
  //    stride 1, then the trailing axes from last down to 2, and axis 0 is
  //    outermost.
  // For INT8x4 and INT8x32 the same standard strides are produced; no resize
  // factor is applied here.
  // NOTE(review): the non-NCHW branch indexes axes 1 and 2, so it presumably
  // expects at least three dims -- confirm against callers.
  static std::vector<int64_t> GenerateStrides(
      const std::vector<int64_t>& dim, cudnnTensorFormat_t filter_format) {
    const size_t rank = dim.size();
    std::vector<int64_t> stride(rank);

    if (filter_format != CUDNN_TENSOR_NCHW) {
      // Channels-last style ordering.
      stride[1] = 1;
      stride[rank - 1] = stride[1] * dim[1];
      for (int64_t axis = static_cast<int64_t>(rank) - 2; axis >= 2; --axis) {
        stride[axis] = stride[axis + 1] * dim[axis + 1];
      }
      stride[0] = stride[2] * dim[2];
      return stride;
    }

    // Row-major ordering over the axes as given.
    stride[rank - 1] = 1;
    for (int64_t axis = static_cast<int64_t>(rank) - 2; axis >= 0; --axis) {
      stride[axis] = stride[axis + 1] * dim[axis + 1];
    }
    return stride;
  }

  // Builds a cudnn_frontend tensor descriptor for `tensor`.
  //  - tensor: supplies the dims, the dtype and the data address used for
  //    the alignment attribute.
  //  - id: identifier stored on the descriptor; callers in this file pass
  //    the characters 'x', 'y' and 'w'.
  //  - layout_format: for CUDNN_TENSOR_NHWC the dims are first passed through
  //    TransformDimOrder; the strides always come from GenerateStrides for
  //    the same format.
  static cudnn_frontend::Tensor GetTensorDescriptor(
      const phi::DenseTensor* tensor,
      int64_t id,
      cudnnTensorFormat_t layout_format) {
    auto transformed_dims = phi::vectorize<int64_t>(tensor->dims());
    if (layout_format == CUDNN_TENSOR_NHWC) {
      transformed_dims =
          phi::backends::gpu::TransformDimOrder(transformed_dims);
    }
    std::vector<int64_t> strides =
        GenerateStrides(transformed_dims, layout_format);
    return cudnn_frontend::TensorBuilder()
        .setDim(transformed_dims.size(), transformed_dims.data())
        .setStrides(strides.size(), strides.data())
        .setId(id)
        .setAlignment(GetAlignment(tensor))
        .setDataType(phi::backends::gpu::ToCudnnDataType(tensor->dtype()))
        .build();
  }

  // Builds the cudnn_frontend convolution descriptor (cross-correlation
  // mode).
  //  - dataType: CUDNN_DATA_DOUBLE selects a double compute type; every other
  //    value selects CUDNN_DATA_FLOAT.
  //  - padding: used for both the pre- and the post-padding.
  //  - stride: per-dimension strides; its length is the number of
  //    convolution dimensions passed to the builder.
  //  - dilation: per-dimension dilations.
  static cudnn_frontend::ConvDesc_v8 GetConvDescriptor(
      cudnnDataType_t dataType,
      const std::vector<int>& padding,
      const std::vector<int>& stride,
      const std::vector<int>& dilation) {
    cudnnDataType_t compute_type = CUDNN_DATA_FLOAT;
    if (dataType == CUDNN_DATA_DOUBLE) {
      compute_type = CUDNN_DATA_DOUBLE;
    }

    // The builder takes 64-bit arrays, so widen the int inputs first.
    std::vector<int64_t> pads = GetInt64Array(padding);
    std::vector<int64_t> steps = GetInt64Array(stride);
    std::vector<int64_t> dilates = GetInt64Array(dilation);
    uint64_t spatial_rank = stride.size();

    return cudnn_frontend::ConvDescBuilder()
        .setDataType(compute_type)
        .setMathMode(CUDNN_CROSS_CORRELATION)
        .setNDims(spatial_rank)
        .setStrides(spatial_rank, steps.data())
        .setPrePadding(spatial_rank, pads.data())
        .setPostPadding(spatial_rank, pads.data())
        .setDilation(spatial_rank, dilates.data())
        .build();
  }

  // Builds a single-operation cudnn_frontend graph for the convolution kind
  // selected by `op_mode`.
  // The x / y / w tensors are described with the ids 'x', 'y' and 'w'; the
  // variant packs built elsewhere in this file use the same ids to bind data
  // pointers. `alpha` and `beta` are forwarded to the operation builder
  // unchanged.
  template <cudnnBackendDescriptorType_t op_mode>
  static cudnn_frontend::OperationGraph BuildConvOperationGraph(
      const phi::DenseTensor* x_tensor,
      const phi::DenseTensor* y_tensor,
      const phi::DenseTensor* w_tensor,
      cudnnTensorFormat_t layout_format,
      const std::vector<int>& strides,
      const std::vector<int>& padding_common,
      const std::vector<int>& dilations,
      cudnnDataType_t dtype,
      cudnnHandle_t handle,
      float alpha,
      float beta) {
    // Describe every operand first, in the order the operation consumes them.
    auto x_desc = GetTensorDescriptor(x_tensor, 'x', layout_format);
    auto y_desc = GetTensorDescriptor(y_tensor, 'y', layout_format);
    auto w_desc = GetTensorDescriptor(w_tensor, 'w', layout_format);
    auto conv_desc =
        GetConvDescriptor(dtype, padding_common, strides, dilations);

    auto conv_op = cudnn_frontend::OperationBuilder(op_mode)
                       .setxDesc(x_desc)
                       .setyDesc(y_desc)
                       .setwDesc(w_desc)
                       .setcDesc(conv_desc)
                       .setAlpha(alpha)
                       .setBeta(beta)
                       .build();

    // The graph builder takes an array of operation pointers; there is
    // exactly one operation here.
    std::array<cudnn_frontend::Operation const*, 1> op_list = {&conv_op};
    return cudnn_frontend::OperationGraphBuilder()
        .setHandle(handle)
        .setOperationGraph(1, op_list.data())
        .build();
  }

  // Produces the list of candidate execution plans for `*op_graph_pointer`.
  //
  // Engine configs come from two sources, in this order: the instant-mode
  // heuristics and the fallback list. Both are passed through
  // cudnn_frontend::filter with IsNonDeterministic when `deterministic` is
  // set, and with AllowAll otherwise. The plan predicate flags every plan
  // whose workspace exceeds CalcWorkspaceLimitInBytes(UseFixedWorkspace()).
  //
  // When `deterministic` is false and either `exhaustive_search` or the
  // global auto-tune switch is on, the plans are additionally run through
  // cudnnFindPlan (median-of-three sampling) on the supplied buffers, using
  // a workspace sized for the largest candidate.
  //
  //  - x_data / y_data / w_data: device buffers bound to the ids 'x', 'y'
  //    and 'w' while timing plans.
  //  - workspace_handle: provides the scratch workspace for the timing run.
  // Returns the plans as handed back by the generator.
  static cudnn_frontend::executionPlans_t FindExecutionPlans(
      cudnn_frontend::OperationGraph* op_graph_pointer,
      bool exhaustive_search,
      bool deterministic,
      void* x_data,
      void* y_data,
      void* w_data,
      cudnnHandle_t handle,
      phi::DnnWorkspaceHandle* workspace_handle) {
    // Shared by both generator sources: runs cudnn_frontend::filter over
    // `candidates` with the callback matching the determinism request and
    // returns the list that filter() filled.
    auto filter_configs =
        [deterministic](cudnn_frontend::EngineConfigList& candidates)
        -> cudnn_frontend::EngineConfigList {
      cudnn_frontend::EngineConfigList filtered_configs;
      cudnn_frontend::filter(candidates,
                             filtered_configs,
                             deterministic ? IsNonDeterministic : AllowAll);
      return filtered_configs;
    };

    auto heurgen_method = [=](cudnn_frontend::OperationGraph& op_graph_)
        -> cudnn_frontend::EngineConfigList {
      auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
                            .setOperationGraph(op_graph_)
                            .setHeurMode(CUDNN_HEUR_MODE_INSTANT)
                            .build();
      VLOG(4) << "Heuristic has " << heuristics.getEngineConfigCount()
              << " configurations ";

      auto& engine_configs =
          heuristics.getEngineConfig(heuristics.getEngineConfigCount());
      return filter_configs(engine_configs);
    };

    auto fallback_method = [=](cudnn_frontend::OperationGraph& op_graph_)
        -> cudnn_frontend::EngineConfigList {
      auto fallback = cudnn_frontend::EngineFallbackListBuilder()
                          .setOperationGraph(op_graph_)
                          .build();
      auto& fallback_list = fallback.getFallbackList();
      return filter_configs(fallback_list);
    };

    std::array<cudnn_frontend::GeneratorSource const, 2> sources = {
        heurgen_method, fallback_method};
    cudnn_frontend::EngineConfigGenerator generator(sources.size(),
                                                    sources.data());

    size_t workspace_size_limit =
        CalcWorkspaceLimitInBytes(UseFixedWorkspace());
    // NOTE(review): a true result presumably makes the generator drop the
    // plan -- confirm against the cudnn_frontend documentation.
    auto predicate_function =
        [=](cudnn_frontend::ExecutionPlan const& plan) -> bool {
      return plan.getWorkspaceSize() > workspace_size_limit;
    };

    auto plans =
        generator.cudnnGetPlan(handle, *op_graph_pointer, predicate_function);

    bool use_autotune = phi::autotune::AutoTuneStatus::Instance().UseAutoTune();

    if (!deterministic && (exhaustive_search || use_autotune)) {
      // One workspace large enough for every candidate is used for timing.
      size_t workspace_size_max = 0;
      for (const auto& plan : plans) {
        workspace_size_max = std::max(
            workspace_size_max, static_cast<size_t>(plan.getWorkspaceSize()));
      }
      VLOG(6) << "[cudnn_frontend] Max workspace size: " << workspace_size_max;
      workspace_handle->RunFunc(
          [&](void* workspace_ptr) {
            void* data_ptrs[] = {x_data, y_data, w_data};
            int64_t uids[] = {'x', 'y', 'w'};
            auto variant_pack = cudnn_frontend::VariantPackBuilder()
                                    .setWorkspacePointer(workspace_ptr)
                                    .setDataPointers(3, data_ptrs)
                                    .setUids(3, uids)
                                    .build();
            plans =
                generator
                    .cudnnFindPlan<cudnn_frontend::CudnnFindSamplingTechnique::
                                       CUDNN_FIND_SAMPLE_MEDIAN_OF_THREE>(
                        handle,
                        *op_graph_pointer,
                        variant_pack,
                        predicate_function);
          },
          workspace_size_max);
    }

    for (auto& plan : plans) {
      VLOG(6) << "Plan tag: " << plan.getTag() << " finished in "
              << plan.getExecutionTime() << " ms,"
              << " workspace: " << plan.getWorkspaceSize() << " bytes";
    }

    return plans;
  }
};  // class CudnnFrontendConvHelper

template <typename T>
void CudnnConvBwdDataV8(const DenseTensor* dy_tensor,
                        const DenseTensor* w_tensor,
                        cudnnHandle_t handle,
                        DnnWorkspaceHandle* workspace_handle,
                        const std::vector<int>& strides,
                        const std::vector<int>& padding_common,
                        const std::vector<int>& dilations,
                        cudnnDataType_t dtype,
                        cudnnTensorFormat_t layout_format,
                        bool use_addto,
                        bool exhaustive_search,
                        bool deterministic,
                        DenseTensor* dx_tensor) {
  auto& plan_cache_bwd_data =
      phi::autotune::AutoTuneCache::Instance().GetConvV8(
          phi::autotune::AlgorithmType::kConvBackwardDataV8);
  T* dy_tensor_data = const_cast<T*>(dy_tensor->data<T>());
  T* w_tensor_data = const_cast<T*>(w_tensor->data<T>());
  T* dx_tensor_data = dx_tensor->data<T>();

  float alpha = 1.0f;
  float beta = use_addto ? 1.0f : 0.0f;

  using helper = CudnnFrontendConvHelper;
  auto op_graph = helper::BuildConvOperationGraph<
      CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR>(
      dx_tensor,
      dy_tensor,
      w_tensor,
      layout_format,
      strides,
      padding_common,
      dilations,
      dtype,
      handle,
      alpha,
      beta);

  if (plan_cache_bwd_data.FindPlan(op_graph, use_addto)) {
    auto engine_config =
        plan_cache_bwd_data.GetConfig(op_graph, handle, use_addto);
    auto cached_plan = cudnn_frontend::ExecutionPlanBuilder()
                           .setHandle(handle)
                           .setEngineConfig(engine_config, op_graph.getTag())
                           .build();
    auto workspace_size = cached_plan.getWorkspaceSize();
    VLOG(4) << "Cached execution plan found." << cached_plan.getTag()
            << "; Require workspace: " << workspace_size;
    workspace_handle->RunFunc(
        [&](void* workspace_ptr) {
          void* data_ptrs[] = {dx_tensor_data, dy_tensor_data, w_tensor_data};
          int64_t uids[] = {'x', 'y', 'w'};
          auto variant_pack = cudnn_frontend::VariantPackBuilder()
                                  .setWorkspacePointer(workspace_ptr)
                                  .setDataPointers(3, data_ptrs)
                                  .setUids(3, uids)
                                  .build();
          PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBackendExecute(
              handle, cached_plan.get_raw_desc(), variant_pack.get_raw_desc()));
        },
        workspace_size);
    return;
  }

  auto plans = helper::FindExecutionPlans(&op_graph,
                                          exhaustive_search,
                                          deterministic,
                                          dx_tensor_data,
                                          dy_tensor_data,
                                          w_tensor_data,
                                          handle,
                                          workspace_handle);

  for (auto& plan : plans) {
    try {
      int64_t workspace_size = plan.getWorkspaceSize();
      workspace_handle->RunFunc(
          [&](void* workspace_ptr) {
            void* data_ptrs[] = {dx_tensor_data, dy_tensor_data, w_tensor_data};
            int64_t uids[] = {'x', 'y', 'w'};
            auto variant_pack = cudnn_frontend::VariantPackBuilder()
                                    .setWorkspacePointer(workspace_ptr)
                                    .setDataPointers(3, data_ptrs)
                                    .setUids(3, uids)
                                    .build();
            PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBackendExecute(
                handle, plan.get_raw_desc(), variant_pack.get_raw_desc()));
          },
          workspace_size);
      if (!exhaustive_search ||
          plan_cache_bwd_data.IsStable(op_graph, plan.getTag(), use_addto)) {
        plan_cache_bwd_data.InsertPlan(op_graph, plan, use_addto);
      }
      return;
    } catch (cudnn_frontend::cudnnException& e) {
    } catch (phi::enforce::EnforceNotMet& e) {
    }
  }
  PADDLE_THROW(
      phi::errors::InvalidArgument("[CUDNN Frontend API] No valid plan could "
                                   "be found to execute conv backward data."));
}

template <typename T>
void CudnnConvBwdFilterV8(const DenseTensor* x_tensor,
                          const DenseTensor* dy_tensor,
                          cudnnHandle_t handle,
                          DnnWorkspaceHandle* workspace_handle,
                          const std::vector<int>& strides,
                          const std::vector<int>& padding_common,
                          const std::vector<int>& dilations,
                          cudnnDataType_t dtype,
                          cudnnTensorFormat_t layout_format,
                          bool use_addto,
                          bool exhaustive_search,
                          bool deterministic,
                          DenseTensor* dw_tensor) {
  auto& plan_cache_bwd_filter =
      phi::autotune::AutoTuneCache::Instance().GetConvV8(
          phi::autotune::AlgorithmType::kConvBackwardFilterV8);
  T* x_tensor_data = const_cast<T*>(x_tensor->data<T>());
  T* dy_tensor_data = const_cast<T*>(dy_tensor->data<T>());
  T* dw_tensor_data = dw_tensor->data<T>();

  float alpha = 1.0f;
  float beta = 0.0f;

  using helper = CudnnFrontendConvHelper;
  auto op_graph = helper::BuildConvOperationGraph<
      CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR>(
      x_tensor,
      dy_tensor,
      dw_tensor,
      layout_format,
      strides,
      padding_common,
      dilations,
      dtype,
      handle,
      alpha,
      beta);

  if (plan_cache_bwd_filter.FindPlan(op_graph)) {
    auto engine_config = plan_cache_bwd_filter.GetConfig(op_graph, handle);
    auto cached_plan = cudnn_frontend::ExecutionPlanBuilder()
                           .setHandle(handle)
                           .setEngineConfig(engine_config, op_graph.getTag())
                           .build();
    auto workspace_size = cached_plan.getWorkspaceSize();
    VLOG(4) << "Cached execution plan found." << cached_plan.getTag()
            << "; Require workspace: " << workspace_size;
    workspace_handle->RunFunc(
        [&](void* workspace_ptr) {
          void* data_ptrs[] = {x_tensor_data, dy_tensor_data, dw_tensor_data};
          int64_t uids[] = {'x', 'y', 'w'};
          auto variant_pack = cudnn_frontend::VariantPackBuilder()
                                  .setWorkspacePointer(workspace_ptr)
                                  .setDataPointers(3, data_ptrs)
                                  .setUids(3, uids)
                                  .build();
          PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBackendExecute(
              handle, cached_plan.get_raw_desc(), variant_pack.get_raw_desc()));
        },
        workspace_size);
    return;
  }

  auto plans = helper::FindExecutionPlans(&op_graph,
                                          exhaustive_search,
                                          deterministic,
                                          x_tensor_data,
                                          dy_tensor_data,
                                          dw_tensor_data,
                                          handle,
                                          workspace_handle);

  for (auto& plan : plans) {
    try {
      int64_t workspace_size = plan.getWorkspaceSize();
      workspace_handle->RunFunc(
          [&](void* workspace_ptr) {
            void* data_ptrs[] = {x_tensor_data, dy_tensor_data, dw_tensor_data};
            int64_t uids[] = {'x', 'y', 'w'};
            auto variant_pack = cudnn_frontend::VariantPackBuilder()
                                    .setWorkspacePointer(workspace_ptr)
                                    .setDataPointers(3, data_ptrs)
                                    .setUids(3, uids)
                                    .build();
            PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBackendExecute(
                handle, plan.get_raw_desc(), variant_pack.get_raw_desc()));
          },
          workspace_size);
      if (!exhaustive_search ||
          plan_cache_bwd_filter.IsStable(op_graph, plan.getTag())) {
        plan_cache_bwd_filter.InsertPlan(op_graph, plan);
      }
      return;
    } catch (cudnn_frontend::cudnnException& e) {
      VLOG(4) << "Plan " << plan.describe()
              << "failed to execute. Trying next plan.";
    } catch (phi::enforce::EnforceNotMet& e) {
      VLOG(4) << "Plan " << plan.describe()
              << "failed to execute. Trying next plan.";
    }
  }

  PADDLE_THROW(phi::errors::InvalidArgument(
      "[CUDNN Frontend API] No valid plan could "
      "be found to execute conv backward filter."));
}

}  // namespace phi