Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
09dab387
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
09dab387
编写于
3月 11, 2022
作者:
M
Megvii Engine Team
提交者:
“wenjuan”
4月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(cuda): support int1 simplewq conv
GitOrigin-RevId: 9c37c41bc7e450f3df81e6059603101de3f14416
上级
6554e262
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
222 addition
and
2 deletion
+222
-2
dnn/src/common/convolution.cpp
dnn/src/common/convolution.cpp
+2
-1
dnn/src/cuda/conv_bias/algo.cpp
dnn/src/cuda/conv_bias/algo.cpp
+2
-1
dnn/src/cuda/conv_bias/algo.h
dnn/src/cuda/conv_bias/algo.h
+20
-0
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
+2
-0
dnn/src/cuda/conv_bias/helper.cpp
dnn/src/cuda/conv_bias/helper.cpp
+3
-0
dnn/src/cuda/conv_bias/opr_impl.cpp
dnn/src/cuda/conv_bias/opr_impl.cpp
+5
-0
dnn/src/cuda/conv_bias/opr_impl.h
dnn/src/cuda/conv_bias/opr_impl.h
+1
-0
dnn/src/cuda/conv_bias/simple_int1.cpp
dnn/src/cuda/conv_bias/simple_int1.cpp
+145
-0
dnn/src/cuda/convolution/forward/algos.cpp
dnn/src/cuda/convolution/forward/algos.cpp
+4
-0
dnn/src/naive/conv_bias/opr_impl.cpp
dnn/src/naive/conv_bias/opr_impl.cpp
+3
-0
dnn/src/naive/convolution/helper.h
dnn/src/naive/convolution/helper.h
+9
-0
dnn/test/cuda/conv_bias.cpp
dnn/test/cuda/conv_bias.cpp
+26
-0
未找到文件。
dnn/src/common/convolution.cpp
浏览文件 @
09dab387
...
@@ -561,7 +561,8 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
...
@@ -561,7 +561,8 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(
src
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
enumv
()
==
DTypeEnum
::
QuantizedS8
||
src
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
||
src
.
enumv
()
==
DTypeEnum
::
Quantized8Asymm
||
src
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
src
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
src
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
{
src
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
||
src
.
enumv
()
==
DTypeEnum
::
QuantizedS1
)
{
supported_dst_dtype
.
push_back
(
dtype
::
QuantizedS32
(
mul_scale
(
src
,
filter
)));
supported_dst_dtype
.
push_back
(
dtype
::
QuantizedS32
(
mul_scale
(
src
,
filter
)));
bool
cond_dst
=
dst
.
valid
()
&&
(
dst
.
enumv
()
==
src
.
enumv
()
||
bool
cond_dst
=
dst
.
valid
()
&&
(
dst
.
enumv
()
==
src
.
enumv
()
||
((
dst
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
((
dst
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
...
...
dnn/src/cuda/conv_bias/algo.cpp
浏览文件 @
09dab387
...
@@ -25,7 +25,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
...
@@ -25,7 +25,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos
.
push_back
(
&
matmul
);
non_cudnn_algos
.
push_back
(
&
matmul
);
non_cudnn_algos
.
push_back
(
&
matmul8x8x32
);
non_cudnn_algos
.
push_back
(
&
matmul8x8x32
);
non_cudnn_algos
.
push_back
(
&
batched_matmul
);
non_cudnn_algos
.
push_back
(
&
batched_matmul
);
non_cudnn_algos
.
push_back
(
&
int1_simple
);
fill_cudnn_algos
();
fill_cudnn_algos
();
for
(
auto
&&
algo
:
cudnn_conv_bias_activations
)
{
for
(
auto
&&
algo
:
cudnn_conv_bias_activations
)
{
all_algos
.
push_back
(
&
algo
);
all_algos
.
push_back
(
&
algo
);
...
@@ -45,6 +45,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
...
@@ -45,6 +45,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
conv_algos
.
push_back
(
&
matmul8x8x32
);
conv_algos
.
push_back
(
&
matmul8x8x32
);
conv_algos
.
push_back
(
&
batched_matmul
);
conv_algos
.
push_back
(
&
batched_matmul
);
conv_algos
.
push_back
(
&
group
);
conv_algos
.
push_back
(
&
group
);
conv_algos
.
push_back
(
&
int1_simple
);
for
(
auto
&&
algo
:
conv_algos
)
{
for
(
auto
&&
algo
:
conv_algos
)
{
all_algos
.
push_back
(
algo
);
all_algos
.
push_back
(
algo
);
...
...
dnn/src/cuda/conv_bias/algo.h
浏览文件 @
09dab387
...
@@ -87,6 +87,7 @@ public:
...
@@ -87,6 +87,7 @@ public:
CUDA_FALLBACK_NCHW_INT4
,
CUDA_FALLBACK_NCHW_INT4
,
CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32
,
CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32
,
CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16
,
CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16
,
CUDA_SIMPLE_INT1
,
};
};
using
Mapper
=
std
::
unordered_map
<
AlgorithmDesc
,
AlgoBase
*>
;
using
Mapper
=
std
::
unordered_map
<
AlgorithmDesc
,
AlgoBase
*>
;
...
@@ -1089,6 +1090,24 @@ private:
...
@@ -1089,6 +1090,24 @@ private:
WorkspaceBundle
get_workspace_bundle
(
void
*
ptr
,
const
SizeArgs
&
args
)
const
;
WorkspaceBundle
get_workspace_bundle
(
void
*
ptr
,
const
SizeArgs
&
args
)
const
;
};
};
class
ConvBiasForwardImpl
::
AlgoSimpleInt1
final
:
public
AlgoBase
{
public:
bool
is_available
(
const
SizeArgs
&
args
)
const
override
;
size_t
get_workspace_in_bytes
(
const
SizeArgs
&
args
)
const
override
;
void
exec
(
const
ExecArgs
&
args
)
const
override
;
std
::
vector
<
SearchItem
>
get_subopr_list
(
const
TensorLayoutArray
&
layouts
,
const
OperatorBase
*
opr
)
const
override
;
const
char
*
name
()
const
override
{
return
"CONVBIAS_SIMPLE_INT1"
;
}
AlgoAttribute
attribute
()
const
override
{
return
AlgoAttribute
::
REPRODUCIBLE
;
}
MEGDNN_DECL_ALGO_TYPE
(
CUDA_SIMPLE_INT1
)
private:
WorkspaceBundle
get_workspace_bundle
(
void
*
ptr
,
const
SizeArgs
&
args
)
const
;
};
class
ConvBiasForwardImpl
::
AlgoPack
:
NonCopyableObj
{
class
ConvBiasForwardImpl
::
AlgoPack
:
NonCopyableObj
{
private:
private:
AlgoBase
::
Mapper
m_all_algos_map
;
AlgoBase
::
Mapper
m_all_algos_map
;
...
@@ -1132,6 +1151,7 @@ public:
...
@@ -1132,6 +1151,7 @@ public:
std
::
vector
<
AlgoFloat16NCHWHMMAImplicitBatchedGemm
>
f16_implicit_bmm
;
std
::
vector
<
AlgoFloat16NCHWHMMAImplicitBatchedGemm
>
f16_implicit_bmm
;
AlgoGroupConvGeneral
group
;
AlgoGroupConvGeneral
group
;
AlgoBFloat16
bfloat16
;
AlgoBFloat16
bfloat16
;
AlgoSimpleInt1
int1_simple
;
AlgoBase
*
cudnn_conv_bias_act_from_enum
(
cudnnConvolutionFwdAlgo_t
algo
);
AlgoBase
*
cudnn_conv_bias_act_from_enum
(
cudnnConvolutionFwdAlgo_t
algo
);
...
...
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
浏览文件 @
09dab387
...
@@ -30,6 +30,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
...
@@ -30,6 +30,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
return
false
;
return
false
;
}
}
}
}
if
(
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS1
)
return
false
;
if
((
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
if
((
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
args
.
filter_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
args
.
filter_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
...
...
dnn/src/cuda/conv_bias/helper.cpp
浏览文件 @
09dab387
...
@@ -134,6 +134,9 @@ void ConvBiasDesc::set_conv(
...
@@ -134,6 +134,9 @@ void ConvBiasDesc::set_conv(
namespace
conv_bias
{
namespace
conv_bias
{
bool
is_cudnn_supported
(
const
BiasForwardSizeArgs
&
args
)
{
bool
is_cudnn_supported
(
const
BiasForwardSizeArgs
&
args
)
{
if
(
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS1
)
return
false
;
if
((
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
if
((
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
||
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
Quantized4Asymm
)
&&
args
.
filter_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
args
.
filter_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS4
)
...
...
dnn/src/cuda/conv_bias/opr_impl.cpp
浏览文件 @
09dab387
...
@@ -221,6 +221,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
...
@@ -221,6 +221,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
return
&
sm_algo_pack
.
fallback_nchw_qs8
;
return
&
sm_algo_pack
.
fallback_nchw_qs8
;
}
}
if
(
sm_algo_pack
.
int1_simple
.
is_available_attribute
(
args
,
positive_attr
,
negative_attr
,
workspace_limit_in_bytes
))
{
return
&
sm_algo_pack
.
int1_simple
;
}
if
(
args
.
src_layout
->
dtype
.
enumv
()
!=
DTypeTrait
<
dtype
::
BFloat16
>::
enumv
)
{
if
(
args
.
src_layout
->
dtype
.
enumv
()
!=
DTypeTrait
<
dtype
::
BFloat16
>::
enumv
)
{
return
megdnn
::
get_algo_match_attribute
<
ConvBiasForwardImpl
>
(
return
megdnn
::
get_algo_match_attribute
<
ConvBiasForwardImpl
>
(
sm_algo_pack
.
non_cudnn_algos
,
args
,
workspace_limit_in_bytes
,
sm_algo_pack
.
non_cudnn_algos
,
args
,
workspace_limit_in_bytes
,
...
...
dnn/src/cuda/conv_bias/opr_impl.h
浏览文件 @
09dab387
...
@@ -72,6 +72,7 @@ public:
...
@@ -72,6 +72,7 @@ public:
class
AlgoInt4Int4NHWCIMMAImplicitGemm
;
class
AlgoInt4Int4NHWCIMMAImplicitGemm
;
class
AlgoUInt4Int4NHWCIMMAImplicitGemm
;
class
AlgoUInt4Int4NHWCIMMAImplicitGemm
;
class
AlgoBFloat16
;
class
AlgoBFloat16
;
class
AlgoSimpleInt1
;
// The following algorithms are suitable for channel wise convolution
// The following algorithms are suitable for channel wise convolution
class
AlgoFloat32NCHWFMAImplicitBatchedGemm
;
class
AlgoFloat32NCHWFMAImplicitBatchedGemm
;
class
AlgoFloat16NCHWHMMAImplicitBatchedGemm
;
class
AlgoFloat16NCHWHMMAImplicitBatchedGemm
;
...
...
dnn/src/cuda/conv_bias/simple_int1.cpp
0 → 100644
浏览文件 @
09dab387
/**
* \file dnn/src/cuda/conv_bias/simple_int1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/algo_base.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"
using
namespace
megdnn
;
using
namespace
cuda
;
using
namespace
conv_bias
;
namespace
{
std
::
pair
<
TensorLayoutArray
,
ConvBiasForwardImpl
::
Param
>
sub_opr_config
(
const
TensorLayoutArray
&
layouts
,
const
ConvBiasForwardImpl
*
opr
)
{
megdnn_assert
(
layouts
.
size
()
>=
3
);
std
::
pair
<
TensorLayoutArray
,
ConvBiasForwardImpl
::
Param
>
ret
;
ret
.
first
=
layouts
;
auto
change_dtype
=
[](
TensorLayout
&
layout
)
{
if
(
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS1
||
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)
{
layout
.
dtype
=
dtype
::
Float32
();
}
};
change_dtype
(
ret
.
first
[
0
]);
change_dtype
(
ret
.
first
[
1
]);
change_dtype
(
ret
.
first
[
2
]);
change_dtype
(
ret
.
first
[
3
]);
change_dtype
(
ret
.
first
[
4
]);
ret
.
second
=
opr
->
param
();
ret
.
second
.
compute_mode
=
ConvBiasForwardImpl
::
Param
::
ComputeMode
::
DEFAULT
;
return
ret
;
}
std
::
pair
<
TensorLayoutArray
,
std
::
unique_ptr
<
ConvBiasForward
>>
prepare_sub_opr
(
const
ConvBiasForwardImpl
::
AlgoBase
::
SizeArgs
&
args
)
{
auto
convbias_opr
=
args
.
handle
->
create_operator
<
ConvBias
>
();
auto
&&
config
=
sub_opr_config
(
{
*
args
.
src_layout
,
*
args
.
filter_layout
,
*
args
.
bias_layout
,
*
args
.
z_layout
,
*
args
.
dst_layout
},
args
.
opr
);
convbias_opr
->
param
()
=
config
.
second
;
return
{
config
.
first
,
std
::
move
(
convbias_opr
)};
}
}
// namespace
std
::
vector
<
Algorithm
::
SearchItem
>
ConvBiasForwardImpl
::
AlgoSimpleInt1
::
get_subopr_list
(
const
TensorLayoutArray
&
layouts
,
const
OperatorBase
*
opr
)
const
{
auto
&&
config
=
sub_opr_config
(
layouts
,
static_cast
<
const
ConvBiasForwardImpl
*>
(
opr
));
std
::
string
param_str
;
Algorithm
::
serialize_write_pod
(
config
.
second
,
param_str
);
return
{{
Algorithm
::
OprType
::
CONVBIAS_FORWARD
,
param_str
,
config
.
first
}};
}
bool
ConvBiasForwardImpl
::
AlgoSimpleInt1
::
is_available
(
const
SizeArgs
&
args
)
const
{
if
(
args
.
src_layout
->
dtype
.
valid
()
&&
args
.
filter_layout
->
dtype
.
valid
()
&&
args
.
bias_layout
->
dtype
.
valid
()
&&
args
.
z_layout
->
dtype
.
valid
()
&&
args
.
dst_layout
->
dtype
.
valid
())
{
auto
config
=
prepare_sub_opr
(
args
);
return
args
.
src_layout
->
dtype
.
enumv
()
==
args
.
filter_layout
->
dtype
.
enumv
()
&&
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS1
&&
get_algorithm
(
static_cast
<
ConvBiasForwardImpl
*>
(
config
.
second
.
get
()),
config
.
first
[
0
],
config
.
first
[
1
],
config
.
first
[
2
],
config
.
first
[
3
],
config
.
first
[
4
]);
}
else
{
return
false
;
}
}
WorkspaceBundle
ConvBiasForwardImpl
::
AlgoSimpleInt1
::
get_workspace_bundle
(
void
*
ptr
,
const
SizeArgs
&
args
)
const
{
auto
config
=
prepare_sub_opr
(
args
);
SmallVector
<
size_t
>
sizes
;
auto
get_workspace
=
[
&
sizes
](
const
TensorLayout
&
src
,
const
TensorLayout
&
dst
)
{
if
(
src
.
dtype
!=
dst
.
dtype
)
{
sizes
.
push_back
(
dst
.
span
().
dist_byte
());
}
};
get_workspace
(
*
args
.
src_layout
,
config
.
first
[
0
]);
get_workspace
(
*
args
.
filter_layout
,
config
.
first
[
1
]);
get_workspace
(
*
args
.
bias_layout
,
config
.
first
[
2
]);
get_workspace
(
*
args
.
z_layout
,
config
.
first
[
3
]);
get_workspace
(
*
args
.
dst_layout
,
config
.
first
[
4
]);
sizes
.
push_back
(
config
.
second
->
get_workspace_in_bytes
(
config
.
first
[
0
],
config
.
first
[
1
],
config
.
first
[
2
],
config
.
first
[
3
],
config
.
first
[
4
],
nullptr
));
return
{
ptr
,
std
::
move
(
sizes
)};
}
size_t
ConvBiasForwardImpl
::
AlgoSimpleInt1
::
get_workspace_in_bytes
(
const
SizeArgs
&
args
)
const
{
return
get_workspace_bundle
(
nullptr
,
args
).
total_size_in_bytes
();
}
void
ConvBiasForwardImpl
::
AlgoSimpleInt1
::
exec
(
const
ExecArgs
&
args
)
const
{
TensorND
fsrc_tensor
=
*
args
.
src_tensor
;
TensorND
ffilter_tensor
=
*
args
.
filter_tensor
;
TensorND
fbias_tensor
=
*
args
.
bias_tensor
;
TensorND
fz_tensor
=
*
args
.
z_tensor
;
TensorND
fdst_tensor
=
*
args
.
dst_tensor
;
auto
config
=
prepare_sub_opr
(
args
);
auto
bundle
=
get_workspace_bundle
(
args
.
workspace
.
raw_ptr
,
args
);
CompTypeCvter
<
dtype
::
QuantizedS1
,
dtype
::
Float32
>
cvter
(
args
.
handle
,
&
bundle
);
{
cvter
.
src_to_comp_type
(
*
args
.
src_tensor
,
fsrc_tensor
)
.
src_to_comp_type
(
*
args
.
filter_tensor
,
ffilter_tensor
);
}
WorkspaceBundle
dst_bundle
=
{
bundle
.
get
(
2
),
{
bundle
.
get_size
(
2
),
bundle
.
get_size
(
3
),
bundle
.
get_size
(
4
),
bundle
.
get_size
(
5
)}};
CompTypeCvter
<
dtype
::
QuantizedS32
,
dtype
::
Float32
>
dst_cvter
(
args
.
handle
,
&
dst_bundle
);
{
dst_cvter
.
src_to_comp_type
(
*
args
.
bias_tensor
,
fbias_tensor
)
.
src_to_comp_type
(
*
args
.
z_tensor
,
fz_tensor
)
.
src_to_comp_type
(
*
args
.
dst_tensor
,
fdst_tensor
);
}
config
.
second
->
exec
(
fsrc_tensor
,
ffilter_tensor
,
fbias_tensor
,
fz_tensor
,
fdst_tensor
,
nullptr
,
dst_cvter
.
workspace
());
{
dst_cvter
.
comp_to_dst_type
(
fdst_tensor
,
*
args
.
dst_tensor
);
}
}
// vim: syntax=cpp.doxygen
dnn/src/cuda/convolution/forward/algos.cpp
浏览文件 @
09dab387
...
@@ -44,6 +44,10 @@ std::pair<TensorLayoutArray, ConvBiasForward::Param> sub_opr_config(
...
@@ -44,6 +44,10 @@ std::pair<TensorLayoutArray, ConvBiasForward::Param> sub_opr_config(
src
.
dtype
.
param
<
dtype
::
Quantized4Asymm
>
().
scale
*
src
.
dtype
.
param
<
dtype
::
Quantized4Asymm
>
().
scale
*
filter
.
dtype
.
param
<
dtype
::
Quantized4Asymm
>
().
scale
);
filter
.
dtype
.
param
<
dtype
::
Quantized4Asymm
>
().
scale
);
}
else
if
(
src
.
dtype
.
enumv
()
==
DTypeEnum
::
QuantizedS1
)
{
bias_type
=
dtype
::
QuantizedS32
(
src
.
dtype
.
param
<
dtype
::
QuantizedS1
>
().
scale
*
filter
.
dtype
.
param
<
dtype
::
QuantizedS1
>
().
scale
);
}
else
{
}
else
{
megdnn_assert
(
src
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
megdnn_assert
(
src
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
bias_type
=
src
.
dtype
;
bias_type
=
src
.
dtype
;
...
...
dnn/src/naive/conv_bias/opr_impl.cpp
浏览文件 @
09dab387
...
@@ -278,6 +278,9 @@ void ConvBiasForwardImpl::exec(
...
@@ -278,6 +278,9 @@ void ConvBiasForwardImpl::exec(
DISPATCH_RAW
(
DISPATCH_RAW
(
Quantized4Asymm
,
QuantizedS4
,
QuantizedS32
,
QuantizedS32
,
DEFAULT
,
Quantized4Asymm
,
QuantizedS4
,
QuantizedS32
,
QuantizedS32
,
DEFAULT
,
(
convolution
::
forward_bias
<
dt_quint4
,
dt_qint4
,
dt_qint32
,
dt_qint32
>
))
(
convolution
::
forward_bias
<
dt_quint4
,
dt_qint4
,
dt_qint32
,
dt_qint32
>
))
DISPATCH_RAW
(
QuantizedS1
,
QuantizedS1
,
QuantizedS32
,
QuantizedS32
,
FLOAT32
,
(
convolution
::
forward_bias
<
dt_qint1
,
dt_qint1
,
dt_qint32
,
dt_qint32
>
))
#if !MEGDNN_DISABLE_FLOAT16
#if !MEGDNN_DISABLE_FLOAT16
DISPATCH
(
Float16
,
Float16
)
DISPATCH
(
Float16
,
Float16
)
DISPATCH_RAW
(
DISPATCH_RAW
(
...
...
dnn/src/naive/convolution/helper.h
浏览文件 @
09dab387
...
@@ -84,6 +84,15 @@ inline void StrategyFwd::on(
...
@@ -84,6 +84,15 @@ inline void StrategyFwd::on(
d
+=
cast
(
s
)
*
cast
(
f
);
d
+=
cast
(
s
)
*
cast
(
f
);
}
}
template
<
>
inline
void
StrategyFwd
::
on
(
dt_qint1
&
s
,
dt_qint1
&
f
,
dt_qint32
&
d
,
DType
,
DType
,
DType
)
{
auto
cast
=
[](
const
dt_qint1
&
val
)
{
return
dt_qint32
(
static_cast
<
int32_t
>
(
val
.
as_int8
()));
};
d
+=
cast
(
s
)
*
cast
(
f
);
}
struct
StrategyBwdData
{
struct
StrategyBwdData
{
template
<
typename
st
,
typename
ft
,
typename
dt
>
template
<
typename
st
,
typename
ft
,
typename
dt
>
static
void
on
(
st
&
s
,
ft
&
f
,
dt
&
d
,
DType
,
DType
,
DType
)
{
static
void
on
(
st
&
s
,
ft
&
f
,
dt
&
d
,
DType
,
DType
,
DType
)
{
...
...
dnn/test/cuda/conv_bias.cpp
浏览文件 @
09dab387
...
@@ -133,6 +133,32 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_BF16) {
...
@@ -133,6 +133,32 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_BF16) {
}
}
}
}
TEST_F
(
CUDA
,
CONV_BIAS_FORWARD_QS1
)
{
require_compute_capability
(
6
,
1
);
UniformIntRNG
int_rng
{
1
,
1
};
Checker
<
ConvBiasForward
>
checker
(
handle_cuda
());
checker
.
set_before_exec_callback
(
AlgoChecker
<
ConvBiasForward
>
(
ExecutionPolicyAlgoName
{
"CONVBIAS_SIMPLE_INT1"
,
{{
"MATMUL"
,
{}}}}));
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW
;
param
.
compute_mode
=
param
::
Convolution
::
ComputeMode
::
FLOAT32
;
{
auto
src_shape
=
TensorShape
{
20
,
2
,
224
,
224
};
auto
filter_shape
=
TensorShape
{
20
,
2
,
3
,
3
};
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS1
(
1.0
f
))
.
set_dtype
(
1
,
dtype
::
QuantizedS1
(
1.0
f
))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
1.0
f
))
.
set_dtype
(
3
,
dtype
::
QuantizedS32
(
1.0
f
))
.
set_dtype
(
4
,
dtype
::
QuantizedS32
(
1.0
f
))
.
set_rng
(
0
,
&
int_rng
)
.
set_rng
(
1
,
&
int_rng
)
.
set_param
(
param
)
.
execs
({
src_shape
,
filter_shape
,
{},
{},
{}});
}
}
TEST_F
(
CUDA
,
CONV_BIAS_FORWARD_QS8
)
{
TEST_F
(
CUDA
,
CONV_BIAS_FORWARD_QS8
)
{
require_compute_capability
(
6
,
1
);
require_compute_capability
(
6
,
1
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录