提交 202b4071 编写于 作者: M Megvii Engine Team

fix(core): fix output var replaced by optpass

GitOrigin-RevId: aea62de3454722df4610dc3799acbd362ae942e0
上级 e715423f
......@@ -510,7 +510,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) {
std::vector<std::vector<std::shared_ptr<Tensor>>> inputs;
std::vector<std::vector<std::shared_ptr<Tensor>>> outputs;
std::shared_ptr<Network> network = std::make_shared<Network>();
Config config;
config.options.graph_opt_level = 0;
std::shared_ptr<Network> network = std::make_shared<Network>(config);
network->load_model(model_path);
input_names = network->get_all_input_name();
......@@ -559,10 +562,10 @@ void test_io_no_copy_ax(std::string model_name, int record = 1) {
outputs.push_back(net_outputs);
}
Config config;
config.options.force_output_use_user_specified_memory = true;
config.options.comp_node_seq_record_level = record;
config.options.const_shape = true;
config.options.graph_opt_level = 2;
std::shared_ptr<Network> network_record = std::make_shared<Network>(config);
......
......@@ -10,6 +10,7 @@
*/
#include "megbrain/serialization/serializer.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/utility.h"
namespace mgb {
......@@ -27,6 +28,35 @@ std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
return ret;
}
void GraphLoader::LoadResult::graph_compile_ahead() {
    //! When force_output_use_user_specified_memory is set, the output vars
    //! may be replaced by gopt later, at which point the vars cached in this
    //! LoadResult would no longer exist in the graph. Run the basic
    //! optimize_for_inference pass ahead of time and rebind every var stored
    //! in this LoadResult to its optimized counterpart.
    if (graph->options().force_output_use_user_specified_memory) {
        auto options = gopt::OptimizeForInferenceOptions{};
        auto new_vars = gopt::optimize_for_inference(output_var_list, options);
        output_var_list = new_vars;
        //! rebuild the name -> var map from the optimized vars
        output_var_map.clear();
        for (auto& var : new_vars) {
            output_var_map[var.node()->cname()] = var;
        }
        //! rebuild the id -> var map: match each optimized var against the
        //! old entries by node name (node ids change across the pass, names
        //! are preserved)
        std::unordered_map<size_t, SymbolVar> var_map_id;
        for (auto& var : new_vars) {
            bool found = false;
            for (auto& old_var_it : output_var_map_id) {
                if (old_var_it.second.node()->name() == var.node()->name()) {
                    found = true;
                    var_map_id[old_var_it.first] = var;
                }
            }
            mgb_assert(
                    found, "can't find var name %s when optimize_for_inference. ",
                    var.node()->cname());
        }
        //! fix: var_map_id was previously built but never stored, leaving
        //! output_var_map_id holding stale pre-optimization vars
        output_var_map_id = var_map_id;
    }
}
GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
SharedTensorNameMap ret;
for (auto&& i : shared_tensor_id_map()) {
......
......@@ -946,6 +946,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi
mgb_assert(fbs_end > cur);
// Skip to Graph end
m_file->skip(fbs_end - cur);
result.graph_compile_ahead();
return result;
}
......
......@@ -63,6 +63,14 @@ public:
*/
MGE_WIN_DECLSPEC_FUC std::unique_ptr<cg::AsyncExecutable> graph_compile(
const ComputingGraph::OutputSpec& outspec);
/*!
 * \brief after the graph is loaded, run a basic optimize_for_inference
 * pass ahead of time: some dest vars may be replaced by the optimizer,
 * which causes errors when the optimize flag
 * force_output_use_user_specified_memory is on
 */
MGE_WIN_DECLSPEC_FUC void graph_compile_ahead();
};
//! helper to disable inplace arith graph optimization during
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册