Commit 5ebc9d50 authored by: Megvii Engine Team

fix(pylite): fix lite global layout transform and fast run conflict error

GitOrigin-RevId: 910c8da19f3c9973a088b70b782e91f6b366d4f9
Parent: 49d92d9c
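Context: this change lets one lite network enable the fast-run profiling policy and the global layout transform at the same time. Below is a minimal sketch of the user-facing flow, assuming the megenginelite package exports used by the tests in this commit; the model path is a hypothetical placeholder.

from megenginelite import (
    LiteAlgoSelectStrategy,
    LiteConfig,
    LiteNetwork,
)

network = LiteNetwork(LiteConfig())
# Profile candidate kernels and keep the optimized ones (fast run).
network.set_network_algo_policy(
    LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
    | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
)
# Also search for a better global tensor layout; combining the two
# features used to fail before this commit.
network.enable_global_layout_transform()
network.load("./model.mgb")  # hypothetical model path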
@@ -40,7 +40,6 @@ public:
    void wait() override;
    //! enable global layout transform
    void set_layout_transform(bool state) { enable_layout_transform = state; }
    //! get the network of lite model
...
@@ -468,3 +468,29 @@ class TestNetwork(TestShuffleNet):
        fi = open("./model_afer_layoutTrans.mgb", "r")
        fi.close()
        os.remove("./model_afer_layoutTrans.mgb")

    def test_fast_run_and_global_layout_transform(self):
        config_ = LiteConfig()
        network = LiteNetwork(config_)
        fast_run_cache = "./algo_cache"
        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
        # Enable fast-run profiling together with the optimized strategy.
        network.set_network_algo_policy(
            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
        )
        # Combine it with the global layout transform; this pairing used to
        # conflict before this fix.
        network.enable_global_layout_transform()
        network.load(self.model_path)
        self.do_forward(network)
        network.dump_layout_transform_model(global_layout_transform_model)
        LiteGlobal.dump_persistent_cache(fast_run_cache)
        # Both artifacts must exist on disk.
        fi = open(fast_run_cache, "r")
        fi.close()
        fi = open(global_layout_transform_model, "r")
        fi.close()
        # Reload the dumped cache and run again to check it is usable.
        LiteGlobal.set_persistent_cache(path=fast_run_cache)
        self.do_forward(network)
        os.remove(fast_run_cache)
        os.remove(global_layout_transform_model)
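As a hedged follow-up, the two artifacts dumped above can be reused in a later session so that neither profiling nor the layout search has to run again. The sketch below is an assumption built only from calls that appear in this test (set_persistent_cache plus a plain load of the dumped model):

# Sketch: reuse the artifacts from the test above in a fresh process.
LiteGlobal.set_persistent_cache(path="./algo_cache")  # reload fast-run results
network2 = LiteNetwork(LiteConfig())
network2.load("./model_afer_layoutTrans.mgb")  # dumped, layout-transformed model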
@@ -293,3 +293,31 @@ class TestNetwork(TestShuffleNetCuda):
        fi = open("./model_afer_layoutTrans.mgb", "r")
        fi.close()
        os.remove("./model_afer_layoutTrans.mgb")

    @require_cuda()
    def test_fast_run_and_global_layout_transform(self):
        # Same flow as the CPU test above, but on a CUDA device.
        config_ = LiteConfig()
        config_.device_type = LiteDeviceType.LITE_CUDA
        network = LiteNetwork(config_)
        fast_run_cache = "./algo_cache"
        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
        network.set_network_algo_policy(
            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
        )
        network.enable_global_layout_transform()
        network.load(self.model_path)
        self.do_forward(network)
        network.dump_layout_transform_model(global_layout_transform_model)
        LiteGlobal.dump_persistent_cache(fast_run_cache)
        fi = open(fast_run_cache, "r")
        fi.close()
        fi = open(global_layout_transform_model, "r")
        fi.close()
        LiteGlobal.set_persistent_cache(path=fast_run_cache)
        self.do_forward(network)
        os.remove(fast_run_cache)
        os.remove(global_layout_transform_model)
@@ -422,6 +422,8 @@ void NetworkImplDft::load_model(
     m_load_result = m_loader->load(m_load_config, true);

+    // Apply the user's algo policy before the global layout transform runs,
+    // so fast-run profiling and the layout search no longer conflict.
+    modify_exection_policy();
     global_layout_transform();
     adapt_option_valid();
@@ -436,7 +438,6 @@ void NetworkImplDft::load_model(
 }

 void NetworkImplDft::compile_graph() {
-    modify_exection_policy();
     replace_dev_input_pass();
     make_output_spec();
     m_execute_func = m_load_result.graph_compile(m_output_spec);
@@ -793,7 +794,8 @@ void NetworkImplDft::set_network_algo_policy(
     if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
         dst_strategy = dst_strategy | S::OPTIMIZED;
     }
-    m_execution_policy = dst_strategy;
+    // Only overwrite the stored policy when the user actually set one, so an
+    // empty strategy cannot clobber an earlier choice.
+    if (static_cast<uint32_t>(dst_strategy) != 0)
+        m_execution_policy = dst_strategy;

     auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
     fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
@@ -808,12 +810,10 @@ void NetworkImplDft::set_network_algo_policy(
 }

 void NetworkImplDft::modify_exection_policy() {
-    mgb::SymbolVarArray vars;
-    for (auto i : m_output_spec) {
-        vars.push_back(i.first);
-    }
-    if (static_cast<uint32_t>(m_execution_policy) != 0)
+    // Take the outputs straight from the load result: m_output_spec is not
+    // built yet when this runs inside load_model(), before compile_graph().
+    auto& vars = m_load_result.output_var_list;
+    if (static_cast<uint32_t>(m_execution_policy) != 0) {
         mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
+    }
 }

 //! set opr algorithm selection strategy in the network
...
@@ -289,21 +289,21 @@ namespace intl {
 template <typename Opr>
 struct OprFormatModifier;

 #define INST(_Opr)                                                          \
     template <>                                                             \
     struct OprFormatModifier<_Opr> {                                        \
         using OprFormat = typename _Opr::Param::Format;                     \
         static VarNode* make(                                               \
                 OprFormat opr_format, const VarNodeArray& i,                \
                 const cg::OperatorNodeBase* opr_) {                         \
             MIDOUT_B(_Opr)                                                  \
             auto&& opr = opr_->cast_final_safe<_Opr>();                     \
             auto param = opr.param();                                       \
             param.format = opr_format;                                      \
             return OprWithPolicyMaker<_Opr>::make(                          \
-                    i, param, opr.execution_policy(), opr.config());        \
+                    i, param, opr.execution_policy_transient(), opr.config()); \
             MIDOUT_E                                                        \
         }                                                                   \
     };
 INST(Convolution);
 INST(ConvBiasForward);
...