From e8ae576d1003786e0398d2360c33144f035436dd Mon Sep 17 00:00:00 2001
From: panyifeng
Date: Fri, 3 Apr 2020 17:09:04 +0800
Subject: [PATCH] fix missing grad due to indirectly dependent free morphism

---
 mindspore/ccsrc/optimizer/ad/dfunctor.cc      | 43 +++++++++++++------
 mindspore/ccsrc/optimizer/ad/dfunctor.h       |  1 +
 mindspore/ccsrc/pipeline/pass.cc              |  2 +-
 .../python/pynative_mode/test_cell_bprop.py   |  3 +-
 .../python/pynative_mode/test_framstruct.py   | 22 ++++++++++
 5 files changed, 54 insertions(+), 17 deletions(-)

diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/optimizer/ad/dfunctor.cc
index 128e4463e..3e1aa6e55 100644
--- a/mindspore/ccsrc/optimizer/ad/dfunctor.cc
+++ b/mindspore/ccsrc/optimizer/ad/dfunctor.cc
@@ -185,19 +185,32 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }
 
+bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
+  // Do not care about non-CNode
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  // Do not care about kPrimReturn
+  if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
+    return false;
+  }
+  auto &users = primal_graph_->manager()->node_users()[node];
+  // Do not care about isolated morphisms
+  if (users.empty()) {
+    return false;
+  }
+  // Not free if it's used by some node in primal_graph
+  bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
+    auto &user = kv.first;
+    return user->func_graph() == primal_graph_;
+  });
+  return !nonfree;
+}
+
 void DFunctor::MapFreeMorphism() {
   // Handle cnode not attached to output, that might be refered in other functions.
   for (auto &node : primal_graph_->nodes()) {
-    auto adjoint = FindAdjoint(node);
-    if (adjoint != nullptr) {
-      continue;
-    }
-    if (!node->isa<CNode>()) {
-      MS_LOG(DEBUG) << "MapFreeMorphism noncnode not mapped after MapMorphism " << node->ToString() << " "
-                    << node->type_name() << ".";
-      continue;
-    }
-    if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
+    if (!IsFreeMorphism(node)) {
       continue;
     }
     MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
@@ -256,9 +269,10 @@ void DFunctor::MapMorphism() {
 
   // Set stop_gradient before MapMorphism.
   BroadCastStopFlag();
+  // Handle free morphisms before the output, because in some cases a free morphism might depend on the output's fv tangent.
+  MapFreeMorphism();
   // Handle morphism from output.
   (void)MapMorphism(primal_graph_->output());
-  MapFreeMorphism();
 
   // Construct K for primal_graph_
   auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
@@ -298,9 +312,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
       const size_t param_diff = 1;
       if (bprop_graph->output()->isa<CNode>() &&
           bprop_graph->output()->cast<CNodePtr>()->size() + param_diff != bprop_graph->parameters().size()) {
-        MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
-                          << primal->output()->scope()->name()
-                          << " output must be a tuple and output number should be the same with inputs.";
+        // This does not matter for the final tangents; it is just a tip for debugging.
+        MS_LOG(DEBUG) << "User defined Cell bprop " << primal->ToString() << " in scope "
+                      << primal->output()->scope()->name()
+                      << " output must be a tuple and output number should be the same with inputs.";
       }
 
       resources_->manager()->AddFuncGraph(bprop_graph);
diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.h b/mindspore/ccsrc/optimizer/ad/dfunctor.h
index f50f866ef..305973617 100644
--- a/mindspore/ccsrc/optimizer/ad/dfunctor.h
+++ b/mindspore/ccsrc/optimizer/ad/dfunctor.h
@@ -61,6 +61,7 @@ class DFunctor {
  private:
   // Map one morphism.
   AdjointPtr MapMorphism(const AnfNodePtr &morph);
+  bool IsFreeMorphism(const AnfNodePtr &node);
   // Map morphism that's not attached to output.
   void MapFreeMorphism();
   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc
index d89a0090a..e0336443e 100644
--- a/mindspore/ccsrc/pipeline/pass.cc
+++ b/mindspore/ccsrc/pipeline/pass.cc
@@ -111,7 +111,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
     irpass.replace_applicator_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
-  opt::OptPassConfig grad = opt::OptPassConfig({irpass.inline_, irpass.expand_jprim_}, true);
+  opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
 
   OptPassGroupMap map_a({{"a_1", a_1},
                          {"a_2", a_2},
diff --git a/tests/ut/python/pynative_mode/test_cell_bprop.py b/tests/ut/python/pynative_mode/test_cell_bprop.py
index 03ae1affa..054afe36c 100644
--- a/tests/ut/python/pynative_mode/test_cell_bprop.py
+++ b/tests/ut/python/pynative_mode/test_cell_bprop.py
@@ -304,5 +304,4 @@ class MulAddWithWrongOutputNum(nn.Cell):
 
 def test_grad_mul_add_with_wrong_output_num():
     mul_add = MulAddWithWrongOutputNum()
-    with pytest.raises(RuntimeError):
-        C.grad_all(mul_add)(1, 2)
+    C.grad_all(mul_add)(1, 2)
diff --git a/tests/ut/python/pynative_mode/test_framstruct.py b/tests/ut/python/pynative_mode/test_framstruct.py
index 293933721..ff7cf67f5 100644
--- a/tests/ut/python/pynative_mode/test_framstruct.py
+++ b/tests/ut/python/pynative_mode/test_framstruct.py
@@ -15,6 +15,7 @@
 """ test_framstruct """
 import pytest
 import numpy as np
+import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.ops import composite as C
@@ -706,3 +707,24 @@ def grad_refactor_14(a, b):
     return inner1(b) + inner2(a) + inner3(a)
 def test_grad_refactor_14():
     assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
+
+
+class IfDeferInline(nn.Cell):
+    def __init__(self, mul_size):
+        super().__init__()
+        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
+        self.mul = P.Mul()
+
+    def construct(self, inputs):
+        x = self.mul(inputs, self.mul_weight)
+        if True:
+            x = x
+        return x
+
+def test_grad_if_defer_inline():
+    """ test_grad_if_defer_inline """
+    network = IfDeferInline([128, 96])
+    network.add_flags(defer_inline=False)
+    inp = Tensor(np.ones([128, 96]).astype(np.float32))
+    grads = C.grad_all(network)(inp)
+    assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
-- 
GitLab
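
Note on the free-morphism pattern this patch targets: a free morphism is a CNode that belongs to the primal graph but is used by no node of the primal graph itself; its only users live inside a nested func graph that captures it as a free variable, so the traversal starting from primal_graph_->output() never reaches it. The Python sketch below is hypothetical (it is not part of the patch; the names are made up, and it is modeled on the grad_refactor_* tests in test_framstruct.py): the multiply producing `mid` is exactly such a node, so IsFreeMorphism returns true for it and MapFreeMorphism must map it, now before the output is mapped.

# Hypothetical sketch, not part of this patch.
from mindspore.ops import composite as C

def free_morphism_example(a, b):
    mid = a * b              # CNode of the outer (primal) graph ...

    def inner(x):
        return mid + x       # ... used only here, as a free variable of `inner`

    return inner(b)

def test_free_morphism_example():
    # f(a, b) = a * b + b, so df/da = b = 3 and df/db = a + 1 = 3 at (2, 3)
    assert C.grad_all(free_morphism_example)(2, 3) == (3, 3)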
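
On the test_cell_bprop.py change: it follows from the KUserDefined hunk above, which downgrades the output-arity complaint from MS_LOG(EXCEPTION) to MS_LOG(DEBUG). The body of MulAddWithWrongOutputNum is not shown in this patch, so the class below is only a guessed reconstruction of its shape; what matters is the arity mismatch between the cell's two inputs and the single gradient its bprop returns.

# Guessed reconstruction -- the real MulAddWithWrongOutputNum body is not
# shown in this patch; only the input/gradient arity mismatch matters.
import mindspore.nn as nn
from mindspore.ops import composite as C

class MulAddWithWrongOutputNum(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out, dout):
        return (2 * dout,)  # one gradient returned for two inputs

# Formerly this raised RuntimeError at grad time; with this patch it runs
# and the mismatch is only logged at DEBUG level.
C.grad_all(MulAddWithWrongOutputNum())(1, 2)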