From b4adbe5c5257c4eacf75779a34dd80b6c7f57ffc Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Tue, 19 Apr 2022 16:47:38 +0800 Subject: [PATCH] [Cherry-pick 2.3] Autotune the workspace and kernel choosing of conv (#41833) Cherry-pick #40338 #41741 #41313 --- paddle/fluid/eager/CMakeLists.txt | 2 +- paddle/fluid/framework/conv_search_cache.h | 1 - paddle/fluid/imperative/CMakeLists.txt | 4 +- paddle/fluid/operators/conv_base_helper.h | 113 ++ paddle/fluid/operators/conv_cudnn_helper.h | 1037 ++++++++++------- paddle/fluid/operators/conv_cudnn_op_cache.h | 2 +- paddle/fluid/operators/conv_miopen_helper.h | 72 +- .../fused/fusion_conv_inception_op.cu | 2 - paddle/fluid/platform/device/gpu/gpu_info.cc | 4 + paddle/fluid/platform/device_context.cc | 7 +- paddle/fluid/platform/flags.cc | 16 +- paddle/fluid/pybind/pybind.cc | 6 +- paddle/phi/backends/gpu/gpu_context.cc | 40 +- paddle/phi/backends/gpu/gpu_context.h | 16 +- paddle/phi/kernels/CMakeLists.txt | 14 +- paddle/phi/kernels/autotune/CMakeLists.txt | 9 +- paddle/phi/kernels/autotune/cache.cc | 37 + paddle/phi/kernels/autotune/cache.h | 96 +- paddle/phi/kernels/autotune/cache_test.cc | 2 +- .../phi/kernels/autotune/switch_autotune.cc | 74 ++ paddle/phi/kernels/autotune/switch_autotune.h | 94 +- .../kernels/gpudnn/conv_grad_grad_kernel.cu | 71 +- paddle/phi/kernels/gpudnn/conv_grad_kernel.cu | 43 +- paddle/phi/kernels/gpudnn/conv_kernel.cu | 18 +- .../gpudnn/conv_transpose_grad_kernel.cu | 107 +- .../kernels/gpudnn/conv_transpose_kernel.cu | 13 +- paddle/phi/kernels/impl/conv_cudnn_impl.h | 2 +- .../tests/unittests/test_switch_autotune.py | 115 +- 28 files changed, 1178 insertions(+), 839 deletions(-) create mode 100644 paddle/fluid/operators/conv_base_helper.h create mode 100644 paddle/phi/kernels/autotune/switch_autotune.cc diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index da326ff7d76..53ac895bfbc 100644 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -15,7 +15,7 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) add_subdirectory(pylayer) cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator) add_dependencies(grad_tensor_holder eager_final_state_codegen) - cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info) + cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info switch_autotune) endif() cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor) diff --git a/paddle/fluid/framework/conv_search_cache.h b/paddle/fluid/framework/conv_search_cache.h index 51446f287e9..4da2aeb4d04 100644 --- a/paddle/fluid/framework/conv_search_cache.h +++ b/paddle/fluid/framework/conv_search_cache.h @@ -16,7 +16,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator_kernel_configs.h" - #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" namespace paddle { diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 3d8a5ab21f0..69cd45222ce 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -9,8 +9,8 @@ cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_f add_subdirectory(jit) cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper) cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper) -cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator) -cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator) +cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator switch_autotune) +cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator switch_autotune) cc_library(imperative_profiler SRCS profiler.cc DEPS flags) if(NOT WIN32) if(WITH_NCCL OR WITH_RCCL) diff --git a/paddle/fluid/operators/conv_base_helper.h b/paddle/fluid/operators/conv_base_helper.h new file mode 100644 index 00000000000..9e1a323fc9f --- /dev/null +++ b/paddle/fluid/operators/conv_base_helper.h @@ -0,0 +1,113 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/conv_search_cache.h" +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/kernels/autotune/cache.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DataLayout = platform::DataLayout; +using framework::AlgorithmsCache; +using framework::ConvSearchCache; + +template +using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; + +// As the basic for SearchAlgorithm struct. +template +struct SearchAlgorithm {}; + +// As the container of searchAlgorithm::Find() result. +template +struct SearchResult { + SearchResult() {} + explicit SearchResult(AlgoT a) : algo(a) {} + + AlgoT algo = static_cast(0); + float time = -1.f; + size_t workspace_size = 0; +}; + +template +static std::ostream& operator<<(std::ostream& out, const std::vector& v) { + out << "["; + for (auto const& tmp : v) out << tmp << ","; + out << "]"; + return out; +} + +// As the container of conv relevant descriptors. +template +struct ConvArgsBase { + HandleT handle; + platform::TensorDescriptor idesc, odesc; + platform::FilterDescriptor wdesc; + platform::ConvolutionDescriptor cdesc; + const framework::Tensor *x, *w, *o; + DataT cudnn_dtype; + + // strides + std::vector s; + // paddings + std::vector p; + // dilations + std::vector d; + + ConvArgsBase(const framework::Tensor* x, const framework::Tensor* w, + const framework::Tensor* o, const std::vector s, + const std::vector p, const std::vector d, DataT dtype) + : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {} + + template + size_t GetCacheKey() const { + auto x_shape = phi::vectorize(x->dims()); + auto w_shape = phi::vectorize(w->dims()); + VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape + << ", strides=" << s << ", paddings=" << p << ", dilations=" << d; + return phi::autotune::ConvKey( + x_shape, w_shape, p, s, d, + paddle::experimental::CppTypeToDataType::Type()); + } +}; + +static inline void GetNCDHW(const framework::DDim& dims, + const DataLayout& layout, int* N, int* C, int* D, + int* H, int* W) { + *N = dims[0]; + *C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; + int i = layout == DataLayout::kNCHW ? 0 : 1; + if (dims.size() == 5) { + *D = dims[2 - i]; + *H = dims[3 - i]; + *W = dims[4 - i]; + } else { + *D = 1; + *H = dims[2 - i]; + *W = dims[3 - i]; + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index 4e6fda3d09a..419fb8a4ca7 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -14,44 +14,17 @@ limitations under the License. */ #pragma once -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/conv_search_cache.h" -#include "paddle/fluid/framework/operator_kernel_configs.h" -#include "paddle/fluid/operators/conv_cudnn_op_cache.h" -#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/operators/conv_base_helper.h" #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/phi/kernels/autotune/switch_autotune.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using DataLayout = platform::DataLayout; -template -using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; -using framework::AlgorithmsCache; -static inline void GetNCDHW(const framework::DDim& dims, - const DataLayout& layout, int* N, int* C, int* D, - int* H, int* W) { - *N = dims[0]; - *C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; - int i = layout == DataLayout::kNCHW ? 0 : 1; - if (dims.size() == 5) { - *D = dims[2 - i]; - *H = dims[3 - i]; - *W = dims[4 - i]; - } else { - *D = 1; - *H = dims[2 - i]; - *W = dims[3 - i]; - } -} +using ConvArgs = ConvArgsBase; template static void RemovePaddingSlice(const phi::GPUContext& context, @@ -68,121 +41,117 @@ static void RemovePaddingSlice(const phi::GPUContext& context, extents[i] = new_out_dims[i]; } - int start; for (size_t i = 0; i < axes.size(); ++i) { - start = starts[i]; + int start = starts[i]; if (start < 0) { start = (start + in_dims[axes[i]]); } start = std::max(start, 0); offsets[axes[i]] = start; } + auto in_t = framework::EigenTensor::From( *input); - auto out_t = framework::EigenTensor::From( *out, new_out_dims); - EigenSlice, T, D>::Eval(place, out_t, in_t, - offsets, extents); + + phi::funcs::EigenSlice, T, D>::Eval( + place, out_t, in_t, offsets, extents); } -template -std::ostream& operator<<(std::ostream& out, const std::vector& v) { - out << "["; - for (auto const& tmp : v) out << tmp << ","; - out << "]"; - return out; +static inline double ToMegaBytes(size_t bytes) { + return static_cast(bytes) / (1 << 20); } -inline int MaxBwdFilterAlgos(cudnnHandle_t cudnn_handle) { - int max_algos = 0; -#if CUDNN_VERSION_MIN(7, 0, 1) - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( - cudnn_handle, &max_algos)); -#endif - return max_algos; +static inline bool UseFixedWorkspace() { + return FLAGS_conv_workspace_size_limit >= 0; } -template -void ChooseAlgoByWorkspace(PerfType* perf_results, size_t perf_num, - size_t workspace_byte, AlgoType* algo) { - for (size_t i = 0; i < perf_num; ++i) { - auto result = perf_results[i]; - if (result.status == CUDNN_STATUS_SUCCESS && - result.memory < workspace_byte) { - *algo = result.algo; - VLOG(3) << " algo: " << result.algo << ", time: " << result.time - << " ms, wksp = " << result.memory - << ", status = " << result.status; - return; - } +static size_t CalcWorkspaceLimitInBytes(bool use_fixed_workspace) { + if (!use_fixed_workspace) { + int device_id = platform::GetCurrentDeviceId(); + int64_t allocated = memory::StatGetCurrentValue("Allocated", device_id); + int64_t reserved = memory::StatGetCurrentValue("Reserved", device_id); + int64_t availble = platform::GpuAvailableMemToAlloc(); + VLOG(3) << "[memory] allocated=" << ToMegaBytes(allocated) + << " MB, reserved=" << ToMegaBytes(reserved) + << " MB, available_to_alloc=" << ToMegaBytes(availble) << " MB."; + return std::max(availble, reserved - allocated); + } else { + return FLAGS_conv_workspace_size_limit * 1024 * 1024; } - VLOG(3) << "Can not find alog that requires memory < " - << static_cast(workspace_byte) / (1 << 20) << " MB"; } -template -void ChooseAlgo(const std::vector& perf_results, - size_t workspace_byte, AlgoType* algo) { - VLOG(3) << "=========BwdFilterAlgo Perf result========="; - for (const auto& result : perf_results) { - auto math_type_str = "False"; - if (result.mathType == CUDNN_TENSOR_OP_MATH) { - math_type_str = "True"; - } - VLOG(3) << " algo: " << result.algo << ", TensorCore: " << math_type_str - << ", time: " << result.time << " ms" - << ", wksp = " << result.memory << ", status = " << result.status; +template +std::string GetPerfResultString(std::string prefix, + const std::vector& perf_results, + int actual_algo_count, size_t workspace_limit) { + std::ostringstream out; + out << prefix << " (workspace limit=" << ToMegaBytes(workspace_limit) + << " MB):\n"; + for (int i = 0; i < actual_algo_count; ++i) { + const auto& result = perf_results[i]; + auto math_type_str = (result.mathType == CUDNN_TENSOR_OP_MATH) ? "T" : "F"; + out << " algo=" << result.algo << ": tensor_core=" << math_type_str + << ", time=" << result.time + << " ms, memory=" << ToMegaBytes(result.memory) + << " MB, status=" << result.status << "\n"; } + return out.str(); +} - for (size_t i = 0; i != perf_results.size(); ++i) { - const auto& result = perf_results[i]; +// Choose an algorithm which has the minimize time cost and less memory. +// NOTE: perf_results is ordered by time. +template +void ChooseAlgoByWorkspace(const std::vector& perf_results, + size_t workspace_limit, + SearchResult* search_result) { + int best_algo_idx = -1; + for (size_t i = 0; i < perf_results.size(); ++i) { + auto result = perf_results[i]; if (result.status == CUDNN_STATUS_SUCCESS && - (result.memory <= workspace_byte)) { - if ((result.mathType == CUDNN_TENSOR_OP_MATH) && - (i != perf_results.size() - 1)) { - const auto& next_result = perf_results[i + 1]; - if (next_result.status == CUDNN_STATUS_SUCCESS && - next_result.algo == result.algo && - next_result.memory == result.memory && - next_result.mathType != CUDNN_TENSOR_OP_MATH && - next_result.time < 1.01 * result.time) { - // Skip over this result- it's not really a Tensor Core algo. - // Because it is only 1% performance difference. - // Prefer to choose the next equivalent non-Tensor Core algo. - continue; + result.memory < workspace_limit) { + if (best_algo_idx == -1) { + // The algorithm which has minimize time cost and need a workspace_size + // fitting the workspace_limit constraint. + best_algo_idx = i; + // Each perf_results[i].time is set to be -1 in heuristic search. + if (perf_results[best_algo_idx].time < 0) { + break; + } + } else { + float best_algo_time = perf_results[best_algo_idx].time; + if ((result.time - best_algo_time) / best_algo_time < 0.01) { + best_algo_idx = (result.memory < perf_results[best_algo_idx].memory) + ? i + : best_algo_idx; + break; } } - *algo = result.algo; - auto math_type_str = "0"; - if (result.mathType == CUDNN_TENSOR_OP_MATH) { - math_type_str = "1"; - } - VLOG(3) << " choose algo: " << result.algo << ", TC: " << math_type_str - << ", time: " << result.time << " ms" - << ", wksp = " << result.memory << ", status = " << result.status; - break; } } + if (best_algo_idx != -1) { + search_result->algo = perf_results[best_algo_idx].algo; + search_result->time = perf_results[best_algo_idx].time; + search_result->workspace_size = perf_results[best_algo_idx].memory; + } else { + VLOG(3) << "Can not find an algorithm that requires memory < " + << ToMegaBytes(workspace_limit) << " MB"; + } } -using framework::ConvSearchCache; - static void SetConvMathType(const phi::GPUContext& ctx, cudnnDataType_t dtype, const platform::ConvolutionDescriptor& cdesc) { #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1) - auto& dev_ctx = ctx; - if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { + if (ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( cdesc.desc(), CUDNN_TENSOR_OP_MATH)); VLOG(5) << "use cudnn_tensor_op_math"; #if CUDA_VERSION >= 11000 #if CUDNN_VERSION_MIN(8, 1, 0) - } else if (dev_ctx.GetComputeCapability() >= 80 && - dtype == CUDNN_DATA_BFLOAT16) { + } else if (ctx.GetComputeCapability() >= 80 && dtype == CUDNN_DATA_BFLOAT16) { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( cdesc.desc(), CUDNN_TENSOR_OP_MATH)); #endif // CUDNN_VERSION_MIN(8, 1, 0) @@ -198,411 +167,593 @@ static void SetConvMathType(const phi::GPUContext& ctx, cudnnDataType_t dtype, #endif } -struct ConvArgs { - cudnnHandle_t handle; - platform::TensorDescriptor idesc, odesc; - platform::FilterDescriptor wdesc; - platform::ConvolutionDescriptor cdesc; - const framework::Tensor *x, *w, *o; - cudnnDataType_t cudnn_dtype; - - // strides - std::vector s; - // paddings - std::vector p; - // dilations - std::vector d; - - ConvArgs(const framework::Tensor* x, const framework::Tensor* w, - const framework::Tensor* o, const std::vector s, - const std::vector p, const std::vector d, - cudnnDataType_t dtype) - : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {} -}; - -template -struct SearchAlgorithm {}; - +// cuDNN convolution forward algorithm searcher, consisted of three searching +// modes, namely: deterministic, heuristic and exhaustive_search mode. +// As well as one workspace size acquirsition function with respect to +// the chosen alogrithm. template <> struct SearchAlgorithm { - using perf_t = cudnnConvolutionFwdAlgoPerf_t; - using algo_t = cudnnConvolutionFwdAlgo_t; + using PerfT = cudnnConvolutionFwdAlgoPerf_t; + using AlgoT = cudnnConvolutionFwdAlgo_t; template - static algo_t Find(const ConvArgs& args, bool exhaustive_search, - bool deterministic, const phi::GPUContext& ctx) { + static SearchResult Find(const ConvArgs& args, bool exhaustive_search, + bool deterministic, + const phi::GPUContext& ctx) { + SearchResult result; auto dtype = platform::CudnnDataType::type; - bool has_got_workspace_size = true; - size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; - size_t workspace_size = 0; - algo_t algo; SetConvMathType(ctx, dtype, args.cdesc); - if (!exhaustive_search && !deterministic) { + if (deterministic) { + result = FindAlgoDeterministic(); + } else { + // 1. Once turning on exhaustive FLAGS, always get exhaustive_search. + // 2. Once turning on auto-tune, runn heuristic search(default) before + // auto-tune process, run exhaustive_search during mentioned process. + // 3. After auto-tune process, run cached algorithm if cached, run + // default mode for the rest. + size_t key = args.GetCacheKey(); + auto& cache = phi::autotune::AutoTuneCache::Instance().GetConvForward(); + if (cache.Find(key)) { + result.algo = static_cast(cache.Get(key)); + } else { + bool use_autotune = + phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); + if (exhaustive_search || use_autotune) { + result = FindAlgoExhaustiveSearch(args, ctx); + cache.Set(key, static_cast(result.algo)); + } else { + result = FindAlgoHeuristic(args, ctx); + } + } + } + VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search + << ", deterministic=" << deterministic + << ", choose algo=" << result.algo << ", workspace=" + << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB"; + return result; + } + + static size_t GetWorkspaceSize(const ConvArgs& args, + cudnnConvolutionFwdAlgo_t algo) { + size_t workspace_size = 0; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + args.handle, args.idesc.desc(), args.wdesc.desc(), + args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size)); + return workspace_size; + } + + private: + static SearchResult FindAlgoDeterministic() { + return SearchResult(static_cast(1)); + } + + // Heuristic search mode, calling the cudnnGetXxxAlgorithm. + static SearchResult FindAlgoHeuristic(const ConvArgs& args, + const phi::GPUContext& ctx) { + SearchResult result; + size_t workspace_size_limit = + CalcWorkspaceLimitInBytes(UseFixedWorkspace()); + #if CUDNN_VERSION >= 7001 - int perf_count; - int best_algo_idx = 0; - std::unique_ptr perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7( - args.handle, args.idesc.desc(), args.wdesc.desc(), - args.cdesc.desc(), args.odesc.desc(), kNUM_CUDNN_FWD_ALGS, - &perf_count, perf_results.get())); - algo = (perf_results.get())[best_algo_idx].algo; - workspace_size = (perf_results.get())[best_algo_idx].memory; + int actual_perf_count; + int best_algo_idx = 0; + std::vector perf_results(kNUM_CUDNN_FWD_ALGS); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7( + args.handle, args.idesc.desc(), args.wdesc.desc(), + args.cdesc.desc(), args.odesc.desc(), kNUM_CUDNN_FWD_ALGS, + &actual_perf_count, perf_results.data())); + result.algo = perf_results[best_algo_idx].algo; + result.workspace_size = perf_results[best_algo_idx].memory; - if (workspace_size > workspace_size_limit) { + if (result.workspace_size > workspace_size_limit) { #if CUDNN_VERSION >= 8000 - // cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8 - ChooseAlgoByWorkspace(perf_results.get(), - kNUM_CUDNN_FWD_ALGS, - workspace_size_limit, &algo); -#else - VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue " - "the workspace size request(" - << workspace_size << ") exceeds the limit(" - << workspace_size_limit << ")"; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionForwardAlgorithm( - args.handle, args.idesc.desc(), args.wdesc.desc(), - args.cdesc.desc(), args.odesc.desc(), - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); -#endif - } + // cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8 + ChooseAlgoByWorkspace(perf_results, workspace_size_limit, + &result); #else + VLOG(3) << "Fallback to non-v7 method to find conv algorithm " + "becasue the workspace size request(" + << result.workspace_size << ") exceeds the limit(" + << workspace_size_limit << ")"; PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cudnnGetConvolutionForwardAlgorithm( args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(), args.odesc.desc(), CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); + workspace_size_limit, &(result.algo))); #endif - VLOG(3) << "choose algo " << algo; - } else if (deterministic) { - algo = static_cast(1); - } else { - auto& dev_ctx = ctx; - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - - AlgorithmsCache& algo_cache = - *(framework::ConvSearchCache::Instance().GetForward()); - - auto x_dims = phi::vectorize(args.x->dims()); - auto w_dims = phi::vectorize(args.w->dims()); - - VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:" - << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" - << args.s << ", args.p" << args.p << ", args.d" << args.d; - - algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { - int returned_algo_count; - std::array perf_stat; - - auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( - args.handle, args.idesc.desc(), args.x->data(), - args.wdesc.desc(), args.w->data(), args.cdesc.desc(), - args.odesc.desc(), const_cast(args.o->data()), - kNUM_CUDNN_FWD_ALGS, &returned_algo_count, - perf_stat.data(), cudnn_workspace_ptr, - workspace_size_limit)); - }; - workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); - - VLOG(3) << "FwdAlgo Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = perf_stat[i]; - VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time - << " " << stat.memory; - } - return perf_stat[0].algo; - }); } - VLOG(3) << "choose algo " << algo; - return algo; - } - - static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { - size_t workspace_size = 0; +#else PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + platform::dynload::cudnnGetConvolutionForwardAlgorithm( args.handle, args.idesc.desc(), args.wdesc.desc(), - args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size)); - return workspace_size; + args.cdesc.desc(), args.odesc.desc(), + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit, + &(result.algo))); +#endif + return result; + } + + template + static SearchResult FindAlgoExhaustiveSearch( + const ConvArgs& args, const phi::GPUContext& ctx) { + SearchResult result; + size_t workspace_size_limit = + CalcWorkspaceLimitInBytes(UseFixedWorkspace()); + size_t max_workspace_size = GetMaxWorkspaceSize(args, workspace_size_limit); + VLOG(4) << "max_workspace_size=" << ToMegaBytes(max_workspace_size) + << " MB"; + + int returned_algo_count; + std::vector perf_results(kNUM_CUDNN_FWD_ALGS); + auto cudnn_find_func = [&](void* workspace_ptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( + args.handle, args.idesc.desc(), args.x->data(), + args.wdesc.desc(), args.w->data(), args.cdesc.desc(), + args.odesc.desc(), const_cast(args.o->data()), + kNUM_CUDNN_FWD_ALGS, &returned_algo_count, perf_results.data(), + workspace_ptr, max_workspace_size)); + }; + + auto workspace_handle = ctx.cudnn_workspace_handle(); + workspace_handle.RunFuncSync(cudnn_find_func, max_workspace_size, + UseFixedWorkspace()); + + VLOG(4) << GetPerfResultString( + "[Exhaustive Search] FwdAlgo Perf result", perf_results, + returned_algo_count, workspace_size_limit); + ChooseAlgoByWorkspace(perf_results, workspace_size_limit, + &result); + + return result; + } + + static size_t GetMaxWorkspaceSize(const ConvArgs& args, + size_t workspace_size_limit) { + if (!UseFixedWorkspace()) { + size_t max_workspace_size = 0; + for (size_t algo = 0; algo < kNUM_CUDNN_FWD_ALGS; ++algo) { + size_t workspace_size = 0; + auto status = + platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( + args.handle, args.idesc.desc(), args.wdesc.desc(), + args.cdesc.desc(), args.odesc.desc(), + static_cast(algo), &workspace_size); + if (status == CUDNN_STATUS_SUCCESS && + workspace_size <= workspace_size_limit) { + max_workspace_size = std::max(workspace_size, max_workspace_size); + } + } + return max_workspace_size; + } else { + return workspace_size_limit; + } } }; +// cuDNN convolution backward data-algorithm searcher, consisting of three +// searching modes, namely: deterministic, heuristic, and exhaustive_search +// mode. Specially, there are 2 pattens of exhaustive search mode, one for +// HALF precision only, one for the rest. +// As well as one workspace size acquirsition function with +// respect to the chosen alogrithm. template <> struct SearchAlgorithm { - using perf_t = cudnnConvolutionBwdDataAlgoPerf_t; - using algo_t = cudnnConvolutionBwdDataAlgo_t; + using PerfT = cudnnConvolutionBwdDataAlgoPerf_t; + using AlgoT = cudnnConvolutionBwdDataAlgo_t; template - static algo_t Find(const ConvArgs& args, bool exhaustive_search, - bool deterministic, const phi::GPUContext& ctx) { + static SearchResult Find(const ConvArgs& args, bool exhaustive_search, + bool deterministic, + const phi::GPUContext& ctx) { + SearchResult result; auto dtype = platform::CudnnDataType::type; - size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; - size_t workspace_size = 0; - bool has_got_workspace_size = true; - algo_t algo; SetConvMathType(ctx, dtype, args.cdesc); - if (!exhaustive_search && !deterministic) { + if (deterministic) { + result = FindAlgoDeterministic(); + } else { + // 1. Once turning on exhaustive FLAGS, always get exhaustive_search. + // 2. Once turning on auto-tune, runn heuristic search(default) before + // auto-tune process, run exhaustive_search during mentioned process. + // 3. After auto-tune process, run cached algorithm if cached, run + // default mode for the rest. + size_t key = args.GetCacheKey(); + auto& cache = + phi::autotune::AutoTuneCache::Instance().GetConvBackwardData(); + if (cache.Find(key)) { + result.algo = static_cast(cache.Get(key)); + } else { + bool use_autotune = + phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); + if (exhaustive_search || use_autotune) { + result = FindAlgoExhaustiveSearch(args, ctx); + cache.Set(key, static_cast(result.algo)); + } else { + result = FindAlgoHeuristic(args, ctx); + } + } + } + VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search + << ", deterministic=" << deterministic + << ", choose algo=" << result.algo << ", workspace=" + << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB"; + return result; + } + + static size_t GetWorkspaceSize(const ConvArgs& args, + cudnnConvolutionBwdDataAlgo_t algo) { + size_t workspace_size = 0; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + args.handle, args.wdesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size)); + return workspace_size; + } + + private: + static SearchResult FindAlgoDeterministic() { + return SearchResult(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1); + } + + static SearchResult FindAlgoHeuristic(const ConvArgs& args, + const phi::GPUContext& ctx) { + SearchResult result; + size_t workspace_size_limit = + CalcWorkspaceLimitInBytes(UseFixedWorkspace()); + #if CUDNN_VERSION >= 7001 - int perf_count; - int best_algo_idx = 0; - std::unique_ptr perf_results( - new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7( - args.handle, args.wdesc.desc(), args.odesc.desc(), - args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS, - &perf_count, perf_results.get())); - algo = (perf_results.get())[best_algo_idx].algo; + int actual_perf_count; + int best_algo_idx = 0; + std::vector perf_results(kNUM_CUDNN_BWD_DATA_ALGS); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7( + args.handle, args.wdesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS, + &actual_perf_count, perf_results.data())); + result.algo = perf_results[best_algo_idx].algo; #if CUDNN_VERSION < 7500 - int stride_dim = args.x->dims().size() - 2; - bool blacklist = std::any_of(args.s.begin(), args.s.begin() + stride_dim, - [=](int n) { return n != 1; }); - if (blacklist && (static_cast( - perf_results[best_algo_idx].algo) == - CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING || - static_cast( - perf_results[best_algo_idx].algo) == - CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) { - algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - } + int stride_dim = args.x->dims().size() - 2; + bool blacklist = std::any_of(args.s.begin(), args.s.begin() + stride_dim, + [=](int n) { return n != 1; }); + if (blacklist && (perf_results[best_algo_idx].algo == + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING || + perf_results[best_algo_idx].algo == + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) { + result.algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } #endif - workspace_size = GetWorkspaceSize(args, algo); - if (workspace_size > workspace_size_limit) { - has_got_workspace_size = false; + result.workspace_size = GetWorkspaceSize(args, result.algo); + if (result.workspace_size > workspace_size_limit) { #if CUDNN_VERSION >= 8000 - // cudnnGetConvolutionBackwardDataAlgorithm is removed in CUDNN-8 - ChooseAlgoByWorkspace(perf_results.get(), - kNUM_CUDNN_BWD_DATA_ALGS, - workspace_size_limit, &algo); -#else - VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue " - "the workspace size request(" - << workspace_size << ") exceeds the limit(" - << workspace_size_limit << ")"; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( - args.handle, args.wdesc.desc(), args.odesc.desc(), - args.cdesc.desc(), args.idesc.desc(), - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); -#endif - } + // cudnnGetConvolutionBackwardDataAlgorithm is removed in CUDNN-8 + ChooseAlgoByWorkspace(perf_results, workspace_size_limit, + &result); #else + VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue " + "the workspace size request(" + << result.workspace_size << ") exceeds the limit(" + << workspace_size_limit << ")"; PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( args.handle, args.wdesc.desc(), args.odesc.desc(), args.cdesc.desc(), args.idesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); + workspace_size_limit, &(result.algo))); #endif - } else if (deterministic) { - return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - } else { - auto& dev_ctx = ctx; - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - - AlgorithmsCache& algo_cache = - *(framework::ConvSearchCache::Instance().GetBackwardData()); - - auto x_dims = phi::vectorize(args.x->dims()); - auto w_dims = phi::vectorize(args.w->dims()); - - VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t" - << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" - << args.s << ", args.p" << args.p << ", args.d" << args.d; - - algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { - int returned_algo_count; - std::array perf_stat; - - auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload:: - cudnnFindConvolutionBackwardDataAlgorithmEx( - args.handle, args.wdesc.desc(), args.w->data(), - args.odesc.desc(), args.o->data(), - args.cdesc.desc(), args.idesc.desc(), - const_cast(args.x->data()), - kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, - perf_stat.data(), cudnn_workspace_ptr, - workspace_size_limit)); - }; - workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); - - VLOG(3) << "BwdDataAlgo Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = perf_stat[i]; - VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time - << " " << stat.memory; - } - - return perf_stat[0].algo; - }); } - VLOG(3) << "choose algo " << algo; - return algo; - } - - static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { - size_t workspace_size = 0; +#else PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( args.handle, args.wdesc.desc(), args.odesc.desc(), - args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size)); - return workspace_size; + args.cdesc.desc(), args.idesc.desc(), + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &(result.algo))); +#endif + + return result; + } + + template + static SearchResult FindAlgoExhaustiveSearch( + const ConvArgs& args, const phi::GPUContext& ctx) { + SearchResult result; + size_t workspace_size_limit = + CalcWorkspaceLimitInBytes(UseFixedWorkspace()); + size_t max_workspace_size = GetMaxWorkspaceSize(args, workspace_size_limit); + VLOG(3) << "max_workspace_size=" << ToMegaBytes(max_workspace_size) + << " MB"; + + int returned_algo_count; + std::vector perf_results(kNUM_CUDNN_BWD_DATA_ALGS); + auto cudnn_find_func = [&](void* workspace_ptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnFindConvolutionBackwardDataAlgorithmEx( + args.handle, args.wdesc.desc(), args.w->data(), + args.odesc.desc(), args.o->data(), args.cdesc.desc(), + args.idesc.desc(), const_cast(args.x->data()), + kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count, + perf_results.data(), workspace_ptr, max_workspace_size)); + }; + + auto workspace_handle = ctx.cudnn_workspace_handle(); + workspace_handle.RunFuncSync(cudnn_find_func, max_workspace_size, + UseFixedWorkspace()); + + VLOG(4) << GetPerfResultString( + "[Exhaustive Search] BwdDataAlgo Perf result", perf_results, + returned_algo_count, workspace_size_limit); + ChooseAlgoByWorkspace(perf_results, workspace_size_limit, + &result); + + return result; + } + + static size_t GetMaxWorkspaceSize(const ConvArgs& args, + size_t workspace_size_limit) { + if (!UseFixedWorkspace()) { + size_t max_workspace_size = 0; + for (size_t algo = 0; algo < kNUM_CUDNN_BWD_DATA_ALGS; ++algo) { + size_t workspace_size = 0; + auto status = + platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + args.handle, args.wdesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.idesc.desc(), + static_cast(algo), + &workspace_size); + if (status == CUDNN_STATUS_SUCCESS && + workspace_size <= workspace_size_limit) { + max_workspace_size = std::max(workspace_size, max_workspace_size); + } + } + return max_workspace_size; + } else { + return workspace_size_limit; + } } }; +// cuDNN convution backward filter-algorithm searcher, consisted of three +// algorithm searching modes, namely: deterministic, heuristic, and +// exhaustive_search mode. As well as one workspace size acquirsition function +// with respect to the chosen alogrithm. template <> struct SearchAlgorithm { - using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t; - using algo_t = cudnnConvolutionBwdFilterAlgo_t; + using PerfT = cudnnConvolutionBwdFilterAlgoPerf_t; + using AlgoT = cudnnConvolutionBwdFilterAlgo_t; template - static algo_t Find(const ConvArgs& args, bool exhaustive_search, - bool deterministic, const phi::GPUContext& ctx) { + static SearchResult Find(const ConvArgs& args, bool exhaustive_search, + bool deterministic, + const phi::GPUContext& ctx) { platform::CUDAGraphCaptureModeGuard guard; + SearchResult result; auto dtype = platform::CudnnDataType::type; - size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; - size_t workspace_size = 0; - bool has_got_workspace_size = true; SetConvMathType(ctx, dtype, args.cdesc); - algo_t algo; - if (!exhaustive_search && !deterministic) { + if (deterministic) { + result = FindAlgoDeterministic(); + } else { + // 1. Once turning on exhaustive FLAGS, always get exhaustive_search. + // 2. Once turning on auto-tune, runn heuristic search(default) before + // auto-tune process, run exhaustive_search during mentioned process. + // 3. After auto-tune process, run cached algorithm if cached, run + // default mode for the rest. + size_t key = args.GetCacheKey(); + auto& cache = + phi::autotune::AutoTuneCache::Instance().GetConvBackwardFilter(); + if (cache.Find(key)) { + result.algo = static_cast(cache.Get(key)); + } else { + bool use_autotune = + phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); + if (exhaustive_search || use_autotune) { + result = FindAlgoExhaustiveSearch(args, ctx); + cache.Set(key, static_cast(result.algo)); + } else { + result = FindAlgoHeuristic(args, ctx); + } + } + } + VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search + << ", deterministic=" << deterministic + << ", choose algo=" << result.algo << ", workspace=" + << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB"; + return result; + } + + static size_t GetWorkspaceSize(const ConvArgs& args, + cudnnConvolutionBwdFilterAlgo_t algo) { + platform::CUDAGraphCaptureModeGuard guard; + size_t workspace_size = 0; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( + args.handle, args.idesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.wdesc.desc(), algo, &workspace_size)); + return workspace_size; + } + + private: + static SearchResult FindAlgoDeterministic() { + return SearchResult(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1); + } + + static SearchResult FindAlgoHeuristic(const ConvArgs& args, + const phi::GPUContext& ctx) { + SearchResult result; + size_t workspace_size_limit = + CalcWorkspaceLimitInBytes(UseFixedWorkspace()); + #if CUDNN_VERSION >= 7001 - using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t; - int perf_count; - int best_algo_idx = 0; - std::unique_ptr perf_results( - new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7( - args.handle, args.idesc.desc(), args.odesc.desc(), - args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS, - &perf_count, perf_results.get())); - algo = (perf_results.get())[best_algo_idx].algo; - workspace_size = (perf_results.get())[best_algo_idx].memory; + int actual_perf_count; + int best_algo_idx = 0; + std::vector perf_results(kNUM_CUDNN_BWD_FILTER_ALGS); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7( + args.handle, args.idesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS, + &actual_perf_count, perf_results.data())); + result.algo = perf_results[best_algo_idx].algo; + result.workspace_size = perf_results[best_algo_idx].memory; - if (workspace_size > workspace_size_limit) { - workspace_size = workspace_size_limit; + if (result.workspace_size > workspace_size_limit) { #if CUDNN_VERSION >= 8000 - // cudnnGetConvolutionBackwardFilterAlgorithm is removed in CUDNN-8 - ChooseAlgoByWorkspace(perf_results.get(), - kNUM_CUDNN_BWD_FILTER_ALGS, - workspace_size_limit, &algo); -#else - VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue " - "the workspace size request(" - << workspace_size << ") exceeds the limit(" - << workspace_size_limit << ")"; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( - args.handle, args.idesc.desc(), args.odesc.desc(), - args.cdesc.desc(), args.wdesc.desc(), - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); -#endif - } + // cudnnGetConvolutionBackwardFilterAlgorithm is removed in CUDNN-8 + ChooseAlgoByWorkspace(perf_results, workspace_size_limit, + &result); #else + VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue " + "the workspace size request(" + << result.workspace_size << ") exceeds the limit(" + << workspace_size_limit << ")"; PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( args.handle, args.idesc.desc(), args.odesc.desc(), args.cdesc.desc(), args.wdesc.desc(), CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, &algo)); + workspace_size_limit, &(result.algo))); #endif - } else if (deterministic) { - return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + } +#else + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( + args.handle, args.idesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.wdesc.desc(), + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_size_limit, &(result.algo))); +#endif + + return result; + } + + template + static SearchResult FindAlgoExhaustiveSearch( + const ConvArgs& args, const phi::GPUContext& ctx) { + SearchResult result; + int returned_algo_count = 0; + std::vector perf_results(kNUM_CUDNN_BWD_FILTER_ALGS); + size_t workspace_size_limit = + CalcWorkspaceLimitInBytes(UseFixedWorkspace()); + auto workspace_handle = ctx.cudnn_workspace_handle(); + if (platform::CudnnDataType::type != CUDNN_DATA_HALF) { + size_t max_workspace_size = + GetMaxWorkspaceSize(args, workspace_size_limit); + VLOG(3) << "max_workspace_size=" << ToMegaBytes(max_workspace_size) + << " MB"; + + auto cudnn_find_func = [&](void* workspace_ptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnFindConvolutionBackwardFilterAlgorithmEx( + args.handle, args.idesc.desc(), args.x->data(), + args.odesc.desc(), args.o->data(), args.cdesc.desc(), + args.wdesc.desc(), const_cast(args.w->data()), + kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count, + perf_results.data(), workspace_ptr, max_workspace_size)); + }; + workspace_handle.RunFuncSync(cudnn_find_func, max_workspace_size, + UseFixedWorkspace()); + + VLOG(4) << GetPerfResultString( + "[Exhaustive Search] BwdFilterAlgo Perf result", perf_results, + returned_algo_count, workspace_size_limit); + ChooseAlgoByWorkspace(perf_results, workspace_size_limit, + &result); } else { - auto& dev_ctx = ctx; - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - AlgorithmsCache& algo_cache = - *(framework::ConvSearchCache::Instance().GetBackwardFilter()); - - auto x_dims = phi::vectorize(args.x->dims()); - auto w_dims = phi::vectorize(args.w->dims()); - - VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:" - << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" - << args.s << ", args.p" << args.p << ", args.d" << args.d; - if (dtype != CUDNN_DATA_HALF) { - algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { - int returned_algo_count; - std::array perf_stat; - auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload:: - cudnnFindConvolutionBackwardFilterAlgorithmEx( - args.handle, args.idesc.desc(), args.x->data(), - args.odesc.desc(), args.o->data(), - args.cdesc.desc(), args.wdesc.desc(), - const_cast(args.w->data()), - kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count, - perf_stat.data(), cudnn_workspace_ptr, - workspace_size_limit)); - }; - workspace_handle.RunFuncSync(cudnn_find_func, - workspace_size_limit); - - VLOG(3) - << "BwdFilterAlgo Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = perf_stat[i]; - VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time - << " " << stat.memory; - } - return perf_stat[0].algo; - }); - } else { - auto max_algos = MaxBwdFilterAlgos(args.handle); - algo = algo_cache.GetAlgorithm( - x_dims, w_dims, args.s, args.p, args.d, 0, - static_cast(args.cudnn_dtype), [&]() { - algo_t chosen_algo; - std::vector perf_results(max_algos); - int actual_algos = 0; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload:: - cudnnFindConvolutionBackwardFilterAlgorithm( - args.handle, args.idesc.desc(), args.odesc.desc(), - args.cdesc.desc(), args.wdesc.desc(), - perf_results.size(), &actual_algos, - perf_results.data())); - perf_results.resize(actual_algos); - ChooseAlgo(perf_results, workspace_size_limit, - &chosen_algo); - return chosen_algo; - }); + int max_algos = GetAlgorithmMaxCount(args.handle); + std::vector perf_results(max_algos); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnFindConvolutionBackwardFilterAlgorithm( + args.handle, args.idesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.wdesc.desc(), perf_results.size(), + &returned_algo_count, perf_results.data())); + perf_results.resize(returned_algo_count); + + VLOG(4) << GetPerfResultString( + "[Exhaustive Search] BwdFilterAlgo Perf result", perf_results, + perf_results.size(), workspace_size_limit); + ChooseAlgo(perf_results, workspace_size_limit, &result); + } + + return result; + } + + static int GetAlgorithmMaxCount(cudnnHandle_t handle) { +#if CUDNN_VERSION_MIN(7, 0, 1) + int max_algos = 0; + auto status = + platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + handle, &max_algos); + if (status == gpuSuccess) { + VLOG(5) << "[BackwardFilter] max_algos: predefined=" + << kNUM_CUDNN_BWD_FILTER_ALGS << ", actual=" << max_algos; + return max_algos; + } +#endif + return kNUM_CUDNN_BWD_FILTER_ALGS; + } + + static size_t GetMaxWorkspaceSize(const ConvArgs& args, + size_t workspace_size_limit) { + if (!UseFixedWorkspace()) { + size_t max_workspace_size = 0; + for (size_t algo = 0; algo < kNUM_CUDNN_BWD_FILTER_ALGS; ++algo) { + size_t workspace_size = 0; + auto status = + platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( + args.handle, args.idesc.desc(), args.odesc.desc(), + args.cdesc.desc(), args.wdesc.desc(), + static_cast(algo), + &workspace_size); + if (status == CUDNN_STATUS_SUCCESS && + workspace_size <= workspace_size_limit) { + max_workspace_size = std::max(workspace_size, max_workspace_size); + } } + return max_workspace_size; + } else { + return workspace_size_limit; } - VLOG(3) << "choose algo " << algo; - return algo; } - static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) { - platform::CUDAGraphCaptureModeGuard guard; - size_t workspace_size = 0; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( - args.handle, args.idesc.desc(), args.odesc.desc(), - args.cdesc.desc(), args.wdesc.desc(), algo, &workspace_size)); - return workspace_size; + static void ChooseAlgo(const std::vector& perf_results, + size_t workspace_limit, + SearchResult* algo_result) { + for (size_t i = 0; i != perf_results.size(); ++i) { + const auto& result = perf_results[i]; + if (result.status == CUDNN_STATUS_SUCCESS && + (result.memory <= workspace_limit)) { + if ((result.mathType == CUDNN_TENSOR_OP_MATH) && + (i != perf_results.size() - 1)) { + const auto& next_result = perf_results[i + 1]; + if (next_result.status == CUDNN_STATUS_SUCCESS && + next_result.algo == result.algo && + next_result.memory == result.memory && + next_result.mathType != CUDNN_TENSOR_OP_MATH && + next_result.time < 1.01 * result.time) { + // Skip over this result- it's not really a Tensor Core algo. + // Because it is only 1% performance difference. + // Prefer to choose the next equivalent non-Tensor Core algo. + continue; + } + } + algo_result->algo = result.algo; + algo_result->time = result.time; + auto math_type_str = "0"; + if (result.mathType == CUDNN_TENSOR_OP_MATH) { + math_type_str = "1"; + } + VLOG(3) << " choose algo: " << result.algo + << ", TC: " << math_type_str << ", time: " << result.time + << " ms, wksp = " << result.memory + << ", status = " << result.status; + break; + } + } } }; diff --git a/paddle/fluid/operators/conv_cudnn_op_cache.h b/paddle/fluid/operators/conv_cudnn_op_cache.h index 291e5f92f32..af67d857e0e 100644 --- a/paddle/fluid/operators/conv_cudnn_op_cache.h +++ b/paddle/fluid/operators/conv_cudnn_op_cache.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -DECLARE_uint64(conv_workspace_size_limit); +DECLARE_int64(conv_workspace_size_limit); DECLARE_bool(cudnn_exhaustive_search); DECLARE_int64(cudnn_exhaustive_search_times); diff --git a/paddle/fluid/operators/conv_miopen_helper.h b/paddle/fluid/operators/conv_miopen_helper.h index 66f71869384..abc7be7fb8b 100644 --- a/paddle/fluid/operators/conv_miopen_helper.h +++ b/paddle/fluid/operators/conv_miopen_helper.h @@ -14,42 +14,12 @@ limitations under the License. */ #pragma once -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/conv_search_cache.h" -#include "paddle/fluid/framework/operator_kernel_configs.h" -#include "paddle/fluid/operators/conv_cudnn_op_cache.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/fluid/operators/conv_base_helper.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using DataLayout = platform::DataLayout; -template -using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; -using framework::AlgorithmsCache; -static inline void GetNCDHW(const framework::DDim& dims, - const DataLayout& layout, int* N, int* C, int* D, - int* H, int* W) { - *N = dims[0]; - *C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; - int i = layout == DataLayout::kNCHW ? 0 : 1; - if (dims.size() == 5) { - *D = dims[2 - i]; - *H = dims[3 - i]; - *W = dims[4 - i]; - } else { - *D = 1; - *H = dims[2 - i]; - *W = dims[3 - i]; - } -} +using ConvArgs = ConvArgsBase; template static void RemovePaddingSlice(const phi::GPUContext& context, @@ -66,9 +36,8 @@ static void RemovePaddingSlice(const phi::GPUContext& context, extents[i] = new_out_dims[i]; } - int start; for (size_t i = 0; i < axes.size(); ++i) { - start = starts[i]; + int start = starts[i]; if (start < 0) { start = (start + in_dims[axes[i]]); } @@ -85,41 +54,6 @@ static void RemovePaddingSlice(const phi::GPUContext& context, out_t.device(place) = in_t.slice(offsets, extents); } -template -std::ostream& operator<<(std::ostream& out, const std::vector& v) { - out << "["; - for (auto const& tmp : v) out << tmp << ","; - out << "]"; - return out; -} - -using framework::ConvSearchCache; - -struct ConvArgs { - miopenHandle_t handle; - platform::TensorDescriptor idesc, odesc; - platform::FilterDescriptor wdesc; - platform::ConvolutionDescriptor cdesc; - const framework::Tensor *x, *w, *o; - miopenDataType_t cudnn_dtype; - - // strides - std::vector s; - // paddings - std::vector p; - // dilations - std::vector d; - - ConvArgs(const framework::Tensor* x, const framework::Tensor* w, - const framework::Tensor* o, const std::vector s, - const std::vector p, const std::vector d, - miopenDataType_t dtype) - : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {} -}; - -template -struct SearchAlgorithm {}; - template <> struct SearchAlgorithm { using perf_t = miopenConvAlgoPerf_t; diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cu b/paddle/fluid/operators/fused/fusion_conv_inception_op.cu index 39b42ec194c..bd7134f2f33 100644 --- a/paddle/fluid/operators/fused/fusion_conv_inception_op.cu +++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cu @@ -16,8 +16,6 @@ limitations under the License. */ #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -DECLARE_uint64(conv_workspace_size_limit); - namespace paddle { namespace operators { diff --git a/paddle/fluid/platform/device/gpu/gpu_info.cc b/paddle/fluid/platform/device/gpu/gpu_info.cc index a671381d07f..89e3b74bb3a 100644 --- a/paddle/fluid/platform/device/gpu/gpu_info.cc +++ b/paddle/fluid/platform/device/gpu/gpu_info.cc @@ -188,6 +188,8 @@ class RecordedGpuMallocHelper { if (UNLIKELY(malloc_managed_memory)) { result = cudaMallocManaged(ptr, size); } else { + VLOG(10) << "[cudaMalloc] size=" << static_cast(size) / (1 << 20) + << " MB"; result = cudaMalloc(ptr, size); } #endif @@ -226,6 +228,8 @@ class RecordedGpuMallocHelper { if (err != hipErrorDeinitialized) { #else auto err = cudaFree(ptr); + VLOG(10) << "[cudaFree] size=" << static_cast(size) / (1 << 20) + << " MB"; if (err != cudaErrorCudartUnloading) { #endif PADDLE_ENFORCE_GPU_SUCCESS(err); diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index f3934c7d871..904e4854ba6 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -522,8 +522,8 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) { cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place)); auto& instance = memory::allocation::AllocatorFacade::Instance(); instance.SetDefaultStream(place, phi::GPUContext::stream()); - workspace_.reset( - new phi::DnnWorkspaceHandle(instance.GetAllocator(place).get())); + workspace_.reset(new phi::DnnWorkspaceHandle( + instance.GetAllocator(place).get(), stream())); } CUDADeviceContext::~CUDADeviceContext() = default; @@ -623,7 +623,8 @@ phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { return phi::DnnWorkspaceHandle( memory::allocation::AllocatorFacade::Instance() .GetAllocator(GetPlace()) - .get()); + .get(), + stream()); } return phi::GPUContext::cudnn_workspace_handle(); } diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 87bad9cbdfc..c70452c5016 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -158,10 +158,9 @@ PADDLE_DEFINE_EXPORTED_bool( * increased. * Users need to balance memory and speed. */ -PADDLE_DEFINE_EXPORTED_uint64( - conv_workspace_size_limit, - paddle::platform::kDefaultConvWorkspaceSizeLimitMB, - "cuDNN convolution workspace limit in MB unit."); +PADDLE_DEFINE_EXPORTED_int64(conv_workspace_size_limit, + paddle::platform::kDefaultConvWorkspaceSizeLimitMB, + "cuDNN convolution workspace limit in MB unit."); /** * CUDNN related FLAG @@ -800,3 +799,12 @@ DEFINE_bool(enable_ins_parser_file, false, #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait"); #endif + +/** + * Autotune related FLAG + * Name: FLAGS_use_autotune + * Since Version: 2.3.0 + * Value Range: bool, default=false + * Example: + */ +PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune."); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 89b475f075e..e192d428a15 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -4430,7 +4430,7 @@ All parameter, weight, gradient are variables in Paddle. return phi::autotune::AutoTuneStatus::Instance().DisableAutoTune(); }); - m.def("autotune_range", [](int64_t start, int64_t stop) { + m.def("set_autotune_range", [](int64_t start, int64_t stop) { return phi::autotune::AutoTuneStatus::Instance().SetAutoTuneRange(start, stop); }); @@ -4439,10 +4439,8 @@ All parameter, weight, gradient are variables in Paddle. [] { return phi::autotune::AutoTuneStatus::Instance().Update(); }); m.def("autotune_status", [] { - phi::autotune::AutoTuneCache::Instance().UpdateStatus(); py::dict res; - res["use_autotune"] = - phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); + phi::autotune::AutoTuneCache::Instance().UpdateStatus(); res["step_id"] = phi::autotune::AutoTuneStatus::Instance().StepID(); res["cache_size"] = phi::autotune::AutoTuneCache::Instance().Size(); res["cache_hit_rate"] = diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index 0394835aa8b..ff238b79978 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -12,6 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/phi/backends/gpu/gpu_context.h" #include #include @@ -155,6 +156,39 @@ static void StreamCallbackFunc(gpuStream_t stream, } // namespace internal +void DnnWorkspaceHandle::RunFuncSync( + const std::function& cudnn_func, + size_t required_workspace_bytes, + bool use_cached_allocation) { + bool need_realloc = required_workspace_bytes > WorkspaceSize(); + if (need_realloc && !use_cached_allocation) { + void* workspace_ptr = nullptr; + size_t size = ((required_workspace_bytes + 255) >> 8) << 8; + std::lock_guard guard(*mtx_); +#ifdef PADDLE_WITH_HIP + auto status = hipMalloc(&workspace_ptr, size); +#else + auto status = cudaMalloc(&workspace_ptr, size); +#endif + if (status == gpuSuccess) { + cudnn_func(workspace_ptr); + phi::backends::gpu::GpuStreamSync(stream_); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipFree(workspace_ptr)); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaFree(workspace_ptr)); +#endif + return; + } + } + + RunFunc(cudnn_func, required_workspace_bytes); + if (need_realloc) { + // Release the workspace allocated in this running. + ResetWorkspace(); + } +} + void DnnWorkspaceHandle::ResetWorkspace() { allocation_ = nullptr; } void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) { @@ -295,13 +329,13 @@ struct GPUContext::Impl { void InitDnnWorkspace() { PD_CHECK(allocator_ != nullptr, "the device allocator for gpu context is nullptr."); - workspace_ = new DnnWorkspaceHandle(allocator_); + workspace_ = new DnnWorkspaceHandle(allocator_, stream_); } void DestoryInternalWorkspace() { if (owned_ && workspace_ != nullptr) { delete workspace_; - stream_ = nullptr; + workspace_ = nullptr; } } @@ -313,7 +347,7 @@ struct GPUContext::Impl { DnnWorkspaceHandle GetDnnWorkspace() { PD_CHECK(allocator_ != nullptr, "the device allocator for gpu context is nullptr."); - return DnnWorkspaceHandle(allocator_); + return DnnWorkspaceHandle(allocator_, stream_); } void InitStream() { diff --git a/paddle/phi/backends/gpu/gpu_context.h b/paddle/phi/backends/gpu/gpu_context.h index d268d4ae8d8..8d44acaa4a0 100644 --- a/paddle/phi/backends/gpu/gpu_context.h +++ b/paddle/phi/backends/gpu/gpu_context.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/forwards.h" #include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/backends/gpu/gpu_helper.h" +#include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" @@ -28,8 +29,8 @@ namespace phi { class DnnWorkspaceHandle { public: - explicit inline DnnWorkspaceHandle(Allocator* allocator) - : allocator_(allocator) { + inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream) + : allocator_(allocator), stream_(stream) { mtx_.reset(new std::mutex()); } @@ -48,11 +49,9 @@ class DnnWorkspaceHandle { * running the function. Currently this function is only used when cudnn * exhaustive searching and callers have to guarantee that the input function * is host blocking */ - inline void RunFuncSync(const std::function& cudnn_func, - size_t required_workspace_bytes) { - RunFunc(cudnn_func, required_workspace_bytes); - ResetWorkspace(); - } + void RunFuncSync(const std::function& cudnn_func, + size_t required_workspace_bytes, + bool use_cached_allocation = true); inline size_t WorkspaceSize() { if (allocation_ == nullptr) { @@ -70,7 +69,8 @@ class DnnWorkspaceHandle { private: Allocator::AllocationPtr allocation_{nullptr}; - Allocator* allocator_{nullptr}; + Allocator* allocator_{nullptr}; // Not owned + gpuStream_t stream_{nullptr}; // Not owned std::unique_ptr mtx_; }; diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 9c8756359e0..a6fe6fecb57 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -6,12 +6,15 @@ file(APPEND ${kernel_declare_file} "#include \"paddle/phi/core/kernel_registry.h # phi functors and functions called by kernels add_subdirectory(funcs) +# kernel autotune +add_subdirectory(autotune) + # phi depends all phi kernel targets set_property(GLOBAL PROPERTY PHI_KERNELS "") # [ 1. Common kernel compilation dependencies ] set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel) -set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor ) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) @@ -27,13 +30,17 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) # Some kernels depend on some targets that are not commonly used. # These targets are not suitable for common dependencies. # In this case, you need to manually generate them here. -set(MANUAL_BUILD_KERNELS cross_entropy_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel +set(AUTOTUNE_KERNELS conv_kernel conv_grad_kernel conv_grad_grad_kernel conv_transpose_kernel conv_transpose_grad_kernel) +set(MANUAL_BUILD_KERNELS ${AUTOTUNE_KERNELS} cross_entropy_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel triangular_solve_grad_kernel determinant_grad_kernel reduce_sum_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel) +foreach(src ${AUTOTUNE_KERNELS}) + kernel_library(${src} DEPS ${COMMON_KERNEL_DEPS} switch_autotune) +endforeach() kernel_library(cross_entropy_kernel DEPS ${COMMON_KERNEL_DEPS} softmax cross_entropy) kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) kernel_library(deformable_conv_grad_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor) @@ -74,6 +81,3 @@ add_subdirectory(selected_rows) copy_if_different(${kernel_declare_file} ${kernel_declare_file_final}) # For strings kernels add_subdirectory(strings) - -# 5. kernel autotune -add_subdirectory(autotune) diff --git a/paddle/phi/kernels/autotune/CMakeLists.txt b/paddle/phi/kernels/autotune/CMakeLists.txt index b933e0993de..63dc2245944 100644 --- a/paddle/phi/kernels/autotune/CMakeLists.txt +++ b/paddle/phi/kernels/autotune/CMakeLists.txt @@ -1,11 +1,12 @@ if (WITH_GPU) - nv_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest) - nv_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest) + nv_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest) + nv_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest) elseif (WITH_ROCM) - hip_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest) - hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest) + hip_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest) + hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest) endif() cc_library(cache SRCS cache.cc DEPS boost) +cc_library(switch_autotune SRCS switch_autotune.cc DEPS cache flags) cc_test(cache_test SRCS cache_test.cc DEPS gtest cache) diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc index bf68e201015..ef2cbe633d4 100644 --- a/paddle/phi/kernels/autotune/cache.cc +++ b/paddle/phi/kernels/autotune/cache.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/phi/kernels/autotune/cache.h" +#include +#include "glog/logging.h" namespace phi { namespace autotune { @@ -32,5 +34,40 @@ size_t ConvKey(const std::vector& x_dims, static_cast(dtype)); } +std::string AlgorithmTypeString(int64_t algo_type) { + if (algo_type == static_cast(AlgorithmType::kConvForward)) { + return "conv_forward"; + } else if (algo_type == + static_cast(AlgorithmType::kConvBackwardData)) { + return "conv_backward_data"; + } else if (algo_type == + static_cast(AlgorithmType::kConvBackwardFilter)) { + return "conv_backward_filter"; + } + return std::to_string(algo_type); +} + +void AutoTuneCache::UpdateStatus() { + int64_t size = 0; + int64_t cache_hits = 0; + int64_t cache_misses = 0; + int name_width = 24; + std::cout.setf(std::ios::left); + for (auto& v : auto_tune_map_) { + VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width) + << AlgorithmTypeString(v.first) + << " Cache Size: " << v.second.Size() + << " Hits: " << v.second.CacheHits() + << " Misses: " << v.second.CacheMisses() + << " Hit Rate: " << v.second.CacheHitRate(); + size += v.second.Size(); + cache_hits += v.second.CacheHits(); + cache_misses += v.second.CacheMisses(); + } + total_size_ = size; + total_cache_hits_ = cache_hits; + total_cache_misses_ = cache_misses; +} + } // namespace autotune } // namespace phi diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h index d492e7c151f..37c5d134e8a 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -13,11 +13,12 @@ // limitations under the License. #pragma once + #include #include +#include #include #include -#include "glog/logging.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" @@ -92,6 +93,13 @@ class AlgorithmsCache { return ret; } + void Clean() { + std::lock_guard lock(*cache_mutex_); + hash_.clear(); + cache_hits_ = 0; + cache_misses_ = 0; + } + void Set(size_t key, AlgorithmT algo) { std::lock_guard lock(*cache_mutex_); hash_[key] = algo; @@ -116,15 +124,22 @@ class AlgorithmsCache { private: std::unordered_map hash_; std::shared_ptr cache_mutex_; - int64_t cache_hits_ = 0; - int64_t cache_misses_ = 0; + + int64_t cache_hits_{0}; + int64_t cache_misses_{0}; +}; + +enum class AlgorithmType { + kConvForward = 1, + kConvBackwardData = 2, + kConvBackwardFilter = 3, + kAlgorithmCount = 4 }; // AlgorithmsConfigKey -> AlgorithmsID -using AlgorithmsConfigKeyMap = AlgorithmsCache; -// AlgorithmsType -> AlgorithmsCache -using AlgorithmsTypeMap = - std::unordered_map; +using AlgorithmsCacheMap = AlgorithmsCache; +// AlgorithmType -> AlgorithmsCache +using AlgorithmsTypeMap = std::unordered_map; class AutoTuneCache { public: @@ -133,42 +148,30 @@ class AutoTuneCache { return autotune_cache; } - AlgorithmsConfigKeyMap& RegisterOrGet(const std::string& algo_type) { - std::lock_guard lock(*autotune_cache_mutex_); - if (auto_tune_map_.find(algo_type) == auto_tune_map_.end()) { - AlgorithmsConfigKeyMap cache; - auto_tune_map_[algo_type] = cache; - } - return auto_tune_map_[algo_type]; + AlgorithmsCacheMap& Get(const AlgorithmType& algo_type) { + return auto_tune_map_[static_cast(algo_type)]; } - void Clean(float miss_rate) { - std::lock_guard lock(*autotune_cache_mutex_); - // Set a small tolerance to avoid performance degradation - // due to large cache size under dynamic shape. - if (miss_rate > 0.01) { - auto_tune_map_.clear(); - } + AlgorithmsCacheMap& GetConvForward() { + return Get(AlgorithmType::kConvForward); + } + + AlgorithmsCacheMap& GetConvBackwardData() { + return Get(AlgorithmType::kConvBackwardData); + } + + AlgorithmsCacheMap& GetConvBackwardFilter() { + return Get(AlgorithmType::kConvBackwardFilter); } - void UpdateStatus() { - int64_t size = 0; - int64_t cache_hits = 0; - int64_t cache_misses = 0; + void Clean() { for (auto& v : auto_tune_map_) { - VLOG(4) << "AlgoType: " << v.first << " Cache Size: " << v.second.Size() - << " Hits: " << v.second.CacheHits() - << " Misses: " << v.second.CacheMisses() - << " Hit Rate: " << v.second.CacheHitRate(); - size += v.second.Size(); - cache_hits += v.second.CacheHits(); - cache_misses += v.second.CacheMisses(); + v.second.Clean(); } - total_size_ = size; - total_cache_hits_ = cache_hits; - total_cache_misses_ = cache_misses; } + void UpdateStatus(); + // The number of total config cached int64_t Size() const { return total_size_; } @@ -183,17 +186,30 @@ class AutoTuneCache { total_cache_hit_rate = static_cast(total_cache_hits_) / static_cast(total_num_accesses); } - return total_cache_hit_rate; } private: - AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {} + AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) { + for (int i = 1; i < static_cast(AlgorithmType::kAlgorithmCount); ++i) { + Register(static_cast(i)); + } + } + + void Register(const AlgorithmType& algo_type) { + std::lock_guard lock(*autotune_cache_mutex_); + int64_t key = static_cast(algo_type); + if (auto_tune_map_.find(key) == auto_tune_map_.end()) { + AlgorithmsCacheMap cache; + auto_tune_map_[key] = cache; + } + } + AlgorithmsTypeMap auto_tune_map_; std::shared_ptr autotune_cache_mutex_; - int64_t total_cache_hits_ = 0; - int64_t total_cache_misses_ = 0; - int64_t total_size_ = 0; + int64_t total_cache_hits_{0}; + int64_t total_cache_misses_{0}; + int64_t total_size_{0}; }; } // namespace autotune diff --git a/paddle/phi/kernels/autotune/cache_test.cc b/paddle/phi/kernels/autotune/cache_test.cc index 92ba411624f..f99f8bfc8b8 100644 --- a/paddle/phi/kernels/autotune/cache_test.cc +++ b/paddle/phi/kernels/autotune/cache_test.cc @@ -22,7 +22,7 @@ enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 }; TEST(AlgosCache, AlgosCache) { auto autotune_cache = phi::autotune::AutoTuneCache::Instance(); - auto& cache = autotune_cache.RegisterOrGet("conv_fw"); + auto& cache = autotune_cache.GetConvForward(); std::vector x_shape = {4, 224, 224, 3}; std::vector w_shape = {32, 3, 3, 3}; diff --git a/paddle/phi/kernels/autotune/switch_autotune.cc b/paddle/phi/kernels/autotune/switch_autotune.cc new file mode 100644 index 00000000000..6fda24ef3c8 --- /dev/null +++ b/paddle/phi/kernels/autotune/switch_autotune.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/autotune/switch_autotune.h" + +#include "gflags/gflags.h" +#include "glog/logging.h" + +DECLARE_bool(use_autotune); + +namespace phi { +namespace autotune { + +void AutoTuneStatus::EnableAutoTune() { + FLAGS_use_autotune = true; + Init(); +} + +void AutoTuneStatus::DisableAutoTune() { + FLAGS_use_autotune = false; + Init(); +} + +void AutoTuneStatus::Update() { + current_steps_id_ += 1; + if (!FLAGS_use_autotune) { + return; + } + + // This fuction is called when each iter finished. + if (current_steps_id_ + 1 < start_step_id_) { + use_autotune_ = false; + } else if (current_steps_id_ + 1 >= start_step_id_ && + current_steps_id_ + 1 < stop_step_id_) { + use_autotune_ = true; + AutoTuneCache::Instance().UpdateStatus(); + step_hit_rates_.push_back(StepHitRate()); + VLOG(3) << "Step ID: " << current_steps_id_ + << ", Accumulative Cache Hit Rate: " + << static_cast(AutoTuneCache::Instance().CacheHitRate() * 100) + << "%, Cache Size: " << AutoTuneCache::Instance().Size() + << ", Current Step Hit Rate: " + << static_cast(StepHitRate() * 100) << "%"; + } else { + use_autotune_ = false; + // Set a small tolerance to avoid performance degradation + // due to large cache size under dynamic shape. + // TODO(limingshu): Currently works for conv op only, this + // method shall be opimized when more ops involved in. + // float miss_rate = static_cast(1) - RecentHitRate(); + // if (current_steps_id_ == stop_step_id_) { + // AutoTuneCache::Instance().Clean(miss_rate); + // } + if (VLOG_IS_ON(4)) { + AutoTuneCache::Instance().UpdateStatus(); + VLOG(4) << "Step ID: " << current_steps_id_ << ", Current Step Hit Rate: " + << static_cast(StepHitRate() * 100) << "%"; + } + } +} + +} // namespace autotune +} // namespace phi diff --git a/paddle/phi/kernels/autotune/switch_autotune.h b/paddle/phi/kernels/autotune/switch_autotune.h index 2f9621ed207..1793940542d 100644 --- a/paddle/phi/kernels/autotune/switch_autotune.h +++ b/paddle/phi/kernels/autotune/switch_autotune.h @@ -13,10 +13,8 @@ // limitations under the License. #pragma once + #include -#include -#include -#include "glog/logging.h" #include "paddle/phi/kernels/autotune/cache.h" namespace phi { @@ -31,45 +29,11 @@ class AutoTuneStatus { bool UseAutoTune() { return use_autotune_; } - // EnableAutoTune and DisableAutoTune Should be used for debug only. - void EnableAutoTune() { - use_autotune_ = true; - Init(); - } - - void DisableAutoTune() { - use_autotune_ = false; - Init(); - } + // EnableAutoTune and DisableAutoTune should be used for debug only. + void EnableAutoTune(); + void DisableAutoTune(); - void Update() { - current_steps_id_ += 1; - - if (!use_autotune_ && !update_use_autotune_) { - return; - } - - if (current_steps_id_ < start_step_id_) { - use_autotune_ = false; - } else if (current_steps_id_ >= start_step_id_ && - current_steps_id_ < stop_step_id_) { - use_autotune_ = true; - AutoTuneCache::Instance().UpdateStatus(); - step_hit_rates_.push_back(StepHitRate()); - VLOG(3) << "Step ID " << current_steps_id_ - << ", Accumulative Cache Hit Rate: " - << AutoTuneCache::Instance().CacheHitRate() - << ", Cache Size: " << AutoTuneCache::Instance().Size() - << ", Current Step Hit Rate: " << StepHitRate(); - } else if (current_steps_id_ == stop_step_id_) { - use_autotune_ = false; - update_use_autotune_ = false; - // clean cache according miss rate - float miss_rate = static_cast(1) - RecentHitRate(); - AutoTuneCache::Instance().Clean(miss_rate); - VLOG(3) << "Recent Miss Rate: " << miss_rate; - } - } + void Update(); int64_t StepID() { return current_steps_id_; } @@ -84,19 +48,25 @@ class AutoTuneStatus { // Hit Rate of Current Step float StepHitRate() { - int64_t current_hits = AutoTuneCache::Instance().CacheHits(); - int64_t current_misses = AutoTuneCache::Instance().CacheMisses(); - int64_t step_hits_ = current_hits - previous_hits_; - int64_t step_misses_ = current_misses - previous_misses_; - float step_hit_rate = 0.; - int64_t step_num_accesses = step_hits_ + step_misses_; - if (step_num_accesses != 0) { - step_hit_rate = static_cast(step_hits_) / - static_cast(step_num_accesses); + static int64_t last_step_id = -2; + + if (last_step_id != current_steps_id_) { + int64_t current_hits = AutoTuneCache::Instance().CacheHits(); + int64_t current_misses = AutoTuneCache::Instance().CacheMisses(); + int64_t step_hits_ = current_hits - previous_hits_; + int64_t step_misses_ = current_misses - previous_misses_; + float step_hit_rate = 0.; + int64_t step_num_accesses = step_hits_ + step_misses_; + if (step_num_accesses != 0) { + step_hit_rate = static_cast(step_hits_) / + static_cast(step_num_accesses); + } + previous_hits_ = current_hits; + previous_misses_ = current_misses; + current_step_hit_rate_ = step_hit_rate; + last_step_id = current_steps_id_; } - previous_hits_ = current_hits; - previous_misses_ = current_misses; - return step_hit_rate; + return current_step_hit_rate_; } void SetAutoTuneRange(int64_t start, int64_t stop) { @@ -108,21 +78,21 @@ class AutoTuneStatus { AutoTuneStatus() = default; void Init() { - update_use_autotune_ = use_autotune_; + use_autotune_ = false; current_steps_id_ = -1; previous_hits_ = 0; previous_misses_ = 0; step_hit_rates_.clear(); - AutoTuneCache::Instance().Clean(1.0); + AutoTuneCache::Instance().Clean(); } - int64_t start_step_id_ = 0; - int64_t stop_step_id_ = 10; - int64_t current_steps_id_ = -1; - bool use_autotune_ = false; - bool update_use_autotune_ = false; - int64_t previous_hits_ = 0; - int64_t previous_misses_ = 0; + bool use_autotune_{false}; + int64_t start_step_id_{1}; + int64_t stop_step_id_{10}; + int64_t current_steps_id_{-1}; + int64_t previous_hits_{0}; + int64_t previous_misses_{0}; + float current_step_hit_rate_{0.f}; std::vector step_hit_rates_; }; diff --git a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu index 9c5e77d5fd8..74525e63f47 100644 --- a/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_grad_grad_kernel.cu @@ -289,21 +289,17 @@ void ConvCudnnGradGradKernel( dtype}; #ifdef PADDLE_WITH_HIP - miopenConvFwdAlgorithm_t fwd_algo1 = static_cast(0); - miopenConvFwdAlgorithm_t fwd_algo2 = static_cast(0); - miopenConvBwdDataAlgorithm_t data_algo = - static_cast(0); - miopenConvBwdWeightsAlgorithm_t filter_algo = - static_cast(0); + paddle::operators::SearchResult fwd_result1; + paddle::operators::SearchResult fwd_result2; + paddle::operators::SearchResult data_result; + paddle::operators::SearchResult + filter_result; #else - cudnnConvolutionFwdAlgo_t fwd_algo1 = - static_cast(0); - cudnnConvolutionFwdAlgo_t fwd_algo2 = - static_cast(0); - cudnnConvolutionBwdDataAlgo_t data_algo = - static_cast(0); - cudnnConvolutionBwdFilterAlgo_t filter_algo = - static_cast(0); + paddle::operators::SearchResult fwd_result1; + paddle::operators::SearchResult fwd_result2; + paddle::operators::SearchResult data_result; + paddle::operators::SearchResult + filter_result; #endif auto layout = paddle::platform::GetCudnnTensorFormat( @@ -332,13 +328,13 @@ void ConvCudnnGradGradKernel( using search1 = paddle::operators::SearchAlgorithm; workspace_size = search1::GetWorkspaceSize(args1); - fwd_algo1 = search1::Find( + fwd_result1.algo = search1::Find( args1, exhaustive_search, false, workspace_size, ctx); #else using search1 = paddle::operators::SearchAlgorithm; - fwd_algo1 = search1::Find(args1, exhaustive_search, false, ctx); - workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1); + fwd_result1 = search1::Find(args1, exhaustive_search, false, ctx); + workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo); #endif } @@ -360,14 +356,14 @@ void ConvCudnnGradGradKernel( paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2)); - fwd_algo2 = search2::Find( + fwd_result2.algo = search2::Find( args2, exhaustive_search, false, workspace_size, ctx); #else using search2 = paddle::operators::SearchAlgorithm; - fwd_algo2 = search2::Find(args2, exhaustive_search, false, ctx); - workspace_size = - std::max(workspace_size, search2::GetWorkspaceSize(args2, fwd_algo2)); + fwd_result2 = search2::Find(args2, exhaustive_search, false, ctx); + workspace_size = std::max( + workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo)); #endif } } @@ -389,15 +385,15 @@ void ConvCudnnGradGradKernel( using search3 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3)); - filter_algo = search3::Find( + filter_result.algo = search3::Find( args3, exhaustive_search, deterministic, workspace_size, ctx); #else using search3 = paddle::operators::SearchAlgorithm; - filter_algo = + filter_result = search3::Find(args3, exhaustive_search, deterministic, ctx); - workspace_size = - std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo)); + workspace_size = std::max( + workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo)); #endif } @@ -419,14 +415,15 @@ void ConvCudnnGradGradKernel( using search4 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4)); - data_algo = search4::Find( + data_result.algo = search4::Find( args4, exhaustive_search, deterministic, workspace_size, ctx); #else using search4 = paddle::operators::SearchAlgorithm; - data_algo = search4::Find(args4, exhaustive_search, deterministic, ctx); - workspace_size = - std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); + data_result = + search4::Find(args4, exhaustive_search, deterministic, ctx); + workspace_size = std::max( + workspace_size, search4::GetWorkspaceSize(args4, data_result.algo)); #endif } @@ -471,7 +468,7 @@ void ConvCudnnGradGradKernel( args1.wdesc.desc(), w, args1.cdesc.desc(), - fwd_algo1, + fwd_result1.algo, &beta, args1.odesc.desc(), transformed_ddy_channel, @@ -492,7 +489,7 @@ void ConvCudnnGradGradKernel( args1.wdesc.desc(), w + i * group_offset_filter, args1.cdesc.desc(), - fwd_algo1, + fwd_result1.algo, workspace_ptr, workspace_size, &beta, @@ -517,7 +514,7 @@ void ConvCudnnGradGradKernel( args2.wdesc.desc(), ddw, args2.cdesc.desc(), - fwd_algo2, + fwd_result2.algo, &beta, args2.odesc.desc(), transformed_ddy_channel, @@ -538,7 +535,7 @@ void ConvCudnnGradGradKernel( args2.wdesc.desc(), ddw + i * group_offset_filter, args2.cdesc.desc(), - fwd_algo2, + fwd_result2.algo, workspace_ptr, workspace_size, &alpha, @@ -568,7 +565,7 @@ void ConvCudnnGradGradKernel( args3.idesc.desc(), ddx, args3.cdesc.desc(), - filter_algo, + filter_result.algo, &beta, args3.wdesc.desc(), dw, @@ -589,7 +586,7 @@ void ConvCudnnGradGradKernel( args3.odesc.desc(), transformed_dy_channel + i * group_offset_out, args3.cdesc.desc(), - filter_algo, + filter_result.algo, workspace_ptr, workspace_size, &beta, @@ -615,7 +612,7 @@ void ConvCudnnGradGradKernel( args4.wdesc.desc(), ddw, args4.cdesc.desc(), - data_algo, + data_result.algo, &beta, args4.idesc.desc(), transformed_dx, @@ -636,7 +633,7 @@ void ConvCudnnGradGradKernel( args4.odesc.desc(), transformed_dy_channel + i * group_offset_out, args4.cdesc.desc(), - data_algo, + data_result.algo, workspace_ptr, workspace_size, &beta, diff --git a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu index 9856fabfa15..9d4acb95ea4 100644 --- a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu @@ -322,17 +322,16 @@ void ConvCudnnGradKernel(const Context& ctx, int group_offset_in = i_c / groups * i_h * i_w * i_d; int group_offset_out = o_c / groups * o_h * o_w * o_d; int group_offset_filter = transformed_filter_channel.numel() / groups; + // ------------------- cudnn backward algorithm --------------------- #ifdef PADDLE_WITH_HIP - miopenConvBwdDataAlgorithm_t data_algo = - static_cast(0); - miopenConvBwdWeightsAlgorithm_t filter_algo = - static_cast(0); + paddle::operators::SearchResult bwd_result; + paddle::operators::SearchResult + filter_result; #else - cudnnConvolutionBwdDataAlgo_t data_algo = - static_cast(0); - cudnnConvolutionBwdFilterAlgo_t filter_algo = - static_cast(0); + paddle::operators::SearchResult bwd_result; + paddle::operators::SearchResult + filter_result; #endif // input data workspace_size size_t workspace_size_d = 0; @@ -368,14 +367,14 @@ void ConvCudnnGradKernel(const Context& ctx, paddle::operators::SearchAlgorithm; workspace_size_d = std::max(workspace_size_d, search1::GetWorkspaceSize(args1)); - data_algo = search1::Find( + bwd_result.algo = search1::Find( args1, exhaustive_search, deterministic, workspace_size_d, ctx); #else using search1 = paddle::operators::SearchAlgorithm; - data_algo = search1::Find(args1, exhaustive_search, deterministic, ctx); - workspace_size_d = - std::max(workspace_size_d, search1::GetWorkspaceSize(args1, data_algo)); + bwd_result = search1::Find(args1, exhaustive_search, deterministic, ctx); + workspace_size_d = std::max( + workspace_size_d, search1::GetWorkspaceSize(args1, bwd_result.algo)); #endif } @@ -397,15 +396,17 @@ void ConvCudnnGradKernel(const Context& ctx, paddle::operators::SearchAlgorithm; workspace_size_w = std::max(workspace_size_w, search2::GetWorkspaceSize(args2)); - filter_algo = search2::Find( + filter_result.algo = search2::Find( args2, exhaustive_search, deterministic, workspace_size_w, ctx); #else using search2 = paddle::operators::SearchAlgorithm; - filter_algo = + filter_result = search2::Find(args2, exhaustive_search, deterministic, ctx); - workspace_size_w = std::max(workspace_size_w, - search2::GetWorkspaceSize(args2, filter_algo)); + VLOG(3) << "filter algo: " << filter_result.algo << ", time " + << filter_result.time; + workspace_size_w = std::max( + workspace_size_w, search2::GetWorkspaceSize(args2, filter_result.algo)); #endif } @@ -439,7 +440,7 @@ void ConvCudnnGradKernel(const Context& ctx, args1.wdesc.desc(), filter_data, args1.cdesc.desc(), - data_algo, + bwd_result.algo, &beta, args1.idesc.desc(), temp_tensor_data, @@ -471,7 +472,7 @@ void ConvCudnnGradKernel(const Context& ctx, args1.wdesc.desc(), filter_data, args1.cdesc.desc(), - data_algo, + bwd_result.algo, &beta, args1.idesc.desc(), transformed_input_grad_data, @@ -494,7 +495,7 @@ void ConvCudnnGradKernel(const Context& ctx, args1.odesc.desc(), output_grad_data + i * group_offset_out, args1.cdesc.desc(), - data_algo, + bwd_result.algo, cudnn_workspace_ptr, workspace_size_d, &beta, @@ -554,7 +555,7 @@ void ConvCudnnGradKernel(const Context& ctx, args2.idesc.desc(), input_data, args2.cdesc.desc(), - filter_algo, + filter_result.algo, &beta, args2.wdesc.desc(), filter_grad_data, @@ -575,7 +576,7 @@ void ConvCudnnGradKernel(const Context& ctx, args2.odesc.desc(), output_grad_data + i * group_offset_out, args2.cdesc.desc(), - filter_algo, + filter_result.algo, cudnn_workspace_ptr, workspace_size_w, &beta_filter, diff --git a/paddle/phi/kernels/gpudnn/conv_kernel.cu b/paddle/phi/kernels/gpudnn/conv_kernel.cu index 256dcd4baac..3d3ab7b7a4e 100644 --- a/paddle/phi/kernels/gpudnn/conv_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_kernel.cu @@ -18,7 +18,6 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/fluid/framework/eigen.h" #ifdef PADDLE_WITH_HIP #include "paddle/fluid/operators/conv_miopen_helper.h" #else @@ -68,7 +67,6 @@ void ConvCudnnKernel(const Context& ctx, "FLAGS_cudnn_deterministic True at same time.")); const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); - auto dtype = paddle::platform::CudnnDataType::type; #ifdef PADDLE_WITH_HIP @@ -309,17 +307,17 @@ void ConvCudnnKernel(const Context& ctx, size_t workspace_size = 0; // final workspace to allocate. // ------------------- cudnn conv algorithm --------------------- #ifdef PADDLE_WITH_HIP - miopenConvFwdAlgorithm_t algo{}; + paddle::operators::SearchResult fwd_result; using search = paddle::operators::SearchAlgorithm; workspace_size = search::GetWorkspaceSize(args); - algo = search::Find( + fwd_result.algo = search::Find( args, exhaustive_search, deterministic, workspace_size, ctx); #else - cudnnConvolutionFwdAlgo_t algo{}; + paddle::operators::SearchResult fwd_result; using search = paddle::operators::SearchAlgorithm; - algo = search::Find(args, exhaustive_search, deterministic, ctx); - workspace_size = search::GetWorkspaceSize(args, algo); + fwd_result = search::Find(args, exhaustive_search, deterministic, ctx); + workspace_size = search::GetWorkspaceSize(args, fwd_result.algo); #endif #if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1) @@ -328,7 +326,7 @@ void ConvCudnnKernel(const Context& ctx, // in forward computation, so change the algorithm to CUDNN_CONVOLUTION_\ // FWD_ALGO_IMPLICIT_GEMM manually. if (groups > 1) { - algo = static_cast(0); + fwd_result.algo = static_cast(0); } #endif @@ -352,7 +350,7 @@ void ConvCudnnKernel(const Context& ctx, args.wdesc.desc(), filter_data, args.cdesc.desc(), - algo, + fwd_result.algo, &beta, args.odesc.desc(), output_data, @@ -373,7 +371,7 @@ void ConvCudnnKernel(const Context& ctx, args.wdesc.desc(), filter_data + i * group_offset_filter, args.cdesc.desc(), - algo, + fwd_result.algo, workspace_ptr, workspace_size, &beta, diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu index 2893bd74b1b..601ac43eeef 100644 --- a/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_transpose_grad_kernel.cu @@ -188,11 +188,13 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, dtype}; #ifdef PADDLE_WITH_HIP - miopenConvFwdAlgorithm_t data_algo{}; - miopenConvBwdWeightsAlgorithm_t filter_algo{}; + paddle::operators::SearchResult fwd_result; + paddle::operators::SearchResult + filter_result; #else - cudnnConvolutionFwdAlgo_t data_algo{}; - cudnnConvolutionBwdFilterAlgo_t filter_algo{}; + paddle::operators::SearchResult fwd_result; + paddle::operators::SearchResult + filter_result; #endif auto layout_tensor = paddle::platform::GetCudnnTensorFormat(layout); @@ -218,14 +220,14 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, using search1 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1)); - data_algo = + fwd_result.algo = search1::Find(args1, false, deterministic, workspace_size, ctx); #else using search1 = paddle::operators::SearchAlgorithm; - data_algo = search1::Find(args1, false, deterministic, ctx); - workspace_size = - std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo)); + fwd_result = search1::Find(args1, false, deterministic, ctx); + workspace_size = std::max( + workspace_size, search1::GetWorkspaceSize(args1, fwd_result.algo)); #endif } @@ -245,14 +247,14 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, using search2 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2)); - filter_algo = + filter_result.algo = search2::Find(args2, false, deterministic, workspace_size, ctx); #else using search2 = paddle::operators::SearchAlgorithm; - filter_algo = search2::Find(args2, false, deterministic, ctx); - workspace_size = - std::max(workspace_size, search2::GetWorkspaceSize(args2, filter_algo)); + filter_result = search2::Find(args2, false, deterministic, ctx); + workspace_size = std::max( + workspace_size, search2::GetWorkspaceSize(args2, filter_result.algo)); #endif } @@ -278,7 +280,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, args1.wdesc.desc(), filter_data + filter_offset * g, args1.cdesc.desc(), - data_algo, + fwd_result.algo, &beta, args1.odesc.desc(), dx_data + x_offset * g, @@ -295,7 +297,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, args1.wdesc.desc(), filter_data + filter_offset * g, args1.cdesc.desc(), - data_algo, + fwd_result.algo, cudnn_workspace, workspace_size, &beta, @@ -338,7 +340,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, args2.idesc.desc(), dout_data + dout_offset * g, args2.cdesc.desc(), - filter_algo, + filter_result.algo, &beta, args2.wdesc.desc(), dfilter_data + filter_offset * g, @@ -355,7 +357,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, args2.odesc.desc(), x_data + x_offset * g, args2.cdesc.desc(), - filter_algo, + filter_result.algo, cudnn_workspace, workspace_size, &beta, @@ -653,22 +655,17 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( dilations_, dtype}; #ifdef PADDLE_WITH_HIP - miopenConvBwdDataAlgorithm_t bwd_algo1 = - static_cast(0); - miopenConvBwdDataAlgorithm_t bwd_algo2 = - static_cast(0); - miopenConvFwdAlgorithm_t data_algo = static_cast(0); - miopenConvBwdWeightsAlgorithm_t filter_algo = - static_cast(0); + paddle::operators::SearchResult bwd_result1; + paddle::operators::SearchResult bwd_result2; + paddle::operators::SearchResult + filter_result; + paddle::operators::SearchResult fwd_result; #else - cudnnConvolutionBwdDataAlgo_t bwd_algo1 = - static_cast(0); - cudnnConvolutionBwdDataAlgo_t bwd_algo2 = - static_cast(0); - cudnnConvolutionFwdAlgo_t data_algo = - static_cast(0); - cudnnConvolutionBwdFilterAlgo_t filter_algo = - static_cast(0); + paddle::operators::SearchResult bwd_result1; + paddle::operators::SearchResult bwd_result2; + paddle::operators::SearchResult + filter_result; + paddle::operators::SearchResult fwd_result; #endif auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW); @@ -696,13 +693,13 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( using search1 = paddle::operators::SearchAlgorithm; workspace_size = search1::GetWorkspaceSize(args1); - bwd_algo1 = + bwd_result1.algo = search1::Find(args1, false, deterministic, workspace_size, ctx); #else using search1 = paddle::operators::SearchAlgorithm; - bwd_algo1 = search1::Find(args1, false, deterministic, ctx); - workspace_size = search1::GetWorkspaceSize(args1, bwd_algo1); + bwd_result1 = search1::Find(args1, false, deterministic, ctx); + workspace_size = search1::GetWorkspaceSize(args1, bwd_result1.algo); #endif ddfilter_ = ddfilter.data(); @@ -720,14 +717,14 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( using search2 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2)); - bwd_algo2 = + bwd_result2.algo = search2::Find(args2, false, deterministic, workspace_size, ctx); #else using search2 = paddle::operators::SearchAlgorithm; - bwd_algo2 = search2::Find(args2, false, deterministic, ctx); - workspace_size = - std::max(workspace_size, search2::GetWorkspaceSize(args2, bwd_algo2)); + bwd_result2 = search2::Find(args2, false, deterministic, ctx); + workspace_size = std::max( + workspace_size, search2::GetWorkspaceSize(args2, bwd_result2.algo)); #endif } @@ -736,9 +733,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( args3.handle = handle; args3.idesc.set(transformed_dout, iwo_group); args3.wdesc.set(*dfilter, layout, iwo_group); - args3.odesc.set(transformed_ddx_channel, iwo_group); - args3.cdesc.set(dtype, padding_common, strides, @@ -749,14 +744,14 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( using search3 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3)); - filter_algo = + filter_result.algo = search3::Find(args3, false, deterministic, workspace_size, ctx); #else using search3 = paddle::operators::SearchAlgorithm; - filter_algo = search3::Find(args3, false, deterministic, ctx); - workspace_size = - std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo)); + filter_result = search3::Find(args3, false, deterministic, ctx); + workspace_size = std::max( + workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo)); #endif } @@ -777,14 +772,14 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( using search4 = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4)); - data_algo = + fwd_result.algo = search4::Find(args4, false, deterministic, workspace_size, ctx); #else using search4 = paddle::operators::SearchAlgorithm; - data_algo = search4::Find(args4, false, deterministic, ctx); - workspace_size = - std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo)); + fwd_result = search4::Find(args4, false, deterministic, ctx); + workspace_size = std::max( + workspace_size, search4::GetWorkspaceSize(args4, fwd_result.algo)); #endif } @@ -831,7 +826,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( args1.wdesc.desc(), filter_ + i * group_offset_filter, args1.cdesc.desc(), - bwd_algo1, + bwd_result1.algo, &beta, args1.idesc.desc(), transformed_ddout_channel_ + i * group_offset_out, @@ -850,7 +845,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( args1.odesc.desc(), ddx_ + i * group_offset_in, args1.cdesc.desc(), - bwd_algo1, + bwd_result1.algo, workspace_ptr, workspace_size, &beta, @@ -877,7 +872,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( args2.wdesc.desc(), ddfilter_ + i * group_offset_filter, args2.cdesc.desc(), - bwd_algo2, + bwd_result2.algo, &beta, args2.idesc.desc(), conv_x_ddfilter_data + i * group_offset_out, @@ -908,7 +903,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( args2.odesc.desc(), x_ + i * group_offset_in, args2.cdesc.desc(), - bwd_algo2, + bwd_result2.algo, workspace_ptr, workspace_size, &alpha, @@ -964,7 +959,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( args3.idesc.desc(), transformed_dout_channel_ + i * group_offset_out, args3.cdesc.desc(), - filter_algo, + filter_result.algo, &beta, args3.wdesc.desc(), dfilter_ + i * group_offset_filter, @@ -983,7 +978,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( args3.odesc.desc(), ddx_ + i * group_offset_in, args3.cdesc.desc(), - filter_algo, + filter_result.algo, workspace_ptr, workspace_size, &beta, @@ -1009,7 +1004,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( args4.wdesc.desc(), ddfilter_ + i * group_offset_filter, args4.cdesc.desc(), - data_algo, + fwd_result.algo, &beta, args4.odesc.desc(), transformed_dx_ + i * group_offset_in, @@ -1028,7 +1023,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( args4.wdesc.desc(), ddfilter_ + i * group_offset_filter, args4.cdesc.desc(), - data_algo, + fwd_result.algo, workspace_ptr, workspace_size, &beta, diff --git a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu index 5de2df4a70c..ce02a00162b 100644 --- a/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu +++ b/paddle/phi/kernels/gpudnn/conv_transpose_kernel.cu @@ -217,16 +217,19 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx, c_groups); #ifdef PADDLE_WITH_HIP + paddle::operators::SearchResult bwd_result; using search = paddle::operators::SearchAlgorithm; workspace_size = std::max(workspace_size, search::GetWorkspaceSize(args)); - algo = search::Find(args, false, deterministic, workspace_size, ctx); + bwd_result.algo = + search::Find(args, false, deterministic, workspace_size, ctx); #else + paddle::operators::SearchResult bwd_result; using search = paddle::operators::SearchAlgorithm; - algo = search::Find(args, false, deterministic, ctx); + bwd_result = search::Find(args, false, deterministic, ctx); workspace_size = - std::max(workspace_size, search::GetWorkspaceSize(args, algo)); + std::max(workspace_size, search::GetWorkspaceSize(args, bwd_result.algo)); #endif // ------------------- cudnn conv transpose forward --------------------- @@ -247,7 +250,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx, args.wdesc.desc(), filter_data + filter_offset * g, args.cdesc.desc(), - algo, + bwd_result.algo, &beta, args.idesc.desc(), transformed_out_data + out_offset * g, @@ -264,7 +267,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx, args.odesc.desc(), x_data + x_offset * g, args.cdesc.desc(), - algo, + bwd_result.algo, cudnn_workspace, workspace_size, &beta, diff --git a/paddle/phi/kernels/impl/conv_cudnn_impl.h b/paddle/phi/kernels/impl/conv_cudnn_impl.h index 93bc5b64adc..5cf59fe0192 100644 --- a/paddle/phi/kernels/impl/conv_cudnn_impl.h +++ b/paddle/phi/kernels/impl/conv_cudnn_impl.h @@ -36,7 +36,7 @@ #include "paddle/phi/kernels/funcs/batch_norm_utils.h" DECLARE_bool(cudnn_deterministic); -DECLARE_uint64(conv_workspace_size_limit); +DECLARE_int64(conv_workspace_size_limit); DECLARE_bool(cudnn_exhaustive_search); namespace phi { diff --git a/python/paddle/fluid/tests/unittests/test_switch_autotune.py b/python/paddle/fluid/tests/unittests/test_switch_autotune.py index 9fad1eeb5c2..1775272aac6 100644 --- a/python/paddle/fluid/tests/unittests/test_switch_autotune.py +++ b/python/paddle/fluid/tests/unittests/test_switch_autotune.py @@ -14,7 +14,7 @@ import paddle import unittest -import numpy +import numpy as np class SimpleNet(paddle.nn.Layer): @@ -27,6 +27,7 @@ class SimpleNet(paddle.nn.Layer): def train_dygraph(net, data): + data.stop_gradient = False out = net(data) loss = paddle.mean(out) adam = paddle.optimizer.Adam(parameters=net.parameters()) @@ -36,6 +37,7 @@ def train_dygraph(net, data): def static_program(net, data): + data.stop_gradient = False out = net(data) loss = paddle.mean(out) adam = paddle.optimizer.Adam() @@ -44,48 +46,63 @@ def static_program(net, data): class TestAutoTune(unittest.TestCase): + def set_flags(self, enable_autotune): + if paddle.is_compiled_with_cuda(): + if enable_autotune: + paddle.set_flags({'FLAGS_conv_workspace_size_limit': -1}) + else: + paddle.set_flags({'FLAGS_conv_workspace_size_limit': 512}) + + def get_flags(self, name): + res = paddle.get_flags(name) + return res[name] + + def get_expected_res(self, step_id, enable_autotune): + expected_res = { + "step_id": step_id, + "cache_size": 0, + "cache_hit_rate": 0 + } + if paddle.is_compiled_with_cuda(): + # Total 3 * num_iters cache accesses, only iter 2 hits the cache. + if enable_autotune and step_id >= 1: + expected_res["cache_size"] = 3 + if enable_autotune and step_id == 2: + expected_res["cache_hit_rate"] = np.round( + float(3) / float(9), 5) + return expected_res + def test_autotune(self): paddle.fluid.core.disable_autotune() - status = paddle.fluid.core.autotune_status() - self.assertEqual(status["use_autotune"], False) + self.assertEqual(self.get_flags("FLAGS_use_autotune"), False) paddle.fluid.core.enable_autotune() - status = paddle.fluid.core.autotune_status() - self.assertEqual(status["use_autotune"], True) + self.assertEqual(self.get_flags("FLAGS_use_autotune"), True) def check_status(self, expected_res): status = paddle.fluid.core.autotune_status() for key in status.keys(): - self.assertEqual(status[key], expected_res[key]) + if key == "cache_hit_rate": + v = np.round(status[key], 5) + else: + v = status[key] + self.assertEqual(v, expected_res[key]) class TestDygraphAutoTuneStatus(TestAutoTune): def run_program(self, enable_autotune): + self.set_flags(enable_autotune) if enable_autotune: paddle.fluid.core.enable_autotune() else: paddle.fluid.core.disable_autotune() - paddle.fluid.core.autotune_range(1, 2) + paddle.fluid.core.set_autotune_range(1, 2) x_var = paddle.uniform((1, 1, 8, 8), dtype='float32', min=-1., max=1.) net = SimpleNet() for i in range(3): train_dygraph(net, x_var) - if i >= 1 and i < 2: - expected_res = { - "step_id": i, - "use_autotune": enable_autotune, - "cache_size": 0, - "cache_hit_rate": 0 - } - self.check_status(expected_res) - else: - expected_res = { - "step_id": i, - "use_autotune": False, - "cache_size": 0, - "cache_hit_rate": 0 - } - self.check_status(expected_res) + expected_res = self.get_expected_res(i, enable_autotune) + self.check_status(expected_res) def func_enable_autotune(self): self.run_program(enable_autotune=True) @@ -107,59 +124,45 @@ class TestDygraphAutoTuneStatus(TestAutoTune): class TestStaticAutoTuneStatus(TestAutoTune): def run_program(self, enable_autotune): paddle.enable_static() - if enable_autotune: - paddle.fluid.core.enable_autotune() - else: - paddle.fluid.core.disable_autotune() - paddle.fluid.core.autotune_range(1, 2) data_shape = [1, 1, 8, 8] - data = paddle.static.data(name='X', shape=data_shape, dtype='float32') - net = SimpleNet() - loss = static_program(net, data) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + data = paddle.static.data( + name='X', shape=data_shape, dtype='float32') + net = SimpleNet() + loss = static_program(net, data) place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda( ) else paddle.CPUPlace() exe = paddle.static.Executor(place) - exe.run(paddle.static.default_startup_program()) - x = numpy.random.random(size=data_shape).astype('float32') + exe.run(startup_program) + x = np.random.random(size=data_shape).astype('float32') + + self.set_flags(enable_autotune) + if enable_autotune: + paddle.fluid.core.enable_autotune() + else: + paddle.fluid.core.disable_autotune() + paddle.fluid.core.set_autotune_range(1, 2) for i in range(3): - exe.run(feed={'X': x}, fetch_list=[loss]) + exe.run(program=main_program, feed={'X': x}, fetch_list=[loss]) status = paddle.fluid.core.autotune_status() - # In static mode, the startup_program will run at first. - # The expected step_id will be increased by 1. - if i >= 0 and i < 1: - expected_res = { - "step_id": i + 1, - "use_autotune": enable_autotune, - "cache_size": 0, - "cache_hit_rate": 0 - } - self.check_status(expected_res) - else: - expected_res = { - "step_id": i + 1, - "use_autotune": False, - "cache_size": 0, - "cache_hit_rate": 0 - } - self.check_status(expected_res) + expected_res = self.get_expected_res(i, enable_autotune) + self.check_status(expected_res) paddle.disable_static() def func_enable_autotune(self): self.run_program(enable_autotune=True) def test_enable_autotune(self): - with paddle.fluid.framework._test_eager_guard(): - self.func_enable_autotune() self.func_enable_autotune() def func_disable_autotune(self): self.run_program(enable_autotune=False) def test_disable_autotune(self): - with paddle.fluid.framework._test_eager_guard(): - self.func_disable_autotune() self.func_disable_autotune() -- GitLab