diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc index 2a03eb6d5c1b474207414e16c25d8138842facf0..e01b98841ba470d8dd9f74098952f73f6a19ce01 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc @@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr } auto real_eval = dyn_cast(eval); - if (func->context() != nullptr) { - if (!IsVisible(func_graph_, func->context()->func_graph())) { - MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); - } - } else { + if (func->context() == nullptr) { MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); } AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); @@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { // First element is partial, second is func so arg is start from 2 (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end()); func = inputs[1]; - new_inputs = args; - (void)new_inputs.insert(new_inputs.begin(), func); } + new_inputs = args; + (void)new_inputs.insert(new_inputs.begin(), func); AbstractBasePtrList argvals; MS_EXCEPTION_IF_NULL(new_inputs[0]); @@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString(); } - if (func->isa() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) { - auto wrapped_node = BuildSpecializedParameterNode(new_node); - new_inputs[0] = wrapped_node; + if (!func->isa()) { + MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString(); + if (func->abstract()->isa() && !func->abstract()->isa()) { + auto func_abs = func->abstract()->cast(); + EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs); + std::pair result; + AbstractBasePtrList empty_args; + auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result); + MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status; + // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early + if (status == kSpecializeFindUniqueArgvalPoly || + (func->isa() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) || + func->abstract()->isa()))) { + auto wrapped_node = BuildSpecializedParameterNode(new_node); + new_inputs[0] = wrapped_node; + } + } } if (CanSpecializeNode(func)) { diff --git a/tests/ut/python/ops/test_ops_attr_infer.py b/tests/ut/python/ops/test_ops_attr_infer.py index d4f6d8e2993da54df58eb14192baec42ced5ed10..6f187105586502ca0c9f67b4569fa45d714e48e1 100644 --- a/tests/ut/python/ops/test_ops_attr_infer.py +++ b/tests/ut/python/ops/test_ops_attr_infer.py @@ -14,9 +14,12 @@ # ============================================================================ """ test nn ops """ import numpy as np +from numpy.random import normal import mindspore.nn as nn import mindspore.context as context +from mindspore.ops.composite import core +from mindspore.common.api import ms_function from mindspore import Tensor from mindspore.ops import functional as F @@ -59,10 +62,39 @@ def test_conv2d_same_primitive(): net(t1, t2) +# test free variable function list as parameter +def test_remove_and_fv_2(): + @core(loop_can_uroll=True) + def inner_loop(x, input_data, fv_func_list): + ret = () + for fv_fn in fv_func_list: + ele = fv_fn(input_data) + ret += (ele,) + return ret + + @ms_function + def out_loop(input1, input_data): + ret = () + + def fv_func1(y): + return input1 * y + def fv_func2(y): + return input1 - y + fv_func_list = [fv_func1, fv_func2] + ele0 = inner_loop(input1, input_data[0], fv_func_list) + ele1 = inner_loop(input1, input_data[1], fv_func_list) + ret = (ele0, ele1) + return ret + + input_data = (Tensor(normal(0, 0.1, (3, 3))), Tensor(normal(0, 0.1, (3, 1)))) + input1 = Tensor(normal(0, 0.1, (3, 3))) + out_loop(input1, input_data) + + # test cell as high order argument # The graph with free variables used as argument is not supported yet # because of the limit of inference specialize system -def Xtest_conv2d_op_with_arg(): +def test_conv2d_op_with_argi_1(): class Conv2dNet(nn.Cell): def __init__(self): super(Conv2dNet, self).__init__() @@ -279,7 +311,7 @@ def test_op_with_arg_as_input(): # The partial application used as argument is not supported yet # because of the limit of inference specialize system -def Xtest_partial_as_arg(): +def test_partial_as_arg(): class PartialArgNet(nn.Cell): def __init__(self): super(PartialArgNet, self).__init__()