Commit df009e89
Authored Mar 12, 2021 by Megvii Engine Team
feat(dnn/cuda): add cuda conv bias impls for NCHW format tensors with qint4 data type
GitOrigin-RevId: a0a08cf42c01974d376c9852cd40bb9e3ccb09f4
Parent ed922075
Showing 3 changed files with 174 additions and 0 deletions (+174, -0)
dnn/src/cuda/conv_bias/conv_bias_int8.cuh     +8    -0
dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp  +165  -0
dnn/src/cuda/conv_bias/opr_impl.h             +1    -0
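In short, the new AlgoFallbackNCHWQS4 path (added in fallback_nchw_qs4.cpp below) accepts dense, cross-correlation conv-bias on plain NCHW tensors whose src/filter/dst are QuantizedS4 and whose bias is QuantizedS32, on sm_75 or later GPUs, with channel counts that are multiples of 64. The following is a minimal, self-contained sketch of the shape-side preconditions only, using hypothetical names and no megdnn headers; the authoritative checks are in is_available below.

#include <cstddef>

// Hypothetical stand-alone mirror of the shape checks in
// AlgoFallbackNCHWQS4::is_available (dtype, format and sm_75 checks omitted).
bool qs4_nchw_shape_ok(std::size_t ci, std::size_t co, std::size_t hi,
                       std::size_t wi, std::size_t fh, std::size_t fw,
                       std::size_t dh, std::size_t dw) {
    bool ok = true;
    ok &= dh == 1 && dw == 1;            // dilation not supported yet
    ok &= hi >= fh && wi >= fw;          // precomputed offsets stay non-negative
    ok &= fh * fw <= 191;                // 3KB of the 4KB param buffer holds offsets
    ok &= ci % 64 == 0 && co % 64 == 0;  // channels are blocked in groups of 64
    return ok;
}

For instance, a 3x3 filter with ci = co = 128 passes, while ci = 96 is rejected by the channel check.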
dnn/src/cuda/conv_bias/conv_bias_int8.cuh
...
@@ -126,6 +126,14 @@ void do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width(
size_t dh = _param.dilate_h, dw = _param.dilate_w; \
size_t fh = _filter_meta.spatial[0], fw = _filter_meta.spatial[1];
#define UNPACK_CONV_BIAS_NCHW_PARAM(_src, _filter_meta, _dst, _param) \
using Format = param::ConvBias::Format; \
megdnn_assert(_param.format == Format::NCHW); \
size_t n = (_src)[0], ci = (_src)[1], hi = (_src)[2], wi = (_src)[3]; \
size_t co = (_dst)[1], ho = (_dst)[2], wo = (_dst)[3]; \
UNPACK_CONV_PARAMETER(_filter_meta, _param); \
MARK_USED_VAR
#define UNPACK_CONV_BIAS_NCHW4_PARAM(_src, _filter_meta, _dst, _param) \
using Format = param::ConvBias::Format; \
megdnn_assert(_param.format == Format::NCHW4); \
...
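For reference, the UNPACK_CONV_BIAS_NCHW_PARAM macro added above reads dimensions straight out of 4-d NCHW layouts (index 0 = batch, 1 = channels, 2 = height, 3 = width) and then defers to UNPACK_CONV_PARAMETER for the filter, stride, padding and dilation fields. A small hypothetical sketch of the same unpacking without the macro machinery:

#include <cstddef>

// Hypothetical plain-function version of the dimension unpacking done by
// UNPACK_CONV_BIAS_NCHW_PARAM; `src` and `dst` stand for 4-d NCHW shapes.
struct NchwConvDims {
    std::size_t n, ci, hi, wi;  // from src: batch, in channels, height, width
    std::size_t co, ho, wo;     // from dst: out channels, out height, out width
};

NchwConvDims unpack_nchw(const std::size_t (&src)[4],
                         const std::size_t (&dst)[4]) {
    return {src[0], src[1], src[2], src[3], dst[1], dst[2], dst[3]};
}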
dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp
0 → 100644
/**
* \file dnn/src/cuda/conv_bias/fallback_nchw_qs4.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./algo.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;

bool ConvBiasForwardImpl::AlgoFallbackNCHWQS4::is_available(
        const SizeArgs& args) const {
    if (args.bias_layout->ndim <= 0)
        return false;
    using Param = param::ConvBias;
    using Format = Param::Format;
    using Sparse = Param::Sparse;
    using Mode = Param::Mode;
    bool available = true;
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout),
                                                param.format))
        return false;
    if (param.format != Format::NCHW)
        return false;
    UNPACK_CONV_BIAS_NCHW_PARAM(*(args.src_layout), fm, *(args.dst_layout),
                                param);
    // TODO: support group conv
    available &= param.sparse == Sparse::DENSE;
    // mode must be cross correlation
    available &= param.mode == Mode::CROSS_CORRELATION;
    // check data type
    auto src_dtype = args.src_layout->dtype,
         filter_dtype = args.filter_layout->dtype,
         bias_dtype = args.bias_layout->dtype,
         dst_dtype = args.dst_layout->dtype;
    available &= (src_dtype.enumv() == DTypeEnum::QuantizedS4 &&
                  filter_dtype.enumv() == DTypeEnum::QuantizedS4 &&
                  bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
                  dst_dtype.enumv() == DTypeEnum::QuantizedS4);
    // TODO: support dilation
    available &= dh == 1 && dw == 1;
    // ensure precomputed offsets are positive integers
    available &= hi >= fh && wi >= fw;
    // only support sm_75 or later, platform should have tensorcore int8
    // support
    available &= is_compute_capability_required(7, 5);
    // param buffer size is 4K, use 3K to store precomputed offsets,
    // so fh * fw <= (3 * 1024 / 4 / 2 / 2) - 1 = 191
    available &= fh * fw <= 191;
    // channels should be multiples of 64
    available &= ci % 64 == 0 && co % 64 == 0;
    return available;
}
size_t ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_in_bytes(
        const SizeArgs& args) const {
    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}
void ConvBiasForwardImpl::AlgoFallbackNCHWQS4::exec(
        const ExecArgs& args) const {
    using Format = Param::Format;
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    auto layouts = make_underlying_tensor_layout(
            *(args.src_layout), fm, *(args.bias_layout), *(args.z_layout),
            *(args.dst_layout));
    auto ws = get_workspace_bundle(args.workspace.raw_ptr, args);
    auto ws_src = ws.get(0);
    auto ws_filter = ws.get(1);
    auto ws_dst = ws.get(2);
    void* ws_z = nullptr;
    if (args.z_layout->ndim > 0)
        ws_z = ws.get(4);
    auto&& stream = cuda_stream(args.opr->handle());
    auto nchw2nchw64 = [](const TensorND& src, void* raw_dptr) {
        if (raw_dptr == nullptr)
            return;
    };
    auto nchw642nchw = [](const TensorND& src, void* raw_dptr) {
    };
    // reformat src
    nchw2nchw64(*(args.src_tensor), ws_src);
    // reformat filter
    nchw2nchw64(*(args.filter_tensor), ws_filter);
    // reformat z
    nchw2nchw64(*(args.z_tensor), ws_z);
    TensorND src_{ws_src, layouts[0]}, filter_{ws_filter, layouts[1]},
            bias_{args.bias_tensor->raw_ptr, layouts[2]}, z_{ws_z, layouts[3]},
            dst_{ws_dst, layouts[4]};
    ExecArgs args_{args.opr,
                   src_,
                   filter_,
                   bias_,
                   z_,
                   dst_,
                   ws.get_workspace(3),
                   args.preprocessed_filter};
    // run the underlying algorithm on the reformatted operands
    m_underlying_algo.exec(args_);
    // reformat dst
    nchw642nchw(dst_, args.dst_tensor->raw_ptr);
}
SmallVector<TensorLayout>
ConvBiasForwardImpl::AlgoFallbackNCHWQS4::make_underlying_tensor_layout(
        const TensorLayout& src, const CanonizedFilterMeta& filter_meta,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst) const {
    size_t n = src[0], ci = src[1], hi = src[2], wi = src[3];
    size_t co = dst[1], ho = dst[2], wo = dst[3];
    size_t fh = filter_meta.spatial[0], fw = filter_meta.spatial[1];
    SmallVector<TensorLayout> rst;
    rst.emplace_back(TensorLayout{{n, ci / 64, hi, wi, 64}, src.dtype});
    rst.emplace_back(
            TensorLayout{{co, ci / 64, fh, fw, 64}, filter_meta.dtype});
    rst.emplace_back(TensorLayout{{1, co / 64, 1, 1, 64}, bias.dtype});
    if (z.ndim > 0) {
        rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, z.dtype});
    } else {
        rst.emplace_back(TensorLayout{});
    }
    rst.emplace_back(TensorLayout{{n, co / 64, ho, wo, 64}, dst.dtype});
    return rst;
}
WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS4::get_workspace_bundle(
        void* raw_ptr, const SizeArgs& args) const {
    size_t ws_size_src = args.src_layout->span().dist_byte();
    size_t ws_size_filter = args.filter_layout->span().dist_byte();
    size_t ws_size_dst = args.dst_layout->span().dist_byte();
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    auto layouts = make_underlying_tensor_layout(
            *(args.src_layout), fm, *(args.bias_layout), *(args.z_layout),
            *(args.dst_layout));
    SizeArgs args_{args.opr,   layouts[0], layouts[1], layouts[2],
                   layouts[3], layouts[4], args.preprocessed_filter};
    size_t ws_size_underlying_algo =
            m_underlying_algo.get_workspace_in_bytes(args_);
    if (args.z_layout->ndim > 0) {
        size_t ws_size_z = args.z_layout->span().dist_byte();
        // slot order matches exec: src, filter, dst, underlying algo, z
        return WorkspaceBundle{raw_ptr,
                               {ws_size_src, ws_size_filter, ws_size_dst,
                                ws_size_underlying_algo, ws_size_z}};
    }
    // keep the slot order exec expects: src, filter, dst, underlying algo
    return WorkspaceBundle{raw_ptr,
                           {ws_size_src, ws_size_filter, ws_size_dst,
                            ws_size_underlying_algo}};
}
// vim: syntax=cpp.doxygen
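The core of the fallback is make_underlying_tensor_layout: every NCHW operand is re-expressed as a 5-d NCHW64 layout in which the channel dimension is split into ci/64 (or co/64) blocks of 64 channels, the blocking the wrapped implementation consumes; exec then copies src, filter, and z into workspace buffers with those layouts, runs m_underlying_algo, and converts the blocked dst back to NCHW. A quick hypothetical sketch of that shape arithmetic, independent of megdnn:

#include <array>
#include <cstddef>
#include <cstdio>

// Hypothetical helper: NCHW shape -> NCHW64 shape, as built by
// make_underlying_tensor_layout (channels must be a multiple of 64).
std::array<std::size_t, 5> nchw_to_nchw64(std::size_t n, std::size_t c,
                                          std::size_t h, std::size_t w) {
    return {{n, c / 64, h, w, 64}};
}

int main() {
    // e.g. a (2, 128, 14, 14) qint4 src becomes (2, 2, 14, 14, 64)
    auto blocked = nchw_to_nchw64(2, 128, 14, 14);
    std::printf("%zu %zu %zu %zu %zu\n", blocked[0], blocked[1], blocked[2],
                blocked[3], blocked[4]);
    return 0;
}

With a (2, 128, 14, 14) source this prints 2 2 14 14 64.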
dnn/src/cuda/conv_bias/opr_impl.h
...
@@ -64,6 +64,7 @@ public:
    class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter;
    class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth;
    class AlgoInt8NCHW32IMMAImplicitGemm;
    class AlgoFallbackNCHWQS4;
    class AlgoBFloat16;
    class AlgoPack;
...