From 73112558d07ae15632fdfdf6835608a11b8ef22e Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 4 Mar 2022 18:45:06 +0800
Subject: [PATCH] feat(mge/dnn): support checknonfinite for fp16

GitOrigin-RevId: 83fa139ac06ed6851537764b3cdaba812f219773
---
 dnn/include/megdnn/oprs/general.h          |  1 -
 dnn/src/common/check_non_finite.cpp        |  3 +-
 dnn/src/common/reduce_helper_device.h      |  2 +-
 dnn/src/cuda/check_non_finite/kern.cu      | 12 ++++---
 dnn/src/cuda/check_non_finite/opr_impl.cpp | 40 +++++++++++++++++-----
 dnn/src/cuda/check_non_finite/opr_impl.h   |  8 ++++-
 dnn/src/naive/check_non_finite/opr_impl.h  |  2 +-
 7 files changed, 49 insertions(+), 19 deletions(-)

diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h
index 0fbd15f76..30e6c181d 100644
--- a/dnn/include/megdnn/oprs/general.h
+++ b/dnn/include/megdnn/oprs/general.h
@@ -1383,7 +1383,6 @@ public:
 protected:
     void check_exec(
             const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes);
-    virtual size_t _get_workspace_in_bytes() = 0;
 };
 
 /*!
diff --git a/dnn/src/common/check_non_finite.cpp b/dnn/src/common/check_non_finite.cpp
index e03d78001..24fdfb8b2 100644
--- a/dnn/src/common/check_non_finite.cpp
+++ b/dnn/src/common/check_non_finite.cpp
@@ -18,8 +18,7 @@ void CheckNonFinite::check_exec(
         const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes) {
     megdnn_assert_contiguous(dst.layout);
     megdnn_assert(srcs.size() > 0);
-    megdnn_assert(srcs.begin()->layout.dtype == dtype::Float32());
-    auto required_workspace_in_bytes = _get_workspace_in_bytes();
+    auto required_workspace_in_bytes = get_workspace_in_bytes(srcs, dst.layout);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }
 
diff --git a/dnn/src/common/reduce_helper_device.h b/dnn/src/common/reduce_helper_device.h
index 31521261c..05f4b71a4 100644
--- a/dnn/src/common/reduce_helper_device.h
+++ b/dnn/src/common/reduce_helper_device.h
@@ -236,4 +236,4 @@ void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t a
 
 }  // namespace megdnn
 
-// vim: syntax=cpp.doxygen
\ No newline at end of file
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/check_non_finite/kern.cu b/dnn/src/cuda/check_non_finite/kern.cu
index 8251090ae..4a829a59c 100644
--- a/dnn/src/cuda/check_non_finite/kern.cu
+++ b/dnn/src/cuda/check_non_finite/kern.cu
@@ -18,11 +18,15 @@ namespace cuda {
 
 #define COMMA ,
 
-INST_REDUCE(
-        device_reduce::CheckNonFiniteOp<
-                dt_float32 COMMA size_t COMMA dt_int32 COMMA dt_int32>,
-        false);
+#define cb(_dtype)                                                      \
+    INST_REDUCE(                                                        \
+            device_reduce::CheckNonFiniteOp<                            \
+                    _dtype COMMA size_t COMMA dt_int32 COMMA dt_int32>, \
+            false);
 
+cb(dt_float32);
+cb(dt_float16);
+#undef cb
 #undef COMMA
 }  // namespace cuda
 }  // namespace megdnn
diff --git a/dnn/src/cuda/check_non_finite/opr_impl.cpp b/dnn/src/cuda/check_non_finite/opr_impl.cpp
index 54ac36e86..d04390220 100644
--- a/dnn/src/cuda/check_non_finite/opr_impl.cpp
+++ b/dnn/src/cuda/check_non_finite/opr_impl.cpp
@@ -22,13 +22,14 @@ namespace cuda {
 using device_reduce::CheckNonFiniteOp;
 #define total_nr_elems_max 2048
 
+template <typename T>
 size_t CheckNonFiniteImpl::_get_workspace_in_bytes() {
     // Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes
-    typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op;
     megdnn_assert(m_size > 0);
     WorkspaceBundle bundle(
             nullptr, {
-                             sizeof(dt_float32*) * m_size,
+                             sizeof(T*) * m_size,
                              sizeof(size_t) * m_size,
                      });
     return get_reduce_workspace_in_bytes<Op>(1, m_size * total_nr_elems_max, 1) +
@@ -41,17 +42,38 @@ size_t CheckNonFiniteImpl::get_workspace_in_bytes(
     for (const auto& src : srcs) {
         m_size += DIVUP(src.layout.total_nr_elems(), total_nr_elems_max);
     }
-    return _get_workspace_in_bytes();
+    if (srcs.begin()->layout.dtype == dtype::Float32()) {
+        return _get_workspace_in_bytes<dt_float32>();
+    } else if (srcs.begin()->layout.dtype == dtype::Float16()) {
+        return _get_workspace_in_bytes<dt_float16>();
+    } else {
+        megdnn_log_warn("only support fp16 and fp32, fallback to fp32");
+        return _get_workspace_in_bytes<dt_float32>();
+    }
 }
 
 void CheckNonFiniteImpl::exec(
         _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
         _megdnn_workspace workspace) {
+    if (srcs.begin()->layout.dtype == dtype::Float32()) {
+        _exec<dt_float32>(srcs, dst, workspace);
+    }
+#ifdef DNN_INC_FLOAT16
+    else if (srcs.begin()->layout.dtype == dtype::Float16()) {
+        _exec<dt_float16>(srcs, dst, workspace);
+    }
+#endif
+}
+
+template <typename T>
+void CheckNonFiniteImpl::_exec(
+        _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
+        _megdnn_workspace workspace) {
     check_exec(srcs, dst, workspace.size);
-    typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<T, size_t, dt_int32, dt_int32> Op;
     auto stream = cuda_stream(this->handle());
     SmallVector<size_t> workspace_sizes{
-            sizeof(dt_float32*) * m_size,
+            sizeof(T*) * m_size,
             sizeof(size_t) * m_size,
     };
     WorkspaceBundle workspace_cpu(nullptr, workspace_sizes),
@@ -63,8 +85,8 @@ void CheckNonFiniteImpl::exec(
 
     workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
     workspace_gpu = WorkspaceBundle(workspace_gpu_raw, workspace_sizes);
-    auto srcs_cpu = static_cast<dt_float32**>(workspace_cpu.get(0));
-    auto srcs_gpu = static_cast<dt_float32**>(workspace_gpu.get(0));
+    auto srcs_cpu = static_cast<T**>(workspace_cpu.get(0));
+    auto srcs_gpu = static_cast<T**>(workspace_gpu.get(0));
     auto srcs_total_nr_elems_cpu = static_cast<size_t*>(workspace_cpu.get(1));
     auto srcs_total_nr_elems_gpu = static_cast<size_t*>(workspace_gpu.get(1));
 
@@ -75,7 +97,7 @@ void CheckNonFiniteImpl::exec(
         size_t src_nr_elems = src.layout.total_nr_elems();
         size_t nr_elems = DIVUP(src_nr_elems, total_nr_elems_max);
         for (size_t j = 0; j < nr_elems; ++j, ++i) {
-            srcs_cpu[i] = src.ptr<dt_float32>() + j * total_nr_elems_max;
+            srcs_cpu[i] = src.ptr<T>() + j * total_nr_elems_max;
             if (j + 1 == nr_elems && src_nr_elems % total_nr_elems_max) {
                 srcs_total_nr_elems_cpu[i] = src_nr_elems % total_nr_elems_max;
             } else {
@@ -97,7 +119,7 @@ void CheckNonFiniteImpl::exec(
                     workspace_gpu.total_size_in_bytes())),
             1, m_size * total_nr_elems_max, 1, stream,
             Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(),
-               total_nr_elems_max, param().scale));
+               total_nr_elems_max, static_cast<T>(param().scale)));
 }
 
 }  // namespace cuda
diff --git a/dnn/src/cuda/check_non_finite/opr_impl.h b/dnn/src/cuda/check_non_finite/opr_impl.h
index 7392c0625..f47213380 100644
--- a/dnn/src/cuda/check_non_finite/opr_impl.h
+++ b/dnn/src/cuda/check_non_finite/opr_impl.h
@@ -18,7 +18,13 @@ namespace megdnn {
 namespace cuda {
 
 class CheckNonFiniteImpl final : public CheckNonFinite {
-    size_t _get_workspace_in_bytes() override;
+    template <typename T>
+    size_t _get_workspace_in_bytes();
+
+    template <typename T>
+    void _exec(
+            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
+            _megdnn_workspace workspace);
 
 public:
     using CheckNonFinite::CheckNonFinite;
diff --git a/dnn/src/naive/check_non_finite/opr_impl.h b/dnn/src/naive/check_non_finite/opr_impl.h
index 2360a7191..fef7bfdef 100644
--- a/dnn/src/naive/check_non_finite/opr_impl.h
+++ b/dnn/src/naive/check_non_finite/opr_impl.h
@@ -17,7 +17,7 @@ namespace megdnn {
 namespace naive {
 
 class CheckNonFiniteImpl final : public CheckNonFinite {
-    size_t _get_workspace_in_bytes() override { return 0; }
+    size_t _get_workspace_in_bytes() { return 0; }
 
 public:
     using CheckNonFinite::CheckNonFinite;
--
GitLab
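
Note on the dispatch pattern introduced by this patch: the CUDA backend now inspects the input dtype once in CheckNonFiniteImpl::exec / get_workspace_in_bytes and forwards to the type-templated _exec<T> / _get_workspace_in_bytes<T>, while kern.cu instantiates the CheckNonFiniteOp reduction for both dt_float32 and dt_float16 via the cb(_dtype) macro. Below is a minimal standalone C++ sketch of that runtime-dtype-to-template dispatch. It is intentionally independent of MegDNN: the DTypeTag enum, Half struct, and check_non_finite* helpers are illustrative stand-ins and not the library's API.

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative dtype tag; stands in for the dtype::Float32()/Float16() checks.
enum class DTypeTag { Float32, Float16 };

// Minimal fp16 stand-in that only stores the raw IEEE binary16 bits.
struct Half {
    uint16_t bits;
};

// A binary16 value is inf or NaN exactly when all five exponent bits are set.
static bool is_non_finite(Half h) {
    return (h.bits & 0x7C00u) == 0x7C00u;
}
static bool is_non_finite(float f) {
    return !std::isfinite(f);
}

// Type-templated worker, analogous in spirit to CheckNonFiniteImpl::_exec<T>.
template <typename T>
int check_non_finite_impl(const T* data, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        if (is_non_finite(data[i]))
            return 1;  // found at least one inf/NaN
    }
    return 0;
}

// Runtime dispatch on the dtype tag, mirroring the branches added in exec().
int check_non_finite(DTypeTag tag, const void* data, std::size_t n) {
    if (tag == DTypeTag::Float32)
        return check_non_finite_impl<float>(static_cast<const float*>(data), n);
    return check_non_finite_impl<Half>(static_cast<const Half*>(data), n);
}

int main() {
    std::vector<float> xs = {1.0f, 2.0f, INFINITY};
    std::vector<Half> hs = {{0x3C00u /* 1.0 */}, {0x7E00u /* NaN */}};
    std::printf("fp32 has non-finite: %d\n",
                check_non_finite(DTypeTag::Float32, xs.data(), xs.size()));
    std::printf("fp16 has non-finite: %d\n",
                check_non_finite(DTypeTag::Float16, hs.data(), hs.size()));
    return 0;
}

Dispatching once at the operator boundary keeps the inner loop fully typed for each dtype, which is also why the patch casts param().scale to T before handing it to the reduction op.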