diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index f18873e16911b1704ecf0c8cfeaac1947c6f24f2..109c27af6db815a71c273c8461107344714e35f5 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -157,13 +157,19 @@ bool CombineLikeGraphs(const ResourcePtr &res) { if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) { continue; } + auto &cloned_nodes = *cloner->cloned_node(); for (auto &fv : fg->paramter_obj_nodes()) { TraceManager::DebugTrace(std::make_shared(fv->debug_info())); auto param = base_graph->add_parameter(); TraceManager::EndTrace(); auto &node_users = res->manager()->node_users()[fv]; for (auto &n : node_users) { - auto repl_n = (*cloner->cloned_node())[n.first]->cast(); + // If the user is not in this graph, no need to change. + auto cloned = cloned_nodes[n.first]; + if (cloned == nullptr) { + continue; + } + auto repl_n = cloned->cast(); repl_n->set_input(n.second, param); } } diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 9d81dc4c3b7719ce051590fcfe313312e965f222..60ec92167d90bde2792a17bbd492a388d0393227 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -109,7 +109,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object node->set_abstract(abs); para_node = node; } - + func_graph->add_parameter_obj_node(para_node); return para_node; } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index dd35f15dd6eec273c52cb3155163163790272335..be45af748af18089c9638e346cc3f75228eed792 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -189,12 +189,8 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->joined_shapes_.clear(); std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), - std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { - if (arg_spec->isa()) { - return arg_spec->cast()->ref()->GetShapeTrack(); - } - return arg_spec->GetShapeTrack(); - }); + std::back_inserter(func_graph_->joined_shapes_), + [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); joined_args_spec_list = NormalizeArgs(joined_args_spec_list); MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; } @@ -212,12 +208,8 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); func_graph_->joined_shapes_.clear(); std::transform(joined_args_spec_list.begin(), joined_args_spec_list.end(), - std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { - if (arg_spec->isa()) { - return arg_spec->cast()->ref()->GetShapeTrack(); - } - return arg_spec->GetShapeTrack(); - }); + std::back_inserter(func_graph_->joined_shapes_), + [](const AbstractBasePtr &arg_spec) { return arg_spec->GetShapeTrack(); }); joined_args_spec_list = NormalizeArgs(joined_args_spec_list); MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; } @@ -317,10 +309,17 @@ EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { AbstractBasePtrList args_spec_list; + auto is_py_eval = (identifier_ == "PythonPrimEvaluator"); (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { + [is_py_eval](const ConfigPtr &conf) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); + auto abstract = conf->GetEvaluatedValue()->abstract(); + // broaden the ref_key, while infer python prim for cache + if (is_py_eval && abstract->isa()) { + auto abs_ref = abstract->cast(); + abstract = std::make_shared(abs_ref->ref_key()->Broaden(), abs_ref); + } + return abstract; }); EvalResultPtr ret = EvalPrim(engine, args_spec_list); return ret; diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index ec5d181469204fd51b160ad54663bdb4cf5cdb47..65d9a7deffcb15c586db9421b4dd37095b61164d 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -409,7 +409,7 @@ def _run_op(obj, op_name, args): if op_name == "Cast" or obj.update_parameter: cast_args = args else: - cast_args = args + cast_args = list(args) for idx, arg in enumerate(args): cast_type = getattr(arg, "cast_type", None) if cast_type: