Commit 64c922c4 authored by Megvii Engine Team

Revert "fix(api_cache): fix serialization for conv_desc"

This reverts commit 95dbc9c685cced46dd910997bd585363c392ccbd.

GitOrigin-RevId: ca8c67b6b307b547603e083554b7ffd5736e2d98
Parent 4e95c136
@@ -131,18 +131,12 @@ public:
     T read_plain() {
         static_assert(std::is_trivially_copyable<T>::value, "invalid type");
         T ret;
-        std::memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
+        memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
         m_cursor += sizeof(T);
         return ret;
     }
     template <typename T>
-    void read_plain(T* dest) {
-        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
-        std::memcpy(dest, m_buffer.data() + m_cursor, sizeof(T));
-        m_cursor += sizeof(T);
-    }
-    template <typename T>
-    void write_plain(const T& value) {
+    void write_plain(T value) {
         static_assert(std::is_trivially_copyable<T>::value,
                       "type should be trivially copyable");
         m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
@@ -150,7 +144,7 @@ public:
     std::string take() { return std::move(m_buffer); }
     void reset(std::string new_buf) {
         m_cursor = 0;
-        m_buffer = std::move(new_buf);
+        m_buffer = new_buf;
     }
 };
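For context, here is a minimal standalone sketch of the plain-byte-buffer serializer interface that this revert restores: a value-returning read_plain<T>() paired with a by-value write_plain(T), both restricted to trivially copyable types. The class and member names below are illustrative, not the exact api_cache.h definitions.

#include <cassert>
#include <cstring>
#include <string>
#include <type_traits>

class PlainSerializer {
    std::string m_buffer;
    size_t m_cursor = 0;

public:
    template <typename T>
    T read_plain() {
        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
        T ret;
        std::memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
        m_cursor += sizeof(T);
        return ret;
    }
    template <typename T>
    void write_plain(T value) {
        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
        m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
    }
    std::string take() { return std::move(m_buffer); }
    void reset(std::string new_buf) {
        m_cursor = 0;
        m_buffer = std::move(new_buf);
    }
};

int main() {
    PlainSerializer ser;
    ser.write_plain(3);      // an int
    ser.write_plain(2.5f);   // a float
    ser.reset(ser.take());   // rewind over the same bytes
    assert(ser.read_plain<int>() == 3);
    assert(ser.read_plain<float>() == 2.5f);
}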
@@ -159,7 +153,7 @@ struct Empty {};
 // in: seq[1, 2, ..., m]
 // out: seq[N+1, N+2, ... N+m]
 template <std::size_t N, std::size_t... Seq>
-inline std::index_sequence<N + Seq...> inc_index_sequence(
+static std::index_sequence<N + Seq...> inc_index_sequence(
         std::index_sequence<Seq...>) {
     return {};
 }
@@ -178,7 +172,7 @@ private:
     // deconstruct tuple and call functor
     template <typename TFunctor, size_t... Indices>
-    auto call_helper(TFunctor&& functor, std::index_sequence<Indices...>) {
+    auto call_helper(TFunctor functor, std::index_sequence<Indices...>) {
         return functor(std::get<Indices>(m_storage).value...);
     }
@@ -209,7 +203,7 @@ private:
     template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
     void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
                            TArgs&&... args) {
-        std::get<Index>(m_storage).value = std::forward<TArg>(arg);
+        std::get<Index>(m_storage).value = arg;
         set_values_helper(std::index_sequence<Indices...>(),
                           std::forward<TArgs>(args)...);
     }
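For reference, a self-contained sketch of the std::index_sequence expansion that call_helper and set_values_helper rely on, using hypothetical names rather than the ParamBundle internals:

#include <cstdio>
#include <tuple>
#include <utility>

template <typename TFunctor, typename TTuple, size_t... Indices>
auto apply_helper(TFunctor&& functor, TTuple& storage,
                  std::index_sequence<Indices...>) {
    // expand the stored tuple elements into the functor's argument list
    return functor(std::get<Indices>(storage)...);
}

template <typename TFunctor, typename... TArgs>
auto apply_stored(TFunctor&& functor, std::tuple<TArgs...>& storage) {
    return apply_helper(std::forward<TFunctor>(functor), storage,
                        std::make_index_sequence<sizeof...(TArgs)>());
}

int main() {
    std::tuple<int, int> storage{2, 40};
    int sum = apply_stored([](int a, int b) { return a + b; }, storage);
    std::printf("%d\n", sum);  // prints 42
}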
@@ -259,7 +253,7 @@ public:
     }
     Empty deserialize(StringSerializer& ser, Empty) {
-        ser.read_plain(&value);
+        value = ser.read_plain<T>();
         return Empty{};
     }
 };
@@ -291,8 +285,7 @@ private:
     template <size_t... Indices>
     static auto declbundle_helper(std::index_sequence<Indices...>)
-            -> ParamBundle<std::remove_reference_t<
-                    decltype(std::get<Indices>(declargs()))>...> {
+            -> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
         return {};
     }
@@ -319,11 +312,9 @@ public:
     // declare new input
     template <typename TNewInput>
     auto input() {
-        static_assert(std::tuple_size<TOutputs>::value == 0,
-                      "input arg cannot be declared after output");
-        using TNewInputs =
-                decltype(std::tuple_cat(std::declval<TInputs>(),
-                                        std::declval<std::tuple<TNewInput>>()));
+        using TNewInputs = decltype(
+                std::tuple_cat(std::declval<TInputs>(),
+                               std::make_tuple(std::declval<TNewInput>())));
         return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
     }
     // declare new output
@@ -331,29 +322,31 @@ public:
     auto output() {
         using TNewOutputs = decltype(
                 std::tuple_cat(std::declval<TOutputs>(),
-                               std::declval<std::tuple<TNewOutput>>()));
+                               std::make_tuple(std::declval<TNewOutput>())));
         return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
     }
     // summary
     template <typename TFunctor>
-    function_t build(TFunctor&& func) {
-        constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
-        constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
+    function_t build(TFunctor func) {
         auto cache = std::make_shared<FunctionCache<std::string(bundle_t)>>();
         // bundle -> ser(in args)
         cache->key_mapper = [](bundle_t bundle) {
             StringSerializer ser;
-            bundle.template serialize_params<0, n_inputs>(ser);
+            bundle.template serialize_params<0,
+                                             std::tuple_size<TInputs>::value>(
+                    ser);
             return ser.take();
         };
         // bundle -> ser(out args)
-        cache->value_mapper = [func](bundle_t bundle) {
+        cache->value_mapper = [=](bundle_t bundle) {
             StringSerializer ser;
             TRet ret;
             ret.value = bundle.call_by(func);
             ret.serialize(ser, Empty{});
-            bundle.template serialize_params<n_inputs, n_inputs + n_outputs>(
-                    ser);
+            bundle.template serialize_params<
+                    std::tuple_size<TInputs>::value,
+                    std::tuple_size<TInputs>::value +
+                            std::tuple_size<TOutputs>::value>(ser);
             return ser.take();
         };
         return [=](auto&&... args) mutable {
@@ -368,6 +361,8 @@ public:
                     std::forward<decltype(args)>(args)...);
             ser.reset((*cache)(bundle));
             ret.deserialize(ser, Empty{});
+            constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
+            constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
             bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
                     ser);
             return ret.value;
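As a rough illustration of what build() assembles, the sketch below reproduces the key/value mapping idea with a plain std::unordered_map: arguments are serialized into a string key, the serialized result is memoized, and the cached bytes are returned on later calls. FunctionCacheSketch and its fixed int argument are simplifications, not the MegEngine FunctionCache API.

#include <cstdio>
#include <functional>
#include <string>
#include <unordered_map>

struct FunctionCacheSketch {
    std::function<std::string(int)> key_mapper;    // args -> key bytes
    std::function<std::string(int)> value_mapper;  // args -> result bytes
    std::unordered_map<std::string, std::string> storage;

    std::string operator()(int arg) {
        std::string key = key_mapper(arg);
        auto iter = storage.find(key);
        if (iter == storage.end()) {
            // miss: run the real function and remember its serialized result
            iter = storage.emplace(key, value_mapper(arg)).first;
        }
        return iter->second;
    }
};

int main() {
    FunctionCacheSketch cache;
    cache.key_mapper = [](int x) { return std::to_string(x); };
    cache.value_mapper = [](int x) {
        std::printf("computing for %d\n", x);  // printed only on a miss
        return std::to_string(x * x);
    };
    std::printf("%s\n", cache(7).c_str());  // miss: computes 49
    std::printf("%s\n", cache(7).c_str());  // hit: returns cached "49"
}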
@@ -399,8 +394,7 @@ public:
         return *value;
     }
     T deserialize(StringSerializer& ser, Empty) {
-        ser.read_plain(value);
-        return *value;
+        return *value = ser.read_plain<T>();
     }
 };
@@ -408,20 +402,16 @@ public:
 template <typename TSize, typename TItem>
 class ArrayParam {
 public:
-    decltype(std::declval<TItem>().value)* value;
+    TItem* value;
     Empty serialize(StringSerializer& ser, TSize size) {
-        TItem param;
         for (TSize i = 0; i < size; ++i) {
-            param.value = value[i];
-            param.serialize(ser, Empty{});
+            ser.write_plain(value[i]);
         }
         return Empty{};
     }
     Empty deserialize(StringSerializer& ser, TSize size) {
-        TItem param;
         for (TSize i = 0; i < size; ++i) {
-            param.deserialize(ser, Empty{});
-            value[i] = param.value;
+            value[i] = ser.read_plain<TItem>();
         }
         return Empty{};
     }
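The reverted ArrayParam copies each trivially copyable element straight into the byte buffer, with the element count supplied by a separate size parameter. A standalone sketch of that element-wise round trip, with illustrative helper names:

#include <cassert>
#include <cstring>
#include <string>

template <typename TItem>
void write_array(std::string& buf, const TItem* items, int count) {
    for (int i = 0; i < count; ++i) {
        buf.append(reinterpret_cast<const char*>(&items[i]), sizeof(TItem));
    }
}

template <typename TItem>
void read_array(const std::string& buf, size_t& cursor, TItem* items, int count) {
    for (int i = 0; i < count; ++i) {
        std::memcpy(&items[i], buf.data() + cursor, sizeof(TItem));
        cursor += sizeof(TItem);
    }
}

int main() {
    int src[3] = {1, 2, 3};
    int dst[3] = {0, 0, 0};
    std::string buf;
    write_array(buf, src, 3);
    size_t cursor = 0;
    read_array(buf, cursor, dst, 3);
    assert(dst[0] == 1 && dst[1] == 2 && dst[2] == 3);
}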
...
@@ -20,16 +20,14 @@ class CudnnConvDescParam {
 public:
     cudnnConvolutionDescriptor_t value;
     Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int maxNbDims = CUDNN_DIM_MAX - 2;
-        int nbDims = maxNbDims;
-        int padA[maxNbDims];
-        int strideA[maxNbDims];
-        int dilationA[maxNbDims];
+        int nbDims = MEGDNN_MAX_NDIM;
+        int padA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        int dilationA[MEGDNN_MAX_NDIM];
         cudnnConvolutionMode_t mode;
         cudnnDataType_t computeType;
-        cudnnGetConvolutionNdDescriptor(value, maxNbDims, &nbDims, padA,
-                                        strideA, dilationA, &mode,
-                                        &computeType);
+        cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA,
+                                        dilationA, &mode, &computeType);
         ser.write_plain(nbDims);
         for (int i = 0; i < nbDims; ++i) {
             ser.write_plain(padA[i]);
@@ -40,8 +38,23 @@ public:
         ser.write_plain(computeType);
         return Empty{};
     }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int ndim = ser.read_plain<int>();
+        int padA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        int dilationA[MEGDNN_MAX_NDIM];
+        for (int i = 0; i < ndim; ++i) {
+            padA[i] = ser.read_plain<int>();
+            strideA[i] = ser.read_plain<int>();
+            dilationA[i] = ser.read_plain<int>();
+        }
+        cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
+        cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
+        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA,
+                                        mode, computeType);
+        return Empty{};
+    }
 };
 class CudnnTensorDescParam {
 public:
     cudnnTensorDescriptor_t value;
@@ -50,8 +63,8 @@ public:
         cudnnDataType_t dataType;
         int dimA[MEGDNN_MAX_NDIM];
         int strideA[MEGDNN_MAX_NDIM];
-        cudnnGetTensorNdDescriptor(value, MEGDNN_MAX_NDIM, &dataType, &nbDims,
-                                   dimA, strideA);
+        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA,
+                                   strideA);
         ser.write_plain(nbDims);
         for (int i = 0; i < nbDims; ++i) {
             ser.write_plain(dimA[i]);
@@ -60,8 +73,21 @@ public:
         ser.write_plain(dataType);
         return Empty{};
     }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        int dimA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        nbDims = ser.read_plain<int>();
+        for (int i = 0; i < nbDims; ++i) {
+            dimA[i] = ser.read_plain<int>();
+            strideA[i] = ser.read_plain<int>();
+        }
+        dataType = ser.read_plain<cudnnDataType_t>();
+        cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
+        return Empty{};
+    }
 };
 class CudnnFilterDescParam {
 public:
     cudnnFilterDescriptor_t value;
@@ -80,29 +106,18 @@ public:
         ser.write_plain(format);
         return Empty{};
     }
-};
-template <typename T>
-class CudnnConvAlgoPerfParam {
-public:
-    T value;
-    Empty serialize(StringSerializer& ser, Empty) {
-        ser.write_plain(value.algo);
-        ser.write_plain(value.status);
-        ser.write_plain(value.time);
-        ser.write_plain(value.memory);
-        ser.write_plain(value.determinism);
-        ser.write_plain(value.mathType);
-        return Empty{};
-    }
     Empty deserialize(StringSerializer& ser, Empty) {
-        ser.read_plain(&value.algo);
-        ser.read_plain(&value.status);
-        ser.read_plain(&value.time);
-        ser.read_plain(&value.memory);
-        ser.read_plain(&value.determinism);
-        ser.read_plain(&value.mathType);
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        cudnnTensorFormat_t format;
+        int filterDimA[MEGDNN_MAX_NDIM];
+        nbDims = ser.read_plain<int>();
+        for (int i = 0; i < nbDims; ++i) {
+            filterDimA[i] = ser.read_plain<int>();
+        }
+        dataType = ser.read_plain<cudnnDataType_t>();
+        format = ser.read_plain<cudnnTensorFormat_t>();
+        cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
         return Empty{};
     }
 };
...
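All of the descriptor params above follow the same query-then-rebuild pattern: serialize reads every field of the live cuDNN descriptor, and deserialize writes those fields back into a descriptor. A hedged standalone sketch using the public cuDNN descriptor API (the pad/stride values are illustrative and error checking is omitted):

#include <cstdio>
#include <cudnn.h>

int main() {
    // cuDNN limits Nd descriptor queries to CUDNN_DIM_MAX - 2 spatial dims
    constexpr int kMaxSpatialDims = CUDNN_DIM_MAX - 2;
    int pad[2] = {1, 1}, stride[2] = {1, 1}, dilation[2] = {1, 1};

    cudnnConvolutionDescriptor_t src, dst;
    cudnnCreateConvolutionDescriptor(&src);
    cudnnCreateConvolutionDescriptor(&dst);
    cudnnSetConvolutionNdDescriptor(src, 2, pad, stride, dilation,
                                    CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);

    // "serialize": query every field of src
    int nbDims;
    int padA[kMaxSpatialDims], strideA[kMaxSpatialDims];
    int dilationA[kMaxSpatialDims];
    cudnnConvolutionMode_t mode;
    cudnnDataType_t computeType;
    cudnnGetConvolutionNdDescriptor(src, kMaxSpatialDims, &nbDims, padA,
                                    strideA, dilationA, &mode, &computeType);

    // "deserialize": rebuild an equivalent descriptor in dst
    cudnnSetConvolutionNdDescriptor(dst, nbDims, padA, strideA, dilationA,
                                    mode, computeType);
    std::printf("copied a %d-d convolution descriptor\n", nbDims);

    cudnnDestroyConvolutionDescriptor(src);
    cudnnDestroyConvolutionDescriptor(dst);
    return 0;
}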
@@ -168,8 +168,7 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
                     .input<CudnnTensorDescParam>()
                     .input<Param<int>>()
                     .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int,
-                                       Param<cudnnConvolutionFwdAlgoPerf_t>>>()
+                    .output<ArrayParam<int, cudnnConvolutionFwdAlgoPerf_t>>()
                     .ret<Param<cudnnStatus_t>>()
                     .build(&cudnnGetConvolutionForwardAlgorithm_v7);
     GetConvolutionForwardAlgorithmMaxCount =
@@ -200,8 +199,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
                     .input<CudnnTensorDescParam>()
                     .input<Param<int>>()
                     .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<
-                            int, Param<cudnnConvolutionBwdDataAlgoPerf_t>>>()
+                    .output<ArrayParam<int,
+                                       cudnnConvolutionBwdDataAlgoPerf_t>>()
                     .ret<Param<cudnnStatus_t>>()
                     .build(&cudnnGetConvolutionBackwardDataAlgorithm_v7);
     GetConvolutionBackwardDataAlgorithmMaxCount =
@@ -232,8 +231,8 @@ HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
                     .input<CudnnFilterDescParam>()
                     .input<Param<int>>()
                     .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<
-                            int, Param<cudnnConvolutionBwdFilterAlgoPerf_t>>>()
+                    .output<ArrayParam<int,
+                                       cudnnConvolutionBwdFilterAlgoPerf_t>>()
                     .ret<Param<cudnnStatus_t>>()
                     .build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7);
     GetConvolutionBackwardFilterAlgorithmMaxCount =
...