diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp index ad17fa14f04e9ce7a1cc77e4e1933f09aedb38d5..2b2564d6083d944e072b59dfbe5aae40456f1a78 100644 --- a/lite/test/test_network.cpp +++ b/lite/test/test_network.cpp @@ -510,7 +510,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) { std::vector>> inputs; std::vector>> outputs; - std::shared_ptr network = std::make_shared(); + Config config; + + config.options.graph_opt_level = 0; + std::shared_ptr network = std::make_shared(config); network->load_model(model_path); input_names = network->get_all_input_name(); @@ -559,10 +562,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) { outputs.push_back(net_outputs); } - Config config; config.options.force_output_use_user_specified_memory = true; config.options.comp_node_seq_record_level = record; config.options.const_shape = true; + config.options.graph_opt_level = 2; std::shared_ptr network_record = std::make_shared(config); diff --git a/src/serialization/impl/serializer.cpp b/src/serialization/impl/serializer.cpp index 88bd6922eca8471b754d90b48dacfb3fa7d824f9..8e46060ec27dabcd41c9b4f573929f182431c1b4 100644 --- a/src/serialization/impl/serializer.cpp +++ b/src/serialization/impl/serializer.cpp @@ -10,6 +10,7 @@ */ #include "megbrain/serialization/serializer.h" +#include "megbrain/gopt/inference.h" #include "megbrain/opr/utility.h" namespace mgb { @@ -27,6 +28,35 @@ std::unique_ptr GraphLoader::LoadResult::graph_compile( return ret; } +void GraphLoader::LoadResult::graph_compile_ahead() { + //! when force_output_use_user_specified_memory is set, the output var may + //! be changed by gopt, then the var in LoadResult can not exist, so here + //! just do basic optimize_for_inference ahead, and replace the var in + //! LoadResult + if (graph->options().force_output_use_user_specified_memory) { + auto options = gopt::OptimizeForInferenceOptions{}; + auto new_vars = gopt::optimize_for_inference(output_var_list, options); + output_var_list = new_vars; + output_var_map.clear(); + for (auto& var : new_vars) { + output_var_map[var.node()->cname()] = var; + } + std::unordered_map var_map_id; + for (auto& var : new_vars) { + bool found = false; + for (auto& old_var_it : output_var_map_id) { + if (old_var_it.second.node()->name() == var.node()->name()) { + found = true; + var_map_id[old_var_it.first] = var; + } + } + mgb_assert( + found, "can't find var name %s when optimize_for_inference. ", + var.node()->cname()); + } + } +} + GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() { SharedTensorNameMap ret; for (auto&& i : shared_tensor_id_map()) { diff --git a/src/serialization/impl/serializer_oss.cpp b/src/serialization/impl/serializer_oss.cpp index b2aa73705e9fdd7a34e9a9b26d68df74b928d6cf..9b0cae341fc0f6da65107bf51a9f1e911e479716 100644 --- a/src/serialization/impl/serializer_oss.cpp +++ b/src/serialization/impl/serializer_oss.cpp @@ -946,6 +946,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi mgb_assert(fbs_end > cur); // Skip to Graph end m_file->skip(fbs_end - cur); + result.graph_compile_ahead(); return result; } diff --git a/src/serialization/include/megbrain/serialization/serializer.h b/src/serialization/include/megbrain/serialization/serializer.h index feea6efd4d74d8e62203699cf3f9a2763339a8ca..1868918d93a3af325e20cbc5053dd9221ac5cea6 100644 --- a/src/serialization/include/megbrain/serialization/serializer.h +++ b/src/serialization/include/megbrain/serialization/serializer.h @@ -63,6 +63,14 @@ public: */ MGE_WIN_DECLSPEC_FUC std::unique_ptr graph_compile( const ComputingGraph::OutputSpec& outspec); + + /*! + * \brief after graph is loaded, do some basic optimized_for_inference, + * because some dest var maybe replaced, case error when optimize flag + * force_output_use_user_specified_memory is on + * + */ + MGE_WIN_DECLSPEC_FUC void graph_compile_ahead(); }; //! helper to disable inplace arith graph optimization during