Commit de363c04 authored by Megvii Engine Team

Revert "perf(cuda/conv): cache serval cudnn api"

This reverts commit 188c62cdd65ba27793b0af89164c6eee0122b21f.

GitOrigin-RevId: 92a82b8cd9f1bed053c7be89fa61ffb05acd1279
Parent: 729ee649
@@ -12,28 +12,32 @@
 #pragma once
-#include <cstring>
+#include <unordered_map>
 #include <memory>
+#include <cstring>
 #include <tuple>
-#include <unordered_map>
 #include "megdnn/thin/function.h"
 namespace megdnn {
-template <typename... TArgs>
-class FunctionCache {
+template <typename TSignature>
+class FunctionCache;
+template <typename TRet, typename... TArgs>
+class FunctionCache<TRet(TArgs...)> {
 public:
     using key_t = std::string;
-    using value_t = std::string;
+    using value_t = TRet;
     using key_mapper_t = thin_function<key_t(TArgs...)>;
     using value_mapper_t = thin_function<value_t(TArgs...)>;
     using storage_t = std::unordered_map<key_t, value_t>;
+public:
     storage_t storage;
     key_mapper_t key_mapper;
     value_mapper_t value_mapper;
+public:
-    value_t operator()(TArgs... args) {
+    TRet operator()(TArgs... args) {
         key_t key = key_mapper(args...);
         if (storage.count(key) == 0) {
             storage[key] = value_mapper(std::forward<TArgs>(args)...);
@@ -42,28 +46,28 @@ public:
     }
 };
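For orientation before the remaining hunks: the FunctionCache being restored here is nothing more than a string-keyed memoizer. Below is a minimal standalone sketch of the same pattern, with std::function standing in for megdnn::thin_function; all names are illustrative.

```cpp
// A minimal standalone sketch of the memoization pattern FunctionCache
// implements (std::function used in place of megdnn::thin_function).
#include <functional>
#include <string>
#include <unordered_map>

template <typename TRet, typename... TArgs>
struct SimpleCache {
    std::unordered_map<std::string, TRet> storage;
    std::function<std::string(TArgs...)> key_mapper;  // args -> cache key
    std::function<TRet(TArgs...)> value_mapper;       // args -> computed value

    TRet operator()(TArgs... args) {
        std::string key = key_mapper(args...);
        auto it = storage.find(key);
        if (it == storage.end())  // miss: compute once and remember
            it = storage.emplace(key, value_mapper(args...)).first;
        return it->second;
    }
};

int main() {
    SimpleCache<int, int> cache;
    cache.key_mapper = [](int x) { return std::to_string(x); };
    cache.value_mapper = [](int x) { return x * x; };  // stands in for an expensive query
    return cache(7) == cache(7) ? 0 : 1;  // the second call is a cache hit
}
```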
 // FIFO
 class StringSerializer {
 private:
     std::string m_buffer;
     size_t m_cursor = 0;
 public:
     template <typename T>
     T read_plain() {
-        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
-        T ret;
-        memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
+        T result;
+        std::memcpy(&result, m_buffer.data() + m_cursor, sizeof(T));
         m_cursor += sizeof(T);
-        return ret;
+        return result;
     }
     template <typename T>
     void write_plain(T value) {
-        static_assert(std::is_trivially_copyable<T>::value,
-                      "type should be trivially copyable");
-        m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
+        m_buffer.resize(m_buffer.size() + sizeof(T));
+        std::memcpy(const_cast<char*>(m_buffer.data()) + (m_buffer.size() - sizeof(T)), &value, sizeof(T));
     }
     std::string take() {
+        std::string result;
-        m_buffer.erase(0, m_cursor);
         return std::move(m_buffer);
     }
     void set(std::string new_buf) {
@@ -72,20 +76,20 @@ public:
     }
 };
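StringSerializer implements the FIFO byte protocol that cache keys and cached results travel through: write_plain() appends the raw bytes of a trivially copyable value, read_plain() consumes them in write order. A standalone sketch of that round trip, under the same trivially-copyable assumption (and, as in the original, without bounds checking):

```cpp
// Standalone sketch of the FIFO byte protocol StringSerializer implements.
#include <cassert>
#include <cstring>
#include <string>
#include <type_traits>

struct ByteSerializer {
    std::string buf;
    size_t cursor = 0;

    template <typename T>
    void write_plain(const T& v) {  // append raw bytes
        static_assert(std::is_trivially_copyable<T>::value, "plain data only");
        buf.append(reinterpret_cast<const char*>(&v), sizeof(T));
    }
    template <typename T>
    T read_plain() {  // consume raw bytes in write order
        T v;
        std::memcpy(&v, buf.data() + cursor, sizeof(T));
        cursor += sizeof(T);
        return v;
    }
};

int main() {
    ByteSerializer ser;
    ser.write_plain(42);
    ser.write_plain(3.5);
    assert(ser.read_plain<int>() == 42);
    assert(ser.read_plain<double>() == 3.5);
}
```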
 struct Empty {};
 template <typename... TParams>
 class ParamBundle {
 private:
-    template <std::size_t N, std::size_t... Seq>
-    static std::index_sequence<N + Seq...> add_all(
-            std::index_sequence<Seq...>) {
+    template<std::size_t N, std::size_t... Seq>
+    static std::index_sequence<N + Seq ...> add_all(std::index_sequence<Seq...>){
         return {};
     }
-    template <std::size_t Min, std::size_t Max>
-    using make_index_range =
-            decltype(add_all<Min>(std::make_index_sequence<Max - Min>()));
+    template<std::size_t Min, std::size_t Max>
+    using make_index_range = decltype(add_all<Min>(std::make_index_sequence<Max-Min>()));
     using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
     storage_t m_storage;
@@ -95,31 +99,21 @@ private:
         return functor(std::get<Indices>(m_storage).value...);
     }
     template <size_t Index, size_t... Indices, typename TPrev>
-    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
-                          std::index_sequence<Index, Indices...>) {
-        return serialize_helper(ser,
-                                std::get<Index>(m_storage).serialize(ser, prev),
-                                std::index_sequence<Indices...>());
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
+        return serialize_helper(ser, std::get<Index>(m_storage).serialize(ser, prev), std::index_sequence<Indices...>());
     }
     template <typename TPrev>
-    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
-                          std::index_sequence<>) {}
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
     template <size_t Index, size_t... Indices, typename TPrev>
-    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
-                            std::index_sequence<Index, Indices...>) {
-        return deserialize_helper(
-                ser, std::get<Index>(m_storage).deserialize(ser, prev),
-                std::index_sequence<Indices...>());
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
+        return deserialize_helper(ser, std::get<Index>(m_storage).deserialize(ser, prev), std::index_sequence<Indices...>());
     }
     template <typename TPrev>
-    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
-                            std::index_sequence<>) {}
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
     template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
-    void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
-                           TArgs&&... args) {
+    void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg, TArgs&&... args) {
         std::get<Index>(m_storage).value = arg;
-        set_values_helper(std::index_sequence<Indices...>(),
-                          std::forward<TArgs>(args)...);
+        set_values_helper(std::index_sequence<Indices...>(), std::forward<TArgs>(args)...);
     }
     template <size_t... Indices>
     void set_values_helper(std::index_sequence<Indices...>) {
@@ -129,33 +123,27 @@ private:
 public:
     template <typename TFunctor>
     auto call_by(TFunctor&& functor) {
-        return call_helper(std::forward<TFunctor>(functor),
-                           std::make_index_sequence<sizeof...(TParams)>());
+        return call_helper(std::forward<TFunctor>(functor), std::make_index_sequence<sizeof...(TParams)>());
     }
     template <size_t NBegin, size_t NEnd>
     void serialize_params(StringSerializer& ser) {
         static_assert(NEnd >= NBegin, "invalid range");
-        serialize_helper(
-                ser, Empty{},
-                add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
+        serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
     }
     template <size_t NBegin, size_t NEnd>
     void deserialize_params(StringSerializer& ser) {
         static_assert(NEnd >= NBegin, "invalid range");
-        deserialize_helper(
-                ser, Empty{},
-                add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
+        deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
     }
     template <size_t NBegin, size_t NEnd, typename... TArgs>
     void set_values(TArgs&&... args) {
-        set_values_helper(
-                add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()),
-                std::forward<TArgs>(args)...);
+        set_values_helper(make_index_range<NBegin, NEnd>(), std::forward<TArgs>(args)...);
    }
 };
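ParamBundle's make_index_range is the standard trick for producing the compile-time index range [Min, Max): generate a zero-based index_sequence of length Max - Min, then shift every element up by Min. A standalone illustration of just that trick (C++17 for the fold expression):

```cpp
// Standalone sketch of the compile-time range trick ParamBundle relies on.
#include <cstdio>
#include <utility>

template <std::size_t N, std::size_t... Seq>
constexpr std::index_sequence<N + Seq...> add_all(std::index_sequence<Seq...>) {
    return {};
}

template <std::size_t Min, std::size_t Max>
using make_index_range =
        decltype(add_all<Min>(std::make_index_sequence<Max - Min>()));

template <std::size_t... Is>
void print(std::index_sequence<Is...>) {
    (std::printf("%zu ", Is), ...);  // fold over the comma operator
}

int main() {
    print(make_index_range<2, 6>{});  // prints: 2 3 4 5
}
```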
 template <typename T>
-class Param {
+class RetParam {
 public:
     T value;
     Empty serialize(StringSerializer& ser, Empty) {
@@ -168,68 +156,45 @@ public:
     }
 };
-template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
-          typename TOutputs = std::tuple<>>
+template <typename TRet=RetParam<Empty>, typename TInputs=std::tuple<>, typename TOutputs=std::tuple<>>
 class FunctionCacheBuilder {
 private:
-    static auto declargs()
-            -> decltype(std::tuple_cat(std::declval<TInputs>(),
-                                       std::declval<TOutputs>())) {
-        return {};
-    }
+    static auto declargs() -> decltype(std::tuple_cat(std::declval<TInputs>(), std::declval<TOutputs>())) { return {}; }
     template <size_t... Indices>
-    static auto declfunction_helper(std::index_sequence<Indices...>)
-            -> thin_function<decltype(std::declval<TRet>().value)(
-                    decltype(std::get<Indices>(declargs()).value)...)> {
-        return {};
-    }
+    static auto declfunction_helper(std::index_sequence<Indices...>) -> thin_function<decltype(std::declval<TRet>().value)(decltype(std::get<Indices>(declargs()).value)...)> { return {}; }
     static auto declfunction() {
-        return declfunction_helper(
-                std::make_index_sequence<std::tuple_size<TInputs>::value +
-                                         std::tuple_size<TOutputs>::value>());
+        return declfunction_helper(std::make_index_sequence<std::tuple_size<TInputs>::value + std::tuple_size<TOutputs>::value>());
     }
     template <size_t... Indices>
-    static auto declbundle_helper(std::index_sequence<Indices...>)
-            -> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
-        return {};
-    }
+    static auto declbundle_helper(std::index_sequence<Indices...>) -> ParamBundle<decltype(std::get<Indices>(declargs()))...> { return {}; }
     static auto declbundle() {
-        return declbundle_helper(
-                std::make_index_sequence<std::tuple_size<TInputs>::value +
-                                         std::tuple_size<TOutputs>::value>());
+        return declbundle_helper(std::make_index_sequence<std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>());
     }
     using function_t = decltype(declfunction());
     using bundle_t = decltype(declbundle());
 public:
     template <typename TNewRet>
     auto ret() {
-        static_assert(std::is_same<TRet, Param<Empty>>::value,
-                      "return value redefinition");
+        static_assert(std::is_same<TRet, RetParam<Empty>>::value, "return value redefinition");
         return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
     }
     template <typename TNewInput>
     auto input() {
-        using TNewInputs = decltype(
-                std::tuple_cat(std::declval<TInputs>(),
-                               std::make_tuple(std::declval<TNewInput>())));
+        using TNewInputs = decltype(std::tuple_cat(std::declval<TInputs>(), std::make_tuple(std::declval<TNewInput>())));
         return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
     }
     template <typename TNewOutput>
     auto output() {
-        using TNewOutputs = decltype(
-                std::tuple_cat(std::declval<TOutputs>(),
-                               std::make_tuple(std::declval<TNewOutput>())));
+        using TNewOutputs = decltype(std::tuple_cat(std::declval<TOutputs>(), std::make_tuple(std::declval<TNewOutput>())));
         return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
     }
     template <typename TFunctor>
     function_t build(TFunctor func) {
-        FunctionCache<bundle_t> cache;
+        FunctionCache<std::string(bundle_t)> cache;
         cache.key_mapper = [](bundle_t bundle) {
             StringSerializer ser;
-            bundle.template serialize_params<0,
-                                             std::tuple_size<TInputs>::value>(
-                    ser);
+            bundle.template serialize_params<0, std::tuple_size<TInputs>::value>(ser);
             return ser.take();
         };
         cache.value_mapper = [=](bundle_t bundle) {
@@ -237,33 +202,42 @@ public:
             TRet ret;
             ret.value = bundle.call_by(func);
             ret.serialize(ser, Empty{});
-            bundle.template serialize_params<
-                    std::tuple_size<TInputs>::value,
-                    std::tuple_size<TInputs>::value +
-                            std::tuple_size<TOutputs>::value>(ser);
+            bundle.template serialize_params<std::tuple_size<TInputs>::value, std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>(ser);
             return ser.take();
         };
         return [=](auto&&... args) mutable {
             bundle_t bundle;
             TRet ret;
             StringSerializer ser;
-            static_assert(
-                    sizeof...(args) == std::tuple_size<TInputs>::value +
-                                       std::tuple_size<TOutputs>::value,
-                    "args count mismatch");
-            bundle.template set_values<0, sizeof...(args)>(
-                    std::forward<decltype(args)>(args)...);
+            static_assert(sizeof...(args) == std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value,
+                          "arg count mismatch");
+            bundle.template set_values<0, sizeof...(args)>(std::forward<decltype(args)>(args)...);
             ser.set(cache(bundle));
             ret.deserialize(ser, Empty{});
             constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
             constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
-            bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
-                    ser);
+            bundle.template deserialize_params<n_inputs, n_inputs+n_outputs>(ser);
             return ret.value;
         };
     }
 };
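Taken together, build() serializes the declared inputs into a cache key, and on a hit replays the serialized return value and output parameters instead of calling the wrapped function. A hedged sketch of the intended wiring, modeled on the wrapper code this commit deletes from handle.cpp but written against the restored class names; my_query is an illustrative stand-in, not a real cuDNN or MegDNN symbol, and the fragment assumes this header and cuDNN's headers are included.

```cpp
// Hypothetical query with one input (the handle) and one by-pointer output.
cudnnStatus_t my_query(cudnnHandle_t handle, size_t* out);

auto cached_query =
        FunctionCacheBuilder<>()
                .input<PlainParam<cudnnHandle_t>>()  // serialized into the key
                .output<RefParam<size_t>>()          // replayed from cached bytes
                .ret<RetParam<cudnnStatus_t>>()
                .build(&my_query);
// cached_query(handle, &size) now recomputes only on previously unseen keys.
```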
+template <typename T>
+class PlainParam {
+public:
+    T value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(value);
+        return Empty{};
+    }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        value = ser.read_plain<T>();
+        return Empty{};
+    }
+};
 template <typename T>
 class RefParam {
 public:
@@ -278,6 +252,7 @@ public:
     }
 };
 template <typename T>
 class RefArraySizeParam {
 public:
@@ -291,6 +266,7 @@ public:
     }
 };
 template <typename TSize, typename TItem>
 class ArrayParam {
 public:
@@ -309,4 +285,4 @@ public:
     }
 };
-}  // namespace megdnn
+}
@@ -16,109 +16,105 @@
 #include "src/cuda/cudnn_wrapper.h"
 namespace megdnn {
 class CudnnConvDescParam {
 public:
     cudnnConvolutionDescriptor_t value;
     Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
+        int ndim = MEGDNN_MAX_NDIM;
         int padA[MEGDNN_MAX_NDIM];
         int strideA[MEGDNN_MAX_NDIM];
         int dilationA[MEGDNN_MAX_NDIM];
         cudnnConvolutionMode_t mode;
         cudnnDataType_t computeType;
-        cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA,
-                                        dilationA, &mode, &computeType);
-        ser.write_plain(nbDims);
-        for (int i = 0; i < nbDims; ++i) {
+        cudnnGetConvolutionNdDescriptor(value, MEGDNN_MAX_NDIM, &ndim, padA, strideA, dilationA, &mode, &computeType);
+        ser.write_plain(ndim);
+        for (int i = 0; i < ndim; ++i) {
             ser.write_plain(padA[i]);
             ser.write_plain(strideA[i]);
             ser.write_plain(dilationA[i]);
         }
         ser.write_plain(mode);
         ser.write_plain(computeType);
         return Empty{};
     }
     Empty deserialize(StringSerializer& ser, Empty) {
         int ndim = ser.read_plain<int>();
         int padA[MEGDNN_MAX_NDIM];
         int strideA[MEGDNN_MAX_NDIM];
         int dilationA[MEGDNN_MAX_NDIM];
         for (int i = 0; i < ndim; ++i) {
             padA[i] = ser.read_plain<int>();
             strideA[i] = ser.read_plain<int>();
             dilationA[i] = ser.read_plain<int>();
         }
         cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
         cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
-        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA,
-                                        mode, computeType);
+        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, mode, computeType);
         return Empty{};
     }
 };
 class CudnnTensorDescParam {
 public:
     cudnnTensorDescriptor_t value;
     Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
+        int nbDims = MEGDNN_MAX_NDIM;
         cudnnDataType_t dataType;
         int dimA[MEGDNN_MAX_NDIM];
         int strideA[MEGDNN_MAX_NDIM];
-        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA,
-                                   strideA);
+        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, strideA);
         ser.write_plain(nbDims);
         for (int i = 0; i < nbDims; ++i) {
             ser.write_plain(dimA[i]);
             ser.write_plain(strideA[i]);
         }
         ser.write_plain(dataType);
         return Empty{};
     }
     Empty deserialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
+        int nbDims = MEGDNN_MAX_NDIM;
         cudnnDataType_t dataType;
         int dimA[MEGDNN_MAX_NDIM];
         int strideA[MEGDNN_MAX_NDIM];
         nbDims = ser.read_plain<int>();
         for (int i = 0; i < nbDims; ++i) {
             dimA[i] = ser.read_plain<int>();
             strideA[i] = ser.read_plain<int>();
         }
         dataType = ser.read_plain<cudnnDataType_t>();
         cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
         return Empty{};
     }
 };
 class CudnnFilterDescParam {
 public:
     cudnnFilterDescriptor_t value;
     Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
+        int nbDims = MEGDNN_MAX_NDIM;
         cudnnDataType_t dataType;
         cudnnTensorFormat_t format;
         int filterDimA[MEGDNN_MAX_NDIM];
-        cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims,
-                                   filterDimA);
+        cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, filterDimA);
         ser.write_plain(nbDims);
         for (int i = 0; i < nbDims; ++i) {
             ser.write_plain(filterDimA[i]);
         }
         ser.write_plain(dataType);
         ser.write_plain(format);
         return Empty{};
     }
     Empty deserialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
+        int nbDims = MEGDNN_MAX_NDIM;
         cudnnDataType_t dataType;
         cudnnTensorFormat_t format;
         int filterDimA[MEGDNN_MAX_NDIM];
         nbDims = ser.read_plain<int>();
         for (int i = 0; i < nbDims; ++i) {
             filterDimA[i] = ser.read_plain<int>();
         }
         dataType = ser.read_plain<cudnnDataType_t>();
         format = ser.read_plain<cudnnTensorFormat_t>();
         cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
         return Empty{};
     }
 };
-}  // namespace megdnn
+}
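These three classes exist because cuDNN descriptors are opaque handles: two distinct descriptor objects with identical contents must produce identical cache keys, so serialize() queries the contents back out and byte-serializes them field by field. The same idea, sketched standalone with a plain struct standing in for a real descriptor:

```cpp
// Standalone illustration of descriptor-contents-as-cache-key; FakeConvDesc
// is a stand-in for the state behind a cudnnConvolutionDescriptor_t.
#include <string>

struct FakeConvDesc {
    int ndim;
    int pad[3], stride[3], dilation[3];
};

std::string make_key(const FakeConvDesc& d) {
    std::string key;
    auto put = [&key](int v) {
        key.append(reinterpret_cast<const char*>(&v), sizeof(v));
    };
    put(d.ndim);
    for (int i = 0; i < d.ndim; ++i) {
        put(d.pad[i]);  // same field order as CudnnConvDescParam::serialize
        put(d.stride[i]);
        put(d.dilation[i]);
    }
    return key;  // descriptors with equal contents yield byte-identical keys
}
```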
@@ -56,8 +56,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
     conv_args.init_conv_desc(D);
     size_t workspace_size;
-    auto& cudnn = conv_args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             conv_args.handle->cudnn_handle(), D.src_desc.desc,
             D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
             m_cudnn_enum, &workspace_size);
@@ -83,8 +82,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
     conv_args.init_conv_desc(D);
     size_t conv_workspace_size;
-    auto& cudnn = conv_args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             conv_args.handle->cudnn_handle(), D.src_desc.desc,
             D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
             m_cudnn_enum, &conv_workspace_size);
......
@@ -149,8 +149,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
             megdnn_throw("unsupported NonlineMode");
     }
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
             D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
             &workspace_size);
@@ -163,8 +162,7 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
     args.init_conv_bias_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
             D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
             &workspace_size);
......
@@ -95,13 +95,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
     CUDNNForwardDescs desc;
     conv_args.init_conv_desc(desc);
 #if CUDNN_MAJOR >= 7
-    auto& cudnn = static_cast<HandleImpl*>(this->handle())->cudnn();
     int max_count = 0;
-    cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
-                                                             &max_count));
+    cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
+                                                            &max_count));
     SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(max_count);
     int ret_count = 0;
-    cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
+    cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
             cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
             desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count,
             &ret_count, algo_perf.data()));
......
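The restored call sites above hit the cuDNN 7 heuristic API directly on every invocation instead of going through the removed cache. A condensed sketch of that query pattern (assumes cuDNN >= 7, a valid handle and descriptors; cudnnStatus_t checks elided for brevity):

```cpp
// Direct cuDNN v7 forward-algorithm heuristic query.
#include <cudnn.h>
#include <vector>

cudnnConvolutionFwdAlgo_t pick_fwd_algo(cudnnHandle_t handle,
                                        cudnnTensorDescriptor_t x,
                                        cudnnFilterDescriptor_t w,
                                        cudnnConvolutionDescriptor_t conv,
                                        cudnnTensorDescriptor_t y) {
    int max_count = 0;
    cudnnGetConvolutionForwardAlgorithmMaxCount(handle, &max_count);
    std::vector<cudnnConvolutionFwdAlgoPerf_t> perf(max_count);
    int ret_count = 0;
    cudnnGetConvolutionForwardAlgorithm_v7(handle, x, w, conv, y, max_count,
                                           &ret_count, perf.data());
    return perf.at(0).algo;  // results are sorted, best algorithm first
}
```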
@@ -42,10 +42,9 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
     if (!conv_bias::is_cudnn_supported(bias_args))
         return false;
-    auto& cudnn = args.handle->cudnn();
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -58,11 +57,10 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
 size_t ConvolutionBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
         const SizeArgs &args) const {
-    auto& cudnn = args.handle->cudnn();
     CUDNNBwdDataDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
......
@@ -29,7 +29,6 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
             return false;
         }
     }
-    auto& cudnn = args.handle->cudnn();
     CUDNNBwdFilterDescs D;
     TensorLayout bias_layout, z_layout;
@@ -44,7 +43,7 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.diff_desc.desc,
@@ -57,11 +56,10 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
 size_t ConvolutionBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
         const SizeArgs &args) const {
-    auto& cudnn = args.handle->cudnn();
     CUDNNBwdFilterDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.diff_desc.desc,
......
@@ -144,13 +144,12 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
 #if CUDNN_MAJOR >= 7
     MEGDNN_MARK_USED_VAR(negative_attr);
-    auto& cudnn = args.handle->cudnn();
     int max_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithmMaxCount(
+    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
             cudnn_handle, &max_count));
     SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(max_count);
     int ret_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithm_v7(
+    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7(
             cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc,
             desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
             algo_perf.data()));
@@ -280,13 +279,12 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
 #endif
 #if CUDNN_MAJOR >= 7
     MEGDNN_MARK_USED_VAR(negative_attr);
-    auto& cudnn = args.handle->cudnn();
     int max_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithmMaxCount(
+    cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
             cudnn_handle, &max_count));
     SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(max_count);
     int ret_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithm_v7(
+    cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
             cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc,
             desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
             algo_perf.data()));
......
@@ -28,8 +28,7 @@ bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -45,8 +44,7 @@ size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
     CUDNNBwdDataDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
......
@@ -28,8 +28,7 @@ bool Convolution3DBackwardFilterImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
             D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
     return status == CUDNN_STATUS_SUCCESS;
@@ -41,8 +40,7 @@ size_t Convolution3DBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
             D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
     megdnn_assert(status == CUDNN_STATUS_SUCCESS,
......
@@ -27,8 +27,7 @@ bool Convolution3DForwardImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.filter_desc.desc,
@@ -44,8 +43,7 @@ size_t Convolution3DForwardImpl::AlgoCUDNN::get_workspace_in_bytes(
     CUDNNForwardDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.filter_desc.desc,
......
@@ -93,7 +93,7 @@ namespace convolution3d {
                 const Workspace &workspace, void *&raw_ptr);
     inline bool cudnn_get_convolution_fwd_algo_helper(
-            Handle* handle, const cudnnTensorDescriptor_t x_desc,
+            cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc,
             const cudnnFilterDescriptor_t w_desc,
             const cudnnConvolutionDescriptor_t conv_desc,
             const cudnnTensorDescriptor_t y_desc,
@@ -103,14 +103,13 @@ namespace convolution3d {
         MEGDNN_MARK_USED_VAR(positive_attr);
         MEGDNN_MARK_USED_VAR(negative_attr);
 #if CUDNN_MAJOR >= 7
-        auto& cudnn = static_cast<HandleImpl*>(handle)->cudnn();
         int algo_max_count = 0;
-        cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(
-                cuda::cudnn_handle(handle), &algo_max_count));
+        cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(
+                cudnn_handle, &algo_max_count));
         SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(algo_max_count);
         int algo_count = 0;
-        cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
-                cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, algo_max_count,
+        cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
+                cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count,
                 &algo_count, algo_perf.data()));
         for (int i = 0; i < algo_count; ++i) {
             if (algo_perf[i].algo ==
@@ -118,8 +117,8 @@ namespace convolution3d {
                             CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)
                 continue;
             size_t workspace_size = 0;
-            cudnn_check(cudnn.GetConvolutionForwardWorkspaceSize(
-                    cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
+            cudnn_check(cudnnGetConvolutionForwardWorkspaceSize(
+                    cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
                     algo_perf[i].algo, &workspace_size));
             if (workspace_size > workspace_limit_in_bytes) continue;
             if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
@@ -135,7 +134,7 @@ namespace convolution3d {
         return false;
 #else
         cudnn_check(cudnnGetConvolutionForwardAlgorithm(
-                cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
+                cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
                 CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
                 workspace_limit_in_bytes, algo));
         return true;
......
@@ -64,12 +64,13 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
     auto get_cudnn_algo =
             [this, &args, workspace_limit_in_bytes, positive_attr,
              negative_attr]() -> Convolution3DForwardImpl::AlgoBase* {
+        auto cudnn_handle = cuda::cudnn_handle(this->handle());
         cudnnConvolutionFwdAlgo_t algo;
         CUDNNForwardDescs desc;
         args.init_desc(desc);
         bool got = cudnn_get_convolution_fwd_algo_helper(
-                this->handle(), desc.src_desc.desc, desc.filter_desc.desc,
+                cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
                 desc.conv_desc.desc, desc.dst_desc.desc,
                 workspace_limit_in_bytes, &algo, positive_attr, negative_attr);
         if (got) {
......
@@ -56,7 +56,7 @@ namespace convolution {
         using KernLayout = _kern_layout;                            \
         using OutputLayout = _output_layout;                        \
         using Param = _conv_param;                                  \
-        static constexpr bool check_bounds = check_bounds_
+        static constexpr bool check_bounds = check_bounds_;
 #define MEGDNN_COMMA ,
 template <bool check_bounds_, typename src_ldg_dtype, typename filter_ldg_dtype,
......
@@ -53,7 +53,7 @@ namespace convolution {
         using KernLayout = _kern_layout;                            \
         using OutputLayout = _output_layout;                        \
         using Param = _conv_param;                                  \
-        static constexpr bool check_bounds = check_bounds_
+        static constexpr bool check_bounds = check_bounds_;
 #define MEGDNN_COMMA ,
 template <bool check_bounds_, typename IMMAConfig_, typename WarpTileConfig_,
......
@@ -53,7 +53,7 @@ namespace convolution {
         using KernLayout = _kern_layout;                            \
         using OutputLayout = _output_layout;                        \
         using Param = _conv_param;                                  \
-        static constexpr bool check_bounds = check_bounds_
+        static constexpr bool check_bounds = check_bounds_;
 #define MEGDNN_COMMA ,
 template <bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
......
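The three kernel-header hunks restore the trailing semicolon inside the macro body, so each expansion is a complete member declaration and use sites do not supply their own. A minimal illustration with a hypothetical macro of the same shape:

```cpp
// DEF_CHECK_BOUNDS is illustrative, not the macro from the diff.
#define DEF_CHECK_BOUNDS(v) static constexpr bool check_bounds = v;

struct KernTraits {
    DEF_CHECK_BOUNDS(true)  // note: no ';' needed at the use site
};

static_assert(KernTraits::check_bounds, "bounds checking should be enabled");
```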
@@ -11,16 +11,13 @@
 #include "src/common/handle_impl.h"
 #include "src/common/version_symbol.h"
-#include "src/common/api_cache.h"
 #include "src/cuda/handle.h"
 #include "src/cuda/utils.h"
-#include "src/cuda/api_cache.h"
 #include "megdnn/common.h"
 #include <cuda.h>
 #include <cstring>
-#include <memory>
 #define STR_HELPER(x) #x
 #define STR(x) STR_HELPER(x)
@@ -94,8 +91,6 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
     // check tk1
     m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
     m_cusolver_handle = nullptr;
-
-    m_cudnn_api_cache = std::make_unique<CUDNN>(m_cudnn_handle);
 }
 HandleImpl::~HandleImpl() noexcept {
@@ -141,111 +136,8 @@ HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
     return HandleVendorType::CUDA;
 }
-HandleImpl::CUDNN& HandleImpl::cudnn() {
-    return *m_cudnn_api_cache;
-}
-HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
-    m_handle = handle;
-    GetConvolutionForwardWorkspaceSize =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<cudnnConvolutionFwdAlgo_t>>()
-                    .output<RefParam<size_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionForwardWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    GetConvolutionForwardAlgorithm_v7 =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<int>>()
-                    .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int, cudnnConvolutionFwdAlgoPerf_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionForwardAlgorithm_v7);
-    GetConvolutionForwardAlgorithmMaxCount =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .output<RefParam<int>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionForwardAlgorithmMaxCount);
-#endif
-    GetConvolutionBackwardDataWorkspaceSize =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<cudnnConvolutionBwdDataAlgo_t>>()
-                    .output<RefParam<size_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardDataWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    GetConvolutionBackwardDataAlgorithm_v7 =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<int>>()
-                    .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int,
-                                       cudnnConvolutionBwdDataAlgoPerf_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardDataAlgorithm_v7);
-    GetConvolutionBackwardDataAlgorithmMaxCount =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .output<RefParam<int>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardDataAlgorithmMaxCount);
-#endif
-    GetConvolutionBackwardFilterWorkspaceSize =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<Param<cudnnConvolutionBwdFilterAlgo_t>>()
-                    .output<RefParam<size_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardFilterWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    GetConvolutionBackwardFilterAlgorithm_v7 =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<Param<int>>()
-                    .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int,
-                                       cudnnConvolutionBwdFilterAlgoPerf_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7);
-    GetConvolutionBackwardFilterAlgorithmMaxCount =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .output<RefParam<int>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardFilterAlgorithmMaxCount);
-#endif
-}
 } // namespace cuda
 } // namespace megdnn
 MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION);
 MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
......
@@ -124,10 +124,6 @@ class HandleImpl: public HandleImplHelper {
     size_t image2d_pitch_alignment() const override;
     HandleVendorType vendor_type() const override;
-
-    class CUDNN;
-    CUDNN& cudnn();
-
 private:
     bool m_is_tegra_k1;
     int m_device_id;
@@ -160,34 +156,9 @@ class HandleImpl: public HandleImplHelper {
     //! device ptr to const scalars
     ConstScalars* m_const_scalars;
-    std::unique_ptr<CUDNN> m_cudnn_api_cache;
     void initialize_cusolver();
 };
-class HandleImpl::CUDNN {
-    cudnnHandle_t m_handle;
-public:
-    CUDNN(cudnnHandle_t handle);
-#define WRAP_CUDNN_API(NAME) thin_function<decltype(cudnn##NAME)> NAME;
-    WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    WRAP_CUDNN_API(GetConvolutionForwardAlgorithm_v7);
-    WRAP_CUDNN_API(GetConvolutionForwardAlgorithmMaxCount);
-#endif
-#if CUDNN_MAJOR >= 7
-    WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithm_v7);
-    WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithmMaxCount);
-#endif
-    WRAP_CUDNN_API(GetConvolutionBackwardDataWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithmMaxCount);
-    WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithm_v7);
-#endif
-    WRAP_CUDNN_API(GetConvolutionBackwardFilterWorkspaceSize);
-#undef WRAP_CUDNN_API
-};
 } // namespace cuda
 } // namespace megdnn
......
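The deleted wrapper class relied on a small decltype trick: applying decltype to a C API function name yields its exact function type, so a function-object member with the same name can be declared mechanically per API. A standalone sketch of the same mechanism, with std::function and an illustrative stand-in for a cudnn* function:

```cpp
// WRAP_API and c_api_add are illustrative stand-ins for the deleted
// WRAP_CUDNN_API macro and a cudnn* entry point.
#include <functional>

int c_api_add(int a, int b) { return a + b; }

#define WRAP_API(NAME) std::function<decltype(c_api_##NAME)> NAME;
struct ApiTable {
    WRAP_API(add)  // expands to: std::function<int(int, int)> add;
};
#undef WRAP_API

int main() {
    ApiTable t;
    t.add = &c_api_add;  // a caching lambda could be installed here instead
    return t.add(2, 3) == 5 ? 0 : 1;
}
```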