From 3e0bb22c7f57b6d0cc736a6a5da5f3dd3919d26c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 27 Dec 2022 14:29:50 +0800
Subject: [PATCH] feat(lite): load and run supports converting fp32 to fp16 online

GitOrigin-RevId: 05c9a17a00301ced15043cf34284f2d362d2608c
---
 lite/include/lite/network.h                   |   1 +
 lite/load_and_run/src/helpers/common.h        |   4 +
 .../src/options/dtype_options.cpp             | 106 ++++++++++++++++++
 lite/load_and_run/src/options/dtype_options.h |  48 ++++++++
 lite/src/mge/network_impl.cpp                 |   3 +
 src/gopt/impl/inference.cpp                   |  64 +++++++++++
 src/gopt/include/megbrain/gopt/inference.h    |   4 +
 7 files changed, 230 insertions(+)
 create mode 100644 lite/load_and_run/src/options/dtype_options.cpp
 create mode 100644 lite/load_and_run/src/options/dtype_options.h

diff --git a/lite/include/lite/network.h b/lite/include/lite/network.h
index 1a69f14ee..7e9915073 100644
--- a/lite/include/lite/network.h
+++ b/lite/include/lite/network.h
@@ -97,6 +97,7 @@ struct LITE_API Options {
     bool enable_nchw4 = false;
     bool enable_nchw32 = false;
     bool enable_nchw64 = false;
+    bool enable_f16_io_comp = false;  // convert I/O and computation to fp16
 };
 
 /**
diff --git a/lite/load_and_run/src/helpers/common.h b/lite/load_and_run/src/helpers/common.h
index 65481090b..cbc9d331b 100644
--- a/lite/load_and_run/src/helpers/common.h
+++ b/lite/load_and_run/src/helpers/common.h
@@ -67,6 +67,10 @@ enum class OptLayoutType {
    NHWCD4 = 1 << 6,
    NCHW44_DOT = 1 << 7
 };
+/*!
+ * \brief: dtype optimization type for running model
+ */
+enum class OptDTypeType { IOC16 = 1 << 0 };
 /**
  * base class to story option value
  */
diff --git a/lite/load_and_run/src/options/dtype_options.cpp b/lite/load_and_run/src/options/dtype_options.cpp
new file mode 100644
index 000000000..dd67298a5
--- /dev/null
+++ b/lite/load_and_run/src/options/dtype_options.cpp
@@ -0,0 +1,106 @@
+#include <gflags/gflags.h>
+
+#include "misc.h"
+#include "models/model_lite.h"
+#include "models/model_mdl.h"
+
+#include "dtype_options.h"
+namespace lar {
+template <>
+void DTypeOption::config_model_internel<ModelLite>(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
+    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
+#define ENABLE_DTYPE(dtype)                               \
+    LITE_LOG("enable " #dtype " optimization");           \
+    model->get_config().options.enable_##dtype = true;    \
+    break;
+
+        switch (m_option_flag) {
+            case OptDTypeType::IOC16:
+                ENABLE_DTYPE(f16_io_comp)
+            default:
+                LITE_THROW(
+                        "Set unsupported dtype; only --enable-ioc16 is supported. "
+                        "The default dtype is fp32.");
+                break;
+        }
+#undef ENABLE_DTYPE
+    }
+}
+
+template <>
+void DTypeOption::config_model_internel<ModelMdl>(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
+    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
+#define ENABLE_DTYPE(dtype)                                                    \
+    mgb_log("enable " #dtype " optimization");                                 \
+    model->get_mdl_config().comp_graph->options().graph_opt.enable_##dtype();  \
+    break;
+
+        switch (m_option_flag) {
+            case OptDTypeType::IOC16:
+                ENABLE_DTYPE(f16_io_comp)
+            default:
+                LITE_THROW(
+                        "Set unsupported dtype; only --enable-ioc16 is supported. "
" + "Default case is fp32."); + break; + } + +#undef ENABLE_DTYPE + } +} +} // namespace lar + +using namespace lar; +bool DTypeOption::m_valid; +void DTypeOption::update() { + m_option_name = "dtype"; + m_option_flag = static_cast(0); + m_option = { + {"enable_ioc16", lar::Bool::make(false)}, + }; + std::static_pointer_cast(m_option["enable_ioc16"]) + ->set_value(FLAGS_enable_ioc16); +} + +bool DTypeOption::is_valid() { + size_t valid_flag = 0; + if (FLAGS_enable_ioc16) { + valid_flag |= static_cast(OptDTypeType::IOC16); + } + //! only one flag is valid + bool ret = valid_flag && !(valid_flag & (valid_flag - 1)); + + return ret | m_valid; +}; + +std::shared_ptr DTypeOption::create_option() { + static std::shared_ptr option(new DTypeOption); + if (DTypeOption::is_valid()) { + option->update(); + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void DTypeOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + size_t valid_flag = 0; + if (FLAGS_enable_ioc16 || + std::static_pointer_cast(m_option["enable_ioc16"])->get_value()) { + valid_flag |= static_cast(OptDTypeType::IOC16); + } + + mgb_throw_if( + valid_flag && (valid_flag & (valid_flag - 1)), mgb::AssertionError, + "invalid options of dtype transform 0x%lx", valid_flag); + m_option_flag = static_cast(valid_flag); + CONFIG_MODEL_FUN; +} + +DEFINE_bool(enable_ioc16, false, "enable fp16 dtype optimization!!"); + +REGIST_OPTION_CREATOR(dtype, lar::DTypeOption::create_option); +REGIST_OPTION_VALIDATER(dtype, lar::DTypeOption::set_valid); \ No newline at end of file diff --git a/lite/load_and_run/src/options/dtype_options.h b/lite/load_and_run/src/options/dtype_options.h new file mode 100644 index 000000000..345b929a0 --- /dev/null +++ b/lite/load_and_run/src/options/dtype_options.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include "helpers/common.h" +#include "models/model.h" +#include "option_base.h" + +DECLARE_bool(enable_ioc16); + +namespace lar { +/*! + * \brief: dtype option for optimization + */ +class DTypeOption final : public OptionBase { +public: + //! check the validation of option flag + static bool is_valid(); + + //! creat options when option is used + static std::shared_ptr create_option(); + + //! config the model, dispatch configuration for different model implement + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + //! get option name + std::string option_name() const override { return m_option_name; }; + + static void set_valid(bool val) { m_valid = val; } + + OptionValMap* get_option() override { return &m_option; } + + void update() override; + +private: + //! Constructor + DTypeOption() = default; + + //! 
+    //! configuration for different model implementations
+    template <typename ModelImpl>
+    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
+
+    OptDTypeType m_option_flag;
+    std::string m_option_name;
+    static bool m_valid;
+    OptionValMap m_option;
+};
+}  // namespace lar
\ No newline at end of file
diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index e074cef13..0375ad5a0 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -97,6 +97,9 @@ void NetworkImplDft::application_config() {
     ConfigOptionLayoutTransform(enable_nchw32);
     ConfigOptionLayoutTransform(enable_nchw64);
 #undef ConfigOptionLayoutTransform
+    if (m_user_config->options.enable_f16_io_comp) {
+        options.graph_opt.enable_f16_io_comp();
+    }
     if (m_user_config->has_compression) {
         m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
     }
diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp
index 43293311a..01004cbc0 100644
--- a/src/gopt/impl/inference.cpp
+++ b/src/gopt/impl/inference.cpp
@@ -641,6 +641,22 @@ void ConvertF32ToF16Pass::apply(OptState& state) const {
             for (size_t i = 0; i < origin_out.size(); i++) {
                 rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
             }
+        } else if (
+                m_multi_tensor_replace_func.find(opr->dyn_typeinfo()) !=
+                m_multi_tensor_replace_func.end()) {
+            auto&& new_inp = new_inp_cache;
+            new_inp.clear();
+            new_inp.reserve(opr->input().size());
+            for (auto i : opr->input()) {
+                new_inp.push_back(rewriter.get_var(i));
+            }
+            auto&& origin_out = opr->output();
+            auto&& cur_out =
+                    m_multi_tensor_replace_func.at(opr->dyn_typeinfo())(opr, new_inp);
+            mgb_assert(origin_out.size() == cur_out.size());
+            for (size_t i = 0; i < origin_out.size(); i++) {
+                rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
+            }
         } else {
             rewriter.auto_replace_outputs(opr);
         }
@@ -691,6 +707,48 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(bool use_f32_comp
         return opr;
     };
 
+    auto replace_multi_sdt_opr = [](OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& multi_sdt_opr = opr->cast_final_safe<opr::MultipleDeviceTensorHolder>();
+
+        VarNodeArray cvt_vars;
+        cvt_vars.reserve(multi_sdt_opr.output().size());
+
+        for (size_t i = 0; i < multi_sdt_opr.output().size(); ++i) {
+            if (multi_sdt_opr.output(i)->dtype() == dtype::Float32()) {
+                cvt_vars.append({opr::TypeCvt::make(
+                                         multi_sdt_opr.output(i), dtype::Float16(), {})
+                                         .node()});
+            } else {
+                cvt_vars.append({multi_sdt_opr.output(i)});
+            }
+        }
+        return cvt_vars;
+    };
+
+    auto replace_multi_sdt_with_format_opr = [](OperatorNodeBase* opr,
+                                                const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& multi_sdt_with_format_opr =
+                opr->cast_final_safe<opr::MultipleDeviceTensorWithFormatHolder>();
+
+        VarNodeArray cvt_vars;
+        cvt_vars.reserve(multi_sdt_with_format_opr.output().size());
+
+        for (size_t i = 0; i < multi_sdt_with_format_opr.output().size(); ++i) {
+            if (multi_sdt_with_format_opr.output(i)->dtype() == dtype::Float32()) {
+                cvt_vars.append({opr::TypeCvt::make(
+                                         multi_sdt_with_format_opr.output(i),
+                                         dtype::Float16(), {})
+                                         .node()});
+            } else {
+                cvt_vars.append({multi_sdt_with_format_opr.output(i)});
+            }
+        }
+        return cvt_vars;
+    };
+
     auto replace_imt_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
         mgb_assert(opr->same_type<opr::ImmutableTensor>());
         mgb_assert(opr->input().size() == new_inp.size());
@@ -934,6 +992,12 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(bool use_f32_comp
     replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
     replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
     replace_func[opr::BatchedMatrixMul::typeinfo()] = replace_batched_matmul_opr;
+
+    auto& tensor_replace_func = ret->m_multi_tensor_replace_func;
+    tensor_replace_func[opr::MultipleDeviceTensorHolder::typeinfo()] =
+            replace_multi_sdt_opr;
+    tensor_replace_func[opr::MultipleDeviceTensorWithFormatHolder::typeinfo()] =
+            replace_multi_sdt_with_format_opr;
     return ret;
 #endif
 }
diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h
index f6933ccd2..eb4bbeecf 100644
--- a/src/gopt/include/megbrain/gopt/inference.h
+++ b/src/gopt/include/megbrain/gopt/inference.h
@@ -73,6 +73,10 @@ class ConvertF32ToF16Pass : public Pass {
             Typeinfo*,
             thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>>
             m_opr_replace_func;
+    ThinHashMap<
+            Typeinfo*,
+            thin_function<VarNodeArray(OperatorNodeBase*, const VarNodeArray&)>>
+            m_multi_tensor_replace_func;
     VarReplaceCheckFlag m_var_replace_check_flag = VarReplaceCheckFlag::CHECK_ALL;
 
 public:
-- 
GitLab
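
Usage note (not part of the patch): the snippet below is a minimal sketch of how the new option
could be enabled through the MegEngine Lite C++ API once this patch is applied. It assumes the
standard lite::Network workflow from lite/include/lite/network.h; the model path "model.mge" is a
placeholder and input/output tensor handling is omitted.

// Sketch: enable online fp32 -> fp16 conversion when loading a model with MegEngine Lite.
// Assumptions: standard lite::Network API; "model.mge" is a placeholder model file.
#include <memory>

#include "lite/network.h"

int main() {
    lite::Config config;
    // Option added by this patch: convert fp32 I/O and computation to fp16.
    config.options.enable_f16_io_comp = true;

    auto network = std::make_shared<lite::Network>(config);
    network->load_model("model.mge");  // placeholder path
    // ... fill input tensors here before running inference ...
    network->forward();
    network->wait();
    return 0;
}

From the load_and_run command line, the same conversion is toggled with the --enable-ioc16 flag
registered above via DEFINE_bool in dtype_options.cpp.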