提交 84695f66 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5383 [bug]fix compile time

Merge pull request !5383 from vlne-v1/fix-compile-cost
......@@ -157,13 +157,19 @@ bool CombineLikeGraphs(const ResourcePtr &res) {
if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) {
continue;
}
auto &cloned_nodes = *cloner->cloned_node();
for (auto &fv : fg->paramter_obj_nodes()) {
TraceManager::DebugTrace(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
auto param = base_graph->add_parameter();
TraceManager::EndTrace();
auto &node_users = res->manager()->node_users()[fv];
for (auto &n : node_users) {
auto repl_n = (*cloner->cloned_node())[n.first]->cast<CNodePtr>();
// If the user is not in this graph, no need to change.
auto cloned = cloned_nodes[n.first];
if (cloned == nullptr) {
continue;
}
auto repl_n = cloned->cast<CNodePtr>();
repl_n->set_input(n.second, param);
}
}
......
......@@ -109,7 +109,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
node->set_abstract(abs);
para_node = node;
}
func_graph->add_parameter_obj_node(para_node);
return para_node;
}
......
......@@ -189,12 +189,8 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
if (arg_spec->isa<AbstractRef>()) {
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
}
return arg_spec->GetShapeTrack();
});
std::back_inserter(func_graph_->joined_shapes_),
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
......@@ -212,12 +208,8 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
func_graph_->joined_shapes_.clear();
std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(),
std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) {
if (arg_spec->isa<AbstractRef>()) {
return arg_spec->cast<AbstractRefPtr>()->ref()->GetShapeTrack();
}
return arg_spec->GetShapeTrack();
});
std::back_inserter(func_graph_->joined_shapes_),
[](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); });
joined_args_spec_list = NormalizeArgs(joined_args_spec_list);
MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
}
......@@ -317,10 +309,17 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr) {
AbstractBasePtrList args_spec_list;
auto is_py_eval = (identifier_ == "PythonPrimEvaluator");
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
[is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue()->abstract();
auto abstract = conf->GetEvaluatedValue()->abstract();
// broaden the ref_key, while infer python prim for cache
if (is_py_eval && abstract->isa<AbstractRef>()) {
auto abs_ref = abstract->cast<AbstractRefPtr>();
abstract = std::make_shared<AbstractRef>(abs_ref->ref_key()->Broaden(), abs_ref);
}
return abstract;
});
EvalResultPtr ret = EvalPrim(engine, args_spec_list);
return ret;
......
......@@ -409,7 +409,7 @@ def _run_op(obj, op_name, args):
if op_name == "Cast" or obj.update_parameter:
cast_args = args
else:
cast_args = args
cast_args = list(args)
for idx, arg in enumerate(args):
cast_type = getattr(arg, "cast_type", None)
if cast_type:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册