未验证 提交 b4adbe5c 编写于 作者: Y Yiqun Liu 提交者: GitHub

[Cherry-pick 2.3] Autotune the workspace and kernel choosing of conv (#41833)

Cherry-pick #40338 #41741 #41313
上级 a9d8b947
......@@ -15,7 +15,7 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_subdirectory(pylayer)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)
add_dependencies(grad_tensor_holder eager_final_state_codegen)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info switch_autotune)
endif()
cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor)
......
......@@ -16,7 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace paddle {
......
......@@ -9,8 +9,8 @@ cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_f
add_subdirectory(jit)
cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(imperative_profiler SRCS profiler.cc DEPS flags)
if(NOT WIN32)
if(WITH_NCCL OR WITH_RCCL)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/autotune/cache.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = platform::DataLayout;
using framework::AlgorithmsCache;
using framework::ConvSearchCache;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
// As the basic for SearchAlgorithm struct.
template <typename PerfT>
struct SearchAlgorithm {};
// As the container of searchAlgorithm::Find() result.
template <typename AlgoT>
struct SearchResult {
SearchResult() {}
explicit SearchResult(AlgoT a) : algo(a) {}
AlgoT algo = static_cast<AlgoT>(0);
float time = -1.f;
size_t workspace_size = 0;
};
template <typename T>
static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
out << "]";
return out;
}
// As the container of conv relevant descriptors.
template <typename HandleT, typename DataT>
struct ConvArgsBase {
HandleT handle;
platform::TensorDescriptor idesc, odesc;
platform::FilterDescriptor wdesc;
platform::ConvolutionDescriptor cdesc;
const framework::Tensor *x, *w, *o;
DataT cudnn_dtype;
// strides
std::vector<int> s;
// paddings
std::vector<int> p;
// dilations
std::vector<int> d;
ConvArgsBase(const framework::Tensor* x, const framework::Tensor* w,
const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d, DataT dtype)
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
template <typename T>
size_t GetCacheKey() const {
auto x_shape = phi::vectorize(x->dims());
auto w_shape = phi::vectorize(w->dims());
VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
<< ", strides=" << s << ", paddings=" << p << ", dilations=" << d;
return phi::autotune::ConvKey(
x_shape, w_shape, p, s, d,
paddle::experimental::CppTypeToDataType<T>::Type());
}
};
static inline void GetNCDHW(const framework::DDim& dims,
const DataLayout& layout, int* N, int* C, int* D,
int* H, int* W) {
*N = dims[0];
*C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
int i = layout == DataLayout::kNCHW ? 0 : 1;
if (dims.size() == 5) {
*D = dims[2 - i];
*H = dims[3 - i];
*W = dims[4 - i];
} else {
*D = 1;
*H = dims[2 - i];
*W = dims[3 - i];
}
}
} // namespace operators
} // namespace paddle
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
DECLARE_uint64(conv_workspace_size_limit);
DECLARE_int64(conv_workspace_size_limit);
DECLARE_bool(cudnn_exhaustive_search);
DECLARE_int64(cudnn_exhaustive_search_times);
......
......@@ -14,42 +14,12 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/fluid/operators/conv_base_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
using framework::AlgorithmsCache;
static inline void GetNCDHW(const framework::DDim& dims,
const DataLayout& layout, int* N, int* C, int* D,
int* H, int* W) {
*N = dims[0];
*C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
int i = layout == DataLayout::kNCHW ? 0 : 1;
if (dims.size() == 5) {
*D = dims[2 - i];
*H = dims[3 - i];
*W = dims[4 - i];
} else {
*D = 1;
*H = dims[2 - i];
*W = dims[3 - i];
}
}
using ConvArgs = ConvArgsBase<miopenHandle_t, miopenDataType_t>;
template <typename DeviceContext, typename T, size_t D>
static void RemovePaddingSlice(const phi::GPUContext& context,
......@@ -66,9 +36,8 @@ static void RemovePaddingSlice(const phi::GPUContext& context,
extents[i] = new_out_dims[i];
}
int start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
int start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
......@@ -85,41 +54,6 @@ static void RemovePaddingSlice(const phi::GPUContext& context,
out_t.device(place) = in_t.slice(offsets, extents);
}
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
out << "]";
return out;
}
using framework::ConvSearchCache;
struct ConvArgs {
miopenHandle_t handle;
platform::TensorDescriptor idesc, odesc;
platform::FilterDescriptor wdesc;
platform::ConvolutionDescriptor cdesc;
const framework::Tensor *x, *w, *o;
miopenDataType_t cudnn_dtype;
// strides
std::vector<int> s;
// paddings
std::vector<int> p;
// dilations
std::vector<int> d;
ConvArgs(const framework::Tensor* x, const framework::Tensor* w,
const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d,
miopenDataType_t dtype)
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
};
template <typename algo_t>
struct SearchAlgorithm {};
template <>
struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
using perf_t = miopenConvAlgoPerf_t;
......
......@@ -16,8 +16,6 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
DECLARE_uint64(conv_workspace_size_limit);
namespace paddle {
namespace operators {
......
......@@ -188,6 +188,8 @@ class RecordedGpuMallocHelper {
if (UNLIKELY(malloc_managed_memory)) {
result = cudaMallocManaged(ptr, size);
} else {
VLOG(10) << "[cudaMalloc] size=" << static_cast<double>(size) / (1 << 20)
<< " MB";
result = cudaMalloc(ptr, size);
}
#endif
......@@ -226,6 +228,8 @@ class RecordedGpuMallocHelper {
if (err != hipErrorDeinitialized) {
#else
auto err = cudaFree(ptr);
VLOG(10) << "[cudaFree] size=" << static_cast<double>(size) / (1 << 20)
<< " MB";
if (err != cudaErrorCudartUnloading) {
#endif
PADDLE_ENFORCE_GPU_SUCCESS(err);
......
......@@ -522,8 +522,8 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
auto& instance = memory::allocation::AllocatorFacade::Instance();
instance.SetDefaultStream(place, phi::GPUContext::stream());
workspace_.reset(
new phi::DnnWorkspaceHandle(instance.GetAllocator(place).get()));
workspace_.reset(new phi::DnnWorkspaceHandle(
instance.GetAllocator(place).get(), stream()));
}
CUDADeviceContext::~CUDADeviceContext() = default;
......@@ -623,7 +623,8 @@ phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return phi::DnnWorkspaceHandle(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(GetPlace())
.get());
.get(),
stream());
}
return phi::GPUContext::cudnn_workspace_handle();
}
......
......@@ -158,8 +158,7 @@ PADDLE_DEFINE_EXPORTED_bool(
* increased.
* Users need to balance memory and speed.
*/
PADDLE_DEFINE_EXPORTED_uint64(
conv_workspace_size_limit,
PADDLE_DEFINE_EXPORTED_int64(conv_workspace_size_limit,
paddle::platform::kDefaultConvWorkspaceSizeLimitMB,
"cuDNN convolution workspace limit in MB unit.");
......@@ -800,3 +799,12 @@ DEFINE_bool(enable_ins_parser_file, false,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
#endif
/**
* Autotune related FLAG
* Name: FLAGS_use_autotune
* Since Version: 2.3.0
* Value Range: bool, default=false
* Example:
*/
PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune.");
......@@ -4430,7 +4430,7 @@ All parameter, weight, gradient are variables in Paddle.
return phi::autotune::AutoTuneStatus::Instance().DisableAutoTune();
});
m.def("autotune_range", [](int64_t start, int64_t stop) {
m.def("set_autotune_range", [](int64_t start, int64_t stop) {
return phi::autotune::AutoTuneStatus::Instance().SetAutoTuneRange(start,
stop);
});
......@@ -4439,10 +4439,8 @@ All parameter, weight, gradient are variables in Paddle.
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
m.def("autotune_status", [] {
phi::autotune::AutoTuneCache::Instance().UpdateStatus();
py::dict res;
res["use_autotune"] =
phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
phi::autotune::AutoTuneCache::Instance().UpdateStatus();
res["step_id"] = phi::autotune::AutoTuneStatus::Instance().StepID();
res["cache_size"] = phi::autotune::AutoTuneCache::Instance().Size();
res["cache_hit_rate"] =
......
......@@ -12,6 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include <algorithm>
#include <array>
......@@ -155,6 +156,39 @@ static void StreamCallbackFunc(gpuStream_t stream,
} // namespace internal
void DnnWorkspaceHandle::RunFuncSync(
const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes,
bool use_cached_allocation) {
bool need_realloc = required_workspace_bytes > WorkspaceSize();
if (need_realloc && !use_cached_allocation) {
void* workspace_ptr = nullptr;
size_t size = ((required_workspace_bytes + 255) >> 8) << 8;
std::lock_guard<std::mutex> guard(*mtx_);
#ifdef PADDLE_WITH_HIP
auto status = hipMalloc(&workspace_ptr, size);
#else
auto status = cudaMalloc(&workspace_ptr, size);
#endif
if (status == gpuSuccess) {
cudnn_func(workspace_ptr);
phi::backends::gpu::GpuStreamSync(stream_);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipFree(workspace_ptr));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaFree(workspace_ptr));
#endif
return;
}
}
RunFunc(cudnn_func, required_workspace_bytes);
if (need_realloc) {
// Release the workspace allocated in this running.
ResetWorkspace();
}
}
void DnnWorkspaceHandle::ResetWorkspace() { allocation_ = nullptr; }
void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
......@@ -295,13 +329,13 @@ struct GPUContext::Impl {
void InitDnnWorkspace() {
PD_CHECK(allocator_ != nullptr,
"the device allocator for gpu context is nullptr.");
workspace_ = new DnnWorkspaceHandle(allocator_);
workspace_ = new DnnWorkspaceHandle(allocator_, stream_);
}
void DestoryInternalWorkspace() {
if (owned_ && workspace_ != nullptr) {
delete workspace_;
stream_ = nullptr;
workspace_ = nullptr;
}
}
......@@ -313,7 +347,7 @@ struct GPUContext::Impl {
DnnWorkspaceHandle GetDnnWorkspace() {
PD_CHECK(allocator_ != nullptr,
"the device allocator for gpu context is nullptr.");
return DnnWorkspaceHandle(allocator_);
return DnnWorkspaceHandle(allocator_, stream_);
}
void InitStream() {
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
......@@ -28,8 +29,8 @@ namespace phi {
class DnnWorkspaceHandle {
public:
explicit inline DnnWorkspaceHandle(Allocator* allocator)
: allocator_(allocator) {
inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream)
: allocator_(allocator), stream_(stream) {
mtx_.reset(new std::mutex());
}
......@@ -48,11 +49,9 @@ class DnnWorkspaceHandle {
* running the function. Currently this function is only used when cudnn
* exhaustive searching and callers have to guarantee that the input function
* is host blocking */
inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes) {
RunFunc(cudnn_func, required_workspace_bytes);
ResetWorkspace();
}
void RunFuncSync(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes,
bool use_cached_allocation = true);
inline size_t WorkspaceSize() {
if (allocation_ == nullptr) {
......@@ -70,7 +69,8 @@ class DnnWorkspaceHandle {
private:
Allocator::AllocationPtr allocation_{nullptr};
Allocator* allocator_{nullptr};
Allocator* allocator_{nullptr}; // Not owned
gpuStream_t stream_{nullptr}; // Not owned
std::unique_ptr<std::mutex> mtx_;
};
......
......@@ -6,12 +6,15 @@ file(APPEND ${kernel_declare_file} "#include \"paddle/phi/core/kernel_registry.h
# phi functors and functions called by kernels
add_subdirectory(funcs)
# kernel autotune
add_subdirectory(autotune)
# phi depends all phi kernel targets
set_property(GLOBAL PROPERTY PHI_KERNELS "")
# [ 1. Common kernel compilation dependencies ]
set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor )
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor)
# remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
......@@ -27,13 +30,17 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS cross_entropy_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel
set(AUTOTUNE_KERNELS conv_kernel conv_grad_kernel conv_grad_grad_kernel conv_transpose_kernel conv_transpose_grad_kernel)
set(MANUAL_BUILD_KERNELS ${AUTOTUNE_KERNELS} cross_entropy_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel
gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel reduce_sum_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel)
foreach(src ${AUTOTUNE_KERNELS})
kernel_library(${src} DEPS ${COMMON_KERNEL_DEPS} switch_autotune)
endforeach()
kernel_library(cross_entropy_kernel DEPS ${COMMON_KERNEL_DEPS} softmax cross_entropy)
kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
kernel_library(deformable_conv_grad_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
......@@ -74,6 +81,3 @@ add_subdirectory(selected_rows)
copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
# For strings kernels
add_subdirectory(strings)
# 5. kernel autotune
add_subdirectory(autotune)
......@@ -7,5 +7,6 @@ elseif (WITH_ROCM)
endif()
cc_library(cache SRCS cache.cc DEPS boost)
cc_library(switch_autotune SRCS switch_autotune.cc DEPS cache flags)
cc_test(cache_test SRCS cache_test.cc DEPS gtest cache)
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/phi/kernels/autotune/cache.h"
#include <iomanip>
#include "glog/logging.h"
namespace phi {
namespace autotune {
......@@ -32,5 +34,40 @@ size_t ConvKey(const std::vector<int64_t>& x_dims,
static_cast<int64_t>(dtype));
}
std::string AlgorithmTypeString(int64_t algo_type) {
if (algo_type == static_cast<int64_t>(AlgorithmType::kConvForward)) {
return "conv_forward";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardData)) {
return "conv_backward_data";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardFilter)) {
return "conv_backward_filter";
}
return std::to_string(algo_type);
}
void AutoTuneCache::UpdateStatus() {
int64_t size = 0;
int64_t cache_hits = 0;
int64_t cache_misses = 0;
int name_width = 24;
std::cout.setf(std::ios::left);
for (auto& v : auto_tune_map_) {
VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width)
<< AlgorithmTypeString(v.first)
<< " Cache Size: " << v.second.Size()
<< " Hits: " << v.second.CacheHits()
<< " Misses: " << v.second.CacheMisses()
<< " Hit Rate: " << v.second.CacheHitRate();
size += v.second.Size();
cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses();
}
total_size_ = size;
total_cache_hits_ = cache_hits;
total_cache_misses_ = cache_misses;
}
} // namespace autotune
} // namespace phi
......@@ -13,11 +13,12 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>
#include "glog/logging.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
......@@ -92,6 +93,13 @@ class AlgorithmsCache {
return ret;
}
void Clean() {
std::lock_guard<std::mutex> lock(*cache_mutex_);
hash_.clear();
cache_hits_ = 0;
cache_misses_ = 0;
}
void Set(size_t key, AlgorithmT algo) {
std::lock_guard<std::mutex> lock(*cache_mutex_);
hash_[key] = algo;
......@@ -116,15 +124,22 @@ class AlgorithmsCache {
private:
std::unordered_map<size_t, AlgorithmT> hash_;
std::shared_ptr<std::mutex> cache_mutex_;
int64_t cache_hits_ = 0;
int64_t cache_misses_ = 0;
int64_t cache_hits_{0};
int64_t cache_misses_{0};
};
enum class AlgorithmType {
kConvForward = 1,
kConvBackwardData = 2,
kConvBackwardFilter = 3,
kAlgorithmCount = 4
};
// AlgorithmsConfigKey -> AlgorithmsID
using AlgorithmsConfigKeyMap = AlgorithmsCache<int64_t>;
// AlgorithmsType -> AlgorithmsCache
using AlgorithmsTypeMap =
std::unordered_map<std::string, AlgorithmsConfigKeyMap>;
using AlgorithmsCacheMap = AlgorithmsCache<int64_t>;
// AlgorithmType -> AlgorithmsCache
using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
class AutoTuneCache {
public:
......@@ -133,42 +148,30 @@ class AutoTuneCache {
return autotune_cache;
}
AlgorithmsConfigKeyMap& RegisterOrGet(const std::string& algo_type) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
if (auto_tune_map_.find(algo_type) == auto_tune_map_.end()) {
AlgorithmsConfigKeyMap cache;
auto_tune_map_[algo_type] = cache;
AlgorithmsCacheMap& Get(const AlgorithmType& algo_type) {
return auto_tune_map_[static_cast<int64_t>(algo_type)];
}
return auto_tune_map_[algo_type];
AlgorithmsCacheMap& GetConvForward() {
return Get(AlgorithmType::kConvForward);
}
void Clean(float miss_rate) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
// Set a small tolerance to avoid performance degradation
// due to large cache size under dynamic shape.
if (miss_rate > 0.01) {
auto_tune_map_.clear();
AlgorithmsCacheMap& GetConvBackwardData() {
return Get(AlgorithmType::kConvBackwardData);
}
AlgorithmsCacheMap& GetConvBackwardFilter() {
return Get(AlgorithmType::kConvBackwardFilter);
}
void UpdateStatus() {
int64_t size = 0;
int64_t cache_hits = 0;
int64_t cache_misses = 0;
void Clean() {
for (auto& v : auto_tune_map_) {
VLOG(4) << "AlgoType: " << v.first << " Cache Size: " << v.second.Size()
<< " Hits: " << v.second.CacheHits()
<< " Misses: " << v.second.CacheMisses()
<< " Hit Rate: " << v.second.CacheHitRate();
size += v.second.Size();
cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses();
v.second.Clean();
}
total_size_ = size;
total_cache_hits_ = cache_hits;
total_cache_misses_ = cache_misses;
}
void UpdateStatus();
// The number of total config cached
int64_t Size() const { return total_size_; }
......@@ -183,17 +186,30 @@ class AutoTuneCache {
total_cache_hit_rate = static_cast<float>(total_cache_hits_) /
static_cast<float>(total_num_accesses);
}
return total_cache_hit_rate;
}
private:
AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {}
AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {
for (int i = 1; i < static_cast<int>(AlgorithmType::kAlgorithmCount); ++i) {
Register(static_cast<AlgorithmType>(i));
}
}
void Register(const AlgorithmType& algo_type) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
int64_t key = static_cast<int64_t>(algo_type);
if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
AlgorithmsCacheMap cache;
auto_tune_map_[key] = cache;
}
}
AlgorithmsTypeMap auto_tune_map_;
std::shared_ptr<std::mutex> autotune_cache_mutex_;
int64_t total_cache_hits_ = 0;
int64_t total_cache_misses_ = 0;
int64_t total_size_ = 0;
int64_t total_cache_hits_{0};
int64_t total_cache_misses_{0};
int64_t total_size_{0};
};
} // namespace autotune
......
......@@ -22,7 +22,7 @@ enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 };
TEST(AlgosCache, AlgosCache) {
auto autotune_cache = phi::autotune::AutoTuneCache::Instance();
auto& cache = autotune_cache.RegisterOrGet("conv_fw");
auto& cache = autotune_cache.GetConvForward();
std::vector<int64_t> x_shape = {4, 224, 224, 3};
std::vector<int64_t> w_shape = {32, 3, 3, 3};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/autotune/switch_autotune.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
DECLARE_bool(use_autotune);
namespace phi {
namespace autotune {
void AutoTuneStatus::EnableAutoTune() {
FLAGS_use_autotune = true;
Init();
}
void AutoTuneStatus::DisableAutoTune() {
FLAGS_use_autotune = false;
Init();
}
void AutoTuneStatus::Update() {
current_steps_id_ += 1;
if (!FLAGS_use_autotune) {
return;
}
// This fuction is called when each iter finished.
if (current_steps_id_ + 1 < start_step_id_) {
use_autotune_ = false;
} else if (current_steps_id_ + 1 >= start_step_id_ &&
current_steps_id_ + 1 < stop_step_id_) {
use_autotune_ = true;
AutoTuneCache::Instance().UpdateStatus();
step_hit_rates_.push_back(StepHitRate());
VLOG(3) << "Step ID: " << current_steps_id_
<< ", Accumulative Cache Hit Rate: "
<< static_cast<int>(AutoTuneCache::Instance().CacheHitRate() * 100)
<< "%, Cache Size: " << AutoTuneCache::Instance().Size()
<< ", Current Step Hit Rate: "
<< static_cast<int>(StepHitRate() * 100) << "%";
} else {
use_autotune_ = false;
// Set a small tolerance to avoid performance degradation
// due to large cache size under dynamic shape.
// TODO(limingshu): Currently works for conv op only, this
// method shall be opimized when more ops involved in.
// float miss_rate = static_cast<float>(1) - RecentHitRate();
// if (current_steps_id_ == stop_step_id_) {
// AutoTuneCache::Instance().Clean(miss_rate);
// }
if (VLOG_IS_ON(4)) {
AutoTuneCache::Instance().UpdateStatus();
VLOG(4) << "Step ID: " << current_steps_id_ << ", Current Step Hit Rate: "
<< static_cast<int>(StepHitRate() * 100) << "%";
}
}
}
} // namespace autotune
} // namespace phi
......@@ -13,10 +13,8 @@
// limitations under the License.
#pragma once
#include <cmath>
#include <mutex>
#include <numeric>
#include "glog/logging.h"
#include "paddle/phi/kernels/autotune/cache.h"
namespace phi {
......@@ -31,45 +29,11 @@ class AutoTuneStatus {
bool UseAutoTune() { return use_autotune_; }
// EnableAutoTune and DisableAutoTune Should be used for debug only.
void EnableAutoTune() {
use_autotune_ = true;
Init();
}
void DisableAutoTune() {
use_autotune_ = false;
Init();
}
void Update() {
current_steps_id_ += 1;
// EnableAutoTune and DisableAutoTune should be used for debug only.
void EnableAutoTune();
void DisableAutoTune();
if (!use_autotune_ && !update_use_autotune_) {
return;
}
if (current_steps_id_ < start_step_id_) {
use_autotune_ = false;
} else if (current_steps_id_ >= start_step_id_ &&
current_steps_id_ < stop_step_id_) {
use_autotune_ = true;
AutoTuneCache::Instance().UpdateStatus();
step_hit_rates_.push_back(StepHitRate());
VLOG(3) << "Step ID " << current_steps_id_
<< ", Accumulative Cache Hit Rate: "
<< AutoTuneCache::Instance().CacheHitRate()
<< ", Cache Size: " << AutoTuneCache::Instance().Size()
<< ", Current Step Hit Rate: " << StepHitRate();
} else if (current_steps_id_ == stop_step_id_) {
use_autotune_ = false;
update_use_autotune_ = false;
// clean cache according miss rate
float miss_rate = static_cast<float>(1) - RecentHitRate();
AutoTuneCache::Instance().Clean(miss_rate);
VLOG(3) << "Recent Miss Rate: " << miss_rate;
}
}
void Update();
int64_t StepID() { return current_steps_id_; }
......@@ -84,6 +48,9 @@ class AutoTuneStatus {
// Hit Rate of Current Step
float StepHitRate() {
static int64_t last_step_id = -2;
if (last_step_id != current_steps_id_) {
int64_t current_hits = AutoTuneCache::Instance().CacheHits();
int64_t current_misses = AutoTuneCache::Instance().CacheMisses();
int64_t step_hits_ = current_hits - previous_hits_;
......@@ -96,7 +63,10 @@ class AutoTuneStatus {
}
previous_hits_ = current_hits;
previous_misses_ = current_misses;
return step_hit_rate;
current_step_hit_rate_ = step_hit_rate;
last_step_id = current_steps_id_;
}
return current_step_hit_rate_;
}
void SetAutoTuneRange(int64_t start, int64_t stop) {
......@@ -108,21 +78,21 @@ class AutoTuneStatus {
AutoTuneStatus() = default;
void Init() {
update_use_autotune_ = use_autotune_;
use_autotune_ = false;
current_steps_id_ = -1;
previous_hits_ = 0;
previous_misses_ = 0;
step_hit_rates_.clear();
AutoTuneCache::Instance().Clean(1.0);
AutoTuneCache::Instance().Clean();
}
int64_t start_step_id_ = 0;
int64_t stop_step_id_ = 10;
int64_t current_steps_id_ = -1;
bool use_autotune_ = false;
bool update_use_autotune_ = false;
int64_t previous_hits_ = 0;
int64_t previous_misses_ = 0;
bool use_autotune_{false};
int64_t start_step_id_{1};
int64_t stop_step_id_{10};
int64_t current_steps_id_{-1};
int64_t previous_hits_{0};
int64_t previous_misses_{0};
float current_step_hit_rate_{0.f};
std::vector<float> step_hit_rates_;
};
......
......@@ -289,21 +289,17 @@ void ConvCudnnGradGradKernel(
dtype};
#ifdef PADDLE_WITH_HIP
miopenConvFwdAlgorithm_t fwd_algo1 = static_cast<miopenConvFwdAlgorithm_t>(0);
miopenConvFwdAlgorithm_t fwd_algo2 = static_cast<miopenConvFwdAlgorithm_t>(0);
miopenConvBwdDataAlgorithm_t data_algo =
static_cast<miopenConvBwdDataAlgorithm_t>(0);
miopenConvBwdWeightsAlgorithm_t filter_algo =
static_cast<miopenConvBwdWeightsAlgorithm_t>(0);
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result1;
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result2;
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> data_result;
paddle::operators::SearchResult<miopenConvBwdWeightsAlgorithm_t>
filter_result;
#else
cudnnConvolutionFwdAlgo_t fwd_algo1 =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
cudnnConvolutionFwdAlgo_t fwd_algo2 =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
cudnnConvolutionBwdDataAlgo_t data_algo =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionBwdFilterAlgo_t filter_algo =
static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result1;
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result2;
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> data_result;
paddle::operators::SearchResult<cudnnConvolutionBwdFilterAlgo_t>
filter_result;
#endif
auto layout = paddle::platform::GetCudnnTensorFormat(
......@@ -332,13 +328,13 @@ void ConvCudnnGradGradKernel(
using search1 =
paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = search1::GetWorkspaceSize(args1);
fwd_algo1 = search1::Find<T>(
fwd_result1.algo = search1::Find<T>(
args1, exhaustive_search, false, workspace_size, ctx);
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
fwd_result1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo);
#endif
}
......@@ -360,14 +356,14 @@ void ConvCudnnGradGradKernel(
paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2));
fwd_algo2 = search2::Find<T>(
fwd_result2.algo = search2::Find<T>(
args2, exhaustive_search, false, workspace_size, ctx);
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2, fwd_algo2));
fwd_result2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo));
#endif
}
}
......@@ -389,15 +385,15 @@ void ConvCudnnGradGradKernel(
using search3 =
paddle::operators::SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3));
filter_algo = search3::Find<T>(
filter_result.algo = search3::Find<T>(
args3, exhaustive_search, deterministic, workspace_size, ctx);
#else
using search3 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
filter_result =
search3::Find<T>(args3, exhaustive_search, deterministic, ctx);
workspace_size =
std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo));
workspace_size = std::max(
workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
#endif
}
......@@ -419,14 +415,15 @@ void ConvCudnnGradGradKernel(
using search4 =
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4));
data_algo = search4::Find<T>(
data_result.algo = search4::Find<T>(
args4, exhaustive_search, deterministic, workspace_size, ctx);
#else
using search4 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
data_result =
search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
workspace_size = std::max(
workspace_size, search4::GetWorkspaceSize(args4, data_result.algo));
#endif
}
......@@ -471,7 +468,7 @@ void ConvCudnnGradGradKernel(
args1.wdesc.desc(),
w,
args1.cdesc.desc(),
fwd_algo1,
fwd_result1.algo,
&beta,
args1.odesc.desc(),
transformed_ddy_channel,
......@@ -492,7 +489,7 @@ void ConvCudnnGradGradKernel(
args1.wdesc.desc(),
w + i * group_offset_filter,
args1.cdesc.desc(),
fwd_algo1,
fwd_result1.algo,
workspace_ptr,
workspace_size,
&beta,
......@@ -517,7 +514,7 @@ void ConvCudnnGradGradKernel(
args2.wdesc.desc(),
ddw,
args2.cdesc.desc(),
fwd_algo2,
fwd_result2.algo,
&beta,
args2.odesc.desc(),
transformed_ddy_channel,
......@@ -538,7 +535,7 @@ void ConvCudnnGradGradKernel(
args2.wdesc.desc(),
ddw + i * group_offset_filter,
args2.cdesc.desc(),
fwd_algo2,
fwd_result2.algo,
workspace_ptr,
workspace_size,
&alpha,
......@@ -568,7 +565,7 @@ void ConvCudnnGradGradKernel(
args3.idesc.desc(),
ddx,
args3.cdesc.desc(),
filter_algo,
filter_result.algo,
&beta,
args3.wdesc.desc(),
dw,
......@@ -589,7 +586,7 @@ void ConvCudnnGradGradKernel(
args3.odesc.desc(),
transformed_dy_channel + i * group_offset_out,
args3.cdesc.desc(),
filter_algo,
filter_result.algo,
workspace_ptr,
workspace_size,
&beta,
......@@ -615,7 +612,7 @@ void ConvCudnnGradGradKernel(
args4.wdesc.desc(),
ddw,
args4.cdesc.desc(),
data_algo,
data_result.algo,
&beta,
args4.idesc.desc(),
transformed_dx,
......@@ -636,7 +633,7 @@ void ConvCudnnGradGradKernel(
args4.odesc.desc(),
transformed_dy_channel + i * group_offset_out,
args4.cdesc.desc(),
data_algo,
data_result.algo,
workspace_ptr,
workspace_size,
&beta,
......
......@@ -322,17 +322,16 @@ void ConvCudnnGradKernel(const Context& ctx,
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = transformed_filter_channel.numel() / groups;
// ------------------- cudnn backward algorithm ---------------------
#ifdef PADDLE_WITH_HIP
miopenConvBwdDataAlgorithm_t data_algo =
static_cast<miopenConvBwdDataAlgorithm_t>(0);
miopenConvBwdWeightsAlgorithm_t filter_algo =
static_cast<miopenConvBwdWeightsAlgorithm_t>(0);
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result;
paddle::operators::SearchResult<miopenConvBwdWeightsAlgorithm_t>
filter_result;
#else
cudnnConvolutionBwdDataAlgo_t data_algo =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionBwdFilterAlgo_t filter_algo =
static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result;
paddle::operators::SearchResult<cudnnConvolutionBwdFilterAlgo_t>
filter_result;
#endif
// input data workspace_size
size_t workspace_size_d = 0;
......@@ -368,14 +367,14 @@ void ConvCudnnGradKernel(const Context& ctx,
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size_d =
std::max(workspace_size_d, search1::GetWorkspaceSize(args1));
data_algo = search1::Find<T>(
bwd_result.algo = search1::Find<T>(
args1, exhaustive_search, deterministic, workspace_size_d, ctx);
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
workspace_size_d =
std::max(workspace_size_d, search1::GetWorkspaceSize(args1, data_algo));
bwd_result = search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
workspace_size_d = std::max(
workspace_size_d, search1::GetWorkspaceSize(args1, bwd_result.algo));
#endif
}
......@@ -397,15 +396,17 @@ void ConvCudnnGradKernel(const Context& ctx,
paddle::operators::SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size_w =
std::max(workspace_size_w, search2::GetWorkspaceSize(args2));
filter_algo = search2::Find<T>(
filter_result.algo = search2::Find<T>(
args2, exhaustive_search, deterministic, workspace_size_w, ctx);
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
filter_result =
search2::Find<T>(args2, exhaustive_search, deterministic, ctx);
workspace_size_w = std::max(workspace_size_w,
search2::GetWorkspaceSize(args2, filter_algo));
VLOG(3) << "filter algo: " << filter_result.algo << ", time "
<< filter_result.time;
workspace_size_w = std::max(
workspace_size_w, search2::GetWorkspaceSize(args2, filter_result.algo));
#endif
}
......@@ -439,7 +440,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args1.wdesc.desc(),
filter_data,
args1.cdesc.desc(),
data_algo,
bwd_result.algo,
&beta,
args1.idesc.desc(),
temp_tensor_data,
......@@ -471,7 +472,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args1.wdesc.desc(),
filter_data,
args1.cdesc.desc(),
data_algo,
bwd_result.algo,
&beta,
args1.idesc.desc(),
transformed_input_grad_data,
......@@ -494,7 +495,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args1.odesc.desc(),
output_grad_data + i * group_offset_out,
args1.cdesc.desc(),
data_algo,
bwd_result.algo,
cudnn_workspace_ptr,
workspace_size_d,
&beta,
......@@ -554,7 +555,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args2.idesc.desc(),
input_data,
args2.cdesc.desc(),
filter_algo,
filter_result.algo,
&beta,
args2.wdesc.desc(),
filter_grad_data,
......@@ -575,7 +576,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args2.odesc.desc(),
output_grad_data + i * group_offset_out,
args2.cdesc.desc(),
filter_algo,
filter_result.algo,
cudnn_workspace_ptr,
workspace_size_w,
&beta_filter,
......
......@@ -18,7 +18,6 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/eigen.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/conv_miopen_helper.h"
#else
......@@ -68,7 +67,6 @@ void ConvCudnnKernel(const Context& ctx,
"FLAGS_cudnn_deterministic True at same time."));
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
auto dtype = paddle::platform::CudnnDataType<T>::type;
#ifdef PADDLE_WITH_HIP
......@@ -309,17 +307,17 @@ void ConvCudnnKernel(const Context& ctx,
size_t workspace_size = 0; // final workspace to allocate.
// ------------------- cudnn conv algorithm ---------------------
#ifdef PADDLE_WITH_HIP
miopenConvFwdAlgorithm_t algo{};
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result;
using search = paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = search::GetWorkspaceSize(args);
algo = search::Find<T>(
fwd_result.algo = search::Find<T>(
args, exhaustive_search, deterministic, workspace_size, ctx);
#else
cudnnConvolutionFwdAlgo_t algo{};
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
using search =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
algo = search::Find<T>(args, exhaustive_search, deterministic, ctx);
workspace_size = search::GetWorkspaceSize(args, algo);
fwd_result = search::Find<T>(args, exhaustive_search, deterministic, ctx);
workspace_size = search::GetWorkspaceSize(args, fwd_result.algo);
#endif
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1)
......@@ -328,7 +326,7 @@ void ConvCudnnKernel(const Context& ctx,
// in forward computation, so change the algorithm to CUDNN_CONVOLUTION_\
// FWD_ALGO_IMPLICIT_GEMM manually.
if (groups > 1) {
algo = static_cast<cudnnConvolutionFwdAlgo_t>(0);
fwd_result.algo = static_cast<cudnnConvolutionFwdAlgo_t>(0);
}
#endif
......@@ -352,7 +350,7 @@ void ConvCudnnKernel(const Context& ctx,
args.wdesc.desc(),
filter_data,
args.cdesc.desc(),
algo,
fwd_result.algo,
&beta,
args.odesc.desc(),
output_data,
......@@ -373,7 +371,7 @@ void ConvCudnnKernel(const Context& ctx,
args.wdesc.desc(),
filter_data + i * group_offset_filter,
args.cdesc.desc(),
algo,
fwd_result.algo,
workspace_ptr,
workspace_size,
&beta,
......
......@@ -188,11 +188,13 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
dtype};
#ifdef PADDLE_WITH_HIP
miopenConvFwdAlgorithm_t data_algo{};
miopenConvBwdWeightsAlgorithm_t filter_algo{};
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result;
paddle::operators::SearchResult<miopenConvBwdWeightsAlgorithm_t>
filter_result;
#else
cudnnConvolutionFwdAlgo_t data_algo{};
cudnnConvolutionBwdFilterAlgo_t filter_algo{};
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
paddle::operators::SearchResult<cudnnConvolutionBwdFilterAlgo_t>
filter_result;
#endif
auto layout_tensor = paddle::platform::GetCudnnTensorFormat(layout);
......@@ -218,14 +220,14 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
using search1 =
paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1));
data_algo =
fwd_result.algo =
search1::Find<T>(args1, false, deterministic, workspace_size, ctx);
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
data_algo = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
fwd_result = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search1::GetWorkspaceSize(args1, fwd_result.algo));
#endif
}
......@@ -245,14 +247,14 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
using search2 =
paddle::operators::SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2));
filter_algo =
filter_result.algo =
search2::Find<T>(args2, false, deterministic, workspace_size, ctx);
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2, filter_algo));
filter_result = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, filter_result.algo));
#endif
}
......@@ -278,7 +280,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
args1.wdesc.desc(),
filter_data + filter_offset * g,
args1.cdesc.desc(),
data_algo,
fwd_result.algo,
&beta,
args1.odesc.desc(),
dx_data + x_offset * g,
......@@ -295,7 +297,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
args1.wdesc.desc(),
filter_data + filter_offset * g,
args1.cdesc.desc(),
data_algo,
fwd_result.algo,
cudnn_workspace,
workspace_size,
&beta,
......@@ -338,7 +340,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
args2.idesc.desc(),
dout_data + dout_offset * g,
args2.cdesc.desc(),
filter_algo,
filter_result.algo,
&beta,
args2.wdesc.desc(),
dfilter_data + filter_offset * g,
......@@ -355,7 +357,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
args2.odesc.desc(),
x_data + x_offset * g,
args2.cdesc.desc(),
filter_algo,
filter_result.algo,
cudnn_workspace,
workspace_size,
&beta,
......@@ -653,22 +655,17 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
dilations_,
dtype};
#ifdef PADDLE_WITH_HIP
miopenConvBwdDataAlgorithm_t bwd_algo1 =
static_cast<miopenConvBwdDataAlgorithm_t>(0);
miopenConvBwdDataAlgorithm_t bwd_algo2 =
static_cast<miopenConvBwdDataAlgorithm_t>(0);
miopenConvFwdAlgorithm_t data_algo = static_cast<miopenConvFwdAlgorithm_t>(0);
miopenConvBwdWeightsAlgorithm_t filter_algo =
static_cast<miopenConvBwdWeightsAlgorithm_t>(0);
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result1;
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result2;
paddle::operators::SearchResult<miopenConvBwdWeightsAlgorithm_t>
filter_result;
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result;
#else
cudnnConvolutionBwdDataAlgo_t bwd_algo1 =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionBwdDataAlgo_t bwd_algo2 =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionFwdAlgo_t data_algo =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
cudnnConvolutionBwdFilterAlgo_t filter_algo =
static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result1;
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result2;
paddle::operators::SearchResult<cudnnConvolutionBwdFilterAlgo_t>
filter_result;
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
#endif
auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW);
......@@ -696,13 +693,13 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
using search1 =
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = search1::GetWorkspaceSize(args1);
bwd_algo1 =
bwd_result1.algo =
search1::Find<T>(args1, false, deterministic, workspace_size, ctx);
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_algo1 = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size = search1::GetWorkspaceSize(args1, bwd_algo1);
bwd_result1 = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size = search1::GetWorkspaceSize(args1, bwd_result1.algo);
#endif
ddfilter_ = ddfilter.data<T>();
......@@ -720,14 +717,14 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
using search2 =
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2));
bwd_algo2 =
bwd_result2.algo =
search2::Find<T>(args2, false, deterministic, workspace_size, ctx);
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_algo2 = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2, bwd_algo2));
bwd_result2 = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, bwd_result2.algo));
#endif
}
......@@ -736,9 +733,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args3.handle = handle;
args3.idesc.set(transformed_dout, iwo_group);
args3.wdesc.set(*dfilter, layout, iwo_group);
args3.odesc.set(transformed_ddx_channel, iwo_group);
args3.cdesc.set(dtype,
padding_common,
strides,
......@@ -749,14 +744,14 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
using search3 =
paddle::operators::SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3));
filter_algo =
filter_result.algo =
search3::Find<T>(args3, false, deterministic, workspace_size, ctx);
#else
using search3 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = search3::Find<T>(args3, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo));
filter_result = search3::Find<T>(args3, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
#endif
}
......@@ -777,14 +772,14 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
using search4 =
paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4));
data_algo =
fwd_result.algo =
search4::Find<T>(args4, false, deterministic, workspace_size, ctx);
#else
using search4 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
data_algo = search4::Find<T>(args4, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
fwd_result = search4::Find<T>(args4, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search4::GetWorkspaceSize(args4, fwd_result.algo));
#endif
}
......@@ -831,7 +826,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args1.wdesc.desc(),
filter_ + i * group_offset_filter,
args1.cdesc.desc(),
bwd_algo1,
bwd_result1.algo,
&beta,
args1.idesc.desc(),
transformed_ddout_channel_ + i * group_offset_out,
......@@ -850,7 +845,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args1.odesc.desc(),
ddx_ + i * group_offset_in,
args1.cdesc.desc(),
bwd_algo1,
bwd_result1.algo,
workspace_ptr,
workspace_size,
&beta,
......@@ -877,7 +872,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args2.wdesc.desc(),
ddfilter_ + i * group_offset_filter,
args2.cdesc.desc(),
bwd_algo2,
bwd_result2.algo,
&beta,
args2.idesc.desc(),
conv_x_ddfilter_data + i * group_offset_out,
......@@ -908,7 +903,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args2.odesc.desc(),
x_ + i * group_offset_in,
args2.cdesc.desc(),
bwd_algo2,
bwd_result2.algo,
workspace_ptr,
workspace_size,
&alpha,
......@@ -964,7 +959,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args3.idesc.desc(),
transformed_dout_channel_ + i * group_offset_out,
args3.cdesc.desc(),
filter_algo,
filter_result.algo,
&beta,
args3.wdesc.desc(),
dfilter_ + i * group_offset_filter,
......@@ -983,7 +978,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args3.odesc.desc(),
ddx_ + i * group_offset_in,
args3.cdesc.desc(),
filter_algo,
filter_result.algo,
workspace_ptr,
workspace_size,
&beta,
......@@ -1009,7 +1004,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args4.wdesc.desc(),
ddfilter_ + i * group_offset_filter,
args4.cdesc.desc(),
data_algo,
fwd_result.algo,
&beta,
args4.odesc.desc(),
transformed_dx_ + i * group_offset_in,
......@@ -1028,7 +1023,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args4.wdesc.desc(),
ddfilter_ + i * group_offset_filter,
args4.cdesc.desc(),
data_algo,
fwd_result.algo,
workspace_ptr,
workspace_size,
&beta,
......
......@@ -217,16 +217,19 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
c_groups);
#ifdef PADDLE_WITH_HIP
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result;
using search =
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search::GetWorkspaceSize(args));
algo = search::Find<T>(args, false, deterministic, workspace_size, ctx);
bwd_result.algo =
search::Find<T>(args, false, deterministic, workspace_size, ctx);
#else
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result;
using search =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
algo = search::Find<T>(args, false, deterministic, ctx);
bwd_result = search::Find<T>(args, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search::GetWorkspaceSize(args, algo));
std::max(workspace_size, search::GetWorkspaceSize(args, bwd_result.algo));
#endif
// ------------------- cudnn conv transpose forward ---------------------
......@@ -247,7 +250,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
args.wdesc.desc(),
filter_data + filter_offset * g,
args.cdesc.desc(),
algo,
bwd_result.algo,
&beta,
args.idesc.desc(),
transformed_out_data + out_offset * g,
......@@ -264,7 +267,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
args.odesc.desc(),
x_data + x_offset * g,
args.cdesc.desc(),
algo,
bwd_result.algo,
cudnn_workspace,
workspace_size,
&beta,
......
......@@ -36,7 +36,7 @@
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
DECLARE_bool(cudnn_deterministic);
DECLARE_uint64(conv_workspace_size_limit);
DECLARE_int64(conv_workspace_size_limit);
DECLARE_bool(cudnn_exhaustive_search);
namespace phi {
......
......@@ -14,7 +14,7 @@
import paddle
import unittest
import numpy
import numpy as np
class SimpleNet(paddle.nn.Layer):
......@@ -27,6 +27,7 @@ class SimpleNet(paddle.nn.Layer):
def train_dygraph(net, data):
data.stop_gradient = False
out = net(data)
loss = paddle.mean(out)
adam = paddle.optimizer.Adam(parameters=net.parameters())
......@@ -36,6 +37,7 @@ def train_dygraph(net, data):
def static_program(net, data):
data.stop_gradient = False
out = net(data)
loss = paddle.mean(out)
adam = paddle.optimizer.Adam()
......@@ -44,47 +46,62 @@ def static_program(net, data):
class TestAutoTune(unittest.TestCase):
def set_flags(self, enable_autotune):
if paddle.is_compiled_with_cuda():
if enable_autotune:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': -1})
else:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': 512})
def get_flags(self, name):
res = paddle.get_flags(name)
return res[name]
def get_expected_res(self, step_id, enable_autotune):
expected_res = {
"step_id": step_id,
"cache_size": 0,
"cache_hit_rate": 0
}
if paddle.is_compiled_with_cuda():
# Total 3 * num_iters cache accesses, only iter 2 hits the cache.
if enable_autotune and step_id >= 1:
expected_res["cache_size"] = 3
if enable_autotune and step_id == 2:
expected_res["cache_hit_rate"] = np.round(
float(3) / float(9), 5)
return expected_res
def test_autotune(self):
paddle.fluid.core.disable_autotune()
status = paddle.fluid.core.autotune_status()
self.assertEqual(status["use_autotune"], False)
self.assertEqual(self.get_flags("FLAGS_use_autotune"), False)
paddle.fluid.core.enable_autotune()
status = paddle.fluid.core.autotune_status()
self.assertEqual(status["use_autotune"], True)
self.assertEqual(self.get_flags("FLAGS_use_autotune"), True)
def check_status(self, expected_res):
status = paddle.fluid.core.autotune_status()
for key in status.keys():
self.assertEqual(status[key], expected_res[key])
if key == "cache_hit_rate":
v = np.round(status[key], 5)
else:
v = status[key]
self.assertEqual(v, expected_res[key])
class TestDygraphAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune):
self.set_flags(enable_autotune)
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.autotune_range(1, 2)
paddle.fluid.core.set_autotune_range(1, 2)
x_var = paddle.uniform((1, 1, 8, 8), dtype='float32', min=-1., max=1.)
net = SimpleNet()
for i in range(3):
train_dygraph(net, x_var)
if i >= 1 and i < 2:
expected_res = {
"step_id": i,
"use_autotune": enable_autotune,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
else:
expected_res = {
"step_id": i,
"use_autotune": False,
"cache_size": 0,
"cache_hit_rate": 0
}
expected_res = self.get_expected_res(i, enable_autotune)
self.check_status(expected_res)
def func_enable_autotune(self):
......@@ -107,42 +124,32 @@ class TestDygraphAutoTuneStatus(TestAutoTune):
class TestStaticAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune):
paddle.enable_static()
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.autotune_range(1, 2)
data_shape = [1, 1, 8, 8]
data = paddle.static.data(name='X', shape=data_shape, dtype='float32')
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
data = paddle.static.data(
name='X', shape=data_shape, dtype='float32')
net = SimpleNet()
loss = static_program(net, data)
place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
x = numpy.random.random(size=data_shape).astype('float32')
exe.run(startup_program)
x = np.random.random(size=data_shape).astype('float32')
self.set_flags(enable_autotune)
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.set_autotune_range(1, 2)
for i in range(3):
exe.run(feed={'X': x}, fetch_list=[loss])
exe.run(program=main_program, feed={'X': x}, fetch_list=[loss])
status = paddle.fluid.core.autotune_status()
# In static mode, the startup_program will run at first.
# The expected step_id will be increased by 1.
if i >= 0 and i < 1:
expected_res = {
"step_id": i + 1,
"use_autotune": enable_autotune,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
else:
expected_res = {
"step_id": i + 1,
"use_autotune": False,
"cache_size": 0,
"cache_hit_rate": 0
}
expected_res = self.get_expected_res(i, enable_autotune)
self.check_status(expected_res)
paddle.disable_static()
......@@ -150,16 +157,12 @@ class TestStaticAutoTuneStatus(TestAutoTune):
self.run_program(enable_autotune=True)
def test_enable_autotune(self):
with paddle.fluid.framework._test_eager_guard():
self.func_enable_autotune()
self.func_enable_autotune()
def func_disable_autotune(self):
self.run_program(enable_autotune=False)
def test_disable_autotune(self):
with paddle.fluid.framework._test_eager_guard():
self.func_disable_autotune()
self.func_disable_autotune()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册