From 5419a95d1ed483cd05d968700b51f3245ae06079 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 16 Mar 2021 19:37:30 +0800 Subject: [PATCH] perf(cuda/conv): cache serval cudnn api GitOrigin-RevId: 188c62cdd65ba27793b0af89164c6eee0122b21f --- dnn/src/common/api_cache.h | 176 +++++++++------- dnn/src/cuda/api_cache.h | 196 +++++++++--------- dnn/src/cuda/conv_bias/cudnn_conv.cpp | 6 +- .../conv_bias/cudnn_conv_bias_activation.cpp | 6 +- dnn/src/cuda/conv_bias/opr_impl.cpp | 5 +- .../cuda/convolution/backward_data/cudnn.cpp | 6 +- .../convolution/backward_filter/cudnn.cpp | 6 +- dnn/src/cuda/convolution/opr_impl.cpp | 10 +- .../convolution3d/backward_data/cudnn.cpp | 6 +- .../convolution3d/backward_filter/cudnn.cpp | 6 +- dnn/src/cuda/convolution3d/forward/cudnn.cpp | 6 +- dnn/src/cuda/convolution3d/helper.h | 17 +- dnn/src/cuda/convolution3d/opr_impl.cpp | 3 +- .../conv_trait/ibatch_conv_trait.cuh | 2 +- .../conv_trait/iconv_imma_trait.cuh | 2 +- .../conv_trait/iconv_trait.cuh | 2 +- dnn/src/cuda/handle.cpp | 112 +++++++++- dnn/src/cuda/handle.h | 29 +++ 18 files changed, 389 insertions(+), 207 deletions(-) diff --git a/dnn/src/common/api_cache.h b/dnn/src/common/api_cache.h index 6bb2d4d7e..c39589bc5 100644 --- a/dnn/src/common/api_cache.h +++ b/dnn/src/common/api_cache.h @@ -12,32 +12,28 @@ #pragma once -#include -#include #include +#include #include +#include #include "megdnn/thin/function.h" namespace megdnn { - -template -class FunctionCache; - -template -class FunctionCache { +template +class FunctionCache { public: using key_t = std::string; - using value_t = TRet; + using value_t = std::string; using key_mapper_t = thin_function; using value_mapper_t = thin_function; using storage_t = std::unordered_map; -public: + storage_t storage; key_mapper_t key_mapper; value_mapper_t value_mapper; -public: - TRet operator()(TArgs... args) { + + value_t operator()(TArgs... args) { key_t key = key_mapper(args...); if (storage.count(key) == 0) { storage[key] = value_mapper(std::forward(args)...); @@ -46,28 +42,28 @@ public: } }; - // FIFO class StringSerializer { private: std::string m_buffer; size_t m_cursor = 0; + public: template T read_plain() { - T result; - std::memcpy(&result, m_buffer.data() + m_cursor, sizeof(T)); + static_assert(std::is_trivially_copyable::value, "invalid type"); + T ret; + memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T)); m_cursor += sizeof(T); - return result; + return ret; } template void write_plain(T value) { - m_buffer.resize(m_buffer.size() + sizeof(T)); - std::memcpy(const_cast(m_buffer.data()) + (m_buffer.size() - sizeof(T)), &value, sizeof(T)); + static_assert(std::is_trivially_copyable::value, + "type should be trivially copyable"); + m_buffer.append(reinterpret_cast(&value), sizeof(T)); } std::string take() { - std::string result; - m_buffer.erase(0, m_cursor); return std::move(m_buffer); } void set(std::string new_buf) { @@ -76,20 +72,20 @@ public: } }; - struct Empty {}; - template class ParamBundle { private: - template - static std::index_sequence add_all(std::index_sequence){ + template + static std::index_sequence add_all( + std::index_sequence) { return {}; } - template - using make_index_range = decltype(add_all(std::make_index_sequence())); + template + using make_index_range = + decltype(add_all(std::make_index_sequence())); using storage_t = std::tuple...>; storage_t m_storage; @@ -99,21 +95,31 @@ private: return functor(std::get(m_storage).value...); } template - auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { - return serialize_helper(ser, std::get(m_storage).serialize(ser, prev), std::index_sequence()); + auto serialize_helper(StringSerializer& ser, TPrev&& prev, + std::index_sequence) { + return serialize_helper(ser, + std::get(m_storage).serialize(ser, prev), + std::index_sequence()); } template - auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} + auto serialize_helper(StringSerializer& ser, TPrev&& prev, + std::index_sequence<>) {} template - auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { - return deserialize_helper(ser, std::get(m_storage).deserialize(ser, prev), std::index_sequence()); + auto deserialize_helper(StringSerializer& ser, TPrev&& prev, + std::index_sequence) { + return deserialize_helper( + ser, std::get(m_storage).deserialize(ser, prev), + std::index_sequence()); } template - auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} + auto deserialize_helper(StringSerializer& ser, TPrev&& prev, + std::index_sequence<>) {} template - void set_values_helper(std::index_sequence, TArg&& arg, TArgs&&... args) { + void set_values_helper(std::index_sequence, TArg&& arg, + TArgs&&... args) { std::get(m_storage).value = arg; - set_values_helper(std::index_sequence(), std::forward(args)...); + set_values_helper(std::index_sequence(), + std::forward(args)...); } template void set_values_helper(std::index_sequence) { @@ -123,27 +129,33 @@ private: public: template auto call_by(TFunctor&& functor) { - return call_helper(std::forward(functor), std::make_index_sequence()); + return call_helper(std::forward(functor), + std::make_index_sequence()); } template void serialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); - serialize_helper(ser, Empty{}, make_index_range()); + serialize_helper( + ser, Empty{}, + add_all(std::make_index_sequence())); } template void deserialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); - deserialize_helper(ser, Empty{}, make_index_range()); + deserialize_helper( + ser, Empty{}, + add_all(std::make_index_sequence())); } template void set_values(TArgs&&... args) { - set_values_helper(make_index_range(), std::forward(args)...); + set_values_helper( + add_all(std::make_index_sequence()), + std::forward(args)...); } }; - template -class RetParam { +class Param { public: T value; Empty serialize(StringSerializer& ser, Empty) { @@ -156,45 +168,68 @@ public: } }; - -template , typename TInputs=std::tuple<>, typename TOutputs=std::tuple<>> +template , typename TInputs = std::tuple<>, + typename TOutputs = std::tuple<>> class FunctionCacheBuilder { private: - static auto declargs() -> decltype(std::tuple_cat(std::declval(), std::declval())) { return {}; } + static auto declargs() + -> decltype(std::tuple_cat(std::declval(), + std::declval())) { + return {}; + } template - static auto declfunction_helper(std::index_sequence) -> thin_function().value)(decltype(std::get(declargs()).value)...)> { return {}; } + static auto declfunction_helper(std::index_sequence) + -> thin_function().value)( + decltype(std::get(declargs()).value)...)> { + return {}; + } static auto declfunction() { - return declfunction_helper(std::make_index_sequence::value + std::tuple_size::value>()); + return declfunction_helper( + std::make_index_sequence::value + + std::tuple_size::value>()); } template - static auto declbundle_helper(std::index_sequence) -> ParamBundle(declargs()))...> { return {}; } + static auto declbundle_helper(std::index_sequence) + -> ParamBundle(declargs()))...> { + return {}; + } static auto declbundle() { - return declbundle_helper(std::make_index_sequence::value+std::tuple_size::value>()); + return declbundle_helper( + std::make_index_sequence::value + + std::tuple_size::value>()); } using function_t = decltype(declfunction()); using bundle_t = decltype(declbundle()); + public: template auto ret() { - static_assert(std::is_same>::value, "return value redefinition"); + static_assert(std::is_same>::value, + "return value redefinition"); return FunctionCacheBuilder{}; } template auto input() { - using TNewInputs = decltype(std::tuple_cat(std::declval(), std::make_tuple(std::declval()))); + using TNewInputs = decltype( + std::tuple_cat(std::declval(), + std::make_tuple(std::declval()))); return FunctionCacheBuilder{}; } template auto output() { - using TNewOutputs = decltype(std::tuple_cat(std::declval(), std::make_tuple(std::declval()))); + using TNewOutputs = decltype( + std::tuple_cat(std::declval(), + std::make_tuple(std::declval()))); return FunctionCacheBuilder{}; } template function_t build(TFunctor func) { - FunctionCache cache; + FunctionCache cache; cache.key_mapper = [](bundle_t bundle) { StringSerializer ser; - bundle.template serialize_params<0, std::tuple_size::value>(ser); + bundle.template serialize_params<0, + std::tuple_size::value>( + ser); return ser.take(); }; cache.value_mapper = [=](bundle_t bundle) { @@ -202,42 +237,33 @@ public: TRet ret; ret.value = bundle.call_by(func); ret.serialize(ser, Empty{}); - bundle.template serialize_params::value, std::tuple_size::value+std::tuple_size::value>(ser); + bundle.template serialize_params< + std::tuple_size::value, + std::tuple_size::value + + std::tuple_size::value>(ser); return ser.take(); }; return [=](auto&&... args) mutable { bundle_t bundle; TRet ret; StringSerializer ser; - static_assert(sizeof...(args) == std::tuple_size::value+std::tuple_size::value, - "arg count mismatch"); - bundle.template set_values<0, sizeof...(args)>(std::forward(args)...); + static_assert( + sizeof...(args) == std::tuple_size::value + + std::tuple_size::value, + "args count mismatch"); + bundle.template set_values<0, sizeof...(args)>( + std::forward(args)...); ser.set(cache(bundle)); ret.deserialize(ser, Empty{}); constexpr size_t n_inputs = std::tuple_size::value; constexpr size_t n_outputs = std::tuple_size::value; - bundle.template deserialize_params(ser); + bundle.template deserialize_params( + ser); return ret.value; }; } }; - -template -class PlainParam { -public: - T value; - Empty serialize(StringSerializer& ser, Empty) { - ser.write_plain(value); - return Empty{}; - } - Empty deserialize(StringSerializer& ser, Empty) { - value = ser.read_plain(); - return Empty{}; - } -}; - - template class RefParam { public: @@ -252,7 +278,6 @@ public: } }; - template class RefArraySizeParam { public: @@ -266,7 +291,6 @@ public: } }; - template class ArrayParam { public: @@ -285,4 +309,4 @@ public: } }; -} +} // namespace megdnn diff --git a/dnn/src/cuda/api_cache.h b/dnn/src/cuda/api_cache.h index c1531ea7e..f6f51b754 100644 --- a/dnn/src/cuda/api_cache.h +++ b/dnn/src/cuda/api_cache.h @@ -16,105 +16,109 @@ #include "src/cuda/cudnn_wrapper.h" namespace megdnn { - class CudnnConvDescParam { - public: - cudnnConvolutionDescriptor_t value; - Empty serialize(StringSerializer& ser, Empty) { - int ndim = MEGDNN_MAX_NDIM; - int padA[MEGDNN_MAX_NDIM]; - int strideA[MEGDNN_MAX_NDIM]; - int dilationA[MEGDNN_MAX_NDIM]; - cudnnConvolutionMode_t mode; - cudnnDataType_t computeType; - cudnnGetConvolutionNdDescriptor(value, MEGDNN_MAX_NDIM, &ndim, padA, strideA, dilationA, &mode, &computeType); - ser.write_plain(ndim); - for (int i = 0; i < ndim; ++i) { - ser.write_plain(padA[i]); - ser.write_plain(strideA[i]); - ser.write_plain(dilationA[i]); - } - ser.write_plain(mode); - ser.write_plain(computeType); - return Empty{}; +class CudnnConvDescParam { +public: + cudnnConvolutionDescriptor_t value; + Empty serialize(StringSerializer& ser, Empty) { + constexpr int nbDims = MEGDNN_MAX_NDIM; + int padA[MEGDNN_MAX_NDIM]; + int strideA[MEGDNN_MAX_NDIM]; + int dilationA[MEGDNN_MAX_NDIM]; + cudnnConvolutionMode_t mode; + cudnnDataType_t computeType; + cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA, + dilationA, &mode, &computeType); + ser.write_plain(nbDims); + for (int i = 0; i < nbDims; ++i) { + ser.write_plain(padA[i]); + ser.write_plain(strideA[i]); + ser.write_plain(dilationA[i]); } - Empty deserialize(StringSerializer& ser, Empty) { - int ndim = ser.read_plain(); - int padA[MEGDNN_MAX_NDIM]; - int strideA[MEGDNN_MAX_NDIM]; - int dilationA[MEGDNN_MAX_NDIM]; - for (int i = 0; i < ndim; ++i) { - padA[i] = ser.read_plain(); - strideA[i] = ser.read_plain(); - dilationA[i] = ser.read_plain(); - } - cudnnConvolutionMode_t mode = ser.read_plain(); - cudnnDataType_t computeType = ser.read_plain(); - cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, mode, computeType); - return Empty{}; + ser.write_plain(mode); + ser.write_plain(computeType); + return Empty{}; + } + Empty deserialize(StringSerializer& ser, Empty) { + int ndim = ser.read_plain(); + int padA[MEGDNN_MAX_NDIM]; + int strideA[MEGDNN_MAX_NDIM]; + int dilationA[MEGDNN_MAX_NDIM]; + for (int i = 0; i < ndim; ++i) { + padA[i] = ser.read_plain(); + strideA[i] = ser.read_plain(); + dilationA[i] = ser.read_plain(); } - }; - class CudnnTensorDescParam { - public: - cudnnTensorDescriptor_t value; - Empty serialize(StringSerializer& ser, Empty) { - int nbDims = MEGDNN_MAX_NDIM; - cudnnDataType_t dataType; - int dimA[MEGDNN_MAX_NDIM]; - int strideA[MEGDNN_MAX_NDIM]; - cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, strideA); - ser.write_plain(nbDims); - for (int i = 0; i < nbDims; ++i) { - ser.write_plain(dimA[i]); - ser.write_plain(strideA[i]); - } - ser.write_plain(dataType); - return Empty{}; + cudnnConvolutionMode_t mode = ser.read_plain(); + cudnnDataType_t computeType = ser.read_plain(); + cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, + mode, computeType); + return Empty{}; + } +}; +class CudnnTensorDescParam { +public: + cudnnTensorDescriptor_t value; + Empty serialize(StringSerializer& ser, Empty) { + constexpr int nbDims = MEGDNN_MAX_NDIM; + cudnnDataType_t dataType; + int dimA[MEGDNN_MAX_NDIM]; + int strideA[MEGDNN_MAX_NDIM]; + cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, + strideA); + ser.write_plain(nbDims); + for (int i = 0; i < nbDims; ++i) { + ser.write_plain(dimA[i]); + ser.write_plain(strideA[i]); } - Empty deserialize(StringSerializer& ser, Empty) { - int nbDims = MEGDNN_MAX_NDIM; - cudnnDataType_t dataType; - int dimA[MEGDNN_MAX_NDIM]; - int strideA[MEGDNN_MAX_NDIM]; - nbDims = ser.read_plain(); - for (int i = 0; i < nbDims; ++i) { - dimA[i] = ser.read_plain(); - strideA[i] = ser.read_plain(); - } - dataType = ser.read_plain(); - cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA); - return Empty{}; + ser.write_plain(dataType); + return Empty{}; + } + Empty deserialize(StringSerializer& ser, Empty) { + constexpr int nbDims = MEGDNN_MAX_NDIM; + cudnnDataType_t dataType; + int dimA[MEGDNN_MAX_NDIM]; + int strideA[MEGDNN_MAX_NDIM]; + nbDims = ser.read_plain(); + for (int i = 0; i < nbDims; ++i) { + dimA[i] = ser.read_plain(); + strideA[i] = ser.read_plain(); } - }; - class CudnnFilterDescParam { - public: - cudnnFilterDescriptor_t value; - Empty serialize(StringSerializer& ser, Empty) { - int nbDims = MEGDNN_MAX_NDIM; - cudnnDataType_t dataType; - cudnnTensorFormat_t format; - int filterDimA[MEGDNN_MAX_NDIM]; - cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, filterDimA); - ser.write_plain(nbDims); - for (int i = 0; i < nbDims; ++i) { - ser.write_plain(filterDimA[i]); - } - ser.write_plain(dataType); - ser.write_plain(format); - return Empty{}; + dataType = ser.read_plain(); + cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA); + return Empty{}; + } +}; +class CudnnFilterDescParam { +public: + cudnnFilterDescriptor_t value; + Empty serialize(StringSerializer& ser, Empty) { + constexpr int nbDims = MEGDNN_MAX_NDIM; + cudnnDataType_t dataType; + cudnnTensorFormat_t format; + int filterDimA[MEGDNN_MAX_NDIM]; + cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, + filterDimA); + ser.write_plain(nbDims); + for (int i = 0; i < nbDims; ++i) { + ser.write_plain(filterDimA[i]); } - Empty deserialize(StringSerializer& ser, Empty) { - int nbDims = MEGDNN_MAX_NDIM; - cudnnDataType_t dataType; - cudnnTensorFormat_t format; - int filterDimA[MEGDNN_MAX_NDIM]; - nbDims = ser.read_plain(); - for (int i = 0; i < nbDims; ++i) { - filterDimA[i] = ser.read_plain(); - } - dataType = ser.read_plain(); - format = ser.read_plain(); - cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA); - return Empty{}; + ser.write_plain(dataType); + ser.write_plain(format); + return Empty{}; + } + Empty deserialize(StringSerializer& ser, Empty) { + constexpr int nbDims = MEGDNN_MAX_NDIM; + cudnnDataType_t dataType; + cudnnTensorFormat_t format; + int filterDimA[MEGDNN_MAX_NDIM]; + nbDims = ser.read_plain(); + for (int i = 0; i < nbDims; ++i) { + filterDimA[i] = ser.read_plain(); } - }; -} + dataType = ser.read_plain(); + format = ser.read_plain(); + cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA); + return Empty{}; + } +}; +} // namespace megdnn diff --git a/dnn/src/cuda/conv_bias/cudnn_conv.cpp b/dnn/src/cuda/conv_bias/cudnn_conv.cpp index 64a164032..dbc3f4a96 100644 --- a/dnn/src/cuda/conv_bias/cudnn_conv.cpp +++ b/dnn/src/cuda/conv_bias/cudnn_conv.cpp @@ -39,7 +39,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available( conv_args.init_conv_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionForwardWorkspaceSize( + auto& cudnn = conv_args.handle->cudnn(); + auto status = cudnn.GetConvolutionForwardWorkspaceSize( conv_args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size); @@ -65,7 +66,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle( conv_args.init_conv_desc(D); size_t conv_workspace_size; - auto status = cudnnGetConvolutionForwardWorkspaceSize( + auto& cudnn = conv_args.handle->cudnn(); + auto status = cudnn.GetConvolutionForwardWorkspaceSize( conv_args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &conv_workspace_size); diff --git a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp index ab0968def..b3f3df785 100644 --- a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp +++ b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp @@ -108,7 +108,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( megdnn_throw("unsupported NonlineMode"); } size_t workspace_size; - auto status = cudnnGetConvolutionForwardWorkspaceSize( + auto& cudnn = args.handle->cudnn(); + auto status = cudnn.GetConvolutionForwardWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size); @@ -121,7 +122,8 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes( args.init_conv_bias_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionForwardWorkspaceSize( + auto& cudnn = args.handle->cudnn(); + auto status = cudnn.GetConvolutionForwardWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size); diff --git a/dnn/src/cuda/conv_bias/opr_impl.cpp b/dnn/src/cuda/conv_bias/opr_impl.cpp index 3af19ee46..075cb7c7c 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.cpp +++ b/dnn/src/cuda/conv_bias/opr_impl.cpp @@ -83,12 +83,13 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( CUDNNForwardDescs desc; conv_args.init_conv_desc(desc); #if CUDNN_MAJOR >= 7 + auto& cudnn = static_cast(this->handle())->cudnn(); int max_count = 0; - cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, + cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_count)); SmallVector algo_perf(max_count); int ret_count = 0; - cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7( + cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7( cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count, &ret_count, algo_perf.data())); diff --git a/dnn/src/cuda/convolution/backward_data/cudnn.cpp b/dnn/src/cuda/convolution/backward_data/cudnn.cpp index e7ec05309..0d2d4be5e 100644 --- a/dnn/src/cuda/convolution/backward_data/cudnn.cpp +++ b/dnn/src/cuda/convolution/backward_data/cudnn.cpp @@ -44,9 +44,10 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available( } #endif + auto& cudnn = args.handle->cudnn(); args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( + auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize( args.handle->cudnn_handle(), D.filter_desc.desc, D.diff_desc.desc, @@ -59,10 +60,11 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available( size_t ConvolutionBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes( const SizeArgs &args) const { + auto& cudnn = args.handle->cudnn(); CUDNNBwdDataDescs D; args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( + auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize( args.handle->cudnn_handle(), D.filter_desc.desc, D.diff_desc.desc, diff --git a/dnn/src/cuda/convolution/backward_filter/cudnn.cpp b/dnn/src/cuda/convolution/backward_filter/cudnn.cpp index f66e2b5bd..acf39d75c 100644 --- a/dnn/src/cuda/convolution/backward_filter/cudnn.cpp +++ b/dnn/src/cuda/convolution/backward_filter/cudnn.cpp @@ -21,6 +21,7 @@ using namespace convolution; bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( const SizeArgs &args) const { + auto& cudnn = args.handle->cudnn(); CUDNNBwdFilterDescs D; if (!is_cudnn_supported(args.as_fwd_args())) @@ -28,7 +29,7 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( + auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, @@ -41,10 +42,11 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( size_t ConvolutionBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes( const SizeArgs &args) const { + auto& cudnn = args.handle->cudnn(); CUDNNBwdFilterDescs D; args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( + auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, diff --git a/dnn/src/cuda/convolution/opr_impl.cpp b/dnn/src/cuda/convolution/opr_impl.cpp index c6661d5cc..eb3e3e302 100644 --- a/dnn/src/cuda/convolution/opr_impl.cpp +++ b/dnn/src/cuda/convolution/opr_impl.cpp @@ -141,12 +141,13 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, #if CUDNN_MAJOR >= 7 MEGDNN_MARK_USED_VAR(negative_attr); + auto& cudnn = args.handle->cudnn(); int max_count = 0; - cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithmMaxCount( cudnn_handle, &max_count)); SmallVector algo_perf(max_count); int ret_count = 0; - cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7( + cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithm_v7( cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc, desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count, algo_perf.data())); @@ -286,12 +287,13 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( #endif #if CUDNN_MAJOR >= 7 MEGDNN_MARK_USED_VAR(negative_attr); + auto& cudnn = args.handle->cudnn(); int max_count = 0; - cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithmMaxCount( cudnn_handle, &max_count)); SmallVector algo_perf(max_count); int ret_count = 0; - cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7( + cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithm_v7( cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc, desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count, algo_perf.data())); diff --git a/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp b/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp index c7fcafdfe..614af9caa 100644 --- a/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp +++ b/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp @@ -28,7 +28,8 @@ bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available( args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( + auto& cudnn = args.handle->cudnn(); + auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize( args.handle->cudnn_handle(), D.filter_desc.desc, D.diff_desc.desc, @@ -44,7 +45,8 @@ size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes( CUDNNBwdDataDescs D; args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( + auto& cudnn = args.handle->cudnn(); + auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize( args.handle->cudnn_handle(), D.filter_desc.desc, D.diff_desc.desc, diff --git a/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp b/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp index a0afe7c25..50f330109 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp +++ b/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp @@ -28,7 +28,8 @@ bool Convolution3DBackwardFilterImpl::AlgoCUDNN::is_available( args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( + auto& cudnn = args.handle->cudnn(); + auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); return status == CUDNN_STATUS_SUCCESS; @@ -40,7 +41,8 @@ size_t Convolution3DBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes( args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( + auto& cudnn = args.handle->cudnn(); + auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); megdnn_assert(status == CUDNN_STATUS_SUCCESS, diff --git a/dnn/src/cuda/convolution3d/forward/cudnn.cpp b/dnn/src/cuda/convolution3d/forward/cudnn.cpp index da801e31c..845c1baff 100644 --- a/dnn/src/cuda/convolution3d/forward/cudnn.cpp +++ b/dnn/src/cuda/convolution3d/forward/cudnn.cpp @@ -27,7 +27,8 @@ bool Convolution3DForwardImpl::AlgoCUDNN::is_available( args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionForwardWorkspaceSize( + auto& cudnn = args.handle->cudnn(); + auto status = cudnn.GetConvolutionForwardWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, @@ -43,7 +44,8 @@ size_t Convolution3DForwardImpl::AlgoCUDNN::get_workspace_in_bytes( CUDNNForwardDescs D; args.init_desc(D); size_t workspace_size; - auto status = cudnnGetConvolutionForwardWorkspaceSize( + auto& cudnn = args.handle->cudnn(); + auto status = cudnn.GetConvolutionForwardWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, diff --git a/dnn/src/cuda/convolution3d/helper.h b/dnn/src/cuda/convolution3d/helper.h index 6916515b6..e8f68ccf3 100644 --- a/dnn/src/cuda/convolution3d/helper.h +++ b/dnn/src/cuda/convolution3d/helper.h @@ -92,7 +92,7 @@ namespace convolution3d { const Workspace &workspace, void *&raw_ptr); inline bool cudnn_get_convolution_fwd_algo_helper( - cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, + Handle* handle, const cudnnTensorDescriptor_t x_desc, const cudnnFilterDescriptor_t w_desc, const cudnnConvolutionDescriptor_t conv_desc, const cudnnTensorDescriptor_t y_desc, @@ -102,13 +102,14 @@ namespace convolution3d { MEGDNN_MARK_USED_VAR(positive_attr); MEGDNN_MARK_USED_VAR(negative_attr); #if CUDNN_MAJOR >= 7 + auto& cudnn = static_cast(handle)->cudnn(); int algo_max_count = 0; - cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount( - cudnn_handle, &algo_max_count)); + cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount( + cuda::cudnn_handle(handle), &algo_max_count)); SmallVector algo_perf(algo_max_count); int algo_count = 0; - cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7( - cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count, + cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7( + cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, algo_max_count, &algo_count, algo_perf.data())); for (int i = 0; i < algo_count; ++i) { if (algo_perf[i].algo == @@ -116,8 +117,8 @@ namespace convolution3d { CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) continue; size_t workspace_size = 0; - cudnn_check(cudnnGetConvolutionForwardWorkspaceSize( - cudnn_handle, x_desc, w_desc, conv_desc, y_desc, + cudnn_check(cudnn.GetConvolutionForwardWorkspaceSize( + cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, algo_perf[i].algo, &workspace_size)); if (workspace_size > workspace_limit_in_bytes) continue; if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { @@ -133,7 +134,7 @@ namespace convolution3d { return false; #else cudnn_check(cudnnGetConvolutionForwardAlgorithm( - cudnn_handle, x_desc, w_desc, conv_desc, y_desc, + cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_limit_in_bytes, algo)); return true; diff --git a/dnn/src/cuda/convolution3d/opr_impl.cpp b/dnn/src/cuda/convolution3d/opr_impl.cpp index b16c102af..c7719d6e1 100644 --- a/dnn/src/cuda/convolution3d/opr_impl.cpp +++ b/dnn/src/cuda/convolution3d/opr_impl.cpp @@ -74,13 +74,12 @@ Convolution3DForwardImpl::get_algorithm_heuristic( auto get_cudnn_algo = [this, &args, workspace_limit_in_bytes, positive_attr, negative_attr]() -> Convolution3DForwardImpl::AlgoBase* { - auto cudnn_handle = cuda::cudnn_handle(this->handle()); cudnnConvolutionFwdAlgo_t algo; CUDNNForwardDescs desc; args.init_desc(desc); bool got = cudnn_get_convolution_fwd_algo_helper( - cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, + this->handle(), desc.src_desc.desc, desc.filter_desc.desc, desc.conv_desc.desc, desc.dst_desc.desc, workspace_limit_in_bytes, &algo, positive_attr, negative_attr); if (got) { diff --git a/dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh b/dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh index f21197e15..87391f16a 100644 --- a/dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh +++ b/dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh @@ -56,7 +56,7 @@ namespace convolution { using KernLayout = _kern_layout; \ using OutputLayout = _output_layout; \ using Param = _conv_param; \ - static constexpr bool check_bounds = check_bounds_; + static constexpr bool check_bounds = check_bounds_ #define MEGDNN_COMMA , template #include +#include #define STR_HELPER(x) #x #define STR(x) STR_HELPER(x) @@ -88,6 +91,8 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle): // check tk1 m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0); m_cusolver_handle = nullptr; + + m_cudnn_api_cache = std::make_unique(m_cudnn_handle); } HandleImpl::~HandleImpl() noexcept { @@ -133,8 +138,111 @@ HandleImpl::HandleVendorType HandleImpl::vendor_type() const { return HandleVendorType::CUDA; } -} // namespace cuda -} // namespace megdnn +HandleImpl::CUDNN& HandleImpl::cudnn() { + return *m_cudnn_api_cache; +} + +HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) { + m_handle = handle; + GetConvolutionForwardWorkspaceSize = + FunctionCacheBuilder<>() + .input>() + .input() + .input() + .input() + .input() + .input>() + .output>() + .ret>() + .build(&cudnnGetConvolutionForwardWorkspaceSize); +#if CUDNN_MAJOR >= 7 + GetConvolutionForwardAlgorithm_v7 = + FunctionCacheBuilder<>() + .input>() + .input() + .input() + .input() + .input() + .input>() + .output>() + .output>() + .ret>() + .build(&cudnnGetConvolutionForwardAlgorithm_v7); + GetConvolutionForwardAlgorithmMaxCount = + FunctionCacheBuilder<>() + .input>() + .output>() + .ret>() + .build(&cudnnGetConvolutionForwardAlgorithmMaxCount); +#endif + GetConvolutionBackwardDataWorkspaceSize = + FunctionCacheBuilder<>() + .input>() + .input() + .input() + .input() + .input() + .input>() + .output>() + .ret>() + .build(&cudnnGetConvolutionBackwardDataWorkspaceSize); +#if CUDNN_MAJOR >= 7 + GetConvolutionBackwardDataAlgorithm_v7 = + FunctionCacheBuilder<>() + .input>() + .input() + .input() + .input() + .input() + .input>() + .output>() + .output>() + .ret>() + .build(&cudnnGetConvolutionBackwardDataAlgorithm_v7); + GetConvolutionBackwardDataAlgorithmMaxCount = + FunctionCacheBuilder<>() + .input>() + .output>() + .ret>() + .build(&cudnnGetConvolutionBackwardDataAlgorithmMaxCount); +#endif + GetConvolutionBackwardFilterWorkspaceSize = + FunctionCacheBuilder<>() + .input>() + .input() + .input() + .input() + .input() + .input>() + .output>() + .ret>() + .build(&cudnnGetConvolutionBackwardFilterWorkspaceSize); +#if CUDNN_MAJOR >= 7 + GetConvolutionBackwardFilterAlgorithm_v7 = + FunctionCacheBuilder<>() + .input>() + .input() + .input() + .input() + .input() + .input>() + .output>() + .output>() + .ret>() + .build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7); + GetConvolutionBackwardFilterAlgorithmMaxCount = + FunctionCacheBuilder<>() + .input>() + .output>() + .ret>() + .build(&cudnnGetConvolutionBackwardFilterAlgorithmMaxCount); +#endif +} + +} // namespace cuda +} // namespace megdnn MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION); MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL); diff --git a/dnn/src/cuda/handle.h b/dnn/src/cuda/handle.h index a886c5118..cf9fb5159 100644 --- a/dnn/src/cuda/handle.h +++ b/dnn/src/cuda/handle.h @@ -124,6 +124,10 @@ class HandleImpl: public HandleImplHelper { size_t image2d_pitch_alignment() const override; HandleVendorType vendor_type() const override; + + class CUDNN; + + CUDNN& cudnn(); private: bool m_is_tegra_k1; int m_device_id; @@ -156,9 +160,34 @@ class HandleImpl: public HandleImplHelper { //! device ptr to const scalars ConstScalars* m_const_scalars; + std::unique_ptr m_cudnn_api_cache; + void initialize_cusolver(); }; +class HandleImpl::CUDNN { + cudnnHandle_t m_handle; +public: + CUDNN(cudnnHandle_t handle); +#define WRAP_CUDNN_API(NAME) thin_function NAME; + WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize); +#if CUDNN_MAJOR >= 7 + WRAP_CUDNN_API(GetConvolutionForwardAlgorithm_v7); + WRAP_CUDNN_API(GetConvolutionForwardAlgorithmMaxCount); +#endif +#if CUDNN_MAJOR >= 7 + WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithm_v7); + WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithmMaxCount); +#endif + WRAP_CUDNN_API(GetConvolutionBackwardDataWorkspaceSize); +#if CUDNN_MAJOR >= 7 + WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithmMaxCount); + WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithm_v7); +#endif + WRAP_CUDNN_API(GetConvolutionBackwardFilterWorkspaceSize); +#undef WRAP_CUDNN_API +}; + } // namespace cuda } // namespace megdnn -- GitLab