Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
23032f50
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
411
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
23032f50
编写于
5月 25, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): support float16 for index_incr_multi_axis_vec
GitOrigin-RevId: c2ae93d568892d1af6a602aed3ed7c60f9dba1bd
上级
93894402
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
21 addition
and
6 deletion
+21
-6
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
+3
-3
dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp
dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp
+0
-3
dnn/test/cuda/indexing_multi_axis_vec.cpp
dnn/test/cuda/indexing_multi_axis_vec.cpp
+18
-0
未找到文件。
dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu
浏览文件 @
23032f50
...
...
@@ -11,11 +11,11 @@
#include "megdnn/dtype.h"
#include "src/cuda/utils.cuh"
#if !MEGDNN_DISABLE_FLOAT16
__device__
void
atomicAdd
(
megdnn
::
dt_float16
*
,
megdnn
::
dt_float16
)
{
__trap
();
((
int
*
)
0
)[
0
]
=
1
;
__device__
void
atomicAdd
(
megdnn
::
dt_float16
*
address
,
megdnn
::
dt_float16
val
)
{
::
megdnn
::
cuda
::
atomic_add
(
address
,
val
);
}
__device__
void
atomicAdd
(
megdnn
::
dt_bfloat16
*
,
megdnn
::
dt_bfloat16
)
{
...
...
dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp
浏览文件 @
23032f50
...
...
@@ -199,9 +199,6 @@ size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes(
void
IndexingIncrMultiAxisVecImpl
::
exec
(
_megdnn_tensor_inout
data
,
_megdnn_tensor_in
value
,
const
IndexDesc
&
index
,
_megdnn_workspace
workspace
)
{
DNN_INC_FLOAT16
(
megdnn_assert
(
data
.
layout
.
dtype
!=
dtype
::
Float16
(),
"float16 incr on cuda currently not supported"
));
auto
info
=
check_exec
(
data
.
layout
,
value
.
layout
,
index
,
workspace
.
size
);
info
.
error_tracker
=
m_error_tracker
;
info
.
error_info
=
async_error_info
(
handle
());
...
...
dnn/test/cuda/indexing_multi_axis_vec.cpp
浏览文件 @
23032f50
...
...
@@ -32,6 +32,11 @@ namespace {
for
(
size_t
i
=
0
,
it
=
span
.
dist_elem
();
i
<
it
;
++
i
)
{
ptr
[
i
]
=
i
;
}
}
else
if
(
tensor
.
layout
.
dtype
==
dtype
::
Float16
())
{
auto
ptr
=
tensor
.
ptr
<
dt_float16
>
()
+
span
.
low_elem
;
for
(
size_t
i
=
0
,
it
=
span
.
dist_elem
();
i
<
it
;
++
i
)
{
ptr
[
i
]
=
i
;
}
}
else
{
auto
ptr
=
tensor
.
ptr
<
int
>
()
+
span
.
low_elem
;
for
(
size_t
i
=
0
,
it
=
span
.
dist_elem
();
i
<
it
;
++
i
)
{
...
...
@@ -135,6 +140,19 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
TEST_F
(
CUDA
,
INDEXING_INCR_MULTI_AXIS_VEC
)
{
run_check
<
IndexingIncrMultiAxisVec
>
(
handle_cuda
());
Checker
<
IndexingIncrMultiAxisVec
>
checker
(
handle_cuda
());
OrderedRNG
rng
;
checker
.
set_dtype
(
0
,
dtype
::
Float16
()).
// data
set_dtype
(
1
,
dtype
::
Float16
()).
// value
set_dtype
(
2
,
dtype
::
Int32
()).
// idx0
set_rng
(
0
,
&
rng
).
set_rng
(
1
,
&
rng
).
set_rng
(
2
,
&
rng
);
checker
.
set_proxy
({{
1
}}).
execs
({{
5
,
8
,
3
},
{
5
,
2
,
3
},
{
2
}});
}
TEST_F
(
CUDA
,
INDEXING_SET_MULTI_AXIS_VEC
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录