From 64c922c4bbf2951779708bf816788c686a4a64c3 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 15 Sep 2021 14:17:35 +0800
Subject: [PATCH] Revert "fix(api_cache): fix serialization for conv_desc"

This reverts commit 95dbc9c685cced46dd910997bd585363c392ccbd.

GitOrigin-RevId: ca8c67b6b307b547603e083554b7ffd5736e2d98
---
 dnn/src/common/api_cache.h | 64 +++++++++++++----------------
 dnn/src/cuda/api_cache.h   | 83 ++++++++++++++++++++++----------------
 dnn/src/cuda/handle.cpp    | 11 +++--
 3 files changed, 81 insertions(+), 77 deletions(-)

diff --git a/dnn/src/common/api_cache.h b/dnn/src/common/api_cache.h
index 5e50ece31..9009f5e1a 100644
--- a/dnn/src/common/api_cache.h
+++ b/dnn/src/common/api_cache.h
@@ -131,18 +131,12 @@ public:
     T read_plain() {
         static_assert(std::is_trivially_copyable<T>::value, "invalid type");
         T ret;
-        std::memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
+        memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
         m_cursor += sizeof(T);
         return ret;
     }
     template <typename T>
-    void read_plain(T* dest) {
-        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
-        std::memcpy(dest, m_buffer.data() + m_cursor, sizeof(T));
-        m_cursor += sizeof(T);
-    }
-    template <typename T>
-    void write_plain(const T& value) {
+    void write_plain(T value) {
         static_assert(std::is_trivially_copyable<T>::value,
                       "type should be trivially copyable");
         m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
@@ -150,7 +144,7 @@ public:
     std::string take() { return std::move(m_buffer); }
     void reset(std::string new_buf) {
         m_cursor = 0;
-        m_buffer = std::move(new_buf);
+        m_buffer = new_buf;
     }
 };
@@ -159,7 +153,7 @@ struct Empty {};
 // in: seq[1, 2, ..., m]
 // out: seq[N+1, N+2, ... N+m]
 template <std::size_t N, std::size_t... Seq>
-inline std::index_sequence<N + Seq...> inc_index_sequence(
+static std::index_sequence<N + Seq...> inc_index_sequence(
         std::index_sequence<Seq...>) {
     return {};
 }
@@ -178,7 +172,7 @@ private:
     // deconstruct tuple and call functor
     template <typename TFunctor, size_t... Indices>
-    auto call_helper(TFunctor&& functor, std::index_sequence<Indices...>) {
+    auto call_helper(TFunctor functor, std::index_sequence<Indices...>) {
         return functor(std::get<Indices>(m_storage).value...);
     }
@@ -209,7 +203,7 @@ private:
     template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
     void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
                            TArgs&&... args) {
-        std::get<Index>(m_storage).value = std::forward<TArg>(arg);
+        std::get<Index>(m_storage).value = arg;
         set_values_helper(std::index_sequence<Indices...>(),
                           std::forward<TArgs>(args)...);
     }
@@ -259,7 +253,7 @@ public:
     }
     Empty deserialize(StringSerializer& ser, Empty) {
-        ser.read_plain(&value);
+        value = ser.read_plain<T>();
         return Empty{};
     }
 };
@@ -291,8 +285,7 @@ private:
     template <size_t... Indices>
     static auto declbundle_helper(std::index_sequence<Indices...>)
-            -> ParamBundle<std::decay_t<decltype(std::get<Indices>(declargs()))>...> {
+            -> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
         return {};
     }
@@ -319,11 +312,9 @@ public:
     // declare new input
     template <typename TNewInput>
     auto input() {
-        static_assert(std::tuple_size<TOutputs>::value == 0,
-                      "input arg cannot be declared after output");
-        using TNewInputs =
-                decltype(std::tuple_cat(std::declval<TInputs>(),
-                                        std::declval<std::tuple<TNewInput>>()));
+        using TNewInputs = decltype(
+                std::tuple_cat(std::declval<TInputs>(),
+                               std::make_tuple(std::declval<TNewInput>())));
         return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
     }
     // declare new output
@@ -331,29 +322,31 @@ public:
     auto output() {
         using TNewOutputs = decltype(
                 std::tuple_cat(std::declval<TOutputs>(),
-                               std::declval<std::tuple<TNewOutput>>()));
+                               std::make_tuple(std::declval<TNewOutput>())));
         return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
     }
     // summary
     template <typename TFunctor>
-    function_t build(TFunctor&& func) {
-        constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
-        constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
+    function_t build(TFunctor func) {
         auto cache = std::make_shared<FunctionCache<std::string, std::string>>();
         // bundle -> ser(in args)
         cache->key_mapper = [](bundle_t bundle) {
             StringSerializer ser;
-            bundle.template serialize_params<0, n_inputs>(ser);
+            bundle.template serialize_params<0,
+                                             std::tuple_size<TInputs>::value>(
+                    ser);
             return ser.take();
         };
         // bundle -> ser(out args)
-        cache->value_mapper = [func](bundle_t bundle) {
+        cache->value_mapper = [=](bundle_t bundle) {
             StringSerializer ser;
             TRet ret;
             ret.value = bundle.call_by(func);
             ret.serialize(ser, Empty{});
-            bundle.template serialize_params<n_inputs, n_inputs + n_outputs>(
-                    ser);
+            bundle.template serialize_params<
+                    std::tuple_size<TInputs>::value,
+                    std::tuple_size<TInputs>::value +
+                            std::tuple_size<TOutputs>::value>(ser);
             return ser.take();
         };
         return [=](auto&&... args) mutable {
@@ -368,6 +361,8 @@ public:
                     std::forward<decltype(args)>(args)...);
             ser.reset((*cache)(bundle));
             ret.deserialize(ser, Empty{});
+            constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
+            constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
             bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
                     ser);
             return ret.value;
@@ -399,8 +394,7 @@ public:
         return *value;
     }
     T deserialize(StringSerializer& ser, Empty) {
-        ser.read_plain(value);
-        return *value;
+        return *value = ser.read_plain<T>();
     }
 };
@@ -408,20 +402,16 @@
 template <typename TSize, typename TItem>
 class ArrayParam {
 public:
-    decltype(std::declval<TItem>().value)* value;
+    TItem* value;
     Empty serialize(StringSerializer& ser, TSize size) {
-        TItem param;
         for (TSize i = 0; i < size; ++i) {
-            param.value = value[i];
-            param.serialize(ser, Empty{});
+            ser.write_plain(value[i]);
         }
         return Empty{};
     }
     Empty deserialize(StringSerializer& ser, TSize size) {
-        TItem param;
         for (TSize i = 0; i < size; ++i) {
-            param.deserialize(ser, Empty{});
-            value[i] = param.value;
+            value[i] = ser.read_plain<TItem>();
         }
         return Empty{};
     }
diff --git a/dnn/src/cuda/api_cache.h b/dnn/src/cuda/api_cache.h
index 2299a18a2..f58f6d75b 100644
--- a/dnn/src/cuda/api_cache.h
+++ b/dnn/src/cuda/api_cache.h
@@ -20,16 +20,14 @@ class CudnnConvDescParam {
 public:
     cudnnConvolutionDescriptor_t value;
     Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int maxNbDims = CUDNN_DIM_MAX - 2;
-        int nbDims = maxNbDims;
-        int padA[maxNbDims];
-        int strideA[maxNbDims];
-        int dilationA[maxNbDims];
+        int nbDims = MEGDNN_MAX_NDIM;
+        int padA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        int dilationA[MEGDNN_MAX_NDIM];
         cudnnConvolutionMode_t mode;
         cudnnDataType_t computeType;
-        cudnnGetConvolutionNdDescriptor(value, maxNbDims, &nbDims, padA,
-                                        strideA, dilationA, &mode,
-                                        &computeType);
+        cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA,
+                                        dilationA, &mode, &computeType);
         ser.write_plain(nbDims);
         for (int i = 0; i < nbDims; ++i) {
             ser.write_plain(padA[i]);
@@ -40,8 +38,23 @@ public:
         ser.write_plain(computeType);
         return Empty{};
     }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int ndim = ser.read_plain<int>();
+        int padA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        int dilationA[MEGDNN_MAX_NDIM];
+        for (int i = 0; i < ndim; ++i) {
+            padA[i] = ser.read_plain<int>();
+            strideA[i] = ser.read_plain<int>();
+            dilationA[i] = ser.read_plain<int>();
+        }
+        cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
+        cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
+        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA,
+                                        mode, computeType);
+        return Empty{};
+    }
 };
-
 class CudnnTensorDescParam {
 public:
     cudnnTensorDescriptor_t value;
@@ -50,8 +63,8 @@ public:
         int nbDims = MEGDNN_MAX_NDIM;
         cudnnDataType_t dataType;
         int dimA[MEGDNN_MAX_NDIM];
         int strideA[MEGDNN_MAX_NDIM];
-        cudnnGetTensorNdDescriptor(value, MEGDNN_MAX_NDIM, &dataType, &nbDims,
-                                   dimA, strideA);
+        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA,
+                                   strideA);
         ser.write_plain(nbDims);
         for (int i = 0; i < nbDims; ++i) {
             ser.write_plain(dimA[i]);
@@ -60,8 +73,21 @@ public:
         ser.write_plain(dataType);
         return Empty{};
     }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        int dimA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        nbDims = ser.read_plain<int>();
+        for (int i = 0; i < nbDims; ++i) {
+            dimA[i] = ser.read_plain<int>();
+            strideA[i] = ser.read_plain<int>();
+        }
+        dataType = ser.read_plain<cudnnDataType_t>();
+        cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
+        return Empty{};
+    }
 };
-
 class CudnnFilterDescParam {
 public:
     cudnnFilterDescriptor_t value;
@@ -80,29 +106,18 @@ public:
         ser.write_plain(format);
         return Empty{};
     }
-};
-
-template <typename T>
-class CudnnConvAlgoPerfParam {
-public:
-    T value;
-    Empty serialize(StringSerializer& ser, Empty) {
-        ser.write_plain(value.algo);
-        ser.write_plain(value.status);
-        ser.write_plain(value.time);
-        ser.write_plain(value.memory);
-        ser.write_plain(value.determinism);
-        ser.write_plain(value.mathType);
-        return Empty{};
-    }
-    Empty deserialize(StringSerializer& ser, Empty) {
-        ser.read_plain(&value.algo);
-        ser.read_plain(&value.status);
-        ser.read_plain(&value.time);
-        ser.read_plain(&value.memory);
-        ser.read_plain(&value.determinism);
-        ser.read_plain(&value.mathType);
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        cudnnTensorFormat_t format;
+        int filterDimA[MEGDNN_MAX_NDIM];
+        nbDims = ser.read_plain<int>();
+        for (int i = 0; i < nbDims; ++i) {
+            filterDimA[i] = ser.read_plain<int>();
+        }
+        dataType = ser.read_plain<cudnnDataType_t>();
+        format = ser.read_plain<cudnnTensorFormat_t>();
+        cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
         return Empty{};
     }
 };
diff --git a/dnn/src/cuda/handle.cpp b/dnn/src/cuda/handle.cpp
index 3ab550941..d7347c972 100644
--- a/dnn/src/cuda/handle.cpp
+++ b/dnn/src/cuda/handle.cpp
@@ -168,8 +168,7 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
                     .input<CudnnTensorDescParam>()
                     .input<Param<int>>()
                     .output<RefParam<int>>()
-                    .output<ArrayParam<int,
-                            CudnnConvAlgoPerfParam<cudnnConvolutionFwdAlgoPerf_t>>>()
+                    .output<ArrayParam<int, cudnnConvolutionFwdAlgoPerf_t>>()
                     .ret<Param<cudnnStatus_t>>()
                     .build(&cudnnGetConvolutionForwardAlgorithm_v7);
     GetConvolutionForwardAlgorithmMaxCount =
@@ -200,8 +199,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
                     .input<CudnnTensorDescParam>()
                     .input<Param<int>>()
                     .output<RefParam<int>>()
-                    .output<ArrayParam<int,
-                            CudnnConvAlgoPerfParam<cudnnConvolutionBwdDataAlgoPerf_t>>>()
+                    .output<ArrayParam<int,
+                            cudnnConvolutionBwdDataAlgoPerf_t>>()
                     .ret<Param<cudnnStatus_t>>()
                     .build(&cudnnGetConvolutionBackwardDataAlgorithm_v7);
     GetConvolutionBackwardDataAlgorithmMaxCount =
@@ -232,8 +231,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
                     .input<CudnnFilterDescParam>()
                     .input<Param<int>>()
                     .output<RefParam<int>>()
-                    .output<ArrayParam<int,
-                            CudnnConvAlgoPerfParam<cudnnConvolutionBwdFilterAlgoPerf_t>>>()
+                    .output<ArrayParam<int,
+                            cudnnConvolutionBwdFilterAlgoPerf_t>>()
                     .ret<Param<cudnnStatus_t>>()
                     .build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7);
     GetConvolutionBackwardFilterAlgorithmMaxCount =
-- 
GitLab
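Background for readers of this revert: both before and after it, the api_cache machinery keys a FunctionCache by serializing a call's input parameters into a std::string and stores the serialized output parameters as the cached value; write_plain/read_plain simply copy trivially copyable values into and out of that string buffer. The standalone sketch below illustrates only that round trip, under stated assumptions. It is not MegEngine code: MiniSerializer and Perf are illustrative stand-ins for StringSerializer and a cuDNN perf-result struct.

#include <cassert>
#include <cstring>
#include <string>
#include <type_traits>
#include <utility>

// Illustrative stand-in for the StringSerializer in dnn/src/common/api_cache.h:
// values are appended to and read back from a std::string buffer byte-for-byte.
class MiniSerializer {
    std::string m_buffer;
    size_t m_cursor = 0;

public:
    template <typename T>
    void write_plain(T value) {
        static_assert(std::is_trivially_copyable<T>::value,
                      "type should be trivially copyable");
        m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
    }
    template <typename T>
    T read_plain() {
        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
        T ret;
        std::memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
        m_cursor += sizeof(T);
        return ret;
    }
    std::string take() { return std::move(m_buffer); }
    void reset(std::string buf) {
        m_cursor = 0;
        m_buffer = std::move(buf);
    }
};

// Hypothetical perf-result struct standing in for cudnnConvolutionFwdAlgoPerf_t.
struct Perf {
    int algo;
    float time;
};

int main() {
    // Serialize a count plus an array of results, element by element, the way
    // ArrayParam does after this revert (plain write of each element).
    MiniSerializer ser;
    Perf results[2] = {{1, 0.5f}, {3, 0.25f}};
    ser.write_plain(2);
    for (int i = 0; i < 2; ++i) {
        ser.write_plain(results[i]);
    }
    std::string cached = ser.take();

    // Later, read the cached string back, as the builder's returned lambda does.
    MiniSerializer des;
    des.reset(cached);
    int n = des.read_plain<int>();
    assert(n == 2);
    Perf back = des.read_plain<Perf>();
    assert(back.algo == 1);
    return 0;
}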