Commit f1aa6922 authored by Megvii Engine Team

feat(dnn/apicache): add generic apicache

GitOrigin-RevId: 40b8ac2ab62fc60f8c23b17463d9983ea70d16f0
Parent d7f0fc8f
/**
* \file dnn/src/common/api_cache.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <unordered_map>
#include <memory>
#include <cstring>
#include <tuple>
#include "megdnn/thin/function.h"
namespace megdnn {
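// FunctionCache memoizes a callable: key_mapper serializes the call
// arguments into a string key, and value_mapper computes the value on a
// cache miss. A minimal usage sketch (expensive_compute is a hypothetical
// function, not part of this file):
//
//   FunctionCache<int(int)> cache;
//   cache.key_mapper = [](int x) { return std::to_string(x); };
//   cache.value_mapper = [](int x) { return expensive_compute(x); };
//   int y = cache(42);  // computed once, then replayed from storage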
template <typename TSignature>
class FunctionCache;
template <typename TRet, typename... TArgs>
class FunctionCache<TRet(TArgs...)> {
public:
using key_t = std::string;
using value_t = TRet;
using key_mapper_t = thin_function<key_t(TArgs...)>;
using value_mapper_t = thin_function<value_t(TArgs...)>;
using storage_t = std::unordered_map<key_t, value_t>;
public:
storage_t storage;
key_mapper_t key_mapper;
value_mapper_t value_mapper;
public:
    TRet operator()(TArgs... args) {
        key_t key = key_mapper(args...);
        // look the key up once; compute and insert the value on a miss
        auto iter = storage.find(key);
        if (iter == storage.end()) {
            iter = storage.emplace(std::move(key),
                                   value_mapper(std::forward<TArgs>(args)...))
                           .first;
        }
        return iter->second;
    }
};
};
// StringSerializer is a FIFO byte buffer: write_plain appends the raw bytes
// of a trivially copyable value, and read_plain consumes values in the same
// order they were written.
class StringSerializer {
private:
std::string m_buffer;
size_t m_cursor = 0;
public:
template <typename T>
T read_plain() {
    // consume sizeof(T) bytes at the cursor and reinterpret them as a T
    T result;
    std::memcpy(&result, m_buffer.data() + m_cursor, sizeof(T));
    m_cursor += sizeof(T);
    return result;
}
template <typename T>
void write_plain(T value) {
    // append the raw bytes of a trivially copyable value
    m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
}
std::string take() {
    // hand out everything that has not been consumed yet
    m_buffer.erase(0, m_cursor);
    m_cursor = 0;
    return std::move(m_buffer);
}
void set(std::string new_buf) {
    // replace the buffer and restart reading from the beginning
    m_cursor = 0;
    m_buffer = std::move(new_buf);
}
};
struct Empty {};
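// ParamBundle stores one wrapper object per parameter of the cached function
// (PlainParam, RefParam, ArrayParam, ... defined below).
// serialize_params/deserialize_params walk a compile-time index range over
// the wrappers, threading each wrapper's return value into the next
// wrapper's call; call_by unpacks the wrapped values and invokes the real
// function.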
template <typename... TParams>
class ParamBundle {
private:
    // build std::index_sequence<Min, ..., Max - 1> by adding Min to each
    // element of a zero-based sequence
    template <std::size_t N, std::size_t... Seq>
    static std::index_sequence<N + Seq...> add_all(std::index_sequence<Seq...>) {
        return {};
    }
    template <std::size_t Min, std::size_t Max>
    using make_index_range =
            decltype(add_all<Min>(std::make_index_sequence<Max - Min>()));
using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
storage_t m_storage;
template <typename TFunctor, size_t... Indices>
auto call_helper(TFunctor functor, std::index_sequence<Indices...>) {
return functor(std::get<Indices>(m_storage).value...);
}
template <size_t Index, size_t... Indices, typename TPrev>
auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
return serialize_helper(ser, std::get<Index>(m_storage).serialize(ser, prev), std::index_sequence<Indices...>());
}
template <typename TPrev>
auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
template <size_t Index, size_t... Indices, typename TPrev>
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
return deserialize_helper(ser, std::get<Index>(m_storage).deserialize(ser, prev), std::index_sequence<Indices...>());
}
template <typename TPrev>
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
    template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
    void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg, TArgs&&... args) {
        // forward instead of copying, so move-only arguments also work
        std::get<Index>(m_storage).value = std::forward<TArg>(arg);
        set_values_helper(std::index_sequence<Indices...>(), std::forward<TArgs>(args)...);
    }
template <size_t... Indices>
void set_values_helper(std::index_sequence<Indices...>) {
static_assert(sizeof...(Indices) == 0, "redundant indices");
}
public:
template <typename TFunctor>
auto call_by(TFunctor&& functor) {
return call_helper(std::forward<TFunctor>(functor), std::make_index_sequence<sizeof...(TParams)>());
}
template <size_t NBegin, size_t NEnd>
void serialize_params(StringSerializer& ser) {
static_assert(NEnd >= NBegin, "invalid range");
serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
}
template <size_t NBegin, size_t NEnd>
void deserialize_params(StringSerializer& ser) {
static_assert(NEnd >= NBegin, "invalid range");
deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
}
template <size_t NBegin, size_t NEnd, typename... TArgs>
void set_values(TArgs&&... args) {
set_values_helper(make_index_range<NBegin, NEnd>(), std::forward<TArgs>(args)...);
}
};
template <typename T>
class RetParam {
public:
T value;
Empty serialize(StringSerializer& ser, Empty) {
ser.write_plain(value);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
value = ser.read_plain<T>();
return Empty{};
}
};
template <typename TRet=RetParam<Empty>, typename TInputs=std::tuple<>, typename TOutputs=std::tuple<>>
class FunctionCacheBuilder {
private:
static auto declargs() -> decltype(std::tuple_cat(std::declval<TInputs>(), std::declval<TOutputs>())) { return {}; }
template <size_t... Indices>
static auto declfunction_helper(std::index_sequence<Indices...>) -> thin_function<decltype(std::declval<TRet>().value)(decltype(std::get<Indices>(declargs()).value)...)> { return {}; }
static auto declfunction() {
return declfunction_helper(std::make_index_sequence<std::tuple_size<TInputs>::value + std::tuple_size<TOutputs>::value>());
}
template <size_t... Indices>
static auto declbundle_helper(std::index_sequence<Indices...>) -> ParamBundle<decltype(std::get<Indices>(declargs()))...> { return {}; }
static auto declbundle() {
return declbundle_helper(std::make_index_sequence<std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>());
}
using function_t = decltype(declfunction());
using bundle_t = decltype(declbundle());
public:
template <typename TNewRet>
auto ret() {
static_assert(std::is_same<TRet, RetParam<Empty>>::value, "return value redefinition");
return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
}
template <typename TNewInput>
auto input() {
using TNewInputs = decltype(std::tuple_cat(std::declval<TInputs>(), std::make_tuple(std::declval<TNewInput>())));
return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
}
template <typename TNewOutput>
auto output() {
using TNewOutputs = decltype(std::tuple_cat(std::declval<TOutputs>(), std::make_tuple(std::declval<TNewOutput>())));
return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
}
template <typename TFunctor>
function_t build(TFunctor func) {
FunctionCache<std::string(bundle_t)> cache;
cache.key_mapper = [](bundle_t bundle) {
StringSerializer ser;
bundle.template serialize_params<0, std::tuple_size<TInputs>::value>(ser);
return ser.take();
};
cache.value_mapper = [=](bundle_t bundle) {
StringSerializer ser;
TRet ret;
ret.value = bundle.call_by(func);
ret.serialize(ser, Empty{});
bundle.template serialize_params<std::tuple_size<TInputs>::value, std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>(ser);
return ser.take();
};
return [=](auto&&... args) mutable {
bundle_t bundle;
TRet ret;
StringSerializer ser;
static_assert(sizeof...(args) == std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value,
"arg count mismatch");
bundle.template set_values<0, sizeof...(args)>(std::forward<decltype(args)>(args)...);
ser.set(cache(bundle));
ret.deserialize(ser, Empty{});
constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
bundle.template deserialize_params<n_inputs, n_inputs+n_outputs>(ser);
return ret.value;
};
}
};
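// A minimal sketch of building a cached function (the doubling lambda is
// purely illustrative): the inputs form the cache key; on a miss the return
// value and outputs are computed and serialized, and on a hit they are
// replayed from the cached bytes without calling the function again.
//
//   auto cached = FunctionCacheBuilder<>{}
//           .input<PlainParam<int>>()
//           .output<RefParam<int>>()
//           .ret<RetParam<int>>()
//           .build([](int in, int* out) {
//               *out = in * 2;  // runs only on a cache miss
//               return in + 1;
//           });
//   int out;
//   int r = cached(5, &out);  // r == 6, out == 10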
template <typename T>
class PlainParam {
public:
T value;
Empty serialize(StringSerializer& ser, Empty) {
ser.write_plain(value);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
value = ser.read_plain<T>();
return Empty{};
}
};
template <typename T>
class RefParam {
public:
T* value;
Empty serialize(StringSerializer& ser, Empty) {
ser.write_plain(*value);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
*value = ser.read_plain<T>();
return Empty{};
}
};
template <typename T>
class RefArraySizeParam {
public:
T* value;
T serialize(StringSerializer& ser, Empty) {
ser.write_plain(*value);
return *value;
}
T deserialize(StringSerializer& ser, Empty) {
return *value = ser.read_plain<T>();
}
};
template <typename TSize, typename TItem>
class ArrayParam {
public:
TItem* value;
Empty serialize(StringSerializer& ser, TSize size) {
for (TSize i = 0; i < size; ++i) {
ser.write_plain(value[i]);
}
return Empty{};
}
Empty deserialize(StringSerializer& ser, TSize size) {
for (TSize i = 0; i < size; ++i) {
value[i] = ser.read_plain<TItem>();
}
return Empty{};
}
};
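// RefArraySizeParam returns the size it (de)serialized instead of Empty;
// ParamBundle threads that return value into the next wrapper's call, so an
// ArrayParam declared immediately after it knows how many elements to copy,
// e.g.
//   .input<RefArraySizeParam<int>>().input<ArrayParam<int, float>>()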
}
/**
* \file dnn/src/cuda/api_cache.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/api_cache.h"
#include "src/cuda/cudnn_wrapper.h"
namespace megdnn {
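// Param wrappers that serialize opaque cuDNN descriptors field by field so
// they can take part in a FunctionCache key: serialize() queries the
// descriptor's contents into the byte buffer, and deserialize() writes them
// back into an existing descriptor.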
class CudnnConvDescParam {
public:
cudnnConvolutionDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int ndim = MEGDNN_MAX_NDIM;
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
cudnnConvolutionMode_t mode;
cudnnDataType_t computeType;
cudnnGetConvolutionNdDescriptor(value, MEGDNN_MAX_NDIM, &ndim, padA, strideA, dilationA, &mode, &computeType);
ser.write_plain(ndim);
for (int i = 0; i < ndim; ++i) {
ser.write_plain(padA[i]);
ser.write_plain(strideA[i]);
ser.write_plain(dilationA[i]);
}
ser.write_plain(mode);
ser.write_plain(computeType);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
int ndim = ser.read_plain<int>();
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
for (int i = 0; i < ndim; ++i) {
padA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
dilationA[i] = ser.read_plain<int>();
}
cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, mode, computeType);
return Empty{};
}
};
class CudnnTensorDescParam {
public:
cudnnTensorDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, strideA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(dimA[i]);
ser.write_plain(strideA[i]);
}
ser.write_plain(dataType);
return Empty{};
}
    Empty deserialize(StringSerializer& ser, Empty) {
        cudnnDataType_t dataType;
        int dimA[MEGDNN_MAX_NDIM];
        int strideA[MEGDNN_MAX_NDIM];
        int nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
dimA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
}
dataType = ser.read_plain<cudnnDataType_t>();
cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
return Empty{};
}
};
class CudnnFilterDescParam {
public:
cudnnFilterDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM];
cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, filterDimA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(filterDimA[i]);
}
ser.write_plain(dataType);
ser.write_plain(format);
return Empty{};
}
    Empty deserialize(StringSerializer& ser, Empty) {
        cudnnDataType_t dataType;
        cudnnTensorFormat_t format;
        int filterDimA[MEGDNN_MAX_NDIM];
        int nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
filterDimA[i] = ser.read_plain<int>();
}
dataType = ser.read_plain<cudnnDataType_t>();
format = ser.read_plain<cudnnTensorFormat_t>();
cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
return Empty{};
}
};
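// A hedged sketch of wiring these wrappers into FunctionCacheBuilder (the
// parameter layout below is illustrative, not necessarily how the engine
// registers it): cudnnGetConvolutionForwardWorkspaceSize is cached keyed on
// the handle, the descriptors and the algorithm, and the queried workspace
// size is replayed through the RefParam output on a hit.
//
//   auto cached_ws = FunctionCacheBuilder<>{}
//           .input<PlainParam<cudnnHandle_t>>()
//           .input<CudnnTensorDescParam>()
//           .input<CudnnFilterDescParam>()
//           .input<CudnnConvDescParam>()
//           .input<CudnnTensorDescParam>()
//           .input<PlainParam<cudnnConvolutionFwdAlgo_t>>()
//           .output<RefParam<size_t>>()
//           .ret<RetParam<cudnnStatus_t>>()
//           .build(cudnnGetConvolutionForwardWorkspaceSize);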
}