提交 ce119ef5 编写于 作者: M Megvii Engine Team

fix(lite): fix lite error when record level is 2

GitOrigin-RevId: 7dabfd8876f5c6b3004db1fea6faa8d5152066bd
上级 643ab1c1
......@@ -436,6 +436,10 @@ void NetworkImplDft::start() const {
void NetworkImplDft::forward() {
start();
if (m_load_config.comp_graph &&
m_user_config->options.comp_node_seq_record_level == 2) {
m_load_config.comp_graph.reset();
}
LITE_ASSERT(m_execute_func, "forward must be called after network loaded.");
m_execute_func->execute();
}
......
......@@ -89,6 +89,23 @@ TEST(TestNetWorkOptions, const_shape) {
compare_lite_tensor<float>(output_tensor, result_mgb);
}
TEST(TestNetWorkOptions, record2) {
Config config;
std::string model_path = "./shufflenet.mge";
config.options.var_sanity_check_first_run = false;
config.options.const_shape = true;
config.options.comp_node_seq_record_level = 2;
std::shared_ptr<Network> network = std::make_shared<Network>(config);
network->load_model(model_path);
for (int i = 0; i < 3; i++) {
network->forward();
network->wait();
}
}
TEST(TestNetWorkOptions, NCHW44) {
Config config;
auto tensor = get_input_data("./input_data.npy");
......
......@@ -126,7 +126,7 @@ ComputingGraph::ComputingGraph() {
void ComputingGraph::assert_destroy(std::shared_ptr<ComputingGraph>& ptr) {
mgb_assert(
ptr.use_count() == 1, "unexpected use_count: %zu", size_t(ptr.use_count()));
ptr.use_count() <= 2, "unexpected use_count: %zu", size_t(ptr.use_count()));
ptr.reset();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册