提交 729ee649 编写于 作者: M Megvii Engine Team

Revert "fix(api_cache): lock api cache for thread safety"

This reverts commit 8a244677c3d2d8b2a7eaafb51c3ac13e2dfc55d6.

GitOrigin-RevId: 582488adeb8046ec14d0ca8c6d4af2f13accd9fb
上级 64c922c4
...@@ -12,79 +12,19 @@ ...@@ -12,79 +12,19 @@
#pragma once #pragma once
#include <atomic>
#include <cstring> #include <cstring>
#include <memory> #include <memory>
#include <mutex>
#include <tuple> #include <tuple>
#include <unordered_map> #include <unordered_map>
#include "megdnn/thin/function.h" #include "megdnn/thin/function.h"
#include "./utils.h"
namespace megdnn { namespace megdnn {
template <typename... TArgs>
// https://jfdube.wordpress.com/2014/01/03/implementing-a-recursive-read-write-spinlock/ class FunctionCache {
class RWSpin {
public:
class Lock {
private:
RWSpin* m_spin;
void (RWSpin::*m_lock)(void);
void (RWSpin::*m_unlock)(void);
public:
Lock(RWSpin* spin, decltype(m_lock) lock, decltype(m_unlock) unlock)
: m_spin{spin}, m_lock{lock}, m_unlock{unlock} {}
void lock() { (m_spin->*m_lock)(); }
void unlock() { (m_spin->*m_unlock)(); }
};
private:
std::atomic<uint32_t> m_atomic{0};
static constexpr uint32_t sm_reader_mask = 0x7FFFFFFF;
static constexpr uint32_t sm_writer_mask = 0x80000000;
void _reader_lock() {
uint32_t expected = m_atomic;
do {
expected &= sm_reader_mask;
} while (!m_atomic.compare_exchange_strong(expected, expected + 1));
}
void _reader_unlock() { m_atomic--; }
void _writer_lock() {
uint32_t expected = m_atomic;
do {
expected &= sm_reader_mask;
} while (!m_atomic.compare_exchange_strong(expected,
expected | sm_writer_mask));
while (m_atomic.load() != sm_writer_mask)
;
}
void _writer_unlock() {
// assert m_atomic == sm_writer_mask
m_atomic = 0;
}
public:
Lock reader() {
return {this, &RWSpin::_reader_lock, &RWSpin::_reader_unlock};
}
Lock writer() {
return {this, &RWSpin::_writer_lock, &RWSpin::_writer_unlock};
}
};
template <typename TSignature>
class FunctionCache;
template <typename TRet, typename... TArgs>
class FunctionCache<TRet(TArgs...)> {
public: public:
using key_t = std::string; using key_t = std::string;
using value_t = TRet; using value_t = std::string;
using key_mapper_t = thin_function<key_t(TArgs...)>; using key_mapper_t = thin_function<key_t(TArgs...)>;
using value_mapper_t = thin_function<value_t(TArgs...)>; using value_mapper_t = thin_function<value_t(TArgs...)>;
using storage_t = std::unordered_map<key_t, value_t>; using storage_t = std::unordered_map<key_t, value_t>;
...@@ -93,30 +33,12 @@ public: ...@@ -93,30 +33,12 @@ public:
key_mapper_t key_mapper; key_mapper_t key_mapper;
value_mapper_t value_mapper; value_mapper_t value_mapper;
RWSpin spin; value_t operator()(TArgs... args) {
public:
TRet operator()(TArgs... args) {
key_t key = key_mapper(args...); key_t key = key_mapper(args...);
auto reader_lock = spin.reader(); if (storage.count(key) == 0) {
auto writer_lock = spin.writer(); storage[key] = value_mapper(std::forward<TArgs>(args)...);
{
MEGDNN_LOCK_GUARD(reader_lock);
auto iter = storage.find(key);
if (iter != storage.end()) {
return iter->second;
}
}
// RWSpin doesn't support upgrade
{
MEGDNN_LOCK_GUARD(writer_lock);
if (storage.count(key) != 0) {
return storage[key];
}
value_t ret = value_mapper(std::forward<TArgs>(args)...);
storage[key] = ret;
return ret;
} }
return storage[key];
} }
}; };
...@@ -129,8 +51,8 @@ private: ...@@ -129,8 +51,8 @@ private:
public: public:
template <typename T> template <typename T>
T read_plain() { T read_plain() {
static_assert(std::is_trivially_copyable<T>::value, "invalid type"); static_assert(std::is_trivially_copyable<T>::value, "invalid type");
T ret; T ret;
memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T)); memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
m_cursor += sizeof(T); m_cursor += sizeof(T);
return ret; return ret;
...@@ -141,8 +63,10 @@ public: ...@@ -141,8 +63,10 @@ public:
"type should be trivially copyable"); "type should be trivially copyable");
m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T)); m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
} }
std::string take() { return std::move(m_buffer); } std::string take() {
void reset(std::string new_buf) { return std::move(m_buffer);
}
void set(std::string new_buf) {
m_cursor = 0; m_cursor = 0;
m_buffer = new_buf; m_buffer = new_buf;
} }
...@@ -150,32 +74,26 @@ public: ...@@ -150,32 +74,26 @@ public:
struct Empty {}; struct Empty {};
// in: seq[1, 2, ..., m]
// out: seq[N+1, N+2, ... N+m]
template <std::size_t N, std::size_t... Seq>
static std::index_sequence<N + Seq...> inc_index_sequence(
std::index_sequence<Seq...>) {
return {};
}
template <typename... TParams> template <typename... TParams>
class ParamBundle { class ParamBundle {
private: private:
// out: Min, Min+1, ..., Max template <std::size_t N, std::size_t... Seq>
static std::index_sequence<N + Seq...> add_all(
std::index_sequence<Seq...>) {
return {};
}
template <std::size_t Min, std::size_t Max> template <std::size_t Min, std::size_t Max>
using make_index_range = decltype( using make_index_range =
inc_index_sequence<Min>(std::make_index_sequence<Max - Min>())); decltype(add_all<Min>(std::make_index_sequence<Max - Min>()));
// store params in a tuple
using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>; using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
storage_t m_storage; storage_t m_storage;
// deconstruct tuple and call functor
template <typename TFunctor, size_t... Indices> template <typename TFunctor, size_t... Indices>
auto call_helper(TFunctor functor, std::index_sequence<Indices...>) { auto call_helper(TFunctor functor, std::index_sequence<Indices...>) {
return functor(std::get<Indices>(m_storage).value...); return functor(std::get<Indices>(m_storage).value...);
} }
template <size_t Index, size_t... Indices, typename TPrev> template <size_t Index, size_t... Indices, typename TPrev>
auto serialize_helper(StringSerializer& ser, TPrev&& prev, auto serialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<Index, Indices...>) { std::index_sequence<Index, Indices...>) {
...@@ -183,11 +101,9 @@ private: ...@@ -183,11 +101,9 @@ private:
std::get<Index>(m_storage).serialize(ser, prev), std::get<Index>(m_storage).serialize(ser, prev),
std::index_sequence<Indices...>()); std::index_sequence<Indices...>());
} }
template <typename TPrev> template <typename TPrev>
auto serialize_helper(StringSerializer& ser, TPrev&& prev, auto serialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<>) {} std::index_sequence<>) {}
template <size_t Index, size_t... Indices, typename TPrev> template <size_t Index, size_t... Indices, typename TPrev>
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<Index, Indices...>) { std::index_sequence<Index, Indices...>) {
...@@ -195,11 +111,9 @@ private: ...@@ -195,11 +111,9 @@ private:
ser, std::get<Index>(m_storage).deserialize(ser, prev), ser, std::get<Index>(m_storage).deserialize(ser, prev),
std::index_sequence<Indices...>()); std::index_sequence<Indices...>());
} }
template <typename TPrev> template <typename TPrev>
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<>) {} std::index_sequence<>) {}
template <size_t Index, size_t... Indices, typename TArg, typename... TArgs> template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg, void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
TArgs&&... args) { TArgs&&... args) {
...@@ -207,7 +121,6 @@ private: ...@@ -207,7 +121,6 @@ private:
set_values_helper(std::index_sequence<Indices...>(), set_values_helper(std::index_sequence<Indices...>(),
std::forward<TArgs>(args)...); std::forward<TArgs>(args)...);
} }
template <size_t... Indices> template <size_t... Indices>
void set_values_helper(std::index_sequence<Indices...>) { void set_values_helper(std::index_sequence<Indices...>) {
static_assert(sizeof...(Indices) == 0, "redundant indices"); static_assert(sizeof...(Indices) == 0, "redundant indices");
...@@ -219,26 +132,25 @@ public: ...@@ -219,26 +132,25 @@ public:
return call_helper(std::forward<TFunctor>(functor), return call_helper(std::forward<TFunctor>(functor),
std::make_index_sequence<sizeof...(TParams)>()); std::make_index_sequence<sizeof...(TParams)>());
} }
// recursively store params into ser
template <size_t NBegin, size_t NEnd> template <size_t NBegin, size_t NEnd>
void serialize_params(StringSerializer& ser) { void serialize_params(StringSerializer& ser) {
static_assert(NEnd >= NBegin, "invalid range"); static_assert(NEnd >= NBegin, "invalid range");
serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>()); serialize_helper(
ser, Empty{},
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
} }
// recursively load params from ser
template <size_t NBegin, size_t NEnd> template <size_t NBegin, size_t NEnd>
void deserialize_params(StringSerializer& ser) { void deserialize_params(StringSerializer& ser) {
static_assert(NEnd >= NBegin, "invalid range"); static_assert(NEnd >= NBegin, "invalid range");
deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>()); deserialize_helper(
ser, Empty{},
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
} }
// recursively set params into m_storage
template <size_t NBegin, size_t NEnd, typename... TArgs> template <size_t NBegin, size_t NEnd, typename... TArgs>
void set_values(TArgs&&... args) { void set_values(TArgs&&... args) {
set_values_helper(make_index_range<NBegin, NEnd>(), set_values_helper(
std::forward<TArgs>(args)...); add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()),
std::forward<TArgs>(args)...);
} }
}; };
...@@ -246,12 +158,10 @@ template <typename T> ...@@ -246,12 +158,10 @@ template <typename T>
class Param { class Param {
public: public:
T value; T value;
Empty serialize(StringSerializer& ser, Empty) { Empty serialize(StringSerializer& ser, Empty) {
ser.write_plain(value); ser.write_plain(value);
return Empty{}; return Empty{};
} }
Empty deserialize(StringSerializer& ser, Empty) { Empty deserialize(StringSerializer& ser, Empty) {
value = ser.read_plain<T>(); value = ser.read_plain<T>();
return Empty{}; return Empty{};
...@@ -262,54 +172,42 @@ template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>, ...@@ -262,54 +172,42 @@ template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
typename TOutputs = std::tuple<>> typename TOutputs = std::tuple<>>
class FunctionCacheBuilder { class FunctionCacheBuilder {
private: private:
// decl value with type of tuple-of-args
static auto declargs() static auto declargs()
-> decltype(std::tuple_cat(std::declval<TInputs>(), -> decltype(std::tuple_cat(std::declval<TInputs>(),
std::declval<TOutputs>())) { std::declval<TOutputs>())) {
return {}; return {};
} }
template <size_t... Indices> template <size_t... Indices>
static auto declfunction_helper(std::index_sequence<Indices...>) static auto declfunction_helper(std::index_sequence<Indices...>)
-> thin_function<decltype(std::declval<TRet>().value)( -> thin_function<decltype(std::declval<TRet>().value)(
decltype(std::get<Indices>(declargs()).value)...)> { decltype(std::get<Indices>(declargs()).value)...)> {
return {}; return {};
} }
// decl value with type of original function
static auto declfunction() { static auto declfunction() {
return declfunction_helper( return declfunction_helper(
std::make_index_sequence<std::tuple_size<TInputs>::value + std::make_index_sequence<std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value>()); std::tuple_size<TOutputs>::value>());
} }
template <size_t... Indices> template <size_t... Indices>
static auto declbundle_helper(std::index_sequence<Indices...>) static auto declbundle_helper(std::index_sequence<Indices...>)
-> ParamBundle<decltype(std::get<Indices>(declargs()))...> { -> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
return {}; return {};
} }
// decl value with type of bundle-of-args
static auto declbundle() { static auto declbundle() {
return declbundle_helper( return declbundle_helper(
std::make_index_sequence<std::tuple_size<TInputs>::value + std::make_index_sequence<std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value>()); std::tuple_size<TOutputs>::value>());
} }
// type of original function
using function_t = decltype(declfunction()); using function_t = decltype(declfunction());
// type of bundle-of-args
using bundle_t = decltype(declbundle()); using bundle_t = decltype(declbundle());
public: public:
// declare new return type, cannot be override
template <typename TNewRet> template <typename TNewRet>
auto ret() { auto ret() {
static_assert(std::is_same<TRet, Param<Empty>>::value, static_assert(std::is_same<TRet, Param<Empty>>::value,
"return value redefinition"); "return value redefinition");
return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{}; return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
} }
// declare new input
template <typename TNewInput> template <typename TNewInput>
auto input() { auto input() {
using TNewInputs = decltype( using TNewInputs = decltype(
...@@ -317,7 +215,6 @@ public: ...@@ -317,7 +215,6 @@ public:
std::make_tuple(std::declval<TNewInput>()))); std::make_tuple(std::declval<TNewInput>())));
return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{}; return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
} }
// declare new output
template <typename TNewOutput> template <typename TNewOutput>
auto output() { auto output() {
using TNewOutputs = decltype( using TNewOutputs = decltype(
...@@ -325,20 +222,17 @@ public: ...@@ -325,20 +222,17 @@ public:
std::make_tuple(std::declval<TNewOutput>()))); std::make_tuple(std::declval<TNewOutput>())));
return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{}; return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
} }
// summary
template <typename TFunctor> template <typename TFunctor>
function_t build(TFunctor func) { function_t build(TFunctor func) {
auto cache = std::make_shared<FunctionCache<std::string(bundle_t)>>(); FunctionCache<bundle_t> cache;
// bundle -> ser(in args) cache.key_mapper = [](bundle_t bundle) {
cache->key_mapper = [](bundle_t bundle) {
StringSerializer ser; StringSerializer ser;
bundle.template serialize_params<0, bundle.template serialize_params<0,
std::tuple_size<TInputs>::value>( std::tuple_size<TInputs>::value>(
ser); ser);
return ser.take(); return ser.take();
}; };
// bundle -> ser(out args) cache.value_mapper = [=](bundle_t bundle) {
cache->value_mapper = [=](bundle_t bundle) {
StringSerializer ser; StringSerializer ser;
TRet ret; TRet ret;
ret.value = bundle.call_by(func); ret.value = bundle.call_by(func);
...@@ -359,7 +253,7 @@ public: ...@@ -359,7 +253,7 @@ public:
"args count mismatch"); "args count mismatch");
bundle.template set_values<0, sizeof...(args)>( bundle.template set_values<0, sizeof...(args)>(
std::forward<decltype(args)>(args)...); std::forward<decltype(args)>(args)...);
ser.reset((*cache)(bundle)); ser.set(cache(bundle));
ret.deserialize(ser, Empty{}); ret.deserialize(ser, Empty{});
constexpr size_t n_inputs = std::tuple_size<TInputs>::value; constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
constexpr size_t n_outputs = std::tuple_size<TOutputs>::value; constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
...@@ -384,7 +278,6 @@ public: ...@@ -384,7 +278,6 @@ public:
} }
}; };
// like RefParam but return *value while ser and deser. Working with ArrayParam
template <typename T> template <typename T>
class RefArraySizeParam { class RefArraySizeParam {
public: public:
...@@ -398,7 +291,6 @@ public: ...@@ -398,7 +291,6 @@ public:
} }
}; };
// accept array length from previous param. Working with RefArraySizeParam
template <typename TSize, typename TItem> template <typename TSize, typename TItem>
class ArrayParam { class ArrayParam {
public: public:
......
...@@ -20,7 +20,7 @@ class CudnnConvDescParam { ...@@ -20,7 +20,7 @@ class CudnnConvDescParam {
public: public:
cudnnConvolutionDescriptor_t value; cudnnConvolutionDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) { Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM; constexpr int nbDims = MEGDNN_MAX_NDIM;
int padA[MEGDNN_MAX_NDIM]; int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM]; int dilationA[MEGDNN_MAX_NDIM];
...@@ -59,7 +59,7 @@ class CudnnTensorDescParam { ...@@ -59,7 +59,7 @@ class CudnnTensorDescParam {
public: public:
cudnnTensorDescriptor_t value; cudnnTensorDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) { Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM; constexpr int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType; cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM]; int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM];
...@@ -74,7 +74,7 @@ public: ...@@ -74,7 +74,7 @@ public:
return Empty{}; return Empty{};
} }
Empty deserialize(StringSerializer& ser, Empty) { Empty deserialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM; constexpr int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType; cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM]; int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM];
...@@ -92,7 +92,7 @@ class CudnnFilterDescParam { ...@@ -92,7 +92,7 @@ class CudnnFilterDescParam {
public: public:
cudnnFilterDescriptor_t value; cudnnFilterDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) { Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM; constexpr int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType; cudnnDataType_t dataType;
cudnnTensorFormat_t format; cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM]; int filterDimA[MEGDNN_MAX_NDIM];
...@@ -107,7 +107,7 @@ public: ...@@ -107,7 +107,7 @@ public:
return Empty{}; return Empty{};
} }
Empty deserialize(StringSerializer& ser, Empty) { Empty deserialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM; constexpr int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType; cudnnDataType_t dataType;
cudnnTensorFormat_t format; cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM]; int filterDimA[MEGDNN_MAX_NDIM];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册