Commit 3b557b11 authored by Megvii Engine Team, committed by 黄信达

feat(lite): load and run supports converting fp32 to fp16 online

GitOrigin-RevId: 05c9a17a00301ced15043cf34284f2d362d2608c
Parent 66338652
......@@ -97,6 +97,7 @@ struct LITE_API Options {
bool enable_nchw4 = false;
bool enable_nchw32 = false;
bool enable_nchw64 = false;
bool enable_f16_io_comp = false; // convert I/O and computation to fp16
};
/**
......
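The new field can be set through the Lite config before a model is loaded. A minimal usage sketch, assuming the public lite::Network API (the model path is illustrative):

    #include <memory>
    #include "lite/network.h"

    int main() {
        lite::Config config;
        // request online fp32 -> fp16 conversion of I/O and computation
        config.options.enable_f16_io_comp = true;

        // the conversion is applied while the model is loaded
        auto network = std::make_shared<lite::Network>(config);
        network->load_model("./model.lite");
        network->forward();
        network->wait();
    }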
......@@ -67,6 +67,10 @@ enum class OptLayoutType {
NHWCD4 = 1 << 6,
NCHW44_DOT = 1 << 7
};
/*!
* \brief dtype optimization type for running a model
*/
enum class OptDTypeType { IOC16 = 1 << 0 };
/**
* base class to store option value
*/
......
#include <gflags/gflags.h>
#include "misc.h"
#include "models/model_lite.h"
#include "models/model_mdl.h"
#include "dtype_options.h"
namespace lar {
template <>
void DTypeOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
//! log the optimization, set the corresponding lite option, and leave the switch
#define ENABLE_DTYPE(dtype) \
LITE_LOG("enable " #dtype " optimization"); \
model->get_config().options.enable_##dtype = true; \
break;
switch (m_option_flag) {
case OptDTypeType::IOC16:
ENABLE_DTYPE(f16_io_comp)
default:
LITE_THROW(
"Unsupported dtype option; only --enable-ioc16 is "
"supported (fp32 is the default).");
break;
}
#undef ENABLE_DTYPE
}
}
template <>
void DTypeOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
//! log the optimization, enable it on the graph options, and leave the switch
#define ENABLE_DTYPE(dtype) \
mgb_log("enable " #dtype " optimization"); \
model->get_mdl_config().comp_graph->options().graph_opt.enable_##dtype(); \
break;
switch (m_option_flag) {
case OptDTypeType::IOC16:
ENABLE_DTYPE(f16_io_comp)
default:
LITE_THROW(
"Unsupported dtype option; only --enable-ioc16 is "
"supported (fp32 is the default).");
break;
}
#undef ENABLE_DTYPE
}
}
} // namespace lar
using namespace lar;
bool DTypeOption::m_valid;
void DTypeOption::update() {
m_option_name = "dtype";
m_option_flag = static_cast<OptDTypeType>(0);
m_option = {
{"enable_ioc16", lar::Bool::make(false)},
};
std::static_pointer_cast<lar::Bool>(m_option["enable_ioc16"])
->set_value(FLAGS_enable_ioc16);
}
bool DTypeOption::is_valid() {
size_t valid_flag = 0;
if (FLAGS_enable_ioc16) {
valid_flag |= static_cast<size_t>(OptDTypeType::IOC16);
}
//! only one flag is valid
bool ret = valid_flag && !(valid_flag & (valid_flag - 1));
return ret || m_valid;
}
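Since each OptDTypeType value is a distinct power of two, the check above is the classic single-bit test: x & (x - 1) clears the lowest set bit, so the whole expression is true only when exactly one flag is set. A standalone illustration (the helper name is hypothetical):

    #include <cassert>
    #include <cstddef>

    // true iff exactly one flag bit is set
    static bool exactly_one_flag(size_t flags) {
        return flags && !(flags & (flags - 1));
    }

    int main() {
        assert(exactly_one_flag(1u << 0));                // one flag set: valid
        assert(!exactly_one_flag(0));                     // no flag set: invalid
        assert(!exactly_one_flag((1u << 0) | (1u << 3))); // two flags set: invalid
    }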
std::shared_ptr<OptionBase> DTypeOption::create_option() {
static std::shared_ptr<DTypeOption> option(new DTypeOption);
if (DTypeOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
}
}
void DTypeOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
size_t valid_flag = 0;
if (FLAGS_enable_ioc16 ||
std::static_pointer_cast<lar::Bool>(m_option["enable_ioc16"])->get_value()) {
valid_flag |= static_cast<size_t>(OptDTypeType::IOC16);
}
mgb_throw_if(
valid_flag && (valid_flag & (valid_flag - 1)), mgb::AssertionError,
"invalid dtype transform options 0x%zx", valid_flag);
m_option_flag = static_cast<OptDTypeType>(valid_flag);
CONFIG_MODEL_FUN;
}
DEFINE_bool(enable_ioc16, false, "enable fp16 I/O and computation optimization");
REGIST_OPTION_CREATOR(dtype, lar::DTypeOption::create_option);
REGIST_OPTION_VALIDATER(dtype, lar::DTypeOption::set_valid);
\ No newline at end of file
#pragma once
#include <gflags/gflags.h>
#include "helpers/common.h"
#include "models/model.h"
#include "option_base.h"
DECLARE_bool(enable_ioc16);
namespace lar {
/*!
* \brief: dtype option for optimization
*/
class DTypeOption final : public OptionBase {
public:
//! check whether the option flags are valid
static bool is_valid();
//! create the option when it is used
static std::shared_ptr<OptionBase> create_option();
//! config the model, dispatching configuration to the model implementation
void config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
//! get option name
std::string option_name() const override { return m_option_name; };
static void set_valid(bool val) { m_valid = val; }
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
//! Constructor
DTypeOption() = default;
//! configuration for a specific model implementation
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
OptDTypeType m_option_flag;
std::string m_option_name;
static bool m_valid;
OptionValMap m_option;
};
} // namespace lar
\ No newline at end of file
......@@ -97,6 +97,9 @@ void NetworkImplDft::application_config() {
ConfigOptionLayoutTransform(enable_nchw32);
ConfigOptionLayoutTransform(enable_nchw64);
#undef ConfigOptionLayoutTransform
if (m_user_config->options.enable_f16_io_comp) {
options.graph_opt.enable_f16_io_comp();
}
if (m_user_config->has_compression) {
m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
}
......
......@@ -641,6 +641,22 @@ void ConvertF32ToF16Pass::apply(OptState& state) const {
for (size_t i = 0; i < origin_out.size(); i++) {
rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
}
} else if (
m_multi_tensor_replace_func.find(opr->dyn_typeinfo()) !=
m_multi_tensor_replace_func.end()) {
auto&& new_inp = new_inp_cache;
new_inp.clear();
new_inp.reserve(opr->input().size());
for (auto i : opr->input()) {
new_inp.push_back(rewriter.get_var(i));
}
auto&& origin_out = opr->output();
auto&& cur_out =
m_multi_tensor_replace_func.at(opr->dyn_typeinfo())(opr, new_inp);
mgb_assert(origin_out.size() == cur_out.size());
for (size_t i = 0; i < origin_out.size(); i++) {
rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
}
} else {
rewriter.auto_replace_outputs(opr);
}
......@@ -691,6 +707,48 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(bool use_f32_comp
return opr;
};
auto replace_multi_sdt_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& multi_sdt_opr = opr->cast_final_safe<opr::MultipleDeviceTensorHolder>();
VarNodeArray cvt_vars;
cvt_vars.reserve(multi_sdt_opr.output().size());
for (size_t i = 0; i < multi_sdt_opr.output().size(); ++i) {
if (multi_sdt_opr.output(i)->dtype() == dtype::Float32()) {
cvt_vars.append({opr::TypeCvt::make(
multi_sdt_opr.output(i), dtype::Float16(), {})
.node()});
} else {
cvt_vars.append({multi_sdt_opr.output(i)});
}
}
return cvt_vars;
};
auto replace_multi_sdt_with_format_opr = [](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& multi_sdt_with_format_opr =
opr->cast_final_safe<opr::MultipleDeviceTensorWithFormatHolder>();
VarNodeArray cvt_vars;
cvt_vars.reserve(multi_sdt_with_format_opr.output().size());
for (size_t i = 0; i < multi_sdt_with_format_opr.output().size(); ++i) {
if (multi_sdt_with_format_opr.output(i)->dtype() == dtype::Float32()) {
cvt_vars.append({opr::TypeCvt::make(
multi_sdt_with_format_opr.output(i),
dtype::Float16(), {})
.node()});
} else {
cvt_vars.append({multi_sdt_with_format_opr.output(i)});
}
}
return cvt_vars;
};
auto replace_imt_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
mgb_assert(opr->same_type<opr::ImmutableTensor>());
mgb_assert(opr->input().size() == new_inp.size());
......@@ -934,6 +992,12 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(bool use_f32_comp
replace_func[opr::WarpPerspective::typeinfo()] = replace_warp_opr;
replace_func[opr::Remap::typeinfo()] = replace_remap_opr;
replace_func[opr::BatchedMatrixMul::typeinfo()] = replace_batched_matmul_opr;
auto& tensor_replace_func = ret->m_multi_tensor_replace_func;
tensor_replace_func[opr::MultipleDeviceTensorHolder::typeinfo()] =
replace_multi_sdt_opr;
tensor_replace_func[opr::MultipleDeviceTensorWithFormatHolder::typeinfo()] =
replace_multi_sdt_with_format_opr;
return ret;
#endif
}
......
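For reference, the pass extended above is normally driven through the graph optimizer; this is what enable_f16_io_comp ultimately triggers. A rough sketch of applying it directly, assuming dest_vars holds the endpoint SymbolVars of a loaded fp32 graph (setup omitted):

    #include "megbrain/gopt/framework.h"
    #include "megbrain/gopt/inference.h"

    // rewrite the graph so that I/O and computation run in fp16;
    // passing `false` disables use_f32_comp, i.e. computation is fp16 too
    void convert_to_f16(mgb::SymbolVarArray& dest_vars) {
        mgb::gopt::SubGraph graph{dest_vars};
        dest_vars = mgb::gopt::GraphOptimizer{}
                            .add_pass(mgb::gopt::ConvertF32ToF16Pass::make(false))
                            .apply(graph)
                            .endpoint_vars();
    }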
......@@ -73,6 +73,10 @@ class ConvertF32ToF16Pass : public Pass {
Typeinfo*,
thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>>
m_opr_replace_func;
ThinHashMap<
Typeinfo*,
thin_function<VarNodeArray(OperatorNodeBase*, const VarNodeArray&)>>
m_multi_tensor_replace_func;
VarReplaceCheckFlag m_var_replace_check_flag = VarReplaceCheckFlag::CHECK_ALL;
public:
......