提交 ffa33520 编写于 作者: F fary86

Fix partial primitive poly node

上级 703c1b26
......@@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
}
auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);
if (func->context() != nullptr) {
if (!IsVisible(func_graph_, func->context()->func_graph())) {
MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
}
} else {
if (func->context() == nullptr) {
MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
}
AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
......@@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
// First element is partial, second is func so arg is start from 2
(void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
func = inputs[1];
new_inputs = args;
(void)new_inputs.insert(new_inputs.begin(), func);
}
new_inputs = args;
(void)new_inputs.insert(new_inputs.begin(), func);
AbstractBasePtrList argvals;
MS_EXCEPTION_IF_NULL(new_inputs[0]);
......@@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
<< new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
}
if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) {
auto wrapped_node = BuildSpecializedParameterNode(new_node);
new_inputs[0] = wrapped_node;
if (!func->isa<ValueNode>()) {
MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString();
if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) {
auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
std::pair<AbstractBasePtrList, AbstractBasePtr> result;
AbstractBasePtrList empty_args;
auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
// if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
if (status == kSpecializeFindUniqueArgvalPoly ||
(func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
func->abstract()->isa<PartialAbstractClosure>()))) {
auto wrapped_node = BuildSpecializedParameterNode(new_node);
new_inputs[0] = wrapped_node;
}
}
}
if (CanSpecializeNode(func)) {
......
......@@ -14,9 +14,12 @@
# ============================================================================
""" test nn ops """
import numpy as np
from numpy.random import normal
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.composite import core
from mindspore.common.api import ms_function
from mindspore import Tensor
from mindspore.ops import functional as F
......@@ -59,10 +62,39 @@ def test_conv2d_same_primitive():
net(t1, t2)
# test free variable function list as parameter
def test_remove_and_fv_2():
@core(loop_can_uroll=True)
def inner_loop(x, input_data, fv_func_list):
ret = ()
for fv_fn in fv_func_list:
ele = fv_fn(input_data)
ret += (ele,)
return ret
@ms_function
def out_loop(input1, input_data):
ret = ()
def fv_func1(y):
return input1 * y
def fv_func2(y):
return input1 - y
fv_func_list = [fv_func1, fv_func2]
ele0 = inner_loop(input1, input_data[0], fv_func_list)
ele1 = inner_loop(input1, input_data[1], fv_func_list)
ret = (ele0, ele1)
return ret
input_data = (Tensor(normal(0, 0.1, (3, 3))), Tensor(normal(0, 0.1, (3, 1))))
input1 = Tensor(normal(0, 0.1, (3, 3)))
out_loop(input1, input_data)
# test cell as high order argument
# The graph with free variables used as argument is not supported yet
# because of the limit of inference specialize system
def Xtest_conv2d_op_with_arg():
def test_conv2d_op_with_argi_1():
class Conv2dNet(nn.Cell):
def __init__(self):
super(Conv2dNet, self).__init__()
......@@ -279,7 +311,7 @@ def test_op_with_arg_as_input():
# The partial application used as argument is not supported yet
# because of the limit of inference specialize system
def Xtest_partial_as_arg():
def test_partial_as_arg():
class PartialArgNet(nn.Cell):
def __init__(self):
super(PartialArgNet, self).__init__()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册