提交 2696e4ef 编写于 作者: M Megvii Engine Team

feat(dnn): add float16 for remap backward

GitOrigin-RevId: 02630300515e1805aba1792968f3beab89ca4164
上级 ff171934
...@@ -89,8 +89,9 @@ void RemapBackwardData::check_exec( ...@@ -89,8 +89,9 @@ void RemapBackwardData::check_exec(
size_t workspace_in_bytes) { size_t workspace_in_bytes) {
check_layout_fwd(grad, map_xy, diff); check_layout_fwd(grad, map_xy, diff);
megdnn_assert( megdnn_assert(
grad.dtype == dtype::Float32() grad.dtype ==
DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()), dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16())
DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()),
"Backward Remap only supports Float32/BFloat16."); "Backward Remap only supports Float32/BFloat16.");
auto required_workspace_in_bytes = get_workspace_in_bytes(map_xy, diff, grad); auto required_workspace_in_bytes = get_workspace_in_bytes(map_xy, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
...@@ -102,8 +103,9 @@ void RemapBackwardMat::check_exec( ...@@ -102,8 +103,9 @@ void RemapBackwardMat::check_exec(
check_layout_fwd(src, map_xy, diff); check_layout_fwd(src, map_xy, diff);
megdnn_assert_eq_layout(map_xy, grad); megdnn_assert_eq_layout(map_xy, grad);
megdnn_assert( megdnn_assert(
grad.dtype == dtype::Float32() grad.dtype ==
DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()), dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16())
DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()),
"Backward Remap only supports Float32/BFloat16."); "Backward Remap only supports Float32/BFloat16.");
auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, diff, grad); auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
......
...@@ -61,6 +61,7 @@ void RemapBackwardDataImpl::exec( ...@@ -61,6 +61,7 @@ void RemapBackwardDataImpl::exec(
switch (grad.layout.dtype.enumv()) { switch (grad.layout.dtype.enumv()) {
support_dtype(dtype::Float32); support_dtype(dtype::Float32);
support_dtype(dtype::BFloat16); support_dtype(dtype::BFloat16);
support_dtype(dtype::Float16);
default: default:
megdnn_throw("unsupported dtype in remap backward cuda\n"); megdnn_throw("unsupported dtype in remap backward cuda\n");
} }
......
...@@ -155,6 +155,7 @@ void backwarddata_proxy( ...@@ -155,6 +155,7 @@ void backwarddata_proxy(
FOR_FORMAT_BMODE(float) FOR_FORMAT_BMODE(float)
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
#undef FOR_FORMAT_BMODE #undef FOR_FORMAT_BMODE
#undef INST #undef INST
......
...@@ -62,6 +62,7 @@ void RemapBackwardMatImpl::exec( ...@@ -62,6 +62,7 @@ void RemapBackwardMatImpl::exec(
switch (src.layout.dtype.enumv()) { switch (src.layout.dtype.enumv()) {
support_dtype(dtype::Float32); support_dtype(dtype::Float32);
support_dtype(dtype::BFloat16); support_dtype(dtype::BFloat16);
support_dtype(dtype::Float16);
default: default:
megdnn_throw("unsupported dtype in remap backward cuda\n"); megdnn_throw("unsupported dtype in remap backward cuda\n");
} }
......
...@@ -156,6 +156,7 @@ void backwardmat_proxy( ...@@ -156,6 +156,7 @@ void backwardmat_proxy(
FOR_FORMAT_BMODE(float) FOR_FORMAT_BMODE(float)
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
#undef FOR_FORMAT_BMODE #undef FOR_FORMAT_BMODE
#undef INST #undef INST
......
...@@ -320,6 +320,7 @@ void RemapBackwardDataImpl::exec( ...@@ -320,6 +320,7 @@ void RemapBackwardDataImpl::exec(
support_dtype(dtype::Float32); support_dtype(dtype::Float32);
DNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
DNN_INC_FLOAT16(support_dtype(dtype::Float16));
#undef cb #undef cb
#undef support_dtype #undef support_dtype
...@@ -371,6 +372,7 @@ void RemapBackwardMatImpl::exec( ...@@ -371,6 +372,7 @@ void RemapBackwardMatImpl::exec(
support_dtype(dtype::Float32); support_dtype(dtype::Float32);
DNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
DNN_INC_FLOAT16(support_dtype(dtype::Float16));
#undef cb #undef cb
#undef support_dtype #undef support_dtype
......
...@@ -180,6 +180,7 @@ TEST_F(CUDA, REMAP_BACKWARD_DATA) { ...@@ -180,6 +180,7 @@ TEST_F(CUDA, REMAP_BACKWARD_DATA) {
.execs({arg.map_xy, arg.dst, arg.src}); \ .execs({arg.map_xy, arg.dst, arg.src}); \
} }
cb(dtype::BFloat16(), float_rng); cb(dtype::BFloat16(), float_rng);
cb(dtype::Float16(), float_rng);
#undef cb #undef cb
} }
...@@ -222,6 +223,7 @@ TEST_F(CUDA, REMAP_BACKWARD_MAT) { ...@@ -222,6 +223,7 @@ TEST_F(CUDA, REMAP_BACKWARD_MAT) {
.execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \ .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \
} }
cb(dtype::BFloat16(), float_rng); cb(dtype::BFloat16(), float_rng);
cb(dtype::Float16(), float_rng);
#undef cb #undef cb
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册