未验证 提交 14a6e67b 编写于 作者: T Tian Zheng 提交者: GitHub

CUDNN v8 Implementation of Convolution Kernels (#47454)

* Refactor conv_kernel and conv_grad_kernel to provide interface for CUDNNv8 implementation

* Fix macro

* Add implementation for conv_kernel and conv_grad_kernel

* Modification after rebase onto latest develop

* Modify plan cache to comply with the API of phi::autotune

* Refactor to reduce duplicate code

* Review fix:
- move functions in  conv_kernel_impl_v8.h and conv_grad_kernel_impl_v8.h to conv_kernel.cu and conv_grad_kernelk.cu
- add const specifier for input tensor
- add logging when plans fail to execute
- move CudnnConvBwdFilterV8 and CudnnConvBwdDataV8 to conv_cudnn_frontend.h

* - move plan building outside of cache

* Fix ROCM build
上级 593bc4e2
......@@ -40,12 +40,13 @@ inline cudnnDataType_t ToCudnnDataType(const T& t) {
return ToCudnnDataType(type);
}
inline std::vector<int> TransformDimOrder(const std::vector<int>& dims) {
std::vector<int> transformed_dims(dims.begin(), dims.end());
template <typename T>
inline std::vector<T> TransformDimOrder(const std::vector<T>& dims) {
std::vector<T> transformed_dims(dims.begin(), dims.end());
if (dims.size() < 4) {
return transformed_dims;
}
int H, W, D, C;
T H, W, D, C;
if (dims.size() == 4) {
H = dims[1];
W = dims[2];
......
......@@ -1053,4 +1053,17 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
* Note: Enable CUDNNv8 Frontend API for CUDNN kernels.
*/
PADDLE_DEFINE_EXPORTED_bool(enable_cudnn_frontend, false, "");
/**
* CUDNNv8 related FLAG
* Name: cudnn_cache_saturation_count
* Since Version: 2.5.0
* Value Range: int64_t, default=1
* Example:
* Note: Set saturation count for CUDNNv8 cache. A candidate execution
* plan need to be considered as the fastest plan by exhaustive search
* N times before it is actually added in the cache. It is useful when
* the result of exhaustive search is unstable.
*/
PADDLE_DEFINE_EXPORTED_int32(cudnn_cache_saturation_count, 1, "");
#endif // PADDLE_WITH_CUDNN_FRONTEND
......@@ -84,7 +84,9 @@ if(WITH_NCCL OR WITH_RCCL)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_nccl)
endif()
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_comm_utils)
if(WITH_CUDNN_FRONTEND)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} cudnn-frontend)
endif()
copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
file(GLOB kernel_h "*.h" "selected_rows/*.h" "sparse/*.h" "strings/*.h")
......
cc_library(cache SRCS cache.cc)
if(WITH_CUDNN_FRONTEND)
cc_library(
cache
SRCS cache.cc
DEPS cudnn-frontend)
else()
cc_library(cache SRCS cache.cc)
endif()
cc_library(
switch_autotune
SRCS switch_autotune.cc
......
......@@ -38,6 +38,17 @@ std::string AlgorithmTypeString(int64_t algo_type) {
static_cast<int64_t>(AlgorithmType::kConvBackwardFilter)) {
return "conv_backward_filter";
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
if (algo_type == static_cast<int64_t>(AlgorithmType::kConvForwardV8)) {
return "conv_forward_v8";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardDataV8)) {
return "conv_backward_data_v8";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardFilterV8)) {
return "conv_backward_filter_v8";
}
#endif
return std::to_string(algo_type);
}
......@@ -71,6 +82,20 @@ void AutoTuneCache::UpdateStatus() {
cache_misses += v.second.CacheMisses();
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
for (auto& v : cudnn_v8_auto_tune_map_) {
VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width)
<< AlgorithmTypeString(v.first)
<< " Cache Size: " << v.second.Size()
<< " Hits: " << v.second.CacheHits()
<< " Misses: " << v.second.CacheMisses()
<< " Hit Rate: " << v.second.CacheHitRate();
size += v.second.Size();
cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses();
}
#endif
total_size_ = size;
total_cache_hits_ = cache_hits;
total_cache_misses_ = cache_misses;
......
......@@ -19,7 +19,9 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/autotune/cache_base.h"
#ifdef PADDLE_WITH_CUDNN_FRONTEND
#include "paddle/phi/kernels/autotune/cache_cudnn_frontend.h"
#endif
namespace phi {
namespace autotune {
......@@ -41,8 +43,16 @@ enum class AlgorithmType {
kConvForward = 1,
kConvBackwardData = 2,
kConvBackwardFilter = 3,
#ifdef PADDLE_WITH_CUDNN_FRONTEND
kConvForwardV8 = 4,
kConvBackwardDataV8 = 5,
kConvBackwardFilterV8 = 6,
kTranspose = 7,
kAlgorithmCount = 8
#else
kTranspose = 4,
kAlgorithmCount = 5
#endif
};
// AlgorithmsConfigKey -> AlgorithmsID
......@@ -53,7 +63,10 @@ using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
using ConvAlgorithmsCacheMap = ConvAlgorithmsCache<ConvAutoTuneResult>;
using ConvAlgorithmsTypeMap =
std::unordered_map<int64_t, ConvAlgorithmsCacheMap>;
#ifdef PADDLE_WITH_CUDNN_FRONTEND
using CudnnV8AlgorithmsTypeMap =
std::unordered_map<int64_t, CudnnFrontendPlanCache>;
#endif
class AutoTuneCache {
public:
static AutoTuneCache& Instance() {
......@@ -69,6 +82,12 @@ class AutoTuneCache {
return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnFrontendPlanCache& GetConvV8(const AlgorithmType& algo_type) {
return cudnn_v8_auto_tune_map_[static_cast<int64_t>(algo_type)];
}
#endif
AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }
void Clean() {
......@@ -79,6 +98,12 @@ class AutoTuneCache {
for (auto& v : conv_auto_tune_map_) {
v.second.Clean();
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
for (auto& v : cudnn_v8_auto_tune_map_) {
v.second.Clean();
}
#endif
}
void UpdateStatus();
......@@ -117,6 +142,16 @@ class AutoTuneCache {
ConvAlgorithmsCacheMap cache;
conv_auto_tune_map_[key] = cache;
}
#ifdef PADDLE_WITH_CUDNN_FRONTEND
} else if (algo_type == AlgorithmType::kConvForwardV8 ||
algo_type == AlgorithmType::kConvBackwardDataV8 ||
algo_type == AlgorithmType::kConvBackwardFilterV8) {
int64_t key = static_cast<int64_t>(algo_type);
if (cudnn_v8_auto_tune_map_.find(key) == cudnn_v8_auto_tune_map_.end()) {
CudnnFrontendPlanCache cache;
cudnn_v8_auto_tune_map_[key] = cache;
}
#endif
} else {
int64_t key = static_cast<int64_t>(algo_type);
if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
......@@ -128,6 +163,9 @@ class AutoTuneCache {
AlgorithmsTypeMap auto_tune_map_;
ConvAlgorithmsTypeMap conv_auto_tune_map_;
#ifdef PADDLE_WITH_CUDNN_FRONTEND
CudnnV8AlgorithmsTypeMap cudnn_v8_auto_tune_map_;
#endif
std::shared_ptr<std::mutex> autotune_cache_mutex_;
int64_t total_cache_hits_{0};
int64_t total_cache_misses_{0};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <mutex>
#include <string>
#include <vector>
#include "paddle/phi/backends/dynload/cudnn_frontend.h"
DECLARE_int32(cudnn_cache_saturation_count);
namespace phi {
namespace autotune {
class CudnnFrontendPlanCache {
public:
CudnnFrontendPlanCache() : cache_mutex_(new std::mutex()) {
map_.clear();
tracker_.clear();
saturation_count_ = FLAGS_cudnn_cache_saturation_count;
}
int64_t Size() const { return map_.size(); }
int64_t CacheHits() const { return cache_hits_; }
int64_t CacheMisses() const { return cache_misses_; }
float CacheHitRate() const {
int64_t num_accesses = cache_hits_ + cache_misses_;
float cache_hit_rate = 0.;
if (num_accesses != 0) {
cache_hit_rate =
static_cast<float>(cache_hits_) / static_cast<float>(num_accesses);
}
return cache_hit_rate;
}
void Clean() {
std::lock_guard<std::mutex> lock(*cache_mutex_);
map_.clear();
tracker_.clear();
cache_hits_ = 0;
cache_misses_ = 0;
}
bool FindPlan(const cudnn_frontend::OperationGraph& op_graph,
bool use_addto = false) {
bool ret = false;
std::lock_guard<std::mutex> lock(*cache_mutex_);
if (map_.count(MakeKey(op_graph, use_addto)) > 0) {
cache_hits_++;
ret = true;
} else {
cache_misses_++;
}
return ret;
}
cudnn_frontend::ManagedOpaqueDescriptor GetConfig(
const cudnn_frontend::OperationGraph& op_graph,
cudnnHandle_t handle,
bool use_addto = false) {
std::lock_guard<std::mutex> lock(*cache_mutex_);
auto engine_config = map_[MakeKey(op_graph, use_addto)];
return engine_config;
}
void InsertPlan(const cudnn_frontend::OperationGraph& op_graph,
const cudnn_frontend::ExecutionPlan& plan,
bool use_addto = false) {
VLOG(4) << "[cudnn_frontend] cache: Insert graph tag: "
<< op_graph.getTag();
std::lock_guard<std::mutex> lock(*cache_mutex_);
map_.insert(
std::make_pair(MakeKey(op_graph, use_addto), plan.GetEngineConfig()));
}
bool IsStable(const cudnn_frontend::OperationGraph& op_graph,
const std::string& tag,
bool use_addto = false) {
if (saturation_count_ == 1) {
return true;
}
std::lock_guard<std::mutex> lock(*cache_mutex_);
if (map_.count(MakeKey(op_graph, use_addto))) {
return false;
}
int cnt = tracker_[std::make_pair(MakeKey(op_graph, use_addto), tag)] += 1;
VLOG(4) << "[cudnn_frontend] SaturationTracker: " << op_graph.getTag()
<< " " << tag << " " << cnt;
return cnt >= saturation_count_;
}
private:
static cudnn_frontend::feature_vector_t MakeKey(
const cudnn_frontend::OperationGraph& op_graph, bool use_addto) {
auto key = op_graph.getFeatureVector();
key.push_back(static_cast<uint64_t>(use_addto));
return key;
}
std::map<cudnn_frontend::feature_vector_t,
cudnn_frontend::ManagedOpaqueDescriptor>
map_;
std::shared_ptr<std::mutex> cache_mutex_;
int saturation_count_;
using SaturationTracker =
std::map<std::pair<cudnn_frontend::feature_vector_t, std::string>, int>;
SaturationTracker tracker_;
int64_t cache_hits_{0};
int64_t cache_misses_{0};
}; // class CudnnFrontendPlanCache
} // namespace autotune
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/device/gpu/cuda/cudnn_desc.h"
#include "paddle/phi/backends/dynload/cudnn_frontend.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
namespace phi {
class CudnnFrontendConvHelper {
public:
static bool IsNonDeterministic(cudnnBackendDescriptor_t engine_config) {
return cudnn_frontend::hasNumericalNote<
CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(engine_config);
}
static bool AllowAll(cudnnBackendDescriptor_t engine_config) {
(void)engine_config;
return false;
}
static uint8_t GetAlignment(const phi::DenseTensor* tensor) {
// alignment are in bytes
uint8_t alignment = 1;
uint64_t address = reinterpret_cast<uint64_t>(tensor->data());
while (address % alignment == 0 && alignment < 16) alignment *= 2;
return alignment;
}
static std::vector<int64_t> GetInt64Array(const std::vector<int>& in_array) {
std::vector<int64_t> out_array(in_array.size());
for (int i = 0; i < in_array.size(); i++) {
out_array[i] = static_cast<int64_t>(in_array[i]);
}
return out_array;
}
static std::vector<int64_t> GenerateStrides(
const std::vector<int64_t>& dim, cudnnTensorFormat_t filter_format) {
// ref:
// https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/helpers.cpp
// For INT8x4 and INT8x32 we still compute standard strides here to input
// into the cuDNN functions. We will manually scale by resizeFactor in the
// cpu ref.
size_t nb_dims = dim.size();
std::vector<int64_t> stride(nb_dims);
if (filter_format == CUDNN_TENSOR_NCHW) {
stride[nb_dims - 1] = 1;
for (int64_t d = nb_dims - 2; d >= 0; d--) {
stride[d] = stride[d + 1] * dim[d + 1];
}
} else {
// Here we assume that the format is CUDNN_TENSOR_NHWC
stride[1] = 1;
stride[nb_dims - 1] = stride[1] * dim[1];
for (int64_t d = nb_dims - 2; d >= 2; d--) {
stride[d] = stride[d + 1] * dim[d + 1];
}
stride[0] = stride[2] * dim[2];
}
return stride;
}
static cudnn_frontend::Tensor GetTensorDescriptor(
const phi::DenseTensor* tensor,
int64_t id,
cudnnTensorFormat_t layout_format) {
auto transformed_dims = phi::vectorize<int64_t>(tensor->dims());
if (layout_format == CUDNN_TENSOR_NHWC) {
transformed_dims = paddle::platform::TransformDimOrder(transformed_dims);
}
std::vector<int64_t> strides =
GenerateStrides(transformed_dims, layout_format);
return cudnn_frontend::TensorBuilder()
.setDim(transformed_dims.size(), transformed_dims.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(GetAlignment(tensor))
.setDataType(paddle::platform::ToCudnnDataType(
paddle::framework::TransToProtoVarType(tensor->dtype())))
.build();
}
static cudnn_frontend::ConvDesc_v8 GetConvDescriptor(
cudnnDataType_t dataType,
const std::vector<int>& padding,
const std::vector<int>& stride,
const std::vector<int>& dilation) {
uint64_t conv_dim = stride.size();
cudnnDataType_t compute_type =
(dataType == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
std::vector<int64_t> padding_int64 = GetInt64Array(padding);
std::vector<int64_t> stride_int64 = GetInt64Array(stride);
std::vector<int64_t> dilation_int64 = GetInt64Array(dilation);
return cudnn_frontend::ConvDescBuilder()
.setDataType(compute_type)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(conv_dim)
.setStrides(conv_dim, stride_int64.data())
.setPrePadding(conv_dim, padding_int64.data())
.setPostPadding(conv_dim, padding_int64.data())
.setDilation(conv_dim, dilation_int64.data())
.build();
}
template <cudnnBackendDescriptorType_t op_mode>
static cudnn_frontend::OperationGraph BuildConvOperationGraph(
const phi::DenseTensor* x_tensor,
const phi::DenseTensor* y_tensor,
const phi::DenseTensor* w_tensor,
cudnnTensorFormat_t layout_format,
const std::vector<int>& strides,
const std::vector<int>& padding_common,
const std::vector<int>& dilations,
cudnnDataType_t dtype,
cudnnHandle_t handle,
float alpha,
float beta) {
auto op = cudnn_frontend::OperationBuilder(op_mode)
.setxDesc(GetTensorDescriptor(x_tensor, 'x', layout_format))
.setyDesc(GetTensorDescriptor(y_tensor, 'y', layout_format))
.setwDesc(GetTensorDescriptor(w_tensor, 'w', layout_format))
.setcDesc(GetConvDescriptor(
dtype, padding_common, strides, dilations))
.setAlpha(alpha)
.setBeta(beta)
.build();
std::array<cudnn_frontend::Operation const*, 1> ops = {&op};
return cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(1, ops.data())
.build();
}
static cudnn_frontend::executionPlans_t FindExecutionPlans(
cudnn_frontend::OperationGraph* op_graph_pointer,
bool exhaustive_search,
bool deterministic,
void* x_data,
void* y_data,
void* w_data,
cudnnHandle_t handle,
phi::DnnWorkspaceHandle* workspace_handle) {
auto heurgen_method = [=](cudnn_frontend::OperationGraph& op_graph_)
-> cudnn_frontend::EngineConfigList {
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(op_graph_)
.setHeurMode(CUDNN_HEUR_MODE_INSTANT)
.build();
VLOG(4) << "Heuristic has " << heuristics.getEngineConfigCount()
<< " configurations ";
auto& engine_configs =
heuristics.getEngineConfig(heuristics.getEngineConfigCount());
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(engine_configs,
filtered_configs,
deterministic ? IsNonDeterministic : AllowAll);
return filtered_configs;
};
auto fallback_method = [=](cudnn_frontend::OperationGraph& op_graph_)
-> cudnn_frontend::EngineConfigList {
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(op_graph_)
.build();
auto& fallback_list = fallback.getFallbackList();
cudnn_frontend::EngineConfigList filtered_configs;
cudnn_frontend::filter(fallback_list,
filtered_configs,
deterministic ? IsNonDeterministic : AllowAll);
return filtered_configs;
};
std::array<cudnn_frontend::GeneratorSource const, 2> sources = {
heurgen_method, fallback_method};
cudnn_frontend::EngineConfigGenerator generator(sources.size(),
sources.data());
size_t workspace_size_limit =
CalcWorkspaceLimitInBytes(UseFixedWorkspace());
auto predicate_function =
[=](cudnn_frontend::ExecutionPlan const& plan) -> bool {
return plan.getWorkspaceSize() > workspace_size_limit;
};
auto plans =
generator.cudnnGetPlan(handle, *op_graph_pointer, predicate_function);
bool use_autotune = phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
if (!deterministic && (exhaustive_search || use_autotune)) {
size_t workspace_size_max = 0;
std::for_each(
plans.begin(), plans.end(), [&](cudnn_frontend::ExecutionPlan& opt) {
if (opt.getWorkspaceSize() > workspace_size_max) {
workspace_size_max = opt.getWorkspaceSize();
}
});
VLOG(6) << "[cudnn_frontend] Max workspace size: " << workspace_size_max;
workspace_handle->RunFunc(
[&](void* workspace_ptr) {
void* data_ptrs[] = {x_data, y_data, w_data};
int64_t uids[] = {'x', 'y', 'w'};
auto variant_pack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
plans =
generator
.cudnnFindPlan<cudnn_frontend::CudnnFindSamplingTechnique::
CUDNN_FIND_SAMPLE_MEDIAN_OF_THREE>(
handle,
*op_graph_pointer,
variant_pack,
predicate_function);
},
workspace_size_max);
}
std::for_each(
plans.begin(), plans.end(), [](cudnn_frontend::ExecutionPlan& opt) {
VLOG(6) << "Plan tag: " << opt.getTag() << " finished in "
<< opt.getExecutionTime() << " ms,"
<< " workspace: " << opt.getWorkspaceSize() << " bytes";
});
return plans;
}
}; // class CudnnFrontendConvHelper
template <typename T>
void CudnnConvBwdDataV8(const DenseTensor* dy_tensor,
const DenseTensor* w_tensor,
cudnnHandle_t handle,
DnnWorkspaceHandle* workspace_handle,
const std::vector<int>& strides,
const std::vector<int>& padding_common,
const std::vector<int>& dilations,
cudnnDataType_t dtype,
cudnnTensorFormat_t layout_format,
bool use_addto,
bool exhaustive_search,
bool deterministic,
DenseTensor* dx_tensor) {
auto& plan_cache_bwd_data =
phi::autotune::AutoTuneCache::Instance().GetConvV8(
phi::autotune::AlgorithmType::kConvBackwardDataV8);
T* dy_tensor_data = const_cast<T*>(dy_tensor->data<T>());
T* w_tensor_data = const_cast<T*>(w_tensor->data<T>());
T* dx_tensor_data = dx_tensor->data<T>();
float alpha = 1.0f;
float beta = use_addto ? 1.0f : 0.0f;
using helper = CudnnFrontendConvHelper;
auto op_graph = helper::BuildConvOperationGraph<
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR>(
dx_tensor,
dy_tensor,
w_tensor,
layout_format,
strides,
padding_common,
dilations,
dtype,
handle,
alpha,
beta);
if (plan_cache_bwd_data.FindPlan(op_graph, use_addto)) {
auto engine_config =
plan_cache_bwd_data.GetConfig(op_graph, handle, use_addto);
auto cached_plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(engine_config, op_graph.getTag())
.build();
auto workspace_size = cached_plan.getWorkspaceSize();
VLOG(4) << "Cached execution plan found." << cached_plan.getTag()
<< "; Require workspace: " << workspace_size;
workspace_handle->RunFunc(
[&](void* workspace_ptr) {
void* data_ptrs[] = {dx_tensor_data, dy_tensor_data, w_tensor_data};
int64_t uids[] = {'x', 'y', 'w'};
auto variant_pack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBackendExecute(
handle, cached_plan.get_raw_desc(), variant_pack.get_raw_desc()));
},
workspace_size);
return;
}
auto plans = helper::FindExecutionPlans(&op_graph,
exhaustive_search,
deterministic,
dx_tensor_data,
dy_tensor_data,
w_tensor_data,
handle,
workspace_handle);
for (auto& plan : plans) {
try {
int64_t workspace_size = plan.getWorkspaceSize();
workspace_handle->RunFunc(
[&](void* workspace_ptr) {
void* data_ptrs[] = {dx_tensor_data, dy_tensor_data, w_tensor_data};
int64_t uids[] = {'x', 'y', 'w'};
auto variant_pack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBackendExecute(
handle, plan.get_raw_desc(), variant_pack.get_raw_desc()));
},
workspace_size);
if (!exhaustive_search ||
plan_cache_bwd_data.IsStable(op_graph, plan.getTag(), use_addto)) {
plan_cache_bwd_data.InsertPlan(op_graph, plan, use_addto);
}
return;
} catch (cudnn_frontend::cudnnException& e) {
} catch (phi::enforce::EnforceNotMet& e) {
}
}
PADDLE_THROW(
phi::errors::InvalidArgument("[CUDNN Frontend API] No valid plan could "
"be found to execute conv backward data."));
}
template <typename T>
void CudnnConvBwdFilterV8(const DenseTensor* x_tensor,
const DenseTensor* dy_tensor,
cudnnHandle_t handle,
DnnWorkspaceHandle* workspace_handle,
const std::vector<int>& strides,
const std::vector<int>& padding_common,
const std::vector<int>& dilations,
cudnnDataType_t dtype,
cudnnTensorFormat_t layout_format,
bool use_addto,
bool exhaustive_search,
bool deterministic,
DenseTensor* dw_tensor) {
auto& plan_cache_bwd_filter =
phi::autotune::AutoTuneCache::Instance().GetConvV8(
phi::autotune::AlgorithmType::kConvBackwardFilterV8);
T* x_tensor_data = const_cast<T*>(x_tensor->data<T>());
T* dy_tensor_data = const_cast<T*>(dy_tensor->data<T>());
T* dw_tensor_data = dw_tensor->data<T>();
float alpha = 1.0f;
float beta = 0.0f;
using helper = CudnnFrontendConvHelper;
auto op_graph = helper::BuildConvOperationGraph<
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR>(
x_tensor,
dy_tensor,
dw_tensor,
layout_format,
strides,
padding_common,
dilations,
dtype,
handle,
alpha,
beta);
if (plan_cache_bwd_filter.FindPlan(op_graph)) {
auto engine_config = plan_cache_bwd_filter.GetConfig(op_graph, handle);
auto cached_plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(engine_config, op_graph.getTag())
.build();
auto workspace_size = cached_plan.getWorkspaceSize();
VLOG(4) << "Cached execution plan found." << cached_plan.getTag()
<< "; Require workspace: " << workspace_size;
workspace_handle->RunFunc(
[&](void* workspace_ptr) {
void* data_ptrs[] = {x_tensor_data, dy_tensor_data, dw_tensor_data};
int64_t uids[] = {'x', 'y', 'w'};
auto variant_pack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBackendExecute(
handle, cached_plan.get_raw_desc(), variant_pack.get_raw_desc()));
},
workspace_size);
return;
}
auto plans = helper::FindExecutionPlans(&op_graph,
exhaustive_search,
deterministic,
x_tensor_data,
dy_tensor_data,
dw_tensor_data,
handle,
workspace_handle);
for (auto& plan : plans) {
try {
int64_t workspace_size = plan.getWorkspaceSize();
workspace_handle->RunFunc(
[&](void* workspace_ptr) {
void* data_ptrs[] = {x_tensor_data, dy_tensor_data, dw_tensor_data};
int64_t uids[] = {'x', 'y', 'w'};
auto variant_pack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnBackendExecute(
handle, plan.get_raw_desc(), variant_pack.get_raw_desc()));
},
workspace_size);
if (!exhaustive_search ||
plan_cache_bwd_filter.IsStable(op_graph, plan.getTag())) {
plan_cache_bwd_filter.InsertPlan(op_graph, plan);
}
return;
} catch (cudnn_frontend::cudnnException& e) {
VLOG(4) << "Plan " << plan.describe()
<< "failed to execute. Trying next plan.";
} catch (phi::enforce::EnforceNotMet& e) {
VLOG(4) << "Plan " << plan.describe()
<< "failed to execute. Trying next plan.";
}
}
PADDLE_THROW(phi::errors::InvalidArgument(
"[CUDNN Frontend API] No valid plan could "
"be found to execute conv backward filter."));
}
} // namespace phi
......@@ -23,31 +23,6 @@ namespace phi {
using ConvArgs = ConvArgsBase<cudnnHandle_t, cudnnDataType_t>;
static inline double ToMegaBytes(size_t bytes) {
return static_cast<double>(bytes) / (1 << 20);
}
static inline bool UseFixedWorkspace() {
return FLAGS_conv_workspace_size_limit >= 0;
}
static size_t CalcWorkspaceLimitInBytes(bool use_fixed_workspace) {
if (!use_fixed_workspace) {
int device_id = phi::backends::gpu::GetCurrentDeviceId();
int64_t allocated =
paddle::memory::DeviceMemoryStatCurrentValue("Allocated", device_id);
int64_t reserved =
paddle::memory::DeviceMemoryStatCurrentValue("Reserved", device_id);
int64_t availble = paddle::platform::GpuAvailableMemToAlloc();
VLOG(3) << "[memory] allocated=" << ToMegaBytes(allocated)
<< " MB, reserved=" << ToMegaBytes(reserved)
<< " MB, available_to_alloc=" << ToMegaBytes(availble) << " MB.";
return std::max(availble, reserved - allocated);
} else {
return FLAGS_conv_workspace_size_limit * 1024 * 1024;
}
}
template <typename PerfT>
std::string GetPerfResultString(std::string prefix,
const std::vector<PerfT>& perf_results,
......
......@@ -36,6 +36,31 @@ using ScalingParamType =
enum class ConvKind { kForward = 1, kBackwardData = 2, kBackwardFilter = 3 };
static inline double ToMegaBytes(size_t bytes) {
return static_cast<double>(bytes) / (1 << 20);
}
static inline bool UseFixedWorkspace() {
return FLAGS_conv_workspace_size_limit >= 0;
}
static size_t CalcWorkspaceLimitInBytes(bool use_fixed_workspace) {
if (!use_fixed_workspace) {
int device_id = phi::backends::gpu::GetCurrentDeviceId();
int64_t allocated =
paddle::memory::DeviceMemoryStatCurrentValue("Allocated", device_id);
int64_t reserved =
paddle::memory::DeviceMemoryStatCurrentValue("Reserved", device_id);
int64_t availble = paddle::platform::GpuAvailableMemToAlloc();
VLOG(3) << "[memory] allocated=" << ToMegaBytes(allocated)
<< " MB, reserved=" << ToMegaBytes(reserved)
<< " MB, available_to_alloc=" << ToMegaBytes(availble) << " MB.";
return std::max(availble, reserved - allocated);
} else {
return FLAGS_conv_workspace_size_limit * 1024 * 1024;
}
}
// The container of SearchAlgorithm::Find() result.
template <typename AlgoT>
struct SearchResult {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册