Commit 2696e4ef authored by Megvii Engine Team

feat(dnn): add float16 for remap backward

GitOrigin-RevId: 02630300515e1805aba1792968f3beab89ca4164
Parent ff171934
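
This commit extends the remap backward operators (RemapBackwardData and RemapBackwardMat) to accept Float16 gradients alongside the existing Float32 and BFloat16 paths. The change touches four layers: the dtype assertions in the common check_exec routines, the CUDA dispatch switches and kernel instantiations, the naive reference implementation, and the CUDA tests.

Every Float16 line below is wrapped in DNN_INC_FLOAT16, the MegDNN macro that compiles its argument only when half-precision support is enabled. A minimal sketch of the guard, assuming the conventional MEGDNN_DISABLE_FLOAT16 flag (the exact flag name is not shown in this diff):

    // Sketch: keep float16-only code out of builds without half support.
    // MEGDNN_DISABLE_FLOAT16 is an assumed name for the controlling flag.
    #if !MEGDNN_DISABLE_FLOAT16
    #define DNN_INC_FLOAT16(...) __VA_ARGS__
    #else
    #define DNN_INC_FLOAT16(...)
    #endif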
@@ -89,8 +89,9 @@ void RemapBackwardData::check_exec(
         size_t workspace_in_bytes) {
     check_layout_fwd(grad, map_xy, diff);
     megdnn_assert(
-            grad.dtype == dtype::Float32()
-                    DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()),
-            "Backward Remap only supports Float32/BFloat16.");
+            grad.dtype ==
+                    dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16())
+                            DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()),
+            "Backward Remap only supports Float32/BFloat16/Float16.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(map_xy, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
@@ -102,8 +103,9 @@ void RemapBackwardMat::check_exec(
     check_layout_fwd(src, map_xy, diff);
     megdnn_assert_eq_layout(map_xy, grad);
     megdnn_assert(
-            grad.dtype == dtype::Float32()
-                    DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()),
-            "Backward Remap only supports Float32/BFloat16.");
+            grad.dtype ==
+                    dtype::Float32() DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16())
+                            DNN_INC_FLOAT16(|| grad.dtype == dtype::Float16()),
+            "Backward Remap only supports Float32/BFloat16/Float16.");
     auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
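
When float16 support is compiled out, both assertions above reduce to the plain Float32 check, since each DNN_INC_FLOAT16(...) expands to nothing:

    // Illustrative preprocessed form with float16 disabled:
    megdnn_assert(
            grad.dtype == dtype::Float32(),
            "Backward Remap only supports Float32/BFloat16/Float16.");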
@@ -61,6 +61,7 @@ void RemapBackwardDataImpl::exec(
     switch (grad.layout.dtype.enumv()) {
         support_dtype(dtype::Float32);
         support_dtype(dtype::BFloat16);
+        support_dtype(dtype::Float16);
         default:
             megdnn_throw("unsupported dtype in remap backward cuda\n");
     }
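
The support_dtype macro is defined just above this switch, outside the diff; a plausible sketch of the dispatch it performs, where the case body and argument list are assumptions rather than code from this commit:

    // Hypothetical expansion: route a dtype enum value to the typed proxy.
    #define support_dtype(dt)                                   \
        case DTypeTrait<dt>::enumv:                             \
            backwarddata_proxy<DTypeTrait<dt>::ctype>(          \
                    grad, map_xy, diff, stream); /* assumed */  \
            break;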
@@ -155,6 +155,7 @@ void backwarddata_proxy(
 FOR_FORMAT_BMODE(float)
 DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
+DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
 #undef FOR_FORMAT_BMODE
 #undef INST
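
FOR_FORMAT_BMODE(ctype) bundles the explicit template instantiations of backwarddata_proxy for every supported format and border mode of one element type, so the added line emits the dt_float16 kernels. A rough sketch of the pattern (the macro body and mode list are assumptions):

    // Hypothetical: instantiate the proxy for each border mode of NCHW.
    #define FOR_FORMAT_BMODE(ctype)     \
        INST(ctype, NCHW, CONSTANT)     \
        INST(ctype, NCHW, REPLICATE)    \
        INST(ctype, NCHW, REFLECT)      \
        INST(ctype, NCHW, REFLECT_101)  \
        INST(ctype, NCHW, WRAP)

The mirror hunk for backwardmat_proxy below follows the same pattern.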
@@ -62,6 +62,7 @@ void RemapBackwardMatImpl::exec(
     switch (src.layout.dtype.enumv()) {
         support_dtype(dtype::Float32);
         support_dtype(dtype::BFloat16);
+        support_dtype(dtype::Float16);
         default:
             megdnn_throw("unsupported dtype in remap backward cuda\n");
     }
@@ -156,6 +156,7 @@ void backwardmat_proxy(
 FOR_FORMAT_BMODE(float)
 DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
+DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
 #undef FOR_FORMAT_BMODE
 #undef INST
@@ -320,6 +320,7 @@ void RemapBackwardDataImpl::exec(
         support_dtype(dtype::Float32);
         DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
+        DNN_INC_FLOAT16(support_dtype(dtype::Float16));
 #undef cb
 #undef support_dtype
@@ -371,6 +372,7 @@ void RemapBackwardMatImpl::exec(
         support_dtype(dtype::Float32);
         DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
+        DNN_INC_FLOAT16(support_dtype(dtype::Float16));
 #undef cb
 #undef support_dtype
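
In the naive (reference) backend, the new Float16 case sits under the same DNN_INC_FLOAT16 guard as BFloat16, keeping the CPU reference in step with the CUDA backend whenever half precision is compiled in.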
@@ -180,6 +180,7 @@ TEST_F(CUDA, REMAP_BACKWARD_DATA) {
                 .execs({arg.map_xy, arg.dst, arg.src}); \
     }
     cb(dtype::BFloat16(), float_rng);
+    cb(dtype::Float16(), float_rng);
 #undef cb
 }
@@ -222,6 +223,7 @@ TEST_F(CUDA, REMAP_BACKWARD_MAT) {
                 .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \
     }
     cb(dtype::BFloat16(), float_rng);
+    cb(dtype::Float16(), float_rng);
 #undef cb
 }
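
The cb macro in these tests (defined a few lines above each hunk, outside the diff) runs one Checker pass per dtype/RNG pair. An illustrative expansion of the new cb(dtype::Float16(), float_rng) call for REMAP_BACKWARD_DATA, where the dtype indices and epsilon are assumptions; half precision typically needs a looser tolerance than the float32 default:

    // Sketch of what one cb(...) invocation boils down to:
    Checker<RemapBackwardData> checker(handle_cuda());
    checker.set_rng(0, &float_rng)
            .set_dtype(1, dtype::Float16())  // diff
            .set_dtype(2, dtype::Float16())  // grad
            .set_epsilon(1e-2)               // assumed tolerance
            .execs({arg.map_xy, arg.dst, arg.src});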