diff --git a/dnn/src/common/api_cache.h b/dnn/src/common/api_cache.h
new file mode 100644
index 0000000000000000000000000000000000000000..6bb2d4d7e7c0944ca9fad4306afde759bff37fc2
--- /dev/null
+++ b/dnn/src/common/api_cache.h
@@ -0,0 +1,288 @@
+/**
+ * \file dnn/src/common/api_cache.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include <cstring>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+
+#include "megdnn/thin/function.h"
+
+namespace megdnn {
+
+template <typename TSignature>
+class FunctionCache;
+
+template <typename TRet, typename... TArgs>
+class FunctionCache<TRet(TArgs...)> {
+public:
+    using key_t = std::string;
+    using value_t = TRet;
+    using key_mapper_t = thin_function<key_t(TArgs...)>;
+    using value_mapper_t = thin_function<value_t(TArgs...)>;
+    using storage_t = std::unordered_map<key_t, value_t>;
+public:
+    storage_t storage;
+    key_mapper_t key_mapper;
+    value_mapper_t value_mapper;
+public:
+    TRet operator()(TArgs... args) {
+        key_t key = key_mapper(args...);
+        if (storage.count(key) == 0) {
+            storage[key] = value_mapper(std::forward<TArgs>(args)...);
+        }
+        return storage[key];
+    }
+};
+
+
+// FIFO
+class StringSerializer {
+private:
+    std::string m_buffer;
+    size_t m_cursor = 0;
+public:
+    template <typename T>
+    T read_plain() {
+        T result;
+        std::memcpy(&result, m_buffer.data() + m_cursor, sizeof(T));
+        m_cursor += sizeof(T);
+        return result;
+    }
+    template <typename T>
+    void write_plain(T value) {
+        m_buffer.resize(m_buffer.size() + sizeof(T));
+        std::memcpy(const_cast<char*>(m_buffer.data()) + (m_buffer.size() - sizeof(T)),
+                    &value, sizeof(T));
+    }
+    std::string take() {
+        std::string result;
+        m_buffer.erase(0, m_cursor);
+        return std::move(m_buffer);
+    }
+    void set(std::string new_buf) {
+        m_cursor = 0;
+        m_buffer = new_buf;
+    }
+};
+
+
+struct Empty {};
+
+
+template <typename... TParams>
+class ParamBundle {
+private:
+    template <size_t NBegin, size_t... NIndices>
+    static std::index_sequence<(NBegin + NIndices)...> add_all(
+            std::index_sequence<NIndices...>) {
+        return {};
+    }
+
+    template <size_t NBegin, size_t NEnd>
+    using make_index_range =
+            decltype(add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
+
+    using storage_t = std::tuple<std::remove_reference_t<TParams>...>;
+    storage_t m_storage;
+
+    template <typename TFunctor, size_t... NIndices>
+    auto call_helper(TFunctor functor, std::index_sequence<NIndices...>) {
+        return functor(std::get<NIndices>(m_storage).value...);
+    }
+    template <size_t NIndex, size_t... NIndices, typename TPrev>
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
+                          std::index_sequence<NIndex, NIndices...>) {
+        return serialize_helper(ser, std::get<NIndex>(m_storage).serialize(ser, prev),
+                                std::index_sequence<NIndices...>());
+    }
+    template <typename TPrev>
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
+    template <size_t NIndex, size_t... NIndices, typename TPrev>
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
+                            std::index_sequence<NIndex, NIndices...>) {
+        return deserialize_helper(ser, std::get<NIndex>(m_storage).deserialize(ser, prev),
+                                  std::index_sequence<NIndices...>());
+    }
+    template <typename TPrev>
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
+    template <size_t NIndex, size_t... NIndices, typename TArg, typename... TArgs>
+    void set_values_helper(std::index_sequence<NIndex, NIndices...>, TArg&& arg,
+                           TArgs&&... args) {
+        std::get<NIndex>(m_storage).value = arg;
+        set_values_helper(std::index_sequence<NIndices...>(), std::forward<TArgs>(args)...);
+    }
+    template <size_t... Indices>
+    void set_values_helper(std::index_sequence<Indices...>) {
+        static_assert(sizeof...(Indices) == 0, "redundant indices");
+    }
+
+public:
+    template <typename TFunctor>
+    auto call_by(TFunctor&& functor) {
+        return call_helper(std::forward<TFunctor>(functor),
+                           std::make_index_sequence<sizeof...(TParams)>());
+    }
+    template <size_t NBegin, size_t NEnd>
+    void serialize_params(StringSerializer& ser) {
+        static_assert(NEnd >= NBegin, "invalid range");
+        serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
+    }
+    template <size_t NBegin, size_t NEnd>
+    void deserialize_params(StringSerializer& ser) {
+        static_assert(NEnd >= NBegin, "invalid range");
+        deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
+    }
+    template <size_t NBegin, size_t NEnd, typename... TArgs>
+    void set_values(TArgs&&... args) {
+        set_values_helper(make_index_range<NBegin, NEnd>(), std::forward<TArgs>(args)...);
+    }
+};
+
+
+template <typename T>
+class RetParam {
+public:
+    T value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(value);
+        return Empty{};
+    }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        value = ser.read_plain<T>();
+        return Empty{};
+    }
+};
+
+
+template <typename TRet = RetParam<Empty>, typename TInputs = std::tuple<>,
+          typename TOutputs = std::tuple<>>
+class FunctionCacheBuilder {
+private:
+    static auto declargs()
+            -> decltype(std::tuple_cat(std::declval<TInputs>(), std::declval<TOutputs>())) {
+        return {};
+    }
+    template <size_t... NIndices>
+    static auto declfunction_helper(std::index_sequence<NIndices...>)
+            -> thin_function<decltype(std::declval<TRet>().value)(
+                    decltype(std::get<NIndices>(declargs()).value)...)> {
+        return {};
+    }
+    static auto declfunction() {
+        return declfunction_helper(
+                std::make_index_sequence<std::tuple_size<TInputs>::value +
+                                         std::tuple_size<TOutputs>::value>());
+    }
+    template <size_t... NIndices>
+    static auto declbundle_helper(std::index_sequence<NIndices...>)
+            -> ParamBundle<decltype(std::get<NIndices>(declargs()))...> {
+        return {};
+    }
+    static auto declbundle() {
+        return declbundle_helper(
+                std::make_index_sequence<std::tuple_size<TInputs>::value +
+                                         std::tuple_size<TOutputs>::value>());
+    }
+    using function_t = decltype(declfunction());
+    using bundle_t = decltype(declbundle());
+public:
+    template <typename TNewRet>
+    auto ret() {
+        static_assert(std::is_same<TRet, RetParam<Empty>>::value,
+                      "return value redefinition");
+        return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
+    }
+    template <typename TNewInput>
+    auto input() {
+        using TNewInputs = decltype(std::tuple_cat(
+                std::declval<TInputs>(), std::make_tuple(std::declval<TNewInput>())));
+        return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
+    }
+    template <typename TNewOutput>
+    auto output() {
+        using TNewOutputs = decltype(std::tuple_cat(
+                std::declval<TOutputs>(), std::make_tuple(std::declval<TNewOutput>())));
+        return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
+    }
+    template <typename TFunctor>
+    function_t build(TFunctor func) {
+        FunctionCache<std::string(bundle_t)> cache;
+        cache.key_mapper = [](bundle_t bundle) {
+            StringSerializer ser;
+            bundle.template serialize_params<0, std::tuple_size<TInputs>::value>(ser);
+            return ser.take();
+        };
+        cache.value_mapper = [=](bundle_t bundle) {
+            StringSerializer ser;
+            TRet ret;
+            ret.value = bundle.call_by(func);
+            ret.serialize(ser, Empty{});
+            bundle.template serialize_params<std::tuple_size<TInputs>::value,
+                                             std::tuple_size<TInputs>::value +
+                                                     std::tuple_size<TOutputs>::value>(ser);
+            return ser.take();
+        };
+        return [=](auto&&... args) mutable {
+            bundle_t bundle;
+            TRet ret;
+            StringSerializer ser;
+            static_assert(sizeof...(args) == std::tuple_size<TInputs>::value +
+                                                     std::tuple_size<TOutputs>::value,
+                          "arg count mismatch");
+            bundle.template set_values<0, sizeof...(args)>(
+                    std::forward<decltype(args)>(args)...);
+            ser.set(cache(bundle));
+            ret.deserialize(ser, Empty{});
+            constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
+            constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
+            bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(ser);
+            return ret.value;
+        };
+    }
+};
+
+
+template <typename T>
+class PlainParam {
+public:
+    T value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(value);
+        return Empty{};
+    }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        value = ser.read_plain<T>();
+        return Empty{};
+    }
+};
+
+
+template <typename T>
+class RefParam {
+public:
+    T* value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(*value);
+        return Empty{};
+    }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        *value = ser.read_plain<T>();
+        return Empty{};
+    }
+};
+
+
+template <typename T>
+class RefArraySizeParam {
+public:
+    T* value;
+    T serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(*value);
+        return *value;
+    }
+    T deserialize(StringSerializer& ser, Empty) {
+        return *value = ser.read_plain<T>();
+    }
+};
+
+
+template <typename TSize, typename TItem>
+class ArrayParam {
+public:
+    TItem* value;
+    Empty serialize(StringSerializer& ser, TSize size) {
+        for (TSize i = 0; i < size; ++i) {
+            ser.write_plain(value[i]);
+        }
+        return Empty{};
+    }
+    Empty deserialize(StringSerializer& ser, TSize size) {
+        for (TSize i = 0; i < size; ++i) {
+            value[i] = ser.read_plain<TItem>();
+        }
+        return Empty{};
+    }
+};
+
+}
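
For orientation, a minimal usage sketch of the builder added above. It is illustrative only and not part of the patch; `expensive_query`, `out_size`, and `cached_query` are assumed names. Inputs form the serialized cache key; the return value and output parameters are replayed from the cached record on a key hit.

    #include "src/common/api_cache.h"

    using namespace megdnn;

    // Assumed expensive, deterministic function: the result depends only on `x`.
    size_t expensive_query(int x, size_t* out_size);

    // Build a cached wrapper: PlainParam values are keyed by value, RefParam
    // outputs and RetParam return values are serialized once and replayed.
    static auto cached_query =
            FunctionCacheBuilder<>{}
                    .input<PlainParam<int>>()    // part of the cache key
                    .output<RefParam<size_t>>()  // *out_size replayed on a hit
                    .ret<RetParam<size_t>>()     // return value replayed on a hit
                    .build([](int x, size_t* out_size) {
                        return expensive_query(x, out_size);
                    });

    // Call site keeps the wrapped signature:
    //     size_t out = 0;
    //     size_t r = cached_query(3, &out);
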
diff --git a/dnn/src/cuda/api_cache.h b/dnn/src/cuda/api_cache.h
new file mode 100644
index 0000000000000000000000000000000000000000..c1531ea7efed154b0ae4c4d8d4d4399d111eebe7
--- /dev/null
+++ b/dnn/src/cuda/api_cache.h
@@ -0,0 +1,120 @@
+/**
+ * \file dnn/src/cuda/api_cache.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include "src/common/api_cache.h"
+#include "src/cuda/cudnn_wrapper.h"
+
+namespace megdnn {
+    class CudnnConvDescParam {
+    public:
+        cudnnConvolutionDescriptor_t value;
+        Empty serialize(StringSerializer& ser, Empty) {
+            int ndim = MEGDNN_MAX_NDIM;
+            int padA[MEGDNN_MAX_NDIM];
+            int strideA[MEGDNN_MAX_NDIM];
+            int dilationA[MEGDNN_MAX_NDIM];
+            cudnnConvolutionMode_t mode;
+            cudnnDataType_t computeType;
+            cudnnGetConvolutionNdDescriptor(value, MEGDNN_MAX_NDIM, &ndim, padA, strideA,
+                                            dilationA, &mode, &computeType);
+            ser.write_plain(ndim);
+            for (int i = 0; i < ndim; ++i) {
+                ser.write_plain(padA[i]);
+                ser.write_plain(strideA[i]);
+                ser.write_plain(dilationA[i]);
+            }
+            ser.write_plain(mode);
+            ser.write_plain(computeType);
+            return Empty{};
+        }
+        Empty deserialize(StringSerializer& ser, Empty) {
+            int ndim = ser.read_plain<int>();
+            int padA[MEGDNN_MAX_NDIM];
+            int strideA[MEGDNN_MAX_NDIM];
+            int dilationA[MEGDNN_MAX_NDIM];
+            for (int i = 0; i < ndim; ++i) {
+                padA[i] = ser.read_plain<int>();
+                strideA[i] = ser.read_plain<int>();
+                dilationA[i] = ser.read_plain<int>();
+            }
+            cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
+            cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
+            cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, mode,
+                                            computeType);
+            return Empty{};
+        }
+    };
+    class CudnnTensorDescParam {
+    public:
+        cudnnTensorDescriptor_t value;
+        Empty serialize(StringSerializer& ser, Empty) {
+            int nbDims = MEGDNN_MAX_NDIM;
+            cudnnDataType_t dataType;
+            int dimA[MEGDNN_MAX_NDIM];
+            int strideA[MEGDNN_MAX_NDIM];
+            cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, strideA);
+            ser.write_plain(nbDims);
+            for (int i = 0; i < nbDims; ++i) {
+                ser.write_plain(dimA[i]);
+                ser.write_plain(strideA[i]);
+            }
+            ser.write_plain(dataType);
+            return Empty{};
+        }
+        Empty deserialize(StringSerializer& ser, Empty) {
+            int nbDims = MEGDNN_MAX_NDIM;
+            cudnnDataType_t dataType;
+            int dimA[MEGDNN_MAX_NDIM];
+            int strideA[MEGDNN_MAX_NDIM];
+            nbDims = ser.read_plain<int>();
+            for (int i = 0; i < nbDims; ++i) {
+                dimA[i] = ser.read_plain<int>();
+                strideA[i] = ser.read_plain<int>();
+            }
+            dataType = ser.read_plain<cudnnDataType_t>();
+            cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
+            return Empty{};
+        }
+    };
+    class CudnnFilterDescParam {
+    public:
+        cudnnFilterDescriptor_t value;
+        Empty serialize(StringSerializer& ser, Empty) {
+            int nbDims = MEGDNN_MAX_NDIM;
+            cudnnDataType_t dataType;
+            cudnnTensorFormat_t format;
+            int filterDimA[MEGDNN_MAX_NDIM];
+            cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims,
+                                       filterDimA);
+            ser.write_plain(nbDims);
+            for (int i = 0; i < nbDims; ++i) {
+                ser.write_plain(filterDimA[i]);
+            }
+            ser.write_plain(dataType);
+            ser.write_plain(format);
+            return Empty{};
+        }
+        Empty deserialize(StringSerializer& ser, Empty) {
+            int nbDims = MEGDNN_MAX_NDIM;
+            cudnnDataType_t dataType;
+            cudnnTensorFormat_t format;
+            int filterDimA[MEGDNN_MAX_NDIM];
+            nbDims = ser.read_plain<int>();
+            for (int i = 0; i < nbDims; ++i) {
+                filterDimA[i] = ser.read_plain<int>();
+            }
+            dataType = ser.read_plain<cudnnDataType_t>();
+            format = ser.read_plain<cudnnTensorFormat_t>();
+            cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
+            return Empty{};
+        }
+    };
+}
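
To show how the two headers compose, a hedged sketch of one way these descriptor params could be wired into FunctionCacheBuilder to memoize a cuDNN workspace query. This is not part of the patch; `handle`, `x_desc`, `w_desc`, `conv_desc`, `y_desc`, and the chosen algo are assumed to be created and configured elsewhere. The key is built from the handle pointer, the full descriptor contents, and the algo enum; the queried size and status are replayed on a hit.

    #include "src/cuda/api_cache.h"

    using namespace megdnn;

    // Cached wrapper around cudnnGetConvolutionForwardWorkspaceSize: inputs are
    // serialized into the key, the size_t output and the status are cached.
    static auto cached_get_fwd_workspace =
            FunctionCacheBuilder<>{}
                    .input<PlainParam<cudnnHandle_t>>()
                    .input<CudnnTensorDescParam>()  // xDesc
                    .input<CudnnFilterDescParam>()  // wDesc
                    .input<CudnnConvDescParam>()    // convDesc
                    .input<CudnnTensorDescParam>()  // yDesc
                    .input<PlainParam<cudnnConvolutionFwdAlgo_t>>()
                    .output<RefParam<size_t>>()
                    .ret<RetParam<cudnnStatus_t>>()
                    .build(&cudnnGetConvolutionForwardWorkspaceSize);

    // Call site keeps the cuDNN signature:
    //     size_t workspace = 0;
    //     cudnnStatus_t st = cached_get_fwd_workspace(
    //             handle, x_desc, w_desc, conv_desc, y_desc,
    //             CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, &workspace);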