Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7dc34769
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7dc34769
编写于
4月 24, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): add typecvt uint16
GitOrigin-RevId: d1368c414e99e15d6fb93273b5051832d1995dea
上级
b92866d2
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
69 addition
and
7 deletion
+69
-7
dnn/src/cuda/elemwise_helper.cpp
dnn/src/cuda/elemwise_helper.cpp
+6
-0
dnn/src/cuda/elemwise_helper.cuh
dnn/src/cuda/elemwise_helper.cuh
+1
-0
dnn/src/cuda/type_cvt/kern.cu
dnn/src/cuda/type_cvt/kern.cu
+54
-2
dnn/src/cuda/type_cvt/opr_impl.cpp
dnn/src/cuda/type_cvt/opr_impl.cpp
+5
-4
dnn/test/cuda/type_cvt.cpp
dnn/test/cuda/type_cvt.cpp
+2
-1
src/core/impl/dtype.cpp
src/core/impl/dtype.cpp
+1
-0
未找到文件。
dnn/src/cuda/elemwise_helper.cpp
浏览文件 @
7dc34769
...
...
@@ -148,6 +148,9 @@ INST_FOR_CTYPE
#define ct dt_int16
INST_FOR_CTYPE
#undef ct
#define ct dt_uint16
INST_FOR_CTYPE
#undef ct
#define ct dt_quint8
INST_FOR_CTYPE
#undef ct
...
...
@@ -201,6 +204,9 @@ INST_FOR_CTYPE
#define ct dt_int16
INST_FOR_CTYPE
#undef ct
#define ct dt_uint16
INST_FOR_CTYPE
#undef ct
#define ct dt_quint8
INST_FOR_CTYPE
#undef ct
...
...
dnn/src/cuda/elemwise_helper.cuh
浏览文件 @
7dc34769
...
...
@@ -92,6 +92,7 @@ INST(dt_float16, half4);
INST
(
dt_bfloat16
,
bhalf4
);
INST
(
dt_int32
,
int4
);
INST
(
dt_int16
,
short4
);
INST
(
dt_uint16
,
ushort4
);
INST
(
dt_bool
,
uchar4
);
#undef as_raw
#define as_raw(x) x.as_int8()
...
...
dnn/src/cuda/type_cvt/kern.cu
浏览文件 @
7dc34769
...
...
@@ -247,6 +247,19 @@ struct TypeCvtOpFromQuantizedToQuantized4bit<
namespace
megdnn
{
namespace
cuda
{
// currently only typecvt_kern_{n2q,n2q4} respect this. change others typecvt_kern_* if
// needed.
template
<
typename
dtype_src
,
typename
dtype_dest
,
typename
sfinae
=
void
>
struct
enable_typecvt_kern
{
static
constexpr
bool
value
=
true
;
};
#define MEGDNN_DISABLE_CUDA_TYPECVT_KERN(dtype_src, dtype_dest) \
template <> \
struct enable_typecvt_kern<dtype_src, dtype_dest, void> { \
static constexpr bool value = false; \
};
template
<
typename
dtype_src
,
typename
dtype_dest
>
void
typecvt_kern_q2q
(
const
TensorND
&
dest
,
const
TensorND
&
src
,
...
...
@@ -257,12 +270,28 @@ void typecvt_kern_q2q(
}
template
<
typename
dtype_src
,
typename
dtype_dest
>
void
typecvt_kern_n2q
(
typename
std
::
enable_if
<
enable_typecvt_kern
<
dtype_src
,
dtype_dest
>::
value
>::
type
typecvt_kern_n2q_impl
(
const
TensorND
&
dest
,
const
TensorND
&
src
,
const
CudaDTypeParam
<
dtype_dest
>&
dst_param
,
cudaStream_t
stream
)
{
main_func
(
TypeCvtOpToQuantized
,
op
.
param
=
dst_param
;);
}
template
<
typename
dtype_src
,
typename
dtype_dest
>
typename
std
::
enable_if
<!
enable_typecvt_kern
<
dtype_src
,
dtype_dest
>::
value
>::
type
typecvt_kern_n2q_impl
(
const
TensorND
&
dest
,
const
TensorND
&
src
,
const
CudaDTypeParam
<
dtype_dest
>&
dst_param
,
cudaStream_t
stream
)
{
megdnn_throw
(
"TypeCvt: CUDA kernel for this dtype pair is disabled"
);
}
template
<
typename
dtype_src
,
typename
dtype_dest
>
void
typecvt_kern_n2q
(
const
TensorND
&
dest
,
const
TensorND
&
src
,
const
CudaDTypeParam
<
dtype_dest
>&
dst_param
,
cudaStream_t
stream
)
{
typecvt_kern_n2q_impl
<
dtype_src
,
dtype_dest
>
(
dest
,
src
,
dst_param
,
stream
);
}
template
<
typename
dtype_src
,
typename
dtype_dest
>
void
typecvt_kern_q2n
(
const
TensorND
&
dest
,
const
TensorND
&
src
,
...
...
@@ -312,12 +341,15 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
cb(dtype_src, dt_qint8) \
cb(dtype_src, dt_qint1) \
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC
(
dt_uint16
,
MEGDNN_DISABLE_CUDA_TYPECVT_KERN
)
#define INST_SRC_QUANTIZED(dtype_src) \
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2N) \
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_Q2Q) \
#define INST_SRC_NORMAL(dtype_src) \
MEGDNN_FOREACH_COMPUTING_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2N) \
INST_N2N(dtype_src, dt_uint16) \
MEGDNN_FOREACH_QUANTIZED_DTYPE_WITH_DTYPE_SRC(dtype_src, INST_N2Q) \
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \
...
...
@@ -340,6 +372,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, cudaStream_t st
MEGDNN_FOREACH_QUANTIZED_CTYPE
(
INST_SRC_QUANTIZED
)
MEGDNN_FOREACH_COMPUTING_CTYPE
(
INST_SRC_NORMAL
)
INST_SRC_NORMAL
(
dt_uint16
)
// clang-format on
template
void
typecvt_kern_n2q
<
dtype
::
Int8
,
dtype
::
QuantizedS8
>(
...
...
@@ -377,12 +410,28 @@ void typecvt_kern_q2q4(
}
template
<
typename
dtype_src
,
typename
dtype_dest
>
void
typecvt_kern_n2q4
(
typename
std
::
enable_if
<
enable_typecvt_kern
<
dtype_src
,
dtype_dest
>::
value
>::
type
typecvt_kern_n2q4_impl
(
const
TensorND
&
dest
,
const
TensorND
&
src
,
const
CudaDTypeParam
<
dtype_dest
>&
dst_param
,
cudaStream_t
stream
)
{
main_func_to_q4
(
TypeCvtOpFromNormalToQuantized4bit
,
op
.
dst_param
=
dst_param
;)
}
template
<
typename
dtype_src
,
typename
dtype_dest
>
typename
std
::
enable_if
<!
enable_typecvt_kern
<
dtype_src
,
dtype_dest
>::
value
>::
type
typecvt_kern_n2q4_impl
(
const
TensorND
&
dest
,
const
TensorND
&
src
,
const
CudaDTypeParam
<
dtype_dest
>&
dst_param
,
cudaStream_t
stream
)
{
megdnn_throw
(
"TypeCvt: CUDA kernel for this dtype pair is disabled"
);
}
template
<
typename
dtype_src
,
typename
dtype_dest
>
void
typecvt_kern_n2q4
(
const
TensorND
&
dest
,
const
TensorND
&
src
,
const
CudaDTypeParam
<
dtype_dest
>&
dst_param
,
cudaStream_t
stream
)
{
typecvt_kern_n2q4_impl
<
dtype_src
,
dtype_dest
>
(
dest
,
src
,
dst_param
,
stream
);
}
#define INST_Q2Q4(dtype_src, dtype_dest) \
template void typecvt_kern_q2q4<dtype_src, dtype_dest>( \
const TensorND& dest, const TensorND& src, \
...
...
@@ -399,6 +448,8 @@ void typecvt_kern_n2q4(
cb(dtype_src, dt_qint4) \
cb(dtype_src, dt_quint4) \
MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC
(
dt_uint16
,
MEGDNN_DISABLE_CUDA_TYPECVT_KERN
)
#define INST_SRC_QUANTIZED_LOWBIT(dtype_src) \
MEGDNN_FOREACH_QUANTIZED_LOWBIT_WITH_DTYPE_SRC(dtype_src, INST_Q2Q4) \
...
...
@@ -407,6 +458,7 @@ void typecvt_kern_n2q4(
MEGDNN_FOREACH_QUANTIZED_CTYPE
(
INST_SRC_QUANTIZED_LOWBIT
)
MEGDNN_FOREACH_COMPUTING_CTYPE
(
INST_SRC_NORMAL_LOWBIT
)
INST_SRC_NORMAL_LOWBIT
(
dt_uint16
)
}
// namespace cuda
}
// namespace megdnn
...
...
dnn/src/cuda/type_cvt/opr_impl.cpp
浏览文件 @
7dc34769
...
...
@@ -12,6 +12,8 @@
#include "./opr_impl.h"
#include "./kern.cuh"
#include "megdnn/dtype.h"
#include "src/common/utils.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"
...
...
@@ -87,10 +89,9 @@ void exec_src_normal(const TensorND& dst, const TensorND& src, cudaStream_t stre
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
);
cb
(
::
megdnn
::
dtype
::
Bool
)
;
cb
(
::
megdnn
::
dtype
::
Bool
)
cb
(
::
megdnn
::
dtype
::
Uint16
)
#undef cb
default:
megdnn_assert_internal
(
0
);
default
:
megdnn_assert_internal
(
0
);
}
}
else
if
(
!
is_dst_lowbit
)
{
switch
(
dst
.
layout
.
dtype
.
enumv
())
{
...
...
@@ -138,7 +139,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
cb
(
::
megdnn
::
dtype
::
Bool
)
cb
(
::
megdnn
::
dtype
::
Bool
)
cb
(
::
megdnn
::
dtype
::
Uint16
)
#undef cb
default
:
megdnn_assert_internal
(
0
);
}
...
...
dnn/test/cuda/type_cvt.cpp
浏览文件 @
7dc34769
...
...
@@ -19,7 +19,8 @@ using namespace test;
TEST_F
(
CUDA
,
TYPE_CVT
)
{
UniformFloatRNG
init
(
0
,
20
);
std
::
vector
<
DType
>
dtypes
=
{
dtype
::
Float32
(),
dtype
::
Float16
(),
dtype
::
Int32
(),
dtype
::
Int16
(),
dtype
::
Int8
(),
dtype
::
Uint8
()};
dtype
::
Int16
(),
dtype
::
Int8
(),
dtype
::
Uint8
(),
dtype
::
Uint16
()};
for
(
auto
sdtype
:
dtypes
)
for
(
auto
ddtype
:
dtypes
)
{
TensorLayout
src
({
10
,
10
},
sdtype
),
dst
({
10
,
10
},
ddtype
);
...
...
src/core/impl/dtype.cpp
浏览文件 @
7dc34769
...
...
@@ -210,6 +210,7 @@ typename ctype_enable_if<ctype>::type DTypeScalar::set_retain_dtype(ctype val) {
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
MEGDNN_FOREACH_QUANTIZED_DTYPE
(
cb
)
cb
(
dt_bool
);
cb
(
dt_uint16
);
#undef cb
default:
mgb_throw
(
ConversionError
,
"can not assign to dtype %s"
,
m_dtype
.
name
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录