diff --git a/lite/include/lite/network.h b/lite/include/lite/network.h
index 1a69f14ee15afc1a1e638f45e19d03c758049070..7e9915073d9d261808e051f4907044a6bd93c5a9 100644
--- a/lite/include/lite/network.h
+++ b/lite/include/lite/network.h
@@ -97,6 +97,7 @@ struct LITE_API Options {
     bool enable_nchw4 = false;
     bool enable_nchw32 = false;
     bool enable_nchw64 = false;
+    bool enable_f16_io_comp = false;  // convert to fp16
 };
 
 /**
diff --git a/lite/load_and_run/src/helpers/common.h b/lite/load_and_run/src/helpers/common.h
index 65481090be2bf8f62e8d6822ee120f7449b6aabd..cbc9d331b43f55658c92bfe09b9f0b2417394207 100644
--- a/lite/load_and_run/src/helpers/common.h
+++ b/lite/load_and_run/src/helpers/common.h
@@ -67,6 +67,10 @@ enum class OptLayoutType {
     NHWCD4 = 1 << 6,
     NCHW44_DOT = 1 << 7
 };
+/*!
+ * \brief: dtype type for running model optimization
+ */
+enum class OptDTypeType { IOC16 = 1 << 0 };
 /**
  * base class to story option value
  */
diff --git a/lite/load_and_run/src/options/dtype_options.cpp b/lite/load_and_run/src/options/dtype_options.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..dd67298a59b4a5fd387504d0e7ef932d53321219
--- /dev/null
+++ b/lite/load_and_run/src/options/dtype_options.cpp
@@ -0,0 +1,106 @@
+#include <gflags/gflags.h>
+
+#include "misc.h"
+#include "models/model_lite.h"
+#include "models/model_mdl.h"
+
+#include "dtype_options.h"
+namespace lar {
+template <>
+void DTypeOption::config_model_internel<ModelLite>(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
+    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
+#define ENABLE_DTYPE(dtype)                            \
+    LITE_LOG("enable " #dtype " optimization");        \
+    model->get_config().options.enable_##dtype = true; \
+    break;
+
+        switch (m_option_flag) {
+            case OptDTypeType::IOC16:
+                ENABLE_DTYPE(f16_io_comp)
+            default:
+                LITE_THROW(
+                        "Set unsupported dtype, only --enable-ioc16 is supported. "
+                        "Default case is fp32.");
+                break;
+        }
+#undef ENABLE_DTYPE
+    }
+}
+
+template <>
+void DTypeOption::config_model_internel<ModelMdl>(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
+    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
+#define ENABLE_DTYPE(dtype)                                                    \
+    mgb_log("enable " #dtype " optimization");                                 \
+    model->get_mdl_config().comp_graph->options().graph_opt.enable_##dtype();  \
+    break;
+
+        switch (m_option_flag) {
+            case OptDTypeType::IOC16:
+                ENABLE_DTYPE(f16_io_comp)
+            default:
+                LITE_THROW(
+                        "Set unsupported dtype, only --enable-ioc16 is supported. "
+                        "Default case is fp32.");
+                break;
+        }
+
+#undef ENABLE_DTYPE
+    }
+}
+}  // namespace lar
+
+using namespace lar;
+bool DTypeOption::m_valid;
+void DTypeOption::update() {
+    m_option_name = "dtype";
+    m_option_flag = static_cast<OptDTypeType>(0);
+    m_option = {
+            {"enable_ioc16", lar::Bool::make(false)},
+    };
+    std::static_pointer_cast<lar::Bool>(m_option["enable_ioc16"])
+            ->set_value(FLAGS_enable_ioc16);
+}
+
+bool DTypeOption::is_valid() {
+    size_t valid_flag = 0;
+    if (FLAGS_enable_ioc16) {
+        valid_flag |= static_cast<size_t>(OptDTypeType::IOC16);
+    }
+    //! only one flag is valid
+    bool ret = valid_flag && !(valid_flag & (valid_flag - 1));
+
+    return ret | m_valid;
+};
+
+std::shared_ptr<OptionBase> DTypeOption::create_option() {
+    static std::shared_ptr<DTypeOption> option(new DTypeOption);
+    if (DTypeOption::is_valid()) {
+        option->update();
+        return std::static_pointer_cast<OptionBase>(option);
+    } else {
+        return nullptr;
+    }
+}
+
+void DTypeOption::config_model(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
+    size_t valid_flag = 0;
+    if (FLAGS_enable_ioc16 ||
+        std::static_pointer_cast<lar::Bool>(m_option["enable_ioc16"])->get_value()) {
+        valid_flag |= static_cast<size_t>(OptDTypeType::IOC16);
+    }
+
+    mgb_throw_if(
+            valid_flag && (valid_flag & (valid_flag - 1)), mgb::AssertionError,
+            "invalid options of dtype transform 0x%lx", valid_flag);
+    m_option_flag = static_cast<OptDTypeType>(valid_flag);
+    CONFIG_MODEL_FUN;
+}
+
+DEFINE_bool(enable_ioc16, false, "enable fp16 dtype optimization");
+
+REGIST_OPTION_CREATOR(dtype, lar::DTypeOption::create_option);
+REGIST_OPTION_VALIDATER(dtype, lar::DTypeOption::set_valid);
\ No newline at end of file
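Usage sketch (not part of the patch): the new lite-side flag is consumed through lite::Config, which NetworkImplDft::application_config() later maps onto the graph optimization options. The snippet below is illustrative only; it assumes the lite::Network API declared in lite/include/lite/network.h, and "./model.mge" is a placeholder path.

// Minimal sketch: request fp16 I/O and computation via the option added above.
#include <memory>

#include "lite/network.h"

int main() {
    lite::Config config;
    config.options.enable_f16_io_comp = true;  // the flag introduced by this diff

    auto network = std::make_shared<lite::Network>(config);
    network->load_model("./model.mge");
    // in real use, fill input tensors (e.g. via get_io_tensor) before forward()
    network->forward();
    network->wait();
    return 0;
}
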
diff --git a/lite/load_and_run/src/options/dtype_options.h b/lite/load_and_run/src/options/dtype_options.h
new file mode 100644
index 0000000000000000000000000000000000000000..345b929a0e6608d492ba14db0874162ef9bd2ead
--- /dev/null
+++ b/lite/load_and_run/src/options/dtype_options.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <gflags/gflags.h>
+#include "helpers/common.h"
+#include "models/model.h"
+#include "option_base.h"
+
+DECLARE_bool(enable_ioc16);
+
+namespace lar {
+/*!
+ * \brief: dtype option for optimization
+ */
+class DTypeOption final : public OptionBase {
+public:
+    //! check the validation of option flag
+    static bool is_valid();
+
+    //! create options when option is used
+    static std::shared_ptr<OptionBase> create_option();
+
+    //! config the model, dispatch configuration for different model implementations
+    void config_model(
+            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
+
+    //! get option name
+    std::string option_name() const override { return m_option_name; };
+
+    static void set_valid(bool val) { m_valid = val; }
+
+    OptionValMap* get_option() override { return &m_option; }
+
+    void update() override;
+
+private:
+    //! Constructor
+    DTypeOption() = default;
+
+    //! configuration for different model implementations
+    template <typename ModelImpl>
+    void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
+
+    OptDTypeType m_option_flag;
+    std::string m_option_name;
+    static bool m_valid;
+    OptionValMap m_option;
+};
+}  // namespace lar
\ No newline at end of file
diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index e074cef13a83a962a9662e9c97b2c543a433af54..0375ad5a0abdd79365f01097bfdeb0294162cf87 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -97,6 +97,9 @@ void NetworkImplDft::application_config() {
     ConfigOptionLayoutTransform(enable_nchw32);
     ConfigOptionLayoutTransform(enable_nchw64);
 #undef ConfigOptionLayoutTransform
+    if (m_user_config->options.enable_f16_io_comp) {
+        options.graph_opt.enable_f16_io_comp();
+    }
     if (m_user_config->has_compression) {
         m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
     }
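For the MegBrain path (ModelMdl in load_and_run), the option boils down to flipping graph_opt.enable_f16_io_comp() on the computing graph used for loading. A rough sketch of that usage follows; the header path, GraphLoader::LoadConfig field names and the omitted compile/run steps are assumptions based on the calls visible in this diff, not something the patch adds.

// Sketch only: enable fp16 io/comp for a serialized model loaded through the
// MegBrain serialization API, mirroring what ModelMdl does before loading.
#include "megbrain/serialization/serializer.h"

using namespace mgb;

void load_with_f16_io_comp(const char* model_path) {
    auto file = serialization::InputFile::make_fs(model_path);
    auto loader = serialization::GraphLoader::make(std::move(file));

    serialization::GraphLoader::LoadConfig load_config;
    load_config.comp_graph = ComputingGraph::make();
    // same switch that NetworkImplDft::application_config() and the new
    // DTypeOption flip when fp16 io/comp is requested
    load_config.comp_graph->options().graph_opt.enable_f16_io_comp();

    auto load_result = loader->load(load_config);
    // compile and run load_result.output_var_list as usual
    (void)load_result;
}
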
diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp
index 43293311aa4e31cf49e847f9cd6ca0e10d745041..01004cbc0cfeefd424027787785e585004dd6b36 100644
--- a/src/gopt/impl/inference.cpp
+++ b/src/gopt/impl/inference.cpp
@@ -641,6 +641,22 @@ void ConvertF32ToF16Pass::apply(OptState& state) const {
             for (size_t i = 0; i < origin_out.size(); i++) {
                 rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
             }
+        } else if (
+                m_multi_tensor_replace_func.find(opr->dyn_typeinfo()) !=
+                m_multi_tensor_replace_func.end()) {
+            auto&& new_inp = new_inp_cache;
+            new_inp.clear();
+            new_inp.reserve(opr->input().size());
+            for (auto i : opr->input()) {
+                new_inp.push_back(rewriter.get_var(i));
+            }
+            auto&& origin_out = opr->output();
+            auto&& cur_out =
+                    m_multi_tensor_replace_func.at(opr->dyn_typeinfo())(opr, new_inp);
+            mgb_assert(origin_out.size() == cur_out.size());
+            for (size_t i = 0; i < origin_out.size(); i++) {
+                rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
+            }
         } else {
             rewriter.auto_replace_outputs(opr);
         }
@@ -691,6 +707,48 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(bool use_f32_comp
         return opr;
     };
 
+    auto replace_multi_sdt_opr = [](OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& multi_sdt_opr = opr->cast_final_safe<opr::MultipleDeviceTensorHolder>();
+
+        VarNodeArray cvt_vars;
+        cvt_vars.reserve(multi_sdt_opr.output().size());
+
+        for (size_t i = 0; i < multi_sdt_opr.output().size(); ++i) {
+            if (multi_sdt_opr.output(i)->dtype() == dtype::Float32()) {
+                cvt_vars.append({opr::TypeCvt::make(
+                                         multi_sdt_opr.output(i), dtype::Float16(), {})
+                                         .node()});
+            } else {
+                cvt_vars.append({multi_sdt_opr.output(i)});
+            }
+        }
+        return cvt_vars;
+    };
+
+    auto replace_multi_sdt_with_format_opr = [](OperatorNodeBase* opr,
+                                                const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& multi_sdt_with_format_opr =
+                opr->cast_final_safe<opr::MultipleDeviceTensorWithFormatHolder>();
+
+        VarNodeArray cvt_vars;
+        cvt_vars.reserve(multi_sdt_with_format_opr.output().size());
+
+        for (size_t i = 0; i < multi_sdt_with_format_opr.output().size(); ++i) {
+            if (multi_sdt_with_format_opr.output(i)->dtype() == dtype::Float32()) {
+                cvt_vars.append({opr::TypeCvt::make(
+                                         multi_sdt_with_format_opr.output(i),
+                                         dtype::Float16(), {})
+                                         .node()});
+            } else {
+                cvt_vars.append({multi_sdt_with_format_opr.output(i)});
+            }
+        }
+        return cvt_vars;
+    };
+
     auto replace_imt_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
         mgb_assert(opr->same_type<opr::ImmutableTensor>());
         mgb_assert(opr->input().size() == new_inp.size());
@@ -934,6 +992,12 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(bool use_f32_comp
     replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
     replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
     replace_func[opr::BatchedMatrixMul::typeinfo()] = replace_batched_matmul_opr;
+
+    auto& tensor_replace_func = ret->m_multi_tensor_replace_func;
+    tensor_replace_func[opr::MultipleDeviceTensorHolder::typeinfo()] =
+            replace_multi_sdt_opr;
+    tensor_replace_func[opr::MultipleDeviceTensorWithFormatHolder::typeinfo()] =
+            replace_multi_sdt_with_format_opr;
     return ret;
 #endif
 }
diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h
index f6933ccd2db450394fa50b9192bc4bda70a9506f..eb4bbeecf5a8efdb4e3e0a9ddfd2270da5d76048 100644
--- a/src/gopt/include/megbrain/gopt/inference.h
+++ b/src/gopt/include/megbrain/gopt/inference.h
@@ -73,6 +73,10 @@ class ConvertF32ToF16Pass : public Pass {
             Typeinfo*,
             thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>>
             m_opr_replace_func;
+    ThinHashMap<
+            Typeinfo*,
+            thin_function<VarNodeArray(OperatorNodeBase*, const VarNodeArray&)>>
+            m_multi_tensor_replace_func;
     VarReplaceCheckFlag m_var_replace_check_flag = VarReplaceCheckFlag::CHECK_ALL;
 
 public:
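Finally, a tiny sketch of driving the extended pass directly, roughly the way the gopt tests exercise individual passes. The GraphOptimizer/SubGraph/endpoint_vars() plumbing and the assumption that ConvertF32ToF16Pass::make is publicly callable are mine, not something introduced by this patch; the new m_multi_tensor_replace_func branch only becomes relevant when the graph contains MultipleDeviceTensorHolder / MultipleDeviceTensorWithFormatHolder oprs (e.g. after parameter packing in a dumped model).

// Sketch only: convert the endpoints of an fp32 graph to fp16 with the pass
// modified by this diff. use_f32_comp=false makes both I/O and compute fp16.
#include "megbrain/gopt/inference.h"

using namespace mgb;

SymbolVarArray convert_endpoints_to_f16(const SymbolVarArray& f32_endpoints) {
    return gopt::GraphOptimizer{}
            .add_pass(gopt::ConvertF32ToF16Pass::make(/*use_f32_comp=*/false))
            .apply({f32_endpoints})
            .endpoint_vars();
}
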