Commit 5ebc9d50 authored by Megvii Engine Team

fix(pylite): fix lite global layout transform and fast run conflict error

GitOrigin-RevId: 910c8da19f3c9973a088b70b782e91f6b366d4f9
Parent 49d92d9c
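Before this change, requesting the fast-run algo policy together with the global layout transform could conflict: the execution policy was only applied when the graph was compiled, after the layout-transform pass had already rewritten the operators. The sketch below is distilled from the new tests in this commit and shows the combined workflow the fix enables; the `megenginelite` import path, the placeholder model path, and the explicit forward/wait calls are assumptions, not part of the diff.

```python
# Minimal sketch: enable fast-run profiling and the global layout
# transform on the same LiteNetwork, then persist both artifacts.
from megenginelite import (
    LiteAlgoSelectStrategy,
    LiteConfig,
    LiteGlobal,
    LiteNetwork,
)

network = LiteNetwork(LiteConfig())
# Fast run: profile candidate algorithms and keep the best one.
network.set_network_algo_policy(
    LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
    | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
)
# Must be enabled before load(), as in the tests below.
network.enable_global_layout_transform()
network.load("./shufflenet.mge")  # placeholder model path
# ... fill inputs, then network.forward(); network.wait() ...
network.dump_layout_transform_model("./model_afer_layoutTrans.mgb")
LiteGlobal.dump_persistent_cache("./algo_cache")
```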
@@ -40,7 +40,6 @@ public:
     void wait() override;
     //! enable global layout transform
    void set_layout_transform(bool state) { enable_layout_transform = state; }
     //! get the network of lite model
......
@@ -468,3 +468,29 @@ class TestNetwork(TestShuffleNet):
         fi = open("./model_afer_layoutTrans.mgb", "r")
         fi.close()
         os.remove("./model_afer_layoutTrans.mgb")
+
+    def test_fast_run_and_global_layout_transform(self):
+        config_ = LiteConfig()
+        network = LiteNetwork(config_)
+        fast_run_cache = "./algo_cache"
+        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
+        network.set_network_algo_policy(
+            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
+            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
+        )
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+        network.dump_layout_transform_model(global_layout_transform_model)
+        LiteGlobal.dump_persistent_cache(fast_run_cache)
+        fi = open(fast_run_cache, "r")
+        fi.close()
+        fi = open(global_layout_transform_model, "r")
+        fi.close()
+        LiteGlobal.set_persistent_cache(path=fast_run_cache)
+        self.do_forward(network)
+        os.remove(fast_run_cache)
+        os.remove(global_layout_transform_model)
@@ -293,3 +293,31 @@ class TestNetwork(TestShuffleNetCuda):
         fi = open("./model_afer_layoutTrans.mgb", "r")
         fi.close()
         os.remove("./model_afer_layoutTrans.mgb")
+
+    @require_cuda()
+    def test_fast_run_and_global_layout_transform(self):
+        config_ = LiteConfig()
+        config_.device_type = LiteDeviceType.LITE_CUDA
+        network = LiteNetwork(config_)
+        fast_run_cache = "./algo_cache"
+        global_layout_transform_model = "./model_afer_layoutTrans.mgb"
+        network.set_network_algo_policy(
+            LiteAlgoSelectStrategy.LITE_ALGO_PROFILE
+            | LiteAlgoSelectStrategy.LITE_ALGO_OPTIMIZED
+        )
+        network.enable_global_layout_transform()
+        network.load(self.model_path)
+        self.do_forward(network)
+        network.dump_layout_transform_model(global_layout_transform_model)
+        LiteGlobal.dump_persistent_cache(fast_run_cache)
+        fi = open(fast_run_cache, "r")
+        fi.close()
+        fi = open(global_layout_transform_model, "r")
+        fi.close()
+        LiteGlobal.set_persistent_cache(path=fast_run_cache)
+        self.do_forward(network)
+        os.remove(fast_run_cache)
+        os.remove(global_layout_transform_model)
@@ -422,6 +422,8 @@ void NetworkImplDft::load_model(
     m_load_result = m_loader->load(m_load_config, true);
+
+    modify_exection_policy();
     global_layout_transform();
     adapt_option_valid();
@@ -436,7 +438,6 @@ void NetworkImplDft::load_model(
 }
 
 void NetworkImplDft::compile_graph() {
-    modify_exection_policy();
     replace_dev_input_pass();
     make_output_spec();
     m_execute_func = m_load_result.graph_compile(m_output_spec);
@@ -793,7 +794,8 @@ void NetworkImplDft::set_network_algo_policy(
     if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
         dst_strategy = dst_strategy | S::OPTIMIZED;
     }
-    m_execution_policy = dst_strategy;
+    if (static_cast<uint32_t>(dst_strategy) != 0)
+        m_execution_policy = dst_strategy;
 
     auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
     fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
@@ -808,12 +810,10 @@ void NetworkImplDft::set_network_algo_policy(
 }
 
 void NetworkImplDft::modify_exection_policy() {
-    mgb::SymbolVarArray vars;
-    for (auto i : m_output_spec) {
-        vars.push_back(i.first);
-    }
-    if (static_cast<uint32_t>(m_execution_policy) != 0)
+    auto& vars = m_load_result.output_var_list;
+    if (static_cast<uint32_t>(m_execution_policy) != 0) {
         mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
+    }
 }
 
 //! set opr algorithm selection strategy in the network
......
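The change above moves modify_exection_policy() from compile_graph() into load_model(), so the strategy is attached to m_load_result.output_var_list before global_layout_transform() rewrites the graph, and both writes are now guarded by a non-zero-strategy check so an unset policy no longer overwrites anything. On the Python side this means a dumped fast-run cache can simply be loaded back on a later run; a minimal sketch, reusing the paths from the tests above (the import path and model path are assumptions):

```python
# Sketch: reuse a previously dumped fast-run cache in a fresh process.
from megenginelite import LiteConfig, LiteGlobal, LiteNetwork

LiteGlobal.set_persistent_cache(path="./algo_cache")
network = LiteNetwork(LiteConfig())
network.enable_global_layout_transform()
network.load("./shufflenet.mge")  # placeholder model path
# Forward passes now reuse the cached algorithm choices instead of
# re-profiling every operator.
```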
@@ -289,21 +289,21 @@ namespace intl {
 template <typename Opr>
 struct OprFormatModifier;
 
 #define INST(_Opr)                                                         \
     template <>                                                            \
     struct OprFormatModifier<_Opr> {                                       \
         using OprFormat = typename _Opr::Param::Format;                    \
         static VarNode* make(                                              \
                 OprFormat opr_format, const VarNodeArray& i,               \
                 const cg::OperatorNodeBase* opr_) {                        \
             MIDOUT_B(_Opr)                                                 \
             auto&& opr = opr_->cast_final_safe<_Opr>();                    \
             auto param = opr.param();                                      \
             param.format = opr_format;                                     \
             return OprWithPolicyMaker<_Opr>::make(                         \
-                    i, param, opr.execution_policy(), opr.config());       \
+                    i, param, opr.execution_policy_transient(), opr.config()); \
             MIDOUT_E                                                       \
         }                                                                  \
     };
 INST(Convolution);
 INST(ConvBiasForward);
......