提交 826a7393 编写于 作者: W Wei Luning 提交者: 高东海

fix InsertGradientOf with class method

上级 2ebd6049
......@@ -103,6 +103,14 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object&
if (para_node == nullptr) {
ParameterPtr node = top_graph->AddWeightParameter(param_name);
node->set_default_param(obj);
// set_abstract for parameter
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
ValuePtr converted = nullptr;
(void)ConvertData(to_convert, &converted);
bool broaden = true;
node->set_abstract(abstract::FromValue(converted, broaden));
para_node = node;
}
auto iter = func_graph->make_ref_params().find(para_node);
......
......@@ -112,6 +112,13 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
});
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
opt::irpass::ResolveIRPassLib resolve_irpass;
opt::OptPassConfig resolve_pass = opt::OptPassConfig({
resolve_irpass.resolver_resolve_,
resolve_irpass.resolver_getattr_,
irpass.get_make_ref_eliminate_,
});
OptPassGroupMap map_a({{"a_1", a_1},
{"a_2", a_2},
......@@ -120,6 +127,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
{"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
{"virtual_dataset", virtual_dataset},
{"grad", grad},
{"resolve", resolve_pass},
{"renormalize", opt::OptPassConfig::Renormalize()},
{"cse", opt::OptPassConfig(opt::CSE(false))},
{"a_3", a_3}});
......
......@@ -554,24 +554,6 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat
return eng->ForwardConfig(old_conf, fn_conf);
}
AbstractBasePtr GenerateResolveAbstract(const AnfNodeConfigPtr &out_conf, const py::object &obj,
const ValuePtr &converted_ret) {
if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
TypePtr cls_ptr = parse::ParseDataClass(converted_ret->cast<std::shared_ptr<parse::PyObjectWrapper>>()->obj());
std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial), NewValueNode(prim::kPrimMakeRecord),
NewValueNode(cls_ptr)};
MS_EXCEPTION_IF_NULL(out_conf);
FuncGraphPtr func_graph = out_conf->node()->func_graph();
CNodePtr new_cnode = func_graph->NewCNode(input);
AnalysisEnginePtr eng = out_conf->engine();
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, out_conf->context());
return eng->ForwardConfig(out_conf, fn_conf);
} else {
return ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
}
}
AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list,
const AnfNodeConfigPtr &out_conf) {
......@@ -602,23 +584,16 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng
// item_name to func addr from obj_map
parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
FuncGraphPtr func_graph = out_conf->node()->func_graph();
parse::SymbolResolverPtr symbol_resolver =
std::make_shared<parse::SymbolResolver>(name_space, symbol, out_conf->node());
if (!symbol_resolver->Resolve()) {
auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node());
if (new_node == nullptr) {
MS_LOG(EXCEPTION) << "Resolve node failed";
}
py::object obj = symbol_resolver->result();
ValuePtr converted_ret = nullptr;
bool converted = parse::ConvertData(obj, &converted_ret, true);
if (!converted) {
MS_LOG(EXCEPTION) << "Convert data failed";
}
if (converted_ret->isa<FuncGraph>()) {
AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
}
return GenerateResolveAbstract(out_conf, obj, converted_ret);
AnalysisEnginePtr eng = out_conf->engine();
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context());
return eng->ForwardConfig(out_conf, fn_conf);
}
AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
......
......@@ -17,13 +17,14 @@ import numpy as np
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.api import ms_function
from ....mindspore_test_framework.utils.bprop_util import bprop
from ....mindspore_test_framework.utils.debug_util import PrintShapeTypeCell, PrintGradShapeTypeCell
from mindspore import Tensor
from mindspore import context
import mindspore
def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE)
......@@ -107,3 +108,36 @@ def test_print_shape_type():
return z
bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)),
Tensor(np.ones([2, 2]).astype(np.float32)))
def test_cell_assign():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class GradNetWrap(nn.Cell):
""" GradNetWrap definition """
def __init__(self, net):
super(GradNetWrap, self).__init__()
self.net = net
self.weights = mindspore.ParameterTuple(net.get_parameters())
def construct(self, x, y):
return C.grad_by_list(self.net, self.weights)(x, y)
class Mul(nn.Cell):
def __init__(self):
super(Mul, self).__init__()
self.get_g = P.InsertGradientOf(self.save_gradient)
self.matrix_w = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_w")
self.matrix_g = mindspore.Parameter(Tensor(np.ones([2, 2], np.float32)), name="matrix_g")
def save_gradient(self, dout):
self.matrix_g = dout
return dout
def construct(self, x, y):
z = x * self.matrix_w
z = self.get_g(z)
z = z * y
return z
input_x = Tensor(np.ones([2, 2], np.float32))
input_y = Tensor(np.ones([2, 2], np.float32))
GradNetWrap(Mul())(input_x, input_y)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册