提交 ffa33520 编写于 作者: F fary86

Fix partial primitive poly node

上级 703c1b26
...@@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr ...@@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
} }
auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval); auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
if (func->context() != nullptr) { if (func->context() == nullptr) {
if (!IsVisible(func_graph_, func->context()->func_graph())) {
MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
}
} else {
MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
} }
AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
...@@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { ...@@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
// First element is partial, second is func so arg is start from 2 // First element is partial, second is func so arg is start from 2
(void)args.insert(args.begin(), inputs.begin() + 2, inputs.end()); (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
func = inputs[1]; func = inputs[1];
new_inputs = args;
(void)new_inputs.insert(new_inputs.begin(), func);
} }
new_inputs = args;
(void)new_inputs.insert(new_inputs.begin(), func);
AbstractBasePtrList argvals; AbstractBasePtrList argvals;
MS_EXCEPTION_IF_NULL(new_inputs[0]); MS_EXCEPTION_IF_NULL(new_inputs[0]);
...@@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { ...@@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
<< new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString(); << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
} }
if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) { if (!func->isa<ValueNode>()) {
auto wrapped_node = BuildSpecializedParameterNode(new_node); MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString();
new_inputs[0] = wrapped_node; if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) {
auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
std::pair<AbstractBasePtrList, AbstractBasePtr> result;
AbstractBasePtrList empty_args;
auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
if (status == kSpecializeFindUniqueArgvalPoly ||
(func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
func->abstract()->isa<PartialAbstractClosure>()))) {
auto wrapped_node = BuildSpecializedParameterNode(new_node);
new_inputs[0] = wrapped_node;
}
}
} }
if (CanSpecializeNode(func)) { if (CanSpecializeNode(func)) {
......
...@@ -14,9 +14,12 @@ ...@@ -14,9 +14,12 @@
# ============================================================================ # ============================================================================
""" test nn ops """ """ test nn ops """
import numpy as np import numpy as np
from numpy.random import normal
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.context as context import mindspore.context as context
from mindspore.ops.composite import core
from mindspore.common.api import ms_function
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import functional as F from mindspore.ops import functional as F
...@@ -59,10 +62,39 @@ def test_conv2d_same_primitive(): ...@@ -59,10 +62,39 @@ def test_conv2d_same_primitive():
net(t1, t2) net(t1, t2)
# test free variable function list as parameter
def test_remove_and_fv_2():
@core(loop_can_uroll=True)
def inner_loop(x, input_data, fv_func_list):
ret = ()
for fv_fn in fv_func_list:
ele = fv_fn(input_data)
ret += (ele,)
return ret
@ms_function
def out_loop(input1, input_data):
ret = ()
def fv_func1(y):
return input1 * y
def fv_func2(y):
return input1 - y
fv_func_list = [fv_func1, fv_func2]
ele0 = inner_loop(input1, input_data[0], fv_func_list)
ele1 = inner_loop(input1, input_data[1], fv_func_list)
ret = (ele0, ele1)
return ret
input_data = (Tensor(normal(0, 0.1, (3, 3))), Tensor(normal(0, 0.1, (3, 1))))
input1 = Tensor(normal(0, 0.1, (3, 3)))
out_loop(input1, input_data)
# test cell as high order argument # test cell as high order argument
# The graph with free variables used as argument is not supported yet # The graph with free variables used as argument is not supported yet
# because of the limit of inference specialize system # because of the limit of inference specialize system
def Xtest_conv2d_op_with_arg(): def test_conv2d_op_with_argi_1():
class Conv2dNet(nn.Cell): class Conv2dNet(nn.Cell):
def __init__(self): def __init__(self):
super(Conv2dNet, self).__init__() super(Conv2dNet, self).__init__()
...@@ -279,7 +311,7 @@ def test_op_with_arg_as_input(): ...@@ -279,7 +311,7 @@ def test_op_with_arg_as_input():
# The partial application used as argument is not supported yet # The partial application used as argument is not supported yet
# because of the limit of inference specialize system # because of the limit of inference specialize system
def Xtest_partial_as_arg(): def test_partial_as_arg():
class PartialArgNet(nn.Cell): class PartialArgNet(nn.Cell):
def __init__(self): def __init__(self):
super(PartialArgNet, self).__init__() super(PartialArgNet, self).__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册