Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b87af9f7
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
411
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b87af9f7
编写于
5月 28, 2021
作者:
M
Megvii Engine Team
提交者:
huangxinda
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): topk support fp16
GitOrigin-RevId: c6610d4cf04f258fd13f420e5301c273e399d67f
上级
34262d90
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
59 addition
and
7 deletion
+59
-7
dnn/src/cuda/argsort/argsort.cu
dnn/src/cuda/argsort/argsort.cu
+1
-1
dnn/src/cuda/argsort/argsort.cuh
dnn/src/cuda/argsort/argsort.cuh
+2
-1
dnn/src/cuda/argsort/bitonic_sort.cu
dnn/src/cuda/argsort/bitonic_sort.cu
+14
-0
dnn/src/cuda/cub/util_type.cuh
dnn/src/cuda/cub/util_type.cuh
+1
-1
dnn/src/cuda/topk/opr_impl.cpp
dnn/src/cuda/topk/opr_impl.cpp
+10
-1
dnn/src/cuda/topk/topk_radix.cu
dnn/src/cuda/topk/topk_radix.cu
+2
-1
dnn/src/cuda/topk/topk_radix.cuh
dnn/src/cuda/topk/topk_radix.cuh
+24
-1
dnn/test/cuda/topk.cpp
dnn/test/cuda/topk.cpp
+5
-1
未找到文件。
dnn/src/cuda/argsort/argsort.cu
浏览文件 @
b87af9f7
...
...
@@ -124,7 +124,7 @@ size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype,
ARGSORT_FOREACH_CTYPE
(
cb
)
#undef cb
default:
megdnn_throw
(
"argsort only supports float
and int32
"
);
megdnn_throw
(
"argsort only supports float
, int32 and float16
"
);
}
if
(
!
iptr_src_given
)
{
size
=
DIVUP
(
size
,
sizeof
(
float
))
*
sizeof
(
float
)
+
M
*
N
*
sizeof
(
int
);
...
...
dnn/src/cuda/argsort/argsort.cuh
浏览文件 @
b87af9f7
...
...
@@ -33,7 +33,8 @@ void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace,
const
int
*
iptr_src
=
NULL
);
//! iterate over all supported data types
#define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t)
#define ARGSORT_FOREACH_CTYPE(cb) \
cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))
}
// namespace argsort
}
// namespace cuda
...
...
dnn/src/cuda/argsort/bitonic_sort.cu
浏览文件 @
b87af9f7
...
...
@@ -11,6 +11,7 @@
#include "./bitonic_sort.cuh"
#include "src/cuda/query_blocksize.cuh"
#include "megdnn/dtype.h"
#if __CUDACC_VER_MAJOR__ < 9
#pragma message "warp sync disabled due to insufficient cuda version"
...
...
@@ -82,6 +83,18 @@ struct NumTrait<int32_t> {
static
__device__
__forceinline__
int32_t
min
()
{
return
INT_MIN
;
}
};
#if !MEGDNN_DISABLE_FLOAT16
template
<
>
struct
NumTrait
<
dt_float16
>
{
static
__device__
__forceinline__
dt_float16
max
()
{
return
std
::
numeric_limits
<
dt_float16
>::
max
();
}
static
__device__
__forceinline__
dt_float16
min
()
{
return
std
::
numeric_limits
<
dt_float16
>::
lowest
();
}
};
#endif
struct
LessThan
{
template
<
typename
Key
,
typename
Value
>
static
__device__
__forceinline__
bool
cmp
(
Key
k0
,
Value
v0
,
Key
k1
,
...
...
@@ -295,6 +308,7 @@ namespace cuda {
INST
(
float
,
int
);
INST
(
int32_t
,
int
);
DNN_INC_FLOAT16
(
INST
(
dt_float16
,
int
));
#undef INST
}
// namespace megdnn
...
...
dnn/src/cuda/cub/util_type.cuh
浏览文件 @
b87af9f7
...
...
@@ -1146,7 +1146,7 @@ template <> struct NumericTraits<double> : BaseTraits<FLOATING_POIN
#if (__CUDACC_VER_MAJOR__ >= 9)
template
<
>
struct
NumericTraits
<
__half
>
:
BaseTraits
<
FLOATING_POINT
,
true
,
false
,
unsigned
short
,
__half
>
{};
#endif
template
<
>
struct
NumericTraits
<
half_float
::
half
>
:
BaseTraits
<
FLOATING_POINT
,
true
,
false
,
unsigned
short
,
half_float
::
half
>
{};
template
<
>
struct
NumericTraits
<
bool
>
:
BaseTraits
<
UNSIGNED_INTEGER
,
true
,
false
,
typename
UnitWord
<
bool
>::
VolatileWord
,
bool
>
{};
...
...
dnn/src/cuda/topk/opr_impl.cpp
浏览文件 @
b87af9f7
...
...
@@ -81,9 +81,18 @@ void TopKImpl::do_exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values,
values
.
ptr
<
int32_t
>
(),
indices
,
workspace
.
raw_ptr
);
return
;
#if !MEGDNN_DISABLE_FLOAT16
case
DTypeEnum
::
Float16
:
dispatch_with_ctype
<
dt_float16
>
(
k
,
data
.
layout
[
0
],
data
.
layout
[
1
],
data
.
layout
.
stride
[
0
],
data
.
ptr
<
dt_float16
>
(),
values
.
ptr
<
dt_float16
>
(),
indices
,
workspace
.
raw_ptr
);
return
;
#endif
default:
megdnn_throw
(
ssprintf
(
"only float32 and int32 supported for cuda topk, got: %s"
,
ssprintf
(
"only float32, int32 and float16 supported for "
"cuda topk, got: %s"
,
data
.
layout
.
dtype
.
name
()));
}
}
...
...
dnn/src/cuda/topk/topk_radix.cu
浏览文件 @
b87af9f7
...
...
@@ -489,7 +489,7 @@ cudaError_t topk::find_kth_radix(const ctype* input, ctype* output,
if
(
k
<
0
)
{
k
=
length
+
k
+
1
;
}
if
(
!
(
BUCKET_BITS
==
8
&&
sizeof
(
ctype
)
==
4
))
{
if
(
!
(
BUCKET_BITS
==
8
&&
(
sizeof
(
ctype
)
==
4
||
sizeof
(
ctype
)
==
2
)
))
{
// no c++11 in megdnn cuda; so we just trap instead of using static
// assert
megdnn_trap
();
...
...
@@ -668,6 +668,7 @@ namespace topk {
int32_t, uint32_t, cudaStream_t)
INST
(
float
);
INST
(
int32_t
);
DNN_INC_FLOAT16
(
INST
(
dt_float16
));
#undef INST
}
// namespace topk
...
...
dnn/src/cuda/topk/topk_radix.cuh
浏览文件 @
b87af9f7
...
...
@@ -10,7 +10,7 @@
*/
#pragma once
#include "megdnn/dtype.h"
#include <cuda_runtime.h>
#include <stdint.h>
...
...
@@ -60,6 +60,29 @@ struct RadixConverter<int32_t> {
}
};
#if !MEGDNN_DISABLE_FLOAT16
template
<
>
struct
RadixConverter
<
dt_float16
>
{
union
FIunion
{
FIunion
()
{}
dt_float16
fv
;
uint16_t
iv
;
};
static
__forceinline__
__device__
__host__
uint16_t
to_radix
(
dt_float16
val
)
{
FIunion
fi
;
fi
.
fv
=
val
;
return
fi
.
iv
^
(((
!
(
fi
.
iv
>>
15u
))
-
1u
)
|
0x8000u
);
}
static
__forceinline__
__device__
__host__
dt_float16
from_radix
(
uint16_t
val
)
{
FIunion
fi
;
// do not write as to_radix() to work around a compiler bug in cuda-9.0
uint16_t
m
=
0x8000u
;
fi
.
iv
=
val
^
(
m
|
(
m
-
!
(
val
>>
15u
)));
return
fi
.
fv
;
}
};
#endif
}
// namespace internal
/*!
...
...
dnn/test/cuda/topk.cpp
浏览文件 @
b87af9f7
...
...
@@ -27,6 +27,10 @@ TEST_F(CUDA, TOP_K) {
TEST_F
(
CUDA
,
TOP_K_I32
)
{
run_topk_test
<
dtype
::
Int32
>
(
handle_cuda
());
}
#if !MEGDNN_DISABLE_FLOAT16
TEST_F
(
CUDA
,
TOP_K_F16
)
{
run_topk_test
<
dtype
::
Float16
>
(
handle_cuda
());
}
#endif
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录