Commit e8ae576d authored by panyifeng, committed by 高东海

Fix missing gradients caused by indirectly dependent free morphisms

Parent a86b31ee
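Here a free morphism is a CNode in the primal graph that is not reachable from the graph's output but is still referred to by another function graph, for example through a free variable of an inner function. If such a node is only mapped after the output morphism, its gradient contribution can be lost. A minimal sketch of the pattern, using the composite grad API that the tests below already use (illustrative only, not the repro added in this commit; the added test test_grad_if_defer_inline is the actual regression test):

# Hypothetical illustration of an indirectly dependent free morphism:
# `mid` has no user inside `outer` itself; it is only used through the free
# variable of `inner`, so its tangent must still be mapped for grads to be complete.
from mindspore.ops import composite as C

def outer(x, y):
    mid = x * y          # free morphism of `outer`: used only by `inner`

    def inner(z):
        return mid * z

    return inner(y)

# f(x, y) = x * y * y, so df/dx = y**2 and df/dy = 2*x*y
assert C.grad_all(outer)(2, 3) == (9, 12)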
@@ -185,19 +185,32 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }

+bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
+  // Do not care about non-CNode
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  // Do not care about kPrimReturn
+  if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
+    return false;
+  }
+  auto &users = primal_graph_->manager()->node_users()[node];
+  // Do not care about isolated morphisms
+  if (users.empty()) {
+    return false;
+  }
+  // Not free if it's used by some node in primal_graph
+  bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
+    auto &user = kv.first;
+    return user->func_graph() == primal_graph_;
+  });
+  return !nonfree;
+}
+
 void DFunctor::MapFreeMorphism() {
   // Handle cnode not attached to output, that might be referred to in other functions.
   for (auto &node : primal_graph_->nodes()) {
-    auto adjoint = FindAdjoint(node);
-    if (adjoint != nullptr) {
-      continue;
-    }
-    if (!node->isa<CNode>()) {
-      MS_LOG(DEBUG) << "MapFreeMorphism noncnode not mapped after MapMorphism " << node->ToString() << " "
-                    << node->type_name() << ".";
-      continue;
-    }
-    if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
+    if (!IsFreeMorphism(node)) {
       continue;
     }
     MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
@@ -256,9 +269,10 @@ void DFunctor::MapMorphism() {
   // Set stop_gradient before MapMorphism.
   BroadCastStopFlag();
+  // Handle free morphisms before the output, because in some cases a free morphism might depend on the output's fv tangent.
+  MapFreeMorphism();
   // Handle morphism from output.
   (void)MapMorphism(primal_graph_->output());
-  MapFreeMorphism();
   // Construct K for primal_graph_
   auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
@@ -298,9 +312,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
     const size_t param_diff = 1;
     if (bprop_graph->output()->isa<CNode>() &&
         bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
-      MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
-                        << primal->output()->scope()->name()
-                        << " output must be a tuple and output number should be the same with inputs.";
+      // It does not affect the final tangents; just a tip for debugging.
+      MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
+                    << primal->output()->scope()->name()
+                    << " output must be a tuple and output number should be the same with inputs.";
     }
     resources_->manager()->AddFuncGraph(bprop_graph);
......
@@ -61,6 +61,7 @@ class DFunctor {
  private:
   // Map one morphism.
   AdjointPtr MapMorphism(const AnfNodePtr &morph);
+  bool IsFreeMorphism(const AnfNodePtr &node);
   // Map morphism that's not attached to output.
   void MapFreeMorphism();
   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
......
@@ -111,7 +111,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
     irpass.replace_applicator_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
-  opt::OptPassConfig grad = opt::OptPassConfig({irpass.inline_, irpass.expand_jprim_}, true);
+  opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
   OptPassGroupMap map_a({{"a_1", a_1},
                          {"a_2", a_2},
......
@@ -304,5 +304,4 @@ class MulAddWithWrongOutputNum(nn.Cell):

 def test_grad_mul_add_with_wrong_output_num():
     mul_add = MulAddWithWrongOutputNum()
-    with pytest.raises(RuntimeError):
-        C.grad_all(mul_add)(1, 2)
+    C.grad_all(mul_add)(1, 2)
@@ -15,6 +15,7 @@
 """ test_framstruct """
 import pytest
 import numpy as np
+import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.ops import composite as C
@@ -706,3 +707,24 @@ def grad_refactor_14(a, b):
     return inner1(b) + inner2(a) + inner3(a)
 def test_grad_refactor_14():
     assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
+
+
+class IfDeferInline(nn.Cell):
+    def __init__(self, mul_size):
+        super().__init__()
+        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
+        self.mul = P.Mul()
+
+    def construct(self, inputs):
+        x = self.mul(inputs, self.mul_weight)
+        if True:
+            x = x
+        return x
+
+def test_grad_if_defer_inline():
+    """ test_grad_if_defer_inline """
+    network = IfDeferInline([128, 96])
+    network.add_flags(defer_inline=False)
+    inp = Tensor(np.ones([128, 96]).astype(np.float32))
+    grads = C.grad_all(network)(inp)
+    assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
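For reference on the asserted value: the network computes an elementwise x * w with w filled with 0.6, and d(x * w)/dx = w, so grad_all is expected to return the 0.6-filled weight tensor. A quick scalar check of that identity, independent of MindSpore and not part of the test file:

# Central-difference sanity check that d(x * w)/dx == w for a scalar x.
w = 0.6
f = lambda x: x * w
eps = 1e-6
numeric = (f(1.0 + eps) - f(1.0 - eps)) / (2 * eps)
assert abs(numeric - w) < 1e-4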