diff --git a/dnn/src/common/api_cache.h b/dnn/src/common/api_cache.h index c39589bc5732987634cab63291fc2bbac921416a..9009f5e1a26f7de4ae2002f59408f9426111fe9f 100644 --- a/dnn/src/common/api_cache.h +++ b/dnn/src/common/api_cache.h @@ -12,19 +12,79 @@ #pragma once +#include #include #include +#include #include #include #include "megdnn/thin/function.h" +#include "./utils.h" + namespace megdnn { -template -class FunctionCache { + +// https://jfdube.wordpress.com/2014/01/03/implementing-a-recursive-read-write-spinlock/ +class RWSpin { +public: + class Lock { + private: + RWSpin* m_spin; + void (RWSpin::*m_lock)(void); + void (RWSpin::*m_unlock)(void); + + public: + Lock(RWSpin* spin, decltype(m_lock) lock, decltype(m_unlock) unlock) + : m_spin{spin}, m_lock{lock}, m_unlock{unlock} {} + void lock() { (m_spin->*m_lock)(); } + void unlock() { (m_spin->*m_unlock)(); } + }; + +private: + std::atomic m_atomic{0}; + + static constexpr uint32_t sm_reader_mask = 0x7FFFFFFF; + static constexpr uint32_t sm_writer_mask = 0x80000000; + + void _reader_lock() { + uint32_t expected = m_atomic; + do { + expected &= sm_reader_mask; + } while (!m_atomic.compare_exchange_strong(expected, expected + 1)); + } + void _reader_unlock() { m_atomic--; } + void _writer_lock() { + uint32_t expected = m_atomic; + do { + expected &= sm_reader_mask; + } while (!m_atomic.compare_exchange_strong(expected, + expected | sm_writer_mask)); + while (m_atomic.load() != sm_writer_mask) + ; + } + void _writer_unlock() { + // assert m_atomic == sm_writer_mask + m_atomic = 0; + } + +public: + Lock reader() { + return {this, &RWSpin::_reader_lock, &RWSpin::_reader_unlock}; + } + Lock writer() { + return {this, &RWSpin::_writer_lock, &RWSpin::_writer_unlock}; + } +}; + +template +class FunctionCache; + +template +class FunctionCache { public: using key_t = std::string; - using value_t = std::string; + using value_t = TRet; using key_mapper_t = thin_function; using value_mapper_t = thin_function; using storage_t = std::unordered_map; @@ -33,12 +93,30 @@ public: key_mapper_t key_mapper; value_mapper_t value_mapper; - value_t operator()(TArgs... args) { + RWSpin spin; + +public: + TRet operator()(TArgs... args) { key_t key = key_mapper(args...); - if (storage.count(key) == 0) { - storage[key] = value_mapper(std::forward(args)...); + auto reader_lock = spin.reader(); + auto writer_lock = spin.writer(); + { + MEGDNN_LOCK_GUARD(reader_lock); + auto iter = storage.find(key); + if (iter != storage.end()) { + return iter->second; + } + } + // RWSpin doesn't support upgrade + { + MEGDNN_LOCK_GUARD(writer_lock); + if (storage.count(key) != 0) { + return storage[key]; + } + value_t ret = value_mapper(std::forward(args)...); + storage[key] = ret; + return ret; } - return storage[key]; } }; @@ -51,8 +129,8 @@ private: public: template T read_plain() { - static_assert(std::is_trivially_copyable::value, "invalid type"); - T ret; + static_assert(std::is_trivially_copyable::value, "invalid type"); + T ret; memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T)); m_cursor += sizeof(T); return ret; @@ -63,10 +141,8 @@ public: "type should be trivially copyable"); m_buffer.append(reinterpret_cast(&value), sizeof(T)); } - std::string take() { - return std::move(m_buffer); - } - void set(std::string new_buf) { + std::string take() { return std::move(m_buffer); } + void reset(std::string new_buf) { m_cursor = 0; m_buffer = new_buf; } @@ -74,26 +150,32 @@ public: struct Empty {}; +// in: seq[1, 2, ..., m] +// out: seq[N+1, N+2, ... N+m] +template +static std::index_sequence inc_index_sequence( + std::index_sequence) { + return {}; +} + template class ParamBundle { private: - template - static std::index_sequence add_all( - std::index_sequence) { - return {}; - } - + // out: Min, Min+1, ..., Max template - using make_index_range = - decltype(add_all(std::make_index_sequence())); + using make_index_range = decltype( + inc_index_sequence(std::make_index_sequence())); + // store params in a tuple using storage_t = std::tuple...>; storage_t m_storage; + // deconstruct tuple and call functor template auto call_helper(TFunctor functor, std::index_sequence) { return functor(std::get(m_storage).value...); } + template auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { @@ -101,9 +183,11 @@ private: std::get(m_storage).serialize(ser, prev), std::index_sequence()); } + template auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} + template auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { @@ -111,9 +195,11 @@ private: ser, std::get(m_storage).deserialize(ser, prev), std::index_sequence()); } + template auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} + template void set_values_helper(std::index_sequence, TArg&& arg, TArgs&&... args) { @@ -121,6 +207,7 @@ private: set_values_helper(std::index_sequence(), std::forward(args)...); } + template void set_values_helper(std::index_sequence) { static_assert(sizeof...(Indices) == 0, "redundant indices"); @@ -132,25 +219,26 @@ public: return call_helper(std::forward(functor), std::make_index_sequence()); } + + // recursively store params into ser template void serialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); - serialize_helper( - ser, Empty{}, - add_all(std::make_index_sequence())); + serialize_helper(ser, Empty{}, make_index_range()); } + + // recursively load params from ser template void deserialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); - deserialize_helper( - ser, Empty{}, - add_all(std::make_index_sequence())); + deserialize_helper(ser, Empty{}, make_index_range()); } + + // recursively set params into m_storage template void set_values(TArgs&&... args) { - set_values_helper( - add_all(std::make_index_sequence()), - std::forward(args)...); + set_values_helper(make_index_range(), + std::forward(args)...); } }; @@ -158,10 +246,12 @@ template class Param { public: T value; + Empty serialize(StringSerializer& ser, Empty) { ser.write_plain(value); return Empty{}; } + Empty deserialize(StringSerializer& ser, Empty) { value = ser.read_plain(); return Empty{}; @@ -172,42 +262,54 @@ template , typename TInputs = std::tuple<>, typename TOutputs = std::tuple<>> class FunctionCacheBuilder { private: + // decl value with type of tuple-of-args static auto declargs() -> decltype(std::tuple_cat(std::declval(), std::declval())) { return {}; } + template static auto declfunction_helper(std::index_sequence) -> thin_function().value)( decltype(std::get(declargs()).value)...)> { return {}; } + + // decl value with type of original function static auto declfunction() { return declfunction_helper( std::make_index_sequence::value + std::tuple_size::value>()); } + template static auto declbundle_helper(std::index_sequence) -> ParamBundle(declargs()))...> { return {}; } + + // decl value with type of bundle-of-args static auto declbundle() { return declbundle_helper( std::make_index_sequence::value + std::tuple_size::value>()); } + + // type of original function using function_t = decltype(declfunction()); + // type of bundle-of-args using bundle_t = decltype(declbundle()); public: + // declare new return type, cannot be override template auto ret() { static_assert(std::is_same>::value, "return value redefinition"); return FunctionCacheBuilder{}; } + // declare new input template auto input() { using TNewInputs = decltype( @@ -215,6 +317,7 @@ public: std::make_tuple(std::declval()))); return FunctionCacheBuilder{}; } + // declare new output template auto output() { using TNewOutputs = decltype( @@ -222,17 +325,20 @@ public: std::make_tuple(std::declval()))); return FunctionCacheBuilder{}; } + // summary template function_t build(TFunctor func) { - FunctionCache cache; - cache.key_mapper = [](bundle_t bundle) { + auto cache = std::make_shared>(); + // bundle -> ser(in args) + cache->key_mapper = [](bundle_t bundle) { StringSerializer ser; bundle.template serialize_params<0, std::tuple_size::value>( ser); return ser.take(); }; - cache.value_mapper = [=](bundle_t bundle) { + // bundle -> ser(out args) + cache->value_mapper = [=](bundle_t bundle) { StringSerializer ser; TRet ret; ret.value = bundle.call_by(func); @@ -253,7 +359,7 @@ public: "args count mismatch"); bundle.template set_values<0, sizeof...(args)>( std::forward(args)...); - ser.set(cache(bundle)); + ser.reset((*cache)(bundle)); ret.deserialize(ser, Empty{}); constexpr size_t n_inputs = std::tuple_size::value; constexpr size_t n_outputs = std::tuple_size::value; @@ -278,6 +384,7 @@ public: } }; +// like RefParam but return *value while ser and deser. Working with ArrayParam template class RefArraySizeParam { public: @@ -291,6 +398,7 @@ public: } }; +// accept array length from previous param. Working with RefArraySizeParam template class ArrayParam { public: diff --git a/dnn/src/cuda/api_cache.h b/dnn/src/cuda/api_cache.h index f6f51b754b82ab1a2f28c74db3eee3962ca4f667..f58f6d75b4230ac9259a34ef925f6f0ad6cdcf4c 100644 --- a/dnn/src/cuda/api_cache.h +++ b/dnn/src/cuda/api_cache.h @@ -20,7 +20,7 @@ class CudnnConvDescParam { public: cudnnConvolutionDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { - constexpr int nbDims = MEGDNN_MAX_NDIM; + int nbDims = MEGDNN_MAX_NDIM; int padA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; int dilationA[MEGDNN_MAX_NDIM]; @@ -59,7 +59,7 @@ class CudnnTensorDescParam { public: cudnnTensorDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { - constexpr int nbDims = MEGDNN_MAX_NDIM; + int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; int dimA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; @@ -74,7 +74,7 @@ public: return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { - constexpr int nbDims = MEGDNN_MAX_NDIM; + int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; int dimA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; @@ -92,7 +92,7 @@ class CudnnFilterDescParam { public: cudnnFilterDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { - constexpr int nbDims = MEGDNN_MAX_NDIM; + int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; cudnnTensorFormat_t format; int filterDimA[MEGDNN_MAX_NDIM]; @@ -107,7 +107,7 @@ public: return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { - constexpr int nbDims = MEGDNN_MAX_NDIM; + int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; cudnnTensorFormat_t format; int filterDimA[MEGDNN_MAX_NDIM];