/** * \file dnn/src/common/api_cache.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #pragma once #include #include #include #include #include "megdnn/thin/function.h" namespace megdnn { template class FunctionCache { public: using key_t = std::string; using value_t = std::string; using key_mapper_t = thin_function; using value_mapper_t = thin_function; using storage_t = std::unordered_map; storage_t storage; key_mapper_t key_mapper; value_mapper_t value_mapper; value_t operator()(TArgs... args) { key_t key = key_mapper(args...); if (storage.count(key) == 0) { storage[key] = value_mapper(std::forward(args)...); } return storage[key]; } }; // FIFO class StringSerializer { private: std::string m_buffer; size_t m_cursor = 0; public: template T read_plain() { static_assert(std::is_trivially_copyable::value, "invalid type"); T ret; memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T)); m_cursor += sizeof(T); return ret; } template void write_plain(T value) { static_assert(std::is_trivially_copyable::value, "type should be trivially copyable"); m_buffer.append(reinterpret_cast(&value), sizeof(T)); } std::string take() { return std::move(m_buffer); } void set(std::string new_buf) { m_cursor = 0; m_buffer = new_buf; } }; struct Empty {}; template class ParamBundle { private: template static std::index_sequence add_all( std::index_sequence) { return {}; } template using make_index_range = decltype(add_all(std::make_index_sequence())); using storage_t = std::tuple...>; storage_t m_storage; template auto call_helper(TFunctor functor, std::index_sequence) { return functor(std::get(m_storage).value...); } template auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { return serialize_helper(ser, std::get(m_storage).serialize(ser, prev), std::index_sequence()); } template auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} template auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence) { return deserialize_helper( ser, std::get(m_storage).deserialize(ser, prev), std::index_sequence()); } template auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {} template void set_values_helper(std::index_sequence, TArg&& arg, TArgs&&... args) { std::get(m_storage).value = arg; set_values_helper(std::index_sequence(), std::forward(args)...); } template void set_values_helper(std::index_sequence) { static_assert(sizeof...(Indices) == 0, "redundant indices"); } public: template auto call_by(TFunctor&& functor) { return call_helper(std::forward(functor), std::make_index_sequence()); } template void serialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); serialize_helper( ser, Empty{}, add_all(std::make_index_sequence())); } template void deserialize_params(StringSerializer& ser) { static_assert(NEnd >= NBegin, "invalid range"); deserialize_helper( ser, Empty{}, add_all(std::make_index_sequence())); } template void set_values(TArgs&&... args) { set_values_helper( add_all(std::make_index_sequence()), std::forward(args)...); } }; template class Param { public: T value; Empty serialize(StringSerializer& ser, Empty) { ser.write_plain(value); return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { value = ser.read_plain(); return Empty{}; } }; template , typename TInputs = std::tuple<>, typename TOutputs = std::tuple<>> class FunctionCacheBuilder { private: static auto declargs() -> decltype(std::tuple_cat(std::declval(), std::declval())) { return {}; } template static auto declfunction_helper(std::index_sequence) -> thin_function().value)( decltype(std::get(declargs()).value)...)> { return {}; } static auto declfunction() { return declfunction_helper( std::make_index_sequence::value + std::tuple_size::value>()); } template static auto declbundle_helper(std::index_sequence) -> ParamBundle(declargs()))...> { return {}; } static auto declbundle() { return declbundle_helper( std::make_index_sequence::value + std::tuple_size::value>()); } using function_t = decltype(declfunction()); using bundle_t = decltype(declbundle()); public: template auto ret() { static_assert(std::is_same>::value, "return value redefinition"); return FunctionCacheBuilder{}; } template auto input() { using TNewInputs = decltype( std::tuple_cat(std::declval(), std::make_tuple(std::declval()))); return FunctionCacheBuilder{}; } template auto output() { using TNewOutputs = decltype( std::tuple_cat(std::declval(), std::make_tuple(std::declval()))); return FunctionCacheBuilder{}; } template function_t build(TFunctor func) { FunctionCache cache; cache.key_mapper = [](bundle_t bundle) { StringSerializer ser; bundle.template serialize_params<0, std::tuple_size::value>( ser); return ser.take(); }; cache.value_mapper = [=](bundle_t bundle) { StringSerializer ser; TRet ret; ret.value = bundle.call_by(func); ret.serialize(ser, Empty{}); bundle.template serialize_params< std::tuple_size::value, std::tuple_size::value + std::tuple_size::value>(ser); return ser.take(); }; return [=](auto&&... args) mutable { bundle_t bundle; TRet ret; StringSerializer ser; static_assert( sizeof...(args) == std::tuple_size::value + std::tuple_size::value, "args count mismatch"); bundle.template set_values<0, sizeof...(args)>( std::forward(args)...); ser.set(cache(bundle)); ret.deserialize(ser, Empty{}); constexpr size_t n_inputs = std::tuple_size::value; constexpr size_t n_outputs = std::tuple_size::value; bundle.template deserialize_params( ser); return ret.value; }; } }; template class RefParam { public: T* value; Empty serialize(StringSerializer& ser, Empty) { ser.write_plain(*value); return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { *value = ser.read_plain(); return Empty{}; } }; template class RefArraySizeParam { public: T* value; T serialize(StringSerializer& ser, Empty) { ser.write_plain(*value); return *value; } T deserialize(StringSerializer& ser, Empty) { return *value = ser.read_plain(); } }; template class ArrayParam { public: TItem* value; Empty serialize(StringSerializer& ser, TSize size) { for (TSize i = 0; i < size; ++i) { ser.write_plain(value[i]); } return Empty{}; } Empty deserialize(StringSerializer& ser, TSize size) { for (TSize i = 0; i < size; ++i) { value[i] = ser.read_plain(); } return Empty{}; } }; } // namespace megdnn