diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc index ebc1f65486cad0cbf038614e78c4b8a7f45c7fbb..f90fc5039c0510d033eb5f7ac9840c9ad4a32e20 100644 --- a/mindspore/ccsrc/pipeline/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/parse/resolve.cc @@ -103,6 +103,14 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object& if (para_node == nullptr) { ParameterPtr node = top_graph->AddWeightParameter(param_name); node->set_default_param(obj); + + // set_abstract for parameter + auto to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); + ValuePtr converted = nullptr; + (void)ConvertData(to_convert, &converted); + bool broaden = true; + node->set_abstract(abstract::FromValue(converted, broaden)); + para_node = node; } auto iter = func_graph->make_ref_params().find(para_node); diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index e0336443e65df20a0ed3c0ebf67ad9b5ccaf0331..a58ecf41b6015e27881e10680de0f7a6ddb6a14b 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -112,6 +112,13 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { }); opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); + opt::irpass::ResolveIRPassLib resolve_irpass; + + opt::OptPassConfig resolve_pass = opt::OptPassConfig({ + resolve_irpass.resolver_resolve_, + resolve_irpass.resolver_getattr_, + irpass.get_make_ref_eliminate_, + }); OptPassGroupMap map_a({{"a_1", a_1}, {"a_2", a_2}, @@ -120,6 +127,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, {"virtual_dataset", virtual_dataset}, {"grad", grad}, + {"resolve", resolve_pass}, {"renormalize", opt::OptPassConfig::Renormalize()}, {"cse", opt::OptPassConfig(opt::CSE(false))}, {"a_3", a_3}}); diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 56bcd77f671cf892ae1576615a2a386897967fcb..d71e0980094d69757ed06db0555f96a1fa5a3f80 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -554,24 +554,6 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat return eng->ForwardConfig(old_conf, fn_conf); } -AbstractBasePtr GenerateResolveAbstract(const AnfNodeConfigPtr &out_conf, const py::object &obj, - const ValuePtr &converted_ret) { - if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) { - TypePtr cls_ptr = parse::ParseDataClass(converted_ret->cast>()->obj()); - - std::vector input = {NewValueNode(prim::kPrimPartial), NewValueNode(prim::kPrimMakeRecord), - NewValueNode(cls_ptr)}; - MS_EXCEPTION_IF_NULL(out_conf); - FuncGraphPtr func_graph = out_conf->node()->func_graph(); - CNodePtr new_cnode = func_graph->NewCNode(input); - AnalysisEnginePtr eng = out_conf->engine(); - AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, out_conf->context()); - return eng->ForwardConfig(out_conf, fn_conf); - } else { - return ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); - } -} - AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const AnfNodeConfigPtr &out_conf) { @@ -602,23 +584,16 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng // item_name to func addr from obj_map parse::SymbolPtr symbol = item_v->cast(); parse::NameSpacePtr name_space = data_v->cast(); + FuncGraphPtr func_graph = out_conf->node()->func_graph(); - parse::SymbolResolverPtr symbol_resolver = - std::make_shared(name_space, symbol, out_conf->node()); - if (!symbol_resolver->Resolve()) { + auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node()); + if (new_node == nullptr) { MS_LOG(EXCEPTION) << "Resolve node failed"; } - py::object obj = symbol_resolver->result(); - ValuePtr converted_ret = nullptr; - bool converted = parse::ConvertData(obj, &converted_ret, true); - if (!converted) { - MS_LOG(EXCEPTION) << "Convert data failed"; - } - if (converted_ret->isa()) { - AddToManager(engine, converted_ret->cast()); - } - return GenerateResolveAbstract(out_conf, obj, converted_ret); + AnalysisEnginePtr eng = out_conf->engine(); + AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context()); + return eng->ForwardConfig(out_conf, fn_conf); } AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, diff --git a/tests/ut/python/pynative_mode/test_insert_grad_of.py b/tests/ut/python/pynative_mode/test_insert_grad_of.py index 38432d79f3eb4ccc148c3e860701db43b0126d41..104ac4d1c7b5ffaececf24becf544b01e13e1b0b 100644 --- a/tests/ut/python/pynative_mode/test_insert_grad_of.py +++ b/tests/ut/python/pynative_mode/test_insert_grad_of.py @@ -17,13 +17,14 @@ import numpy as np import mindspore.nn as nn from mindspore.ops import composite as C from mindspore.ops import operations as P +from mindspore.ops import functional as F from mindspore.common.api import ms_function from ....mindspore_test_framework.utils.bprop_util import bprop from ....mindspore_test_framework.utils.debug_util import PrintShapeTypeCell, PrintGradShapeTypeCell from mindspore import Tensor from mindspore import context - +import mindspore def setup_module(module): context.set_context(mode=context.PYNATIVE_MODE) @@ -107,3 +108,36 @@ def test_print_shape_type(): return z bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)), Tensor(np.ones([2, 2]).astype(np.float32))) + +def test_cell_assign(): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + class GradNetWrap(nn.Cell): + """ GradNetWrap definition """ + def __init__(self, net): + super(GradNetWrap, self).__init__() + self.net = net + self.weights = mindspore.ParameterTuple(net.get_parameters()) + + def construct(self, x, y): + return C.grad_by_list(self.net, self.weights)(x, y) + + class Mul(nn.Cell): + def __init__(self): + super(Mul, self).__init__() + self.get_g = P.InsertGradientOf(self.save_gradient) + self.matrix_w = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_w") + self.matrix_g = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_g") + + def save_gradient(self, dout): + self.matrix_g = dout + return dout + + def construct(self, x, y): + z = x * self.matrix_w + z = self.get_g(z) + z = z * y + return z + + input_x = Tensor(np.ones([2, 2], np.float32)) + input_y = Tensor(np.ones([2, 2], np.float32)) + GradNetWrap(Mul())(input_x, input_y)