Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5419a95d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5419a95d
编写于
3月 16, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(cuda/conv): cache serval cudnn api
GitOrigin-RevId: 188c62cdd65ba27793b0af89164c6eee0122b21f
上级
19887942
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
389 addition
and
207 deletion
+389
-207
dnn/src/common/api_cache.h
dnn/src/common/api_cache.h
+100
-76
dnn/src/cuda/api_cache.h
dnn/src/cuda/api_cache.h
+100
-96
dnn/src/cuda/conv_bias/cudnn_conv.cpp
dnn/src/cuda/conv_bias/cudnn_conv.cpp
+4
-2
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
+4
-2
dnn/src/cuda/conv_bias/opr_impl.cpp
dnn/src/cuda/conv_bias/opr_impl.cpp
+3
-2
dnn/src/cuda/convolution/backward_data/cudnn.cpp
dnn/src/cuda/convolution/backward_data/cudnn.cpp
+4
-2
dnn/src/cuda/convolution/backward_filter/cudnn.cpp
dnn/src/cuda/convolution/backward_filter/cudnn.cpp
+4
-2
dnn/src/cuda/convolution/opr_impl.cpp
dnn/src/cuda/convolution/opr_impl.cpp
+6
-4
dnn/src/cuda/convolution3d/backward_data/cudnn.cpp
dnn/src/cuda/convolution3d/backward_data/cudnn.cpp
+4
-2
dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp
dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp
+4
-2
dnn/src/cuda/convolution3d/forward/cudnn.cpp
dnn/src/cuda/convolution3d/forward/cudnn.cpp
+4
-2
dnn/src/cuda/convolution3d/helper.h
dnn/src/cuda/convolution3d/helper.h
+9
-8
dnn/src/cuda/convolution3d/opr_impl.cpp
dnn/src/cuda/convolution3d/opr_impl.cpp
+1
-2
dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh
.../cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh
+1
-1
dnn/src/cuda/convolution_helper/conv_trait/iconv_imma_trait.cuh
...c/cuda/convolution_helper/conv_trait/iconv_imma_trait.cuh
+1
-1
dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh
dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh
+1
-1
dnn/src/cuda/handle.cpp
dnn/src/cuda/handle.cpp
+110
-2
dnn/src/cuda/handle.h
dnn/src/cuda/handle.h
+29
-0
未找到文件。
dnn/src/common/api_cache.h
浏览文件 @
5419a95d
...
...
@@ -12,32 +12,28 @@
#pragma once
#include <unordered_map>
#include <memory>
#include <cstring>
#include <memory>
#include <tuple>
#include <unordered_map>
#include "megdnn/thin/function.h"
namespace
megdnn
{
template
<
typename
TSignature
>
class
FunctionCache
;
template
<
typename
TRet
,
typename
...
TArgs
>
class
FunctionCache
<
TRet
(
TArgs
...)
>
{
template
<
typename
...
TArgs
>
class
FunctionCache
{
public:
using
key_t
=
std
::
string
;
using
value_t
=
TRet
;
using
value_t
=
std
::
string
;
using
key_mapper_t
=
thin_function
<
key_t
(
TArgs
...)
>
;
using
value_mapper_t
=
thin_function
<
value_t
(
TArgs
...)
>
;
using
storage_t
=
std
::
unordered_map
<
key_t
,
value_t
>
;
public:
storage_t
storage
;
key_mapper_t
key_mapper
;
value_mapper_t
value_mapper
;
public:
TRe
t
operator
()(
TArgs
...
args
)
{
value_
t
operator
()(
TArgs
...
args
)
{
key_t
key
=
key_mapper
(
args
...);
if
(
storage
.
count
(
key
)
==
0
)
{
storage
[
key
]
=
value_mapper
(
std
::
forward
<
TArgs
>
(
args
)...);
...
...
@@ -46,28 +42,28 @@ public:
}
};
// FIFO
class
StringSerializer
{
private:
std
::
string
m_buffer
;
size_t
m_cursor
=
0
;
public:
template
<
typename
T
>
T
read_plain
()
{
T
result
;
std
::
memcpy
(
&
result
,
m_buffer
.
data
()
+
m_cursor
,
sizeof
(
T
));
static_assert
(
std
::
is_trivially_copyable
<
T
>::
value
,
"invalid type"
);
T
ret
;
memcpy
(
&
ret
,
m_buffer
.
data
()
+
m_cursor
,
sizeof
(
T
));
m_cursor
+=
sizeof
(
T
);
return
re
sul
t
;
return
ret
;
}
template
<
typename
T
>
void
write_plain
(
T
value
)
{
m_buffer
.
resize
(
m_buffer
.
size
()
+
sizeof
(
T
));
std
::
memcpy
(
const_cast
<
char
*>
(
m_buffer
.
data
())
+
(
m_buffer
.
size
()
-
sizeof
(
T
)),
&
value
,
sizeof
(
T
));
static_assert
(
std
::
is_trivially_copyable
<
T
>::
value
,
"type should be trivially copyable"
);
m_buffer
.
append
(
reinterpret_cast
<
const
char
*>
(
&
value
),
sizeof
(
T
));
}
std
::
string
take
()
{
std
::
string
result
;
m_buffer
.
erase
(
0
,
m_cursor
);
return
std
::
move
(
m_buffer
);
}
void
set
(
std
::
string
new_buf
)
{
...
...
@@ -76,20 +72,20 @@ public:
}
};
struct
Empty
{};
template
<
typename
...
TParams
>
class
ParamBundle
{
private:
template
<
std
::
size_t
N
,
std
::
size_t
...
Seq
>
static
std
::
index_sequence
<
N
+
Seq
...
>
add_all
(
std
::
index_sequence
<
Seq
...
>
){
template
<
std
::
size_t
N
,
std
::
size_t
...
Seq
>
static
std
::
index_sequence
<
N
+
Seq
...
>
add_all
(
std
::
index_sequence
<
Seq
...
>
)
{
return
{};
}
template
<
std
::
size_t
Min
,
std
::
size_t
Max
>
using
make_index_range
=
decltype
(
add_all
<
Min
>
(
std
::
make_index_sequence
<
Max
-
Min
>
()));
template
<
std
::
size_t
Min
,
std
::
size_t
Max
>
using
make_index_range
=
decltype
(
add_all
<
Min
>
(
std
::
make_index_sequence
<
Max
-
Min
>
()));
using
storage_t
=
std
::
tuple
<
typename
std
::
remove_reference_t
<
TParams
>
...
>
;
storage_t
m_storage
;
...
...
@@ -99,21 +95,31 @@ private:
return
functor
(
std
::
get
<
Indices
>
(
m_storage
).
value
...);
}
template
<
size_t
Index
,
size_t
...
Indices
,
typename
TPrev
>
auto
serialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<
Index
,
Indices
...
>
)
{
return
serialize_helper
(
ser
,
std
::
get
<
Index
>
(
m_storage
).
serialize
(
ser
,
prev
),
std
::
index_sequence
<
Indices
...
>
());
auto
serialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<
Index
,
Indices
...
>
)
{
return
serialize_helper
(
ser
,
std
::
get
<
Index
>
(
m_storage
).
serialize
(
ser
,
prev
),
std
::
index_sequence
<
Indices
...
>
());
}
template
<
typename
TPrev
>
auto
serialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<>
)
{}
auto
serialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<>
)
{}
template
<
size_t
Index
,
size_t
...
Indices
,
typename
TPrev
>
auto
deserialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<
Index
,
Indices
...
>
)
{
return
deserialize_helper
(
ser
,
std
::
get
<
Index
>
(
m_storage
).
deserialize
(
ser
,
prev
),
std
::
index_sequence
<
Indices
...
>
());
auto
deserialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<
Index
,
Indices
...
>
)
{
return
deserialize_helper
(
ser
,
std
::
get
<
Index
>
(
m_storage
).
deserialize
(
ser
,
prev
),
std
::
index_sequence
<
Indices
...
>
());
}
template
<
typename
TPrev
>
auto
deserialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<>
)
{}
auto
deserialize_helper
(
StringSerializer
&
ser
,
TPrev
&&
prev
,
std
::
index_sequence
<>
)
{}
template
<
size_t
Index
,
size_t
...
Indices
,
typename
TArg
,
typename
...
TArgs
>
void
set_values_helper
(
std
::
index_sequence
<
Index
,
Indices
...
>
,
TArg
&&
arg
,
TArgs
&&
...
args
)
{
void
set_values_helper
(
std
::
index_sequence
<
Index
,
Indices
...
>
,
TArg
&&
arg
,
TArgs
&&
...
args
)
{
std
::
get
<
Index
>
(
m_storage
).
value
=
arg
;
set_values_helper
(
std
::
index_sequence
<
Indices
...
>
(),
std
::
forward
<
TArgs
>
(
args
)...);
set_values_helper
(
std
::
index_sequence
<
Indices
...
>
(),
std
::
forward
<
TArgs
>
(
args
)...);
}
template
<
size_t
...
Indices
>
void
set_values_helper
(
std
::
index_sequence
<
Indices
...
>
)
{
...
...
@@ -123,27 +129,33 @@ private:
public:
template
<
typename
TFunctor
>
auto
call_by
(
TFunctor
&&
functor
)
{
return
call_helper
(
std
::
forward
<
TFunctor
>
(
functor
),
std
::
make_index_sequence
<
sizeof
...(
TParams
)
>
());
return
call_helper
(
std
::
forward
<
TFunctor
>
(
functor
),
std
::
make_index_sequence
<
sizeof
...(
TParams
)
>
());
}
template
<
size_t
NBegin
,
size_t
NEnd
>
void
serialize_params
(
StringSerializer
&
ser
)
{
static_assert
(
NEnd
>=
NBegin
,
"invalid range"
);
serialize_helper
(
ser
,
Empty
{},
make_index_range
<
NBegin
,
NEnd
>
());
serialize_helper
(
ser
,
Empty
{},
add_all
<
NBegin
>
(
std
::
make_index_sequence
<
NEnd
-
NBegin
>
()));
}
template
<
size_t
NBegin
,
size_t
NEnd
>
void
deserialize_params
(
StringSerializer
&
ser
)
{
static_assert
(
NEnd
>=
NBegin
,
"invalid range"
);
deserialize_helper
(
ser
,
Empty
{},
make_index_range
<
NBegin
,
NEnd
>
());
deserialize_helper
(
ser
,
Empty
{},
add_all
<
NBegin
>
(
std
::
make_index_sequence
<
NEnd
-
NBegin
>
()));
}
template
<
size_t
NBegin
,
size_t
NEnd
,
typename
...
TArgs
>
void
set_values
(
TArgs
&&
...
args
)
{
set_values_helper
(
make_index_range
<
NBegin
,
NEnd
>
(),
std
::
forward
<
TArgs
>
(
args
)...);
set_values_helper
(
add_all
<
NBegin
>
(
std
::
make_index_sequence
<
NEnd
-
NBegin
>
()),
std
::
forward
<
TArgs
>
(
args
)...);
}
};
template
<
typename
T
>
class
Ret
Param
{
class
Param
{
public:
T
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
...
...
@@ -156,45 +168,68 @@ public:
}
};
template
<
typename
TRet
=
RetParam
<
Empty
>,
typename
TInputs
=
std
::
tuple
<>
,
typename
TOutputs
=
std
::
tuple
<>>
template
<
typename
TRet
=
Param
<
Empty
>,
typename
TInputs
=
std
::
tuple
<>
,
typename
TOutputs
=
std
::
tuple
<>>
class
FunctionCacheBuilder
{
private:
static
auto
declargs
()
->
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TInputs
>
(),
std
::
declval
<
TOutputs
>
()))
{
return
{};
}
static
auto
declargs
()
->
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TInputs
>
(),
std
::
declval
<
TOutputs
>
()))
{
return
{};
}
template
<
size_t
...
Indices
>
static
auto
declfunction_helper
(
std
::
index_sequence
<
Indices
...
>
)
->
thin_function
<
decltype
(
std
::
declval
<
TRet
>
().
value
)(
decltype
(
std
::
get
<
Indices
>
(
declargs
()).
value
)...)
>
{
return
{};
}
static
auto
declfunction_helper
(
std
::
index_sequence
<
Indices
...
>
)
->
thin_function
<
decltype
(
std
::
declval
<
TRet
>
().
value
)(
decltype
(
std
::
get
<
Indices
>
(
declargs
()).
value
)...)
>
{
return
{};
}
static
auto
declfunction
()
{
return
declfunction_helper
(
std
::
make_index_sequence
<
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
>
());
return
declfunction_helper
(
std
::
make_index_sequence
<
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
>
());
}
template
<
size_t
...
Indices
>
static
auto
declbundle_helper
(
std
::
index_sequence
<
Indices
...
>
)
->
ParamBundle
<
decltype
(
std
::
get
<
Indices
>
(
declargs
()))...
>
{
return
{};
}
static
auto
declbundle_helper
(
std
::
index_sequence
<
Indices
...
>
)
->
ParamBundle
<
decltype
(
std
::
get
<
Indices
>
(
declargs
()))...
>
{
return
{};
}
static
auto
declbundle
()
{
return
declbundle_helper
(
std
::
make_index_sequence
<
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
>
());
return
declbundle_helper
(
std
::
make_index_sequence
<
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
>
());
}
using
function_t
=
decltype
(
declfunction
());
using
bundle_t
=
decltype
(
declbundle
());
public:
template
<
typename
TNewRet
>
auto
ret
()
{
static_assert
(
std
::
is_same
<
TRet
,
RetParam
<
Empty
>>::
value
,
"return value redefinition"
);
static_assert
(
std
::
is_same
<
TRet
,
Param
<
Empty
>>::
value
,
"return value redefinition"
);
return
FunctionCacheBuilder
<
TNewRet
,
TInputs
,
TOutputs
>
{};
}
template
<
typename
TNewInput
>
auto
input
()
{
using
TNewInputs
=
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TInputs
>
(),
std
::
make_tuple
(
std
::
declval
<
TNewInput
>
())));
using
TNewInputs
=
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TInputs
>
(),
std
::
make_tuple
(
std
::
declval
<
TNewInput
>
())));
return
FunctionCacheBuilder
<
TRet
,
TNewInputs
,
TOutputs
>
{};
}
template
<
typename
TNewOutput
>
auto
output
()
{
using
TNewOutputs
=
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TOutputs
>
(),
std
::
make_tuple
(
std
::
declval
<
TNewOutput
>
())));
using
TNewOutputs
=
decltype
(
std
::
tuple_cat
(
std
::
declval
<
TOutputs
>
(),
std
::
make_tuple
(
std
::
declval
<
TNewOutput
>
())));
return
FunctionCacheBuilder
<
TRet
,
TInputs
,
TNewOutputs
>
{};
}
template
<
typename
TFunctor
>
function_t
build
(
TFunctor
func
)
{
FunctionCache
<
std
::
string
(
bundle_t
)
>
cache
;
FunctionCache
<
bundle_t
>
cache
;
cache
.
key_mapper
=
[](
bundle_t
bundle
)
{
StringSerializer
ser
;
bundle
.
template
serialize_params
<
0
,
std
::
tuple_size
<
TInputs
>
::
value
>
(
ser
);
bundle
.
template
serialize_params
<
0
,
std
::
tuple_size
<
TInputs
>
::
value
>
(
ser
);
return
ser
.
take
();
};
cache
.
value_mapper
=
[
=
](
bundle_t
bundle
)
{
...
...
@@ -202,42 +237,33 @@ public:
TRet
ret
;
ret
.
value
=
bundle
.
call_by
(
func
);
ret
.
serialize
(
ser
,
Empty
{});
bundle
.
template
serialize_params
<
std
::
tuple_size
<
TInputs
>
::
value
,
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
>
(
ser
);
bundle
.
template
serialize_params
<
std
::
tuple_size
<
TInputs
>
::
value
,
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
>
(
ser
);
return
ser
.
take
();
};
return
[
=
](
auto
&&
...
args
)
mutable
{
bundle_t
bundle
;
TRet
ret
;
StringSerializer
ser
;
static_assert
(
sizeof
...(
args
)
==
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
,
"arg count mismatch"
);
bundle
.
template
set_values
<
0
,
sizeof
...(
args
)>(
std
::
forward
<
decltype
(
args
)
>
(
args
)...);
static_assert
(
sizeof
...(
args
)
==
std
::
tuple_size
<
TInputs
>::
value
+
std
::
tuple_size
<
TOutputs
>::
value
,
"args count mismatch"
);
bundle
.
template
set_values
<
0
,
sizeof
...(
args
)>(
std
::
forward
<
decltype
(
args
)
>
(
args
)...);
ser
.
set
(
cache
(
bundle
));
ret
.
deserialize
(
ser
,
Empty
{});
constexpr
size_t
n_inputs
=
std
::
tuple_size
<
TInputs
>::
value
;
constexpr
size_t
n_outputs
=
std
::
tuple_size
<
TOutputs
>::
value
;
bundle
.
template
deserialize_params
<
n_inputs
,
n_inputs
+
n_outputs
>(
ser
);
bundle
.
template
deserialize_params
<
n_inputs
,
n_inputs
+
n_outputs
>(
ser
);
return
ret
.
value
;
};
}
};
template
<
typename
T
>
class
PlainParam
{
public:
T
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
ser
.
write_plain
(
value
);
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
value
=
ser
.
read_plain
<
T
>
();
return
Empty
{};
}
};
template
<
typename
T
>
class
RefParam
{
public:
...
...
@@ -252,7 +278,6 @@ public:
}
};
template
<
typename
T
>
class
RefArraySizeParam
{
public:
...
...
@@ -266,7 +291,6 @@ public:
}
};
template
<
typename
TSize
,
typename
TItem
>
class
ArrayParam
{
public:
...
...
@@ -285,4 +309,4 @@ public:
}
};
}
}
// namespace megdnn
dnn/src/cuda/api_cache.h
浏览文件 @
5419a95d
...
...
@@ -16,105 +16,109 @@
#include "src/cuda/cudnn_wrapper.h"
namespace
megdnn
{
class
CudnnConvDescParam
{
public:
cudnnConvolutionDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
int
ndim
=
MEGDNN_MAX_NDIM
;
int
padA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
int
dilationA
[
MEGDNN_MAX_NDIM
];
cudnnConvolutionMode_t
mode
;
cudnnDataType_t
computeType
;
cudnnGetConvolutionNdDescriptor
(
value
,
MEGDNN_MAX_NDIM
,
&
ndim
,
padA
,
strideA
,
dilationA
,
&
mode
,
&
computeType
);
ser
.
write_plain
(
ndim
);
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
ser
.
write_plain
(
padA
[
i
]);
ser
.
write_plain
(
strideA
[
i
]);
ser
.
write_plain
(
dilationA
[
i
]);
}
ser
.
write_plain
(
mode
);
ser
.
write_plain
(
computeType
);
return
Empty
{};
class
CudnnConvDescParam
{
public:
cudnnConvolutionDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
int
padA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
int
dilationA
[
MEGDNN_MAX_NDIM
];
cudnnConvolutionMode_t
mode
;
cudnnDataType_t
computeType
;
cudnnGetConvolutionNdDescriptor
(
value
,
nbDims
,
&
nbDims
,
padA
,
strideA
,
dilationA
,
&
mode
,
&
computeType
);
ser
.
write_plain
(
nbDims
);
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
ser
.
write_plain
(
padA
[
i
]);
ser
.
write_plain
(
strideA
[
i
]);
ser
.
write_plain
(
dilationA
[
i
]);
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
int
ndim
=
ser
.
read_plain
<
int
>
();
int
padA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
int
dilationA
[
MEGDNN_MAX_NDIM
];
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
padA
[
i
]
=
ser
.
read_plain
<
int
>
();
strideA
[
i
]
=
ser
.
read_plain
<
int
>
();
dilationA
[
i
]
=
ser
.
read_plain
<
int
>
();
}
cudnnConvolutionMode_t
mode
=
ser
.
read_plain
<
cudnnConvolutionMode_t
>
();
cudnnDataType_t
computeType
=
ser
.
read_plain
<
cudnnDataType_t
>
();
cudnnSetConvolutionNdDescriptor
(
value
,
ndim
,
padA
,
strideA
,
dilationA
,
mode
,
computeType
);
return
Empty
{};
ser
.
write_plain
(
mode
);
ser
.
write_plain
(
computeType
);
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
int
ndim
=
ser
.
read_plain
<
int
>
();
int
padA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
int
dilationA
[
MEGDNN_MAX_NDIM
];
for
(
int
i
=
0
;
i
<
ndim
;
++
i
)
{
padA
[
i
]
=
ser
.
read_plain
<
int
>
();
strideA
[
i
]
=
ser
.
read_plain
<
int
>
();
dilationA
[
i
]
=
ser
.
read_plain
<
int
>
();
}
};
class
CudnnTensorDescParam
{
public:
cudnnTensorDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
int
dimA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
cudnnGetTensorNdDescriptor
(
value
,
nbDims
,
&
dataType
,
&
nbDims
,
dimA
,
strideA
);
ser
.
write_plain
(
nbDims
);
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
ser
.
write_plain
(
dimA
[
i
]);
ser
.
write_plain
(
strideA
[
i
]);
}
ser
.
write_plain
(
dataType
);
return
Empty
{};
cudnnConvolutionMode_t
mode
=
ser
.
read_plain
<
cudnnConvolutionMode_t
>
();
cudnnDataType_t
computeType
=
ser
.
read_plain
<
cudnnDataType_t
>
();
cudnnSetConvolutionNdDescriptor
(
value
,
ndim
,
padA
,
strideA
,
dilationA
,
mode
,
computeType
);
return
Empty
{};
}
};
class
CudnnTensorDescParam
{
public:
cudnnTensorDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
int
dimA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
cudnnGetTensorNdDescriptor
(
value
,
nbDims
,
&
dataType
,
&
nbDims
,
dimA
,
strideA
);
ser
.
write_plain
(
nbDims
);
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
ser
.
write_plain
(
dimA
[
i
]);
ser
.
write_plain
(
strideA
[
i
]);
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
int
dimA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
nbDims
=
ser
.
read_plain
<
int
>
();
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
dimA
[
i
]
=
ser
.
read_plain
<
int
>
();
strideA
[
i
]
=
ser
.
read_plain
<
int
>
();
}
dataType
=
ser
.
read_plain
<
cudnnDataType_t
>
();
cudnnSetTensorNdDescriptor
(
value
,
dataType
,
nbDims
,
dimA
,
strideA
);
return
Empty
{};
ser
.
write_plain
(
dataType
);
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
int
dimA
[
MEGDNN_MAX_NDIM
];
int
strideA
[
MEGDNN_MAX_NDIM
];
nbDims
=
ser
.
read_plain
<
int
>
();
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
dimA
[
i
]
=
ser
.
read_plain
<
int
>
();
strideA
[
i
]
=
ser
.
read_plain
<
int
>
();
}
};
class
CudnnFilterDescParam
{
public:
cudnnFilterDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
cudnnTensorFormat_t
format
;
int
filterDimA
[
MEGDNN_MAX_NDIM
];
cudnnGetFilterNdDescriptor
(
value
,
nbDims
,
&
dataType
,
&
format
,
&
nbDims
,
filterDimA
);
ser
.
write_plain
(
nbDims
);
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
ser
.
write_plain
(
filterDimA
[
i
]);
}
ser
.
write_plain
(
dataType
);
ser
.
write_plain
(
format
);
return
Empty
{};
dataType
=
ser
.
read_plain
<
cudnnDataType_t
>
();
cudnnSetTensorNdDescriptor
(
value
,
dataType
,
nbDims
,
dimA
,
strideA
);
return
Empty
{};
}
};
class
CudnnFilterDescParam
{
public:
cudnnFilterDescriptor_t
value
;
Empty
serialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
cudnnTensorFormat_t
format
;
int
filterDimA
[
MEGDNN_MAX_NDIM
];
cudnnGetFilterNdDescriptor
(
value
,
nbDims
,
&
dataType
,
&
format
,
&
nbDims
,
filterDimA
);
ser
.
write_plain
(
nbDims
);
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
ser
.
write_plain
(
filterDimA
[
i
]);
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
cudnnTensorFormat_t
format
;
int
filterDimA
[
MEGDNN_MAX_NDIM
];
nbDims
=
ser
.
read_plain
<
int
>
();
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
filterDimA
[
i
]
=
ser
.
read_plain
<
int
>
();
}
dataType
=
ser
.
read_plain
<
cudnnDataType_t
>
();
format
=
ser
.
read_plain
<
cudnnTensorFormat_t
>
();
cudnnSetFilterNdDescriptor
(
value
,
dataType
,
format
,
nbDims
,
filterDimA
);
return
Empty
{};
ser
.
write_plain
(
dataType
);
ser
.
write_plain
(
format
);
return
Empty
{};
}
Empty
deserialize
(
StringSerializer
&
ser
,
Empty
)
{
constexpr
int
nbDims
=
MEGDNN_MAX_NDIM
;
cudnnDataType_t
dataType
;
cudnnTensorFormat_t
format
;
int
filterDimA
[
MEGDNN_MAX_NDIM
];
nbDims
=
ser
.
read_plain
<
int
>
();
for
(
int
i
=
0
;
i
<
nbDims
;
++
i
)
{
filterDimA
[
i
]
=
ser
.
read_plain
<
int
>
();
}
};
}
dataType
=
ser
.
read_plain
<
cudnnDataType_t
>
();
format
=
ser
.
read_plain
<
cudnnTensorFormat_t
>
();
cudnnSetFilterNdDescriptor
(
value
,
dataType
,
format
,
nbDims
,
filterDimA
);
return
Empty
{};
}
};
}
// namespace megdnn
dnn/src/cuda/conv_bias/cudnn_conv.cpp
浏览文件 @
5419a95d
...
...
@@ -39,7 +39,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
conv_args
.
init_conv_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionForwardWorkspaceSize
(
auto
&
cudnn
=
conv_args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionForwardWorkspaceSize
(
conv_args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
filter_desc
.
desc
,
D
.
conv_desc
.
conv_desc
,
D
.
dst_desc
.
desc
,
m_cudnn_enum
,
&
workspace_size
);
...
...
@@ -65,7 +66,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
conv_args
.
init_conv_desc
(
D
);
size_t
conv_workspace_size
;
auto
status
=
cudnnGetConvolutionForwardWorkspaceSize
(
auto
&
cudnn
=
conv_args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionForwardWorkspaceSize
(
conv_args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
filter_desc
.
desc
,
D
.
conv_desc
.
conv_desc
,
D
.
dst_desc
.
desc
,
m_cudnn_enum
,
&
conv_workspace_size
);
...
...
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
浏览文件 @
5419a95d
...
...
@@ -108,7 +108,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
megdnn_throw
(
"unsupported NonlineMode"
);
}
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionForwardWorkspaceSize
(
auto
&
cudnn
=
args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionForwardWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
filter_desc
.
desc
,
D
.
conv_desc
.
conv_desc
,
D
.
dst_desc
.
desc
,
m_cudnn_enum
,
&
workspace_size
);
...
...
@@ -121,7 +122,8 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
args
.
init_conv_bias_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionForwardWorkspaceSize
(
auto
&
cudnn
=
args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionForwardWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
filter_desc
.
desc
,
D
.
conv_desc
.
conv_desc
,
D
.
dst_desc
.
desc
,
m_cudnn_enum
,
&
workspace_size
);
...
...
dnn/src/cuda/conv_bias/opr_impl.cpp
浏览文件 @
5419a95d
...
...
@@ -83,12 +83,13 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
CUDNNForwardDescs
desc
;
conv_args
.
init_conv_desc
(
desc
);
#if CUDNN_MAJOR >= 7
auto
&
cudnn
=
static_cast
<
HandleImpl
*>
(
this
->
handle
())
->
cudnn
();
int
max_count
=
0
;
cudnn_check
(
cudnnGetConvolutionForwardAlgorithmMaxCount
(
cudnn_handle
,
cudnn_check
(
cudnn
.
GetConvolutionForwardAlgorithmMaxCount
(
cudnn_handle
,
&
max_count
));
SmallVector
<
cudnnConvolutionFwdAlgoPerf_t
>
algo_perf
(
max_count
);
int
ret_count
=
0
;
cudnn_check
(
cudnnGetConvolutionForwardAlgorithm_v7
(
cudnn_check
(
cudnn
.
GetConvolutionForwardAlgorithm_v7
(
cudnn_handle
,
desc
.
src_desc
.
desc
,
desc
.
filter_desc
.
desc
,
desc
.
conv_desc
.
conv_desc
,
desc
.
dst_desc
.
desc
,
max_count
,
&
ret_count
,
algo_perf
.
data
()));
...
...
dnn/src/cuda/convolution/backward_data/cudnn.cpp
浏览文件 @
5419a95d
...
...
@@ -44,9 +44,10 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
}
#endif
auto
&
cudnn
=
args
.
handle
->
cudnn
();
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionBackwardDataWorkspaceSize
(
auto
status
=
cudnn
.
GetConvolutionBackwardDataWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
filter_desc
.
desc
,
D
.
diff_desc
.
desc
,
...
...
@@ -59,10 +60,11 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
size_t
ConvolutionBackwardDataImpl
::
AlgoCUDNN
::
get_workspace_in_bytes
(
const
SizeArgs
&
args
)
const
{
auto
&
cudnn
=
args
.
handle
->
cudnn
();
CUDNNBwdDataDescs
D
;
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionBackwardDataWorkspaceSize
(
auto
status
=
cudnn
.
GetConvolutionBackwardDataWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
filter_desc
.
desc
,
D
.
diff_desc
.
desc
,
...
...
dnn/src/cuda/convolution/backward_filter/cudnn.cpp
浏览文件 @
5419a95d
...
...
@@ -21,6 +21,7 @@ using namespace convolution;
bool
ConvolutionBackwardFilterImpl
::
AlgoCUDNN
::
is_available
(
const
SizeArgs
&
args
)
const
{
auto
&
cudnn
=
args
.
handle
->
cudnn
();
CUDNNBwdFilterDescs
D
;
if
(
!
is_cudnn_supported
(
args
.
as_fwd_args
()))
...
...
@@ -28,7 +29,7 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
auto
status
=
cudnn
.
GetConvolutionBackwardFilterWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
diff_desc
.
desc
,
...
...
@@ -41,10 +42,11 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
size_t
ConvolutionBackwardFilterImpl
::
AlgoCUDNN
::
get_workspace_in_bytes
(
const
SizeArgs
&
args
)
const
{
auto
&
cudnn
=
args
.
handle
->
cudnn
();
CUDNNBwdFilterDescs
D
;
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
auto
status
=
cudnn
.
GetConvolutionBackwardFilterWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
diff_desc
.
desc
,
...
...
dnn/src/cuda/convolution/opr_impl.cpp
浏览文件 @
5419a95d
...
...
@@ -141,12 +141,13 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter,
#if CUDNN_MAJOR >= 7
MEGDNN_MARK_USED_VAR
(
negative_attr
);
auto
&
cudnn
=
args
.
handle
->
cudnn
();
int
max_count
=
0
;
cudnn_check
(
cudnnGetConvolutionBackwardDataAlgorithmMaxCount
(
cudnn_check
(
cudnn
.
GetConvolutionBackwardDataAlgorithmMaxCount
(
cudnn_handle
,
&
max_count
));
SmallVector
<
cudnnConvolutionBwdDataAlgoPerf_t
>
algo_perf
(
max_count
);
int
ret_count
=
0
;
cudnn_check
(
cudnnGetConvolutionBackwardDataAlgorithm_v7
(
cudnn_check
(
cudnn
.
GetConvolutionBackwardDataAlgorithm_v7
(
cudnn_handle
,
desc
.
filter_desc
.
desc
,
desc
.
diff_desc
.
desc
,
desc
.
conv_desc
.
desc
,
desc
.
grad_desc
.
desc
,
max_count
,
&
ret_count
,
algo_perf
.
data
()));
...
...
@@ -286,12 +287,13 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
#endif
#if CUDNN_MAJOR >= 7
MEGDNN_MARK_USED_VAR
(
negative_attr
);
auto
&
cudnn
=
args
.
handle
->
cudnn
();
int
max_count
=
0
;
cudnn_check
(
cudnnGetConvolutionBackwardFilterAlgorithmMaxCount
(
cudnn_check
(
cudnn
.
GetConvolutionBackwardFilterAlgorithmMaxCount
(
cudnn_handle
,
&
max_count
));
SmallVector
<
cudnnConvolutionBwdFilterAlgoPerf_t
>
algo_perf
(
max_count
);
int
ret_count
=
0
;
cudnn_check
(
cudnnGetConvolutionBackwardFilterAlgorithm_v7
(
cudnn_check
(
cudnn
.
GetConvolutionBackwardFilterAlgorithm_v7
(
cudnn_handle
,
desc
.
src_desc
.
desc
,
desc
.
diff_desc
.
desc
,
desc
.
conv_desc
.
desc
,
desc
.
grad_desc
.
desc
,
max_count
,
&
ret_count
,
algo_perf
.
data
()));
...
...
dnn/src/cuda/convolution3d/backward_data/cudnn.cpp
浏览文件 @
5419a95d
...
...
@@ -28,7 +28,8 @@ bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available(
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionBackwardDataWorkspaceSize
(
auto
&
cudnn
=
args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionBackwardDataWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
filter_desc
.
desc
,
D
.
diff_desc
.
desc
,
...
...
@@ -44,7 +45,8 @@ size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
CUDNNBwdDataDescs
D
;
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionBackwardDataWorkspaceSize
(
auto
&
cudnn
=
args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionBackwardDataWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
filter_desc
.
desc
,
D
.
diff_desc
.
desc
,
...
...
dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp
浏览文件 @
5419a95d
...
...
@@ -28,7 +28,8 @@ bool Convolution3DBackwardFilterImpl::AlgoCUDNN::is_available(
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
auto
&
cudnn
=
args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionBackwardFilterWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
diff_desc
.
desc
,
D
.
conv_desc
.
desc
,
D
.
grad_desc
.
desc
,
m_cudnn_enum
,
&
workspace_size
);
return
status
==
CUDNN_STATUS_SUCCESS
;
...
...
@@ -40,7 +41,8 @@ size_t Convolution3DBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
auto
&
cudnn
=
args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionBackwardFilterWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
diff_desc
.
desc
,
D
.
conv_desc
.
desc
,
D
.
grad_desc
.
desc
,
m_cudnn_enum
,
&
workspace_size
);
megdnn_assert
(
status
==
CUDNN_STATUS_SUCCESS
,
...
...
dnn/src/cuda/convolution3d/forward/cudnn.cpp
浏览文件 @
5419a95d
...
...
@@ -27,7 +27,8 @@ bool Convolution3DForwardImpl::AlgoCUDNN::is_available(
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionForwardWorkspaceSize
(
auto
&
cudnn
=
args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionForwardWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
filter_desc
.
desc
,
...
...
@@ -43,7 +44,8 @@ size_t Convolution3DForwardImpl::AlgoCUDNN::get_workspace_in_bytes(
CUDNNForwardDescs
D
;
args
.
init_desc
(
D
);
size_t
workspace_size
;
auto
status
=
cudnnGetConvolutionForwardWorkspaceSize
(
auto
&
cudnn
=
args
.
handle
->
cudnn
();
auto
status
=
cudnn
.
GetConvolutionForwardWorkspaceSize
(
args
.
handle
->
cudnn_handle
(),
D
.
src_desc
.
desc
,
D
.
filter_desc
.
desc
,
...
...
dnn/src/cuda/convolution3d/helper.h
浏览文件 @
5419a95d
...
...
@@ -92,7 +92,7 @@ namespace convolution3d {
const
Workspace
&
workspace
,
void
*&
raw_ptr
);
inline
bool
cudnn_get_convolution_fwd_algo_helper
(
cudnnHandle_t
cudnn_
handle
,
const
cudnnTensorDescriptor_t
x_desc
,
Handle
*
handle
,
const
cudnnTensorDescriptor_t
x_desc
,
const
cudnnFilterDescriptor_t
w_desc
,
const
cudnnConvolutionDescriptor_t
conv_desc
,
const
cudnnTensorDescriptor_t
y_desc
,
...
...
@@ -102,13 +102,14 @@ namespace convolution3d {
MEGDNN_MARK_USED_VAR
(
positive_attr
);
MEGDNN_MARK_USED_VAR
(
negative_attr
);
#if CUDNN_MAJOR >= 7
auto
&
cudnn
=
static_cast
<
HandleImpl
*>
(
handle
)
->
cudnn
();
int
algo_max_count
=
0
;
cudnn_check
(
cudnnGetConvolutionForwardAlgorithmMaxCount
(
cud
nn_handle
,
&
algo_max_count
));
cudnn_check
(
cudnn
.
GetConvolutionForwardAlgorithmMaxCount
(
cud
a
::
cudnn_handle
(
handle
)
,
&
algo_max_count
));
SmallVector
<
cudnnConvolutionFwdAlgoPerf_t
>
algo_perf
(
algo_max_count
);
int
algo_count
=
0
;
cudnn_check
(
cudnnGetConvolutionForwardAlgorithm_v7
(
cud
nn_handle
,
x_desc
,
w_desc
,
conv_desc
,
y_desc
,
algo_max_count
,
cudnn_check
(
cudnn
.
GetConvolutionForwardAlgorithm_v7
(
cud
a
::
cudnn_handle
(
handle
)
,
x_desc
,
w_desc
,
conv_desc
,
y_desc
,
algo_max_count
,
&
algo_count
,
algo_perf
.
data
()));
for
(
int
i
=
0
;
i
<
algo_count
;
++
i
)
{
if
(
algo_perf
[
i
].
algo
==
...
...
@@ -116,8 +117,8 @@ namespace convolution3d {
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
)
continue
;
size_t
workspace_size
=
0
;
cudnn_check
(
cudnnGetConvolutionForwardWorkspaceSize
(
cud
nn_handle
,
x_desc
,
w_desc
,
conv_desc
,
y_desc
,
cudnn_check
(
cudnn
.
GetConvolutionForwardWorkspaceSize
(
cud
a
::
cudnn_handle
(
handle
)
,
x_desc
,
w_desc
,
conv_desc
,
y_desc
,
algo_perf
[
i
].
algo
,
&
workspace_size
));
if
(
workspace_size
>
workspace_limit_in_bytes
)
continue
;
if
(
!
(
positive_attr
&
AlgoAttribute
::
REPRODUCIBLE
))
{
...
...
@@ -133,7 +134,7 @@ namespace convolution3d {
return
false
;
#else
cudnn_check
(
cudnnGetConvolutionForwardAlgorithm
(
cud
nn_handle
,
x_desc
,
w_desc
,
conv_desc
,
y_desc
,
cud
a
::
cudnn_handle
(
handle
)
,
x_desc
,
w_desc
,
conv_desc
,
y_desc
,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
workspace_limit_in_bytes
,
algo
));
return
true
;
...
...
dnn/src/cuda/convolution3d/opr_impl.cpp
浏览文件 @
5419a95d
...
...
@@ -74,13 +74,12 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
auto
get_cudnn_algo
=
[
this
,
&
args
,
workspace_limit_in_bytes
,
positive_attr
,
negative_attr
]()
->
Convolution3DForwardImpl
::
AlgoBase
*
{
auto
cudnn_handle
=
cuda
::
cudnn_handle
(
this
->
handle
());
cudnnConvolutionFwdAlgo_t
algo
;
CUDNNForwardDescs
desc
;
args
.
init_desc
(
desc
);
bool
got
=
cudnn_get_convolution_fwd_algo_helper
(
cudnn_handle
,
desc
.
src_desc
.
desc
,
desc
.
filter_desc
.
desc
,
this
->
handle
()
,
desc
.
src_desc
.
desc
,
desc
.
filter_desc
.
desc
,
desc
.
conv_desc
.
desc
,
desc
.
dst_desc
.
desc
,
workspace_limit_in_bytes
,
&
algo
,
positive_attr
,
negative_attr
);
if
(
got
)
{
...
...
dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh
浏览文件 @
5419a95d
...
...
@@ -56,7 +56,7 @@ namespace convolution {
using KernLayout = _kern_layout; \
using OutputLayout = _output_layout; \
using Param = _conv_param; \
static constexpr bool check_bounds = check_bounds_
;
static constexpr bool check_bounds = check_bounds_
#define MEGDNN_COMMA ,
template
<
bool
check_bounds_
,
typename
src_ldg_dtype
,
typename
filter_ldg_dtype
,
...
...
dnn/src/cuda/convolution_helper/conv_trait/iconv_imma_trait.cuh
浏览文件 @
5419a95d
...
...
@@ -53,7 +53,7 @@ namespace convolution {
using KernLayout = _kern_layout; \
using OutputLayout = _output_layout; \
using Param = _conv_param; \
static constexpr bool check_bounds = check_bounds_
;
static constexpr bool check_bounds = check_bounds_
#define MEGDNN_COMMA ,
template
<
bool
check_bounds_
,
typename
IMMAConfig_
,
typename
WarpTileConfig_
,
...
...
dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh
浏览文件 @
5419a95d
...
...
@@ -53,7 +53,7 @@ namespace convolution {
using KernLayout = _kern_layout; \
using OutputLayout = _output_layout; \
using Param = _conv_param; \
static constexpr bool check_bounds = check_bounds_
;
static constexpr bool check_bounds = check_bounds_
#define MEGDNN_COMMA ,
template
<
bool
check_bounds_
,
typename
ldg_dtype
,
typename
RegBlockConfig_
,
...
...
dnn/src/cuda/handle.cpp
浏览文件 @
5419a95d
...
...
@@ -11,12 +11,15 @@
#include "src/common/handle_impl.h"
#include "src/common/version_symbol.h"
#include "src/common/api_cache.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "src/cuda/api_cache.h"
#include <cuda.h>
#include <cstring>
#include <memory>
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
...
...
@@ -88,6 +91,8 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
// check tk1
m_is_tegra_k1
=
(
strcmp
(
m_device_prop
->
name
,
"GK20A"
)
==
0
);
m_cusolver_handle
=
nullptr
;
m_cudnn_api_cache
=
std
::
make_unique
<
CUDNN
>
(
m_cudnn_handle
);
}
HandleImpl
::~
HandleImpl
()
noexcept
{
...
...
@@ -133,8 +138,111 @@ HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
return
HandleVendorType
::
CUDA
;
}
}
// namespace cuda
}
// namespace megdnn
HandleImpl
::
CUDNN
&
HandleImpl
::
cudnn
()
{
return
*
m_cudnn_api_cache
;
}
HandleImpl
::
CUDNN
::
CUDNN
(
cudnnHandle_t
handle
)
{
m_handle
=
handle
;
GetConvolutionForwardWorkspaceSize
=
FunctionCacheBuilder
<>
()
.
input
<
Param
<
cudnnHandle_t
>>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
CudnnFilterDescParam
>
()
.
input
<
CudnnConvDescParam
>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
Param
<
cudnnConvolutionFwdAlgo_t
>>
()
.
output
<
RefParam
<
size_t
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionForwardWorkspaceSize
);
#if CUDNN_MAJOR >= 7
GetConvolutionForwardAlgorithm_v7
=
FunctionCacheBuilder
<>
()
.
input
<
Param
<
cudnnHandle_t
>>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
CudnnFilterDescParam
>
()
.
input
<
CudnnConvDescParam
>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
Param
<
int
>>
()
.
output
<
RefArraySizeParam
<
int
>>
()
.
output
<
ArrayParam
<
int
,
cudnnConvolutionFwdAlgoPerf_t
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionForwardAlgorithm_v7
);
GetConvolutionForwardAlgorithmMaxCount
=
FunctionCacheBuilder
<>
()
.
input
<
Param
<
cudnnHandle_t
>>
()
.
output
<
RefParam
<
int
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionForwardAlgorithmMaxCount
);
#endif
GetConvolutionBackwardDataWorkspaceSize
=
FunctionCacheBuilder
<>
()
.
input
<
Param
<
cudnnHandle_t
>>
()
.
input
<
CudnnFilterDescParam
>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
CudnnConvDescParam
>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
Param
<
cudnnConvolutionBwdDataAlgo_t
>>
()
.
output
<
RefParam
<
size_t
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionBackwardDataWorkspaceSize
);
#if CUDNN_MAJOR >= 7
GetConvolutionBackwardDataAlgorithm_v7
=
FunctionCacheBuilder
<>
()
.
input
<
Param
<
cudnnHandle_t
>>
()
.
input
<
CudnnFilterDescParam
>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
CudnnConvDescParam
>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
Param
<
int
>>
()
.
output
<
RefArraySizeParam
<
int
>>
()
.
output
<
ArrayParam
<
int
,
cudnnConvolutionBwdDataAlgoPerf_t
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionBackwardDataAlgorithm_v7
);
GetConvolutionBackwardDataAlgorithmMaxCount
=
FunctionCacheBuilder
<>
()
.
input
<
Param
<
cudnnHandle_t
>>
()
.
output
<
RefParam
<
int
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionBackwardDataAlgorithmMaxCount
);
#endif
GetConvolutionBackwardFilterWorkspaceSize
=
FunctionCacheBuilder
<>
()
.
input
<
Param
<
cudnnHandle_t
>>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
CudnnConvDescParam
>
()
.
input
<
CudnnFilterDescParam
>
()
.
input
<
Param
<
cudnnConvolutionBwdFilterAlgo_t
>>
()
.
output
<
RefParam
<
size_t
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionBackwardFilterWorkspaceSize
);
#if CUDNN_MAJOR >= 7
GetConvolutionBackwardFilterAlgorithm_v7
=
FunctionCacheBuilder
<>
()
.
input
<
Param
<
cudnnHandle_t
>>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
CudnnTensorDescParam
>
()
.
input
<
CudnnConvDescParam
>
()
.
input
<
CudnnFilterDescParam
>
()
.
input
<
Param
<
int
>>
()
.
output
<
RefArraySizeParam
<
int
>>
()
.
output
<
ArrayParam
<
int
,
cudnnConvolutionBwdFilterAlgoPerf_t
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionBackwardFilterAlgorithm_v7
);
GetConvolutionBackwardFilterAlgorithmMaxCount
=
FunctionCacheBuilder
<>
()
.
input
<
Param
<
cudnnHandle_t
>>
()
.
output
<
RefParam
<
int
>>
()
.
ret
<
Param
<
cudnnStatus_t
>>
()
.
build
(
&
cudnnGetConvolutionBackwardFilterAlgorithmMaxCount
);
#endif
}
}
// namespace cuda
}
// namespace megdnn
MEGDNN_VERSION_SYMBOL
(
CUDA
,
CUDA_VERSION
);
MEGDNN_VERSION_SYMBOL3
(
CUDNN
,
CUDNN_MAJOR
,
CUDNN_MINOR
,
CUDNN_PATCHLEVEL
);
...
...
dnn/src/cuda/handle.h
浏览文件 @
5419a95d
...
...
@@ -124,6 +124,10 @@ class HandleImpl: public HandleImplHelper {
size_t
image2d_pitch_alignment
()
const
override
;
HandleVendorType
vendor_type
()
const
override
;
class
CUDNN
;
CUDNN
&
cudnn
();
private:
bool
m_is_tegra_k1
;
int
m_device_id
;
...
...
@@ -156,9 +160,34 @@ class HandleImpl: public HandleImplHelper {
//! device ptr to const scalars
ConstScalars
*
m_const_scalars
;
std
::
unique_ptr
<
CUDNN
>
m_cudnn_api_cache
;
void
initialize_cusolver
();
};
class
HandleImpl
::
CUDNN
{
cudnnHandle_t
m_handle
;
public:
CUDNN
(
cudnnHandle_t
handle
);
#define WRAP_CUDNN_API(NAME) thin_function<decltype(cudnn##NAME)> NAME;
WRAP_CUDNN_API
(
GetConvolutionForwardWorkspaceSize
);
#if CUDNN_MAJOR >= 7
WRAP_CUDNN_API
(
GetConvolutionForwardAlgorithm_v7
);
WRAP_CUDNN_API
(
GetConvolutionForwardAlgorithmMaxCount
);
#endif
#if CUDNN_MAJOR >= 7
WRAP_CUDNN_API
(
GetConvolutionBackwardDataAlgorithm_v7
);
WRAP_CUDNN_API
(
GetConvolutionBackwardDataAlgorithmMaxCount
);
#endif
WRAP_CUDNN_API
(
GetConvolutionBackwardDataWorkspaceSize
);
#if CUDNN_MAJOR >= 7
WRAP_CUDNN_API
(
GetConvolutionBackwardFilterAlgorithmMaxCount
);
WRAP_CUDNN_API
(
GetConvolutionBackwardFilterAlgorithm_v7
);
#endif
WRAP_CUDNN_API
(
GetConvolutionBackwardFilterWorkspaceSize
);
#undef WRAP_CUDNN_API
};
}
// namespace cuda
}
// namespace megdnn
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录