From 2696e4efaab39244de980860cd777ec7f9509d73 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 29 Oct 2021 15:40:12 +0800
Subject: [PATCH] feat(dnn): add float16 for remap backward

GitOrigin-RevId: 02630300515e1805aba1792968f3beab89ca4164
---
 dnn/src/common/remap.cpp             | 14 ++++++++------
 dnn/src/cuda/remap/backward_data.cpp |  1 +
 dnn/src/cuda/remap/backward_data.cu  |  1 +
 dnn/src/cuda/remap/backward_mat.cpp  |  1 +
 dnn/src/cuda/remap/backward_mat.cu   |  1 +
 dnn/src/naive/remap/opr_impl.cpp     |  2 ++
 dnn/test/cuda/remap.cpp              |  2 ++
 7 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/dnn/src/common/remap.cpp b/dnn/src/common/remap.cpp
index d77a5fa0..6c2e5a47 100644
--- a/dnn/src/common/remap.cpp
+++ b/dnn/src/common/remap.cpp
@@ -89,8 +89,9 @@ void RemapBackwardData::check_exec(
         size_t workspace_in_bytes) {
     check_layout_fwd(grad, map_xy, diff);
     megdnn_assert(
-            grad.dtype == dtype::Float32()
-                    DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()),
-            "Backward Remap only supports Float32/BFloat16.");
+            grad.dtype ==
+                    dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16())
+                            DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()),
+            "Backward Remap only supports Float32/BFloat16/Float16.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(map_xy, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
@@ -102,8 +103,9 @@ void RemapBackwardMat::check_exec(
     check_layout_fwd(src, map_xy, diff);
     megdnn_assert_eq_layout(map_xy, grad);
     megdnn_assert(
-            grad.dtype == dtype::Float32()
-                    DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()),
-            "Backward Remap only supports Float32/BFloat16.");
+            grad.dtype ==
+                    dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16())
+                            DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()),
+            "Backward Remap only supports Float32/BFloat16/Float16.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
diff --git a/dnn/src/cuda/remap/backward_data.cpp b/dnn/src/cuda/remap/backward_data.cpp
index 1e427e21..a2ce8950 100644
--- a/dnn/src/cuda/remap/backward_data.cpp
+++ b/dnn/src/cuda/remap/backward_data.cpp
@@ -61,6 +61,7 @@ void RemapBackwardDataImpl::exec(
     switch (grad.layout.dtype.enumv()) {
         support_dtype(dtype::Float32);
         support_dtype(dtype::BFloat16);
+        support_dtype(dtype::Float16);
         default:
             megdnn_throw("unsupported dtype in remap backward cuda\n");
     }
diff --git a/dnn/src/cuda/remap/backward_data.cu b/dnn/src/cuda/remap/backward_data.cu
index 370518ed..662f9ac6 100644
--- a/dnn/src/cuda/remap/backward_data.cu
+++ b/dnn/src/cuda/remap/backward_data.cu
@@ -155,6 +155,7 @@ void backwarddata_proxy(
 
 FOR_FORMAT_BMODE(float)
 DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
+DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
 
 #undef FOR_FORMAT_BMODE
 #undef INST
diff --git a/dnn/src/cuda/remap/backward_mat.cpp b/dnn/src/cuda/remap/backward_mat.cpp
index 3e5a12b1..e14d9b09 100644
--- a/dnn/src/cuda/remap/backward_mat.cpp
+++ b/dnn/src/cuda/remap/backward_mat.cpp
@@ -62,6 +62,7 @@ void RemapBackwardMatImpl::exec(
     switch (src.layout.dtype.enumv()) {
         support_dtype(dtype::Float32);
         support_dtype(dtype::BFloat16);
+        support_dtype(dtype::Float16);
         default:
             megdnn_throw("unsupported dtype in remap backward cuda\n");
     }
diff --git a/dnn/src/cuda/remap/backward_mat.cu b/dnn/src/cuda/remap/backward_mat.cu
index f4ab0908..f0f6498a 100644
--- a/dnn/src/cuda/remap/backward_mat.cu
+++ b/dnn/src/cuda/remap/backward_mat.cu
@@ -156,6 +156,7 @@ void backwardmat_proxy(
 
 FOR_FORMAT_BMODE(float)
 DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
+DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
 
 #undef FOR_FORMAT_BMODE
 #undef INST
diff --git a/dnn/src/naive/remap/opr_impl.cpp b/dnn/src/naive/remap/opr_impl.cpp
index 7eef71f3..56e8b632 100644
--- a/dnn/src/naive/remap/opr_impl.cpp
+++ b/dnn/src/naive/remap/opr_impl.cpp
@@ -320,6 +320,7 @@ void RemapBackwardDataImpl::exec(
 
     support_dtype(dtype::Float32);
     DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
+    DNN_INC_FLOAT16(support_dtype(dtype::Float16));
 
 #undef cb
 #undef support_dtype
@@ -371,6 +372,7 @@ void RemapBackwardMatImpl::exec(
 
     support_dtype(dtype::Float32);
     DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
+    DNN_INC_FLOAT16(support_dtype(dtype::Float16));
 
 #undef cb
 #undef support_dtype
diff --git a/dnn/test/cuda/remap.cpp b/dnn/test/cuda/remap.cpp
index 53927eaa..d737d55e 100644
--- a/dnn/test/cuda/remap.cpp
+++ b/dnn/test/cuda/remap.cpp
@@ -180,6 +180,7 @@ TEST_F(CUDA, REMAP_BACKWARD_DATA) {
                 .execs({arg.map_xy, arg.dst, arg.src});             \
     }
     cb(dtype::BFloat16(), float_rng);
+    cb(dtype::Float16(), float_rng);
 #undef cb
 }
 
@@ -222,6 +223,7 @@ TEST_F(CUDA, REMAP_BACKWARD_MAT) {
                 .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \
     }
     cb(dtype::BFloat16(), float_rng);
+    cb(dtype::Float16(), float_rng);
 #undef cb
 }
-- 
GitLab