MegEngine / MegEngine, commit 7b17c118
Authored Jun 28, 2022 by Megvii Engine Team
refactor(dnn): make cudnn_frontend work
GitOrigin-RevId: f089f934945790f1e01659b0a25a4615b87b7db2
Parent: 35e9cc98
Showing 16 changed files with 163 additions and 226 deletions (+163 / -226):
    dnn/CMakeLists.txt                                            +4    -1
    dnn/include/megdnn/algorithm_cache.h                          +1    -1
    dnn/src/cuda/conv_bias/algo.cpp                               +1    -1
    dnn/src/cuda/conv_bias/algo.h                                 +7    -7
    dnn/src/cuda/conv_bias/cudnn_conv_base.cpp                    +0   -11
    dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_base.cpp  +128  -124
    dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_v8.cpp      +1   -12
    dnn/src/cuda/conv_bias/cudnn_conv_v8.cpp                      +1   -12
    dnn/src/cuda/conv_bias/helper.cpp                             +4    -1
    dnn/src/cuda/conv_bias/helper.h                               +3    -2
    dnn/src/cuda/conv_bias/opr_impl.cpp                           +1    -1
    dnn/src/cuda/conv_bias/opr_impl.h                             +3    -3
    dnn/src/cuda/cudnn_wrapper_v8.cpp                             +6   -14
    dnn/src/cuda/cudnn_wrapper_v8.h                               +2   -10
    dnn/src/cuda/handle.cpp                                       +0   -15
    dnn/test/cuda/conv_v8.cpp                                     +1   -11
dnn/CMakeLists.txt

@@ -54,7 +54,10 @@ if(MGE_WITH_CUDA)
   add_library(cutlass INTERFACE)
-  target_include_directories(
-    cutlass INTERFACE $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/include>)
+  target_include_directories(
+    cutlass
+    INTERFACE $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/include>
+              $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/tools/util/include>)
   add_library(cudnn-frontend INTERFACE)
   target_include_directories(cudnn-frontend
dnn/include/megdnn/algorithm_cache.h

@@ -31,7 +31,7 @@ public:
     }
 };

-    class Key {
+    struct Key {
         Handle* m_handle;
         uint32_t m_opr_type;
         const TensorLayout* m_inp_layouts_ptr;
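The single-line change above is load-bearing: Key's members are declared without an access specifier, and `class` defaults them to private while `struct` defaults them to public. A minimal sketch of the difference (illustrative names, not MegEngine's):

    struct KeyS { int v; };  // members default to public
    class KeyC { int v; };   // members default to private

    int use() {
        KeyS s{1};     // OK: v is public, aggregate initialization works
        // KeyC c{1};  // would not compile: v is private
        return s.v;
    }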
dnn/src/cuda/conv_bias/algo.cpp

@@ -15,7 +15,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
     non_cudnn_algos.push_back(&batched_matmul);
     non_cudnn_algos.push_back(&int1_simple);

-#if CUDNN_VERSION > 8004
+#if CUDNN_VERSION >= 8020
     all_algos.push_back(&cudnn_conv_v8);
     all_algos.push_back(&cudnn_conv_bias_activation_v8);
 #endif
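Both the old and the new gate compare against CUDNN_VERSION which, assuming the usual cudnn_version.h definition, is CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL; cuDNN 8.0.4 therefore reports 8004, and the stricter `>= 8020` bar excludes the whole 8.0.x series. A small sketch of inspecting the gate (not MegEngine code):

    #include <cstdio>
    #include <cudnn.h>

    int main() {
        // CUDNN_MAJOR/MINOR/PATCHLEVEL come from cudnn_version.h via cudnn.h.
        std::printf(
                "built against cuDNN %d.%d.%d (CUDNN_VERSION=%d)\n",
                CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL, CUDNN_VERSION);
    #if CUDNN_VERSION >= 8020
        std::printf("v8 frontend algorithms compiled in\n");
    #endif
        return 0;
    }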
dnn/src/cuda/conv_bias/algo.h

@@ -173,10 +173,10 @@ public:
     bool is_cudnn() const override { return true; }
-    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
-    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
-            const SizeArgs& args) const override;
-    void exec_preprocess(const ExecArgs& args) const override;
+    // size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
+    // SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
+    //         const SizeArgs& args) const override;
+    // void exec_preprocess(const ExecArgs& args) const override;

 protected:
     virtual size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const = 0;

@@ -237,7 +237,7 @@ private:
     CudnnAlgoPack::Attr m_attr;
 };

-#if CUDNN_VERSION > 8004
+#if CUDNN_VERSION >= 8020
 class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationV8 final
         : public AlgoCUDNNConvBiasActivationBase {
 public:

@@ -414,7 +414,7 @@ private:
     CudnnAlgoPack::Attr m_attr;
 };

-#if CUDNN_VERSION > 8004
+#if CUDNN_VERSION >= 8020
 class ConvBiasForwardImpl::AlgoCUDNNConvV8 final : public AlgoCUDNNConvBase {
 public:
     AlgoCUDNNConvV8() : AlgoCUDNNConvBase() {

@@ -1247,7 +1247,7 @@ public:
     AlgoGroupConvGeneral group;
     AlgoBFloat16 bfloat16;
     AlgoSimpleInt1 int1_simple;
-#if CUDNN_VERSION > 8004
+#if CUDNN_VERSION >= 8020
     AlgoCUDNNConvV8 cudnn_conv_v8;
     AlgoCUDNNConvBiasActivationV8 cudnn_conv_bias_activation_v8;
 #endif
dnn/src/cuda/conv_bias/cudnn_conv_base.cpp

-/**
- * \file dnn/src/cuda/conv_bias/cudnn_conv_base.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
 #include "src/common/conv_bias.h"
 #include "src/cuda/conv_bias/algo.h"
 #include "src/cuda/utils.h"
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_base.cpp

-/**
- * \file dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_base.cpp
- * ... (license header removed; identical to the one shown for cudnn_conv_base.cpp) ...
- */
 #include "megdnn/oprs/general.h"
 #include "./algo.h"

@@ -26,19 +15,21 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::get_workspace_in_bytes(
         const SizeArgs& args) const {
     auto workspace_size = cudnn_get_workspace_in_bytes(args);
-    auto&& param = args.opr->param();
-    if (args.preprocessed_filter == nullptr) {
-        if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
-            args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
-            // cudnn require bias to be float when executing CONFIG_INT
-            // convert bias to float if bias is not float at first
-            workspace_size += sizeof(float) * args.bias_layout->span().dist_elem();
-        }
-        if (param.format == param::ConvBias::Format::NCHW32) {
-            workspace_size += args.filter_layout->span().dist_byte() +
-                              args.bias_layout->span().dist_byte();
-        }
-    }
+    // if (args.preprocessed_filter == nullptr) {
+    if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
+        args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
+        // cudnn require bias to be float when executing CONFIG_INT
+        // convert bias to float if bias is not float at first
+        workspace_size += sizeof(float) * args.bias_layout->span().dist_elem();
+    }
+    // #if CUDNN_VERSION >= 7500
+    // auto&& param = args.opr->param();
+    // if (param.format == param::ConvBias::Format::NCHW32) {
+    //     workspace_size += args.filter_layout->span().dist_byte() +
+    //                       args.bias_layout->span().dist_byte();
+    // }
+    // #endif
+    // }
     return workspace_size;
 }

@@ -56,55 +47,62 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::exec(
     TensorND filter_tensor;
     TensorND bias_tensor;
-    auto&& param = args.opr->param();
-    if (args.preprocessed_filter != nullptr) {
-        bias_tensor = TensorND{
-                args.bias_tensor->layout,
-                args.preprocessed_filter->tensors[0].raw_ptr()};
-        if (param.format == Param::Format::NCHW32) {
-            megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
-            filter_tensor = TensorND{
-                    args.filter_tensor->layout,
-                    args.preprocessed_filter->tensors[1].raw_ptr()};
-        } else {
-            filter_tensor = *args.filter_tensor;
-        }
-    } else {
-        if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
-            args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
-            auto cvt = args.handle->create_operator<TypeCvt>();
-            auto float_bias_layout = *args.bias_layout;
-            auto converted_bias_layout = *args.bias_layout;
-            converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
-            float_bias_layout.dtype = dtype::Float32();
-            auto bias_size_in_bytes = float_bias_layout.span().dist_byte();
-            megdnn_assert(args.workspace.size >= bias_size_in_bytes);
-            cvt->exec(
-                    {args.bias_tensor->raw_ptr(), converted_bias_layout},
-                    TensorND{workspace_ptr, float_bias_layout});
-            bias_ptr = workspace_ptr;
-            workspace_ptr += bias_size_in_bytes;
-            workspace_size -= bias_size_in_bytes;
-        }
-        if (param.format == Param::Format::NCHW32) {
-            size_t reorder_workspace_size =
-                    args.filter_tensor->layout.span().dist_byte() +
-                    args.bias_tensor->layout.span().dist_byte();
-            auto reorder_filter_ptr = workspace_ptr;
-            auto reorder_bias_ptr =
-                    workspace_ptr + args.filter_tensor->layout.span().dist_byte();
-            cudnn_reorder_filer_and_bias_nchw32(
-                    cudnn_handle(args.opr->handle()), args.filter_tensor->raw_ptr(),
-                    args.filter_meta, bias_ptr, reorder_filter_ptr, reorder_bias_ptr);
-            filter_tensor = TensorND(args.filter_tensor->layout, reorder_filter_ptr);
-            bias_ptr = reorder_bias_ptr;
-            workspace_ptr += reorder_workspace_size;
-            workspace_size -= reorder_workspace_size;
-        } else {
-            filter_tensor = *args.filter_tensor;
-        }
-    }
+    // if (args.preprocessed_filter != nullptr) {
+    //     bias_tensor = TensorND{
+    //             args.bias_tensor->layout,
+    //             args.preprocessed_filter->tensors[0].raw_ptr()};
+    //     // #if CUDNN_VERSION >= 7500
+    //     // auto&& param = args.opr->param();
+    //     // if (param.format == Param::Format::NCHW32) {
+    //     //     megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
+    //     //     filter_tensor = TensorND{
+    //     //             args.filter_tensor->layout,
+    //     //             args.preprocessed_filter->tensors[1].raw_ptr()};
+    //     // }
+    //     // #else
+    //     filter_tensor = *args.filter_tensor;
+    //     // #endif
+    // } else {
+    if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
+        args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
+        auto cvt = args.handle->create_operator<TypeCvt>();
+        auto float_bias_layout = *args.bias_layout;
+        auto converted_bias_layout = *args.bias_layout;
+        converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
+        float_bias_layout.dtype = dtype::Float32();
+        auto bias_size_in_bytes = float_bias_layout.span().dist_byte();
+        megdnn_assert(args.workspace.size >= bias_size_in_bytes);
+        cvt->exec(
+                {args.bias_tensor->raw_ptr(), converted_bias_layout},
+                TensorND{workspace_ptr, float_bias_layout});
+        bias_ptr = workspace_ptr;
+        workspace_ptr += bias_size_in_bytes;
+        workspace_size -= bias_size_in_bytes;
+    }
+    // #if CUDNN_VERSION >= 7500
+    // auto&& param = args.opr->param();
+    // if (param.format == Param::Format::NCHW32) {
+    //     size_t reorder_workspace_size =
+    //             args.filter_tensor->layout.span().dist_byte() +
+    //             args.bias_tensor->layout.span().dist_byte();
+    //     auto reorder_filter_ptr = workspace_ptr;
+    //     auto reorder_bias_ptr =
+    //             workspace_ptr +
+    //             args.filter_tensor->layout.span().dist_byte();
+    //     cudnn_reorder_filter_and_bias_nchw32(
+    //             cudnn_handle(args.opr->handle()),
+    //             args.filter_tensor->raw_ptr(), args.filter_meta,
+    //             bias_ptr, reorder_filter_ptr, reorder_bias_ptr);
+    //     filter_tensor = TensorND(args.filter_tensor->layout,
+    //     reorder_filter_ptr); bias_ptr = reorder_bias_ptr; workspace_ptr
+    //     += reorder_workspace_size; workspace_size -=
+    //     reorder_workspace_size;
+    // }
+    // #else
+    filter_tensor = *args.filter_tensor;
+    // #endif
+    // }
     bias_tensor = TensorND{args.bias_tensor->layout, bias_ptr};

     ExecArgs exec_args{

@@ -153,58 +151,64 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::exec(
     }
 }

-size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::
-        get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
-    auto&& param = args.opr->param();
-    if (param.format == Param::Format::NCHW32) {
-        return args.bias_layout->span().dist_byte();
-    }
-    return 0_z;
-}
-
-SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::
-        deduce_preprocessed_filter_layout(const SizeArgs& args) const {
-    auto&& param = args.opr->param();
-    if (param.format == Param::Format::NCHW32) {
-        return {args.bias_layout->collapse_contiguous(),
-                args.filter_layout->collapse_contiguous()};
-    } else {
-        return {args.bias_layout->collapse_contiguous()};
-    }
-}
-
-void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::exec_preprocess(
-        const ExecArgs& args) const {
-    float alpha, beta;
-    std::tie(alpha, beta) = cudnn_get_conv_bias_act_scale_param(
-            args.src_tensor->layout, args.dst_tensor->layout,
-            args.filter_tensor->layout, args.bias_tensor->layout,
-            args.z_tensor->layout);
-    MEGDNN_MARK_USED_VAR(beta);
-    auto workspace_ptr = args.workspace.raw_ptr;
-    auto workspace_size = args.workspace.size;
-    auto bias_ptr = workspace_size > 0
-                          ? workspace_ptr
-                          : args.preprocessed_filter->tensors[0].raw_ptr();
-    if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
-        args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
-        auto cvt = args.handle->create_operator<TypeCvt>();
-        auto float_bias_layout = *args.bias_layout;
-        auto converted_bias_layout = *args.bias_layout;
-        converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
-        float_bias_layout.dtype = dtype::Float32();
-        cvt->exec(
-                {args.bias_tensor->raw_ptr(), converted_bias_layout},
-                TensorND{bias_ptr, float_bias_layout});
-    }
-    if (args.opr->param().format == Param::Format::NCHW32) {
-        auto reorder_filter_ptr = args.preprocessed_filter->tensors[1].raw_ptr();
-        auto reorder_bias_ptr = args.preprocessed_filter->tensors[0].raw_ptr();
-        cudnn_reorder_filer_and_bias_nchw32(
-                cudnn_handle(args.opr->handle()), args.filter_tensor->raw_ptr(),
-                args.filter_meta, bias_ptr, reorder_filter_ptr, reorder_bias_ptr);
-    }
-}
+// size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::
+//         get_preprocess_workspace_in_bytes(const SizeArgs&) const {
+// #if CUDNN_VERSION >= 7500
+//     auto&& param = args.opr->param();
+//     if (param.format == Param::Format::NCHW32) {
+//         return args.bias_layout->span().dist_byte();
+//     }
+// #endif
+//     return 0_z;
+// }
+
+// SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::
+//         deduce_preprocessed_filter_layout(const SizeArgs& args) const {
+// #if CUDNN_VERSION >= 7500
+//     auto&& param = args.opr->param();
+//     if (param.format == Param::Format::NCHW32) {
+//         return {args.bias_layout->collapse_contiguous(),
+//                 args.filter_layout->collapse_contiguous()};
+//     }
+// #endif
+//     return {args.bias_layout->collapse_contiguous()};
+// }
+
+// void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::exec_preprocess(
+//         const ExecArgs& args) const {
+//     float alpha, beta;
+//     std::tie(alpha, beta) = cudnn_get_conv_bias_act_scale_param(
+//             args.src_tensor->layout, args.dst_tensor->layout,
+//             args.filter_tensor->layout, args.bias_tensor->layout,
+//             args.z_tensor->layout);
+//     MEGDNN_MARK_USED_VAR(beta);
+//     auto workspace_ptr = args.workspace.raw_ptr;
+//     auto workspace_size = args.workspace.size;
+//     auto bias_ptr = workspace_size > 0 ? workspace_ptr
+//                                        : args.preprocessed_filter->tensors[0].raw_ptr();
+//     if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
+//         args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
+//         auto cvt = args.handle->create_operator<TypeCvt>();
+//         auto float_bias_layout = *args.bias_layout;
+//         auto converted_bias_layout = *args.bias_layout;
+//         converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
+//         float_bias_layout.dtype = dtype::Float32();
+//         cvt->exec(
+//                 {args.bias_tensor->raw_ptr(), converted_bias_layout},
+//                 TensorND{bias_ptr, float_bias_layout});
+//     }
+// #if CUDNN_VERSION >= 7500
+//     if (args.opr->param().format == Param::Format::NCHW32) {
+//         auto reorder_filter_ptr = args.preprocessed_filter->tensors[1].raw_ptr();
+//         auto reorder_bias_ptr = args.preprocessed_filter->tensors[0].raw_ptr();
+//         cudnn_reorder_filter_and_bias_nchw32(
+//                 cudnn_handle(args.opr->handle()), args.filter_tensor->raw_ptr(),
+//                 args.filter_meta, bias_ptr, reorder_filter_ptr, reorder_bias_ptr);
+//     }
+// #endif
+// }

 // vim: syntax=cpp.doxygen
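Taken together, the hunks above park the preprocessed-filter fast path (and the NCHW32 reorder that depends on it) behind comments, so exec() now always stages the bias conversion through the caller-supplied workspace. That staging follows a simple bump-pointer pattern over a single workspace blob; a self-contained sketch of the pattern, with illustrative names (this is not MegEngine's Workspace type):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    struct Workspace {
        uint8_t* ptr;
        size_t size;
    };

    // Carve `bytes` off the front of `ws`, mirroring the
    //   bias_ptr = workspace_ptr; workspace_ptr += n; workspace_size -= n;
    // sequence in exec() above.
    void* carve(Workspace& ws, size_t bytes) {
        assert(ws.size >= bytes);  // same guard as megdnn_assert on workspace.size
        void* out = ws.ptr;
        ws.ptr += bytes;
        ws.size -= bytes;
        return out;
    }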
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_v8.cpp

-/**
- * \file dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_v8.cpp
- * ... (license header removed; identical to the one shown for cudnn_conv_base.cpp) ...
- */
 #include "megdnn/oprs/general.h"
 #include "./algo.h"

@@ -17,7 +6,7 @@
 #include "src/cuda/cudnn_wrapper_v8.h"
 #include "src/cuda/utils.h"

-#if CUDNN_VERSION >= 8004
+#if CUDNN_VERSION >= 8020

 using namespace megdnn;
 using namespace cuda;
 using namespace conv_bias;
dnn/src/cuda/conv_bias/cudnn_conv_v8.cpp

-/**
- * \file dnn/src/cuda/conv_bias/cudnn_conv_v8.cpp
- * ... (license header removed; identical to the one shown for cudnn_conv_base.cpp) ...
- */
 #include "src/common/conv_bias.h"
 #include "src/cuda/conv_bias/algo.h"
 #include "src/cuda/cudnn_wrapper_v8.h"
 #include "src/cuda/utils.h"

-#if CUDNN_VERSION >= 8004
+#if CUDNN_VERSION >= 8020

 using namespace megdnn;
 using namespace cuda;
 using namespace conv_bias;
dnn/src/cuda/conv_bias/helper.cpp

@@ -239,7 +239,8 @@ std::pair<float, float> cudnn_get_conv_bias_act_scale_param(
     return {alpha, beta};
 }

-void cudnn_reorder_filer_and_bias_nchw32(
+#if CUDNN_VERSION >= 7500
+void cudnn_reorder_filter_and_bias_nchw32(
         const cudnnHandle_t& handle, const void* filter_ptr,
         const CanonizedFilterMeta& fm, const void* bias_ptr,
         void* reordered_filter_ptr, void* reordered_bias_ptr) {

@@ -250,6 +251,8 @@ void cudnn_reorder_filter_and_bias_nchw32(
             handle, filter_desc.desc, CUDNN_DEFAULT_REORDER, filter_ptr,
             reordered_filter_ptr, reorder_bias, bias_ptr, reordered_bias_ptr));
 }
+#endif

 }  // namespace conv_bias
 }  // namespace cuda
 }  // namespace megdnn
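The renamed helper wraps cuDNN's cudnnReorderFilterAndBias, whose trailing arguments are visible in the second hunk. A hedged sketch of the call shape (descriptor setup is elided, and the int8x32/NCHW_VECT_C use case is an assumption drawn from the NCHW32 format this helper serves):

    #include <cudnn.h>

    void reorder_filter_and_bias(
            cudnnHandle_t handle, cudnnFilterDescriptor_t filter_desc,
            const void* filter, void* filter_reordered, const void* bias,
            void* bias_reordered) {
        // reorderBias != 0 asks cuDNN to reorder the bias alongside the filter.
        cudnnReorderFilterAndBias(
                handle, filter_desc, CUDNN_DEFAULT_REORDER, filter,
                filter_reordered, /*reorderBias=*/1, bias, bias_reordered);
    }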
dnn/src/cuda/conv_bias/helper.h

@@ -117,11 +117,12 @@ std::pair<float, float> cudnn_get_conv_bias_act_scale_param(
         const TensorLayout& x, const TensorLayout& y, const TensorLayout& w,
         const TensorLayout& b, const TensorLayout& z);

-void cudnn_reorder_filer_and_bias_nchw32(
+#if CUDNN_VERSION >= 7500
+void cudnn_reorder_filter_and_bias_nchw32(
         const cudnnHandle_t& handle, const void* filter_ptr,
         const CanonizedFilterMeta& fm, const void* bias_ptr,
         void* reordered_filter_ptr, void* reordered_bias_ptr);
+#endif

 }  // namespace conv_bias
 }  // namespace cuda
 }  // namespace megdnn
dnn/src/cuda/conv_bias/opr_impl.cpp

@@ -47,7 +47,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
         const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
     using namespace conv_bias;
     AlgoBase::SizeArgs args{this, src, filter, bias, z, dst};
-#if CUDNN_VERSION >= 8004
+#if CUDNN_VERSION >= 8020
     if (sm_algo_pack.cudnn_conv_v8.is_available_attribute(
                 args, positive_attr, negative_attr, workspace_limit_in_bytes)) {
         return &sm_algo_pack.cudnn_conv_v8;
dnn/src/cuda/conv_bias/opr_impl.h

@@ -32,12 +32,10 @@ public:
     const char* get_algorithm_set_name() const override;

     class AlgoBase;
-    class AlgoCUDNNConvBiasActivation;
     class AlgoChanwise;
     class AlgoChanwiseSmall;
     class AlgoDepthwiseLargeFilter;
     class AlgoChanwise8x8x32;
-    class AlgoCUDNNConv;
     class AlgoFallbackNCHWQS8;
     class AlgoInplaceMatmul;
     class AlgoMatmul;

@@ -67,8 +65,10 @@ public:
     class AlgoFloat32NCHWFMAImplicitBatchedGemm;
     class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
     class AlgoCUDNNConvBase;
+    class AlgoCUDNNConv;
     class AlgoCUDNNConvBiasActivationBase;
-#if CUDNN_VERSION > 8004
+    class AlgoCUDNNConvBiasActivation;
+#if CUDNN_VERSION >= 8020
     class AlgoCUDNNConvV8;
     class AlgoCUDNNConvBiasActivationV8;
 #endif
dnn/src/cuda/cudnn_wrapper_v8.cpp

-/**
- * \file dnn/src/cuda/cudnn_wrapper_v8.cpp
- * ... (license header removed; identical to the one shown for cudnn_conv_base.cpp) ...
- */
+#if CUDNN_VERSION >= 8020
 #include "src/cuda/cudnn_wrapper_v8.h"
 #include "src/cuda/cudnn_wrapper.h"

@@ -19,7 +10,7 @@
 #include "cudnn_frontend_EngineConfigGenerator.h"

-#include "megdnn/heuristic_cache.h"
+#include "megdnn/algorithm_cache.h"

 using namespace megdnn;
 using namespace cuda;

@@ -240,9 +231,9 @@ auto make_activation_descriptor(
 // high-level api for convolution execution
 struct StaticData {
-    using Key = megdnn::HeuristicCache::Key;
-    using KeyStorage = megdnn::HeuristicCache::KeyStorage;
-    using KeyHash = megdnn::HeuristicCache::Hash;
+    using Key = megdnn::AlgorithmCache::Key;
+    using KeyStorage = megdnn::AlgorithmCache::KeyStorage;
+    using KeyHash = megdnn::AlgorithmCache::Hash;
     using Result = cudnn_frontend::ExecutionPlan;
     using CudnnFrontendExecutionPlanCache =
             std::unordered_map<KeyStorage, Result, KeyHash>;

@@ -682,4 +673,5 @@ void megdnn::cuda::run_conv_bias_act_with_plan(
             handle, plan.get_raw_desc(), variant_pack.get_raw_desc()));
 }
+#endif

 // vim: syntax=cpp.doxygen
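StaticData keys finished cudnn_frontend::ExecutionPlan objects by AlgorithmCache's KeyStorage and Hash, so the expensive engine-config search runs once per unique convolution problem. A minimal sketch of that memoisation shape (illustrative names and locking, not MegEngine's types):

    #include <mutex>
    #include <optional>
    #include <unordered_map>

    template <typename Key, typename Plan, typename Hash>
    class PlanCache {
        std::mutex m_mtx;
        std::unordered_map<Key, Plan, Hash> m_map;

    public:
        // Return a cached plan for `key`, if one was already built.
        std::optional<Plan> get(const Key& key) {
            std::lock_guard<std::mutex> lock{m_mtx};
            auto it = m_map.find(key);
            if (it == m_map.end())
                return std::nullopt;
            return it->second;
        }
        // Remember a freshly built plan.
        void put(const Key& key, Plan plan) {
            std::lock_guard<std::mutex> lock{m_mtx};
            m_map.emplace(key, std::move(plan));
        }
    };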
dnn/src/cuda/cudnn_wrapper_v8.h

-/**
- * \file dnn/src/cuda/cudnn_wrapper_v8.h
- * ... (license header removed; identical to the one shown for cudnn_conv_base.cpp) ...
- */
 #pragma once
+#if CUDNN_VERSION >= 8020
 #include "megdnn/basic_types.h"
 #include "megdnn/oprs/nn.h"
 #include "src/common/utils.h"

@@ -67,4 +58,5 @@ void run_conv_bias_act_with_plan(
 }  // namespace cuda
 }  // namespace megdnn
+#endif

 // vim: syntax=cpp.doxygen
dnn/src/cuda/handle.cpp

@@ -58,11 +58,6 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
 For example `export CUDA_CACHE_MAXSIZE=2147483647` and `export CUDA_CACHE_PATH=/data/.cuda_cache`)");
     }
 #endif
-    size_t free, tot;
-    cudaMemGetInfo(&free, &tot);
-    printf("before cudnn create, free: %.2f MB, tot: %.2f MB, allocated: %.2f MB\n",
-           free / 1024.0 / 1024.0, tot / 1024.0 / 1024.0,
-           (tot - free) / 1024.0 / 1024.0);
     cudnn_check(cudnnCreate(&m_cudnn_handle));
     cublas_check(cublasCreate(&m_cublas_handle));
 #if CUDA_VERSION >= 10010

@@ -74,11 +69,6 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
     cudnn_check(cudnnSetStream(m_cudnn_handle, stream()));
     cublas_check(cublasSetStream(m_cublas_handle, stream()));
-#if CUDNN_VERSION >= 8004
-    // cudnn_check(cudnnOpsInferVersionCheck());
-    // cudnn_check(cudnnCnnInferVersionCheck());
-#endif
     // Note that all cublas scalars (alpha, beta) and scalar results such as dot
     // output resides at device side.
     cublas_check(cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE));

@@ -92,11 +82,6 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
             cudaMemcpyHostToDevice, stream()));
     cuda_check(cudaStreamSynchronize(stream()));
-    cudaMemGetInfo(&free, &tot);
-    printf("after cudnn create, free: %.2f MB, tot: %.2f MB, allocated: %.2f MB\n",
-           free / 1024.0 / 1024.0, tot / 1024.0 / 1024.0,
-           (tot - free) / 1024.0 / 1024.0);
     // check tk1
     m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
     m_cusolver_handle = nullptr;
dnn/test/cuda/conv_v8.cpp

-/**
- * \file dnn/test/cuda/conv_bias.cpp
- * ... (license header removed; identical to the one shown for cudnn_conv_base.cpp) ...
- */
 #include "megdnn/dtype.h"
 #include "test/cuda/fixture.h"

@@ -26,7 +16,7 @@ using namespace megdnn;
 using namespace test;
 using namespace conv_bias;

-#if CUDNN_VERSION >= 8004
+#if CUDNN_VERSION >= 8020
 TEST_F(CUDA, CONV_V8_FLOAT) {
     Checker<ConvBiasForward> checker(handle_cuda());
     checker.set_before_exec_callback(
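For context, megdnn tests of this shape usually continue by pinning the algorithm and running shape/dtype combinations through the Checker; a hedged sketch of how such a test typically reads (the algorithm name string and tensor shapes are placeholders, not the ones in this file):

    TEST_F(CUDA, CONV_V8_FLOAT_SKETCH) {
        Checker<ConvBiasForward> checker(handle_cuda());
        // Restrict execution to one algorithm so the test exercises it directly.
        checker.set_before_exec_callback(AlgoChecker<ConvBiasForward>("CUDNN:Conv"));
        checker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32());
        // execs takes {src, filter, bias, z, dst}; empty layouts are deduced.
        checker.execs({{2, 3, 16, 16}, {8, 3, 3, 3}, {1, 8, 1, 1}, {}, {}});
    }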