提交 23032f50，作者：Megvii Engine Team

feat(dnn/cuda): support float16 for index_incr_multi_axis_vec

GitOrigin-RevId: c2ae93d568892d1af6a602aed3ed7c60f9dba1bd
上级 93894402
......@@ -11,11 +11,11 @@
#include "megdnn/dtype.h"
#include "src/cuda/utils.cuh"
#if !MEGDNN_DISABLE_FLOAT16
// atomicAdd overload for megdnn::dt_float16, so that code written against the
// generic atomicAdd(T*, T) spelling also works for half-precision elements.
//
// The previous stub aborted the kernel (__trap() followed by a null-pointer
// write) because float16 increments were not supported. It now forwards to
// ::megdnn::cuda::atomic_add, which does the actual accumulation.
// NOTE(review): atomic_add is presumably declared in "src/cuda/utils.cuh" and
// expected to provide a dt_float16 overload -- confirm.
//
// address: location that is accumulated into.
// val:     value added to *address.
// Any value returned by atomic_add is intentionally discarded so the
// signature stays `void`, matching the stub it replaces.
__device__ void atomicAdd(megdnn::dt_float16* address, megdnn::dt_float16 val) {
    ::megdnn::cuda::atomic_add(address, val);
}
__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) {
......
......@@ -199,9 +199,6 @@ size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes(
void IndexingIncrMultiAxisVecImpl::exec(
_megdnn_tensor_inout data, _megdnn_tensor_in value,
const IndexDesc &index, _megdnn_workspace workspace) {
DNN_INC_FLOAT16(
megdnn_assert(data.layout.dtype != dtype::Float16(),
"float16 incr on cuda currently not supported"));
auto info = check_exec(data.layout, value.layout, index, workspace.size);
info.error_tracker = m_error_tracker;
info.error_info = async_error_info(handle());
......
......@@ -32,6 +32,11 @@ namespace {
for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
ptr[i] = i;
}
} else if (tensor.layout.dtype == dtype::Float16()) {
auto ptr = tensor.ptr<dt_float16>() + span.low_elem;
for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
ptr[i] = i;
}
} else {
auto ptr = tensor.ptr<int>() + span.low_elem;
for (size_t i = 0, it = span.dist_elem(); i < it; ++ i) {
......@@ -135,6 +140,19 @@ TEST_F(CUDA, INDEXING_MULTI_AXIS_VEC) {
// Exercises IndexingIncrMultiAxisVec on the CUDA handle: first through the
// shared run_check helper, then once more with float16 data/value operands.
TEST_F(CUDA, INDEXING_INCR_MULTI_AXIS_VEC) {
    run_check<IndexingIncrMultiAxisVec>(handle_cuda());

    // Dedicated checker for the float16 configuration.
    Checker<IndexingIncrMultiAxisVec> fp16_checker(handle_cuda());
    OrderedRNG ordered_rng;

    // Operand slots: 0 = data, 1 = value, 2 = idx0.
    fp16_checker.set_dtype(0, dtype::Float16());
    fp16_checker.set_dtype(1, dtype::Float16());
    fp16_checker.set_dtype(2, dtype::Int32());

    // All three operands are filled from the same OrderedRNG instance.
    fp16_checker.set_rng(0, &ordered_rng);
    fp16_checker.set_rng(1, &ordered_rng);
    fp16_checker.set_rng(2, &ordered_rng);

    // NOTE(review): {{1}} presumably selects axis 1 as the indexed axis, and
    // the shapes are given in the same data/value/idx0 order -- confirm
    // against the checker proxy for this operator.
    fp16_checker.set_proxy({{1}});
    fp16_checker.execs({{5, 8, 3}, {5, 2, 3}, {2}});
}
TEST_F(CUDA, INDEXING_SET_MULTI_AXIS_VEC) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册