diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index 0a6fb0b3f678b524a04f72fb478be6a3b8010a84..59dd011c9df1b999149881343bae2d9e550ef003 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -59,6 +59,7 @@ const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
 
 // Statements
 const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
+const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
 const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
 const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
 const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index 8c63660c3e71d1a1038cf0df199fa540b2fc4983..f46f76bfcee36a8ba20f1ec19e7af6869bd86dfd 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -65,6 +65,7 @@ extern const PrimitivePtr kPrimHasType;
 
 // Statements
 extern const PrimitivePtr kPrimSwitch;
+extern const PrimitivePtr kPrimSwitchLayer;
 extern const PrimitivePtr kPrimReturn;
 extern const PrimitivePtr kPrimAssign;
 extern const PrimitivePtr kPrimAssignAdd;
diff --git a/mindspore/ccsrc/operator/prim_statement.cc b/mindspore/ccsrc/operator/prim_statement.cc
index 0b9d491ce6cb87c57a1f2b2201fba260608def0d..e639b58a05bbd8eb6c903b27ce1f692400abb1f6 100644
--- a/mindspore/ccsrc/operator/prim_statement.cc
+++ b/mindspore/ccsrc/operator/prim_statement.cc
@@ -126,6 +126,30 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
   MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
 }
 
+AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  // Inputs: index, branches
+  if (args_spec_list.size() != 2) {
+    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
+                      << args_spec_list.size() << ".";
+  }
+  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
+  AbstractBasePtrList branches = branches_abs->elements();
+  const size_t maximum_layer_num = 1000;
+  if (branches.size() < 1 || branches.size() > maximum_layer_num) {
+    MS_EXCEPTION(ValueError) << "SwitchLayer supports at least 1 and at most " << maximum_layer_num
+                             << " branches, but got " << branches.size() << ".";
+  }
+
+  MS_EXCEPTION_IF_NULL(branches[0]);
+  auto b = branches[0];
+  for (size_t i = 1; i < branches.size(); i++) {
+    MS_EXCEPTION_IF_NULL(branches[i]);
+    b = b->Join(branches[i]);
+  }
+  return b;
+}
+
 std::vector<ValuePtr> GetSupportedTargetValue() {
   std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
   return list;
diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/optimizer/ad/dfunctor.cc
index de368dbdd2f62abdd636af909ad109308b2deca2..6f648b5728c01e82ee1704304959583f828fe305 100644
--- a/mindspore/ccsrc/optimizer/ad/dfunctor.cc
+++ b/mindspore/ccsrc/optimizer/ad/dfunctor.cc
@@ -38,6 +38,7 @@ namespace mindspore {
 namespace ad {
 std::unordered_map<FuncGraphPtr, std::shared_ptr<DFunctor>> DFunctor::func_graph_to_functor_;
 std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
+FuncGraphSet DFunctor::scope_;
 
 DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
     : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
@@ -55,11 +56,15 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
 void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
   func_graph_to_functor_[primal_graph_] = functor;
   is_top_ = is_top;
+  if (is_top) {
+    scope_ = primal_graph_->scope();
+  }
 }
 
 void DFunctor::Clear() {
   func_graph_to_functor_.clear();
   anfnode_to_adjoin_definition_.clear();
+  scope_.clear();
 }
 
 void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
@@ -95,11 +100,48 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
   fv_adjoint->second->AccumulateDout(dfv);
 }
 
+void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
+  // Take switch_layer as a set of candidate functions.
+  auto input = cnode_morph->input(2);
+  if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
+    MS_LOG(EXCEPTION) << "The second input of switch_layer should be a tuple of graphs, but got " << input->ToString() << ".";
+  }
+  auto tuple_graphs = input->cast<CNodePtr>();
+  for (size_t i = 1; i < tuple_graphs->size(); ++i) {
+    auto graph = tuple_graphs->input(i);
+    if (!IsValueNode<FuncGraph>(graph)) {
+      MS_LOG(EXCEPTION) << "The second input of switch_layer should be a tuple of graphs, but got " << graph->ToString()
+                        << " as the " << i << "th element.";
+    }
+    auto func_graph = GetValueNode<FuncGraphPtr>(graph);
+    auto functor = func_graph_to_functor_.find(func_graph);
+    if (functor == func_graph_to_functor_.end()) {
+      MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed: the functor for subgraph does not exist, input[" << i
+                        << "] " << func_graph->ToString() << ".";
+    }
+    // Consider direct and indirect fvs.
+    for (auto fv : func_graph->free_variables_nodes()) {
+      BackPropagateFv(fv, env);
+    }
+    for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
+      MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
+                    << indirect_fv.first->ToString() << ".";
+      BackPropagateFv(indirect_fv.first, env);
+    }
+  }
+}
+
 void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
   auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
   // Call with delimited continuation dout.
   auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
   node_adjoint->RegisterDoutUser(bprop_app, 1);
+  // Special case for switch_layer.
+  if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
+    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)});
+    BackPropagateSwitchLayer(cnode_morph, din);
+    return;
+  }
   for (size_t i = 0; i < cnode_morph->size(); i++) {
     auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
     auto input = cnode_morph->input(i);
@@ -402,6 +444,11 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
   return primal;
 }
 
+bool DFunctor::IsInScope(const AnfNodePtr &node) {
+  return std::any_of(scope_.begin(), scope_.end(),
+                     [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
+}
+
 void DFunctor::MapFvObject() {
   // Map free variable.
   const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
@@ -414,8 +461,8 @@ void DFunctor::MapFvObject() {
     if (parent_adjoint != nullptr) {
       adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
     } else {
-      if (is_top_) {
-        // Top graph for ad, add adjoint for free variables.
+      if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
+        // Out of ad scope, add adjoint for free variables.
         adjoint = std::make_shared<Adjoint>(node, node, tape_);
         UpdateAdjoint(adjoint);
       } else {
diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.h b/mindspore/ccsrc/optimizer/ad/dfunctor.h
index 1358cc8f28538398015455f319af3d074df4fda7..7da90e7c3ea24e79d569a4ed7b093cdbfb6ebdb0 100644
--- a/mindspore/ccsrc/optimizer/ad/dfunctor.h
+++ b/mindspore/ccsrc/optimizer/ad/dfunctor.h
@@ -62,9 +62,11 @@ class DFunctor {
   // Map one morphism.
   AdjointPtr MapMorphism(const AnfNodePtr &morph);
   bool IsFreeMorphism(const AnfNodePtr &node);
+  bool IsInScope(const AnfNodePtr &node);
   // Map morphism that's not attached to output.
   void MapFreeMorphism();
   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
+  void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env);
   void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
   AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
   AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
@@ -101,6 +103,7 @@ class DFunctor {
   bool is_top_;
   static std::unordered_map<FuncGraphPtr, std::shared_ptr<DFunctor>> func_graph_to_functor_;
   static std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_definition_;
+  static FuncGraphSet scope_;
 };
 
 // D Functor's rules to map primitive object.
@@ -120,6 +123,7 @@ class KPrim {
 
  private:
   FuncGraphPtr GetBprop(const PrimitivePtr &prim);
+  FuncGraphPtr GetFprop(const PrimitivePtr &prim);
   FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
   // Given a bprop rule, do the K mapping.
   template <typename T>
diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc
index c74670e55d008c87f3de4065294fb0e165e27fbd..6fbb9d1ae8a314beeb56d305834995bafbc5cbc4 100644
--- a/mindspore/ccsrc/optimizer/ad/kprim.cc
+++ b/mindspore/ccsrc/optimizer/ad/kprim.cc
@@ -62,6 +62,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
   return func_graph;
 }
 
+FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
+  static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
+  std::string func_name = "_fprop_" + prim->name();
+  py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
+  auto func_graph = parse::ParsePythonCode(fn);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  return BasicClone(func_graph);
+}
+
 MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
   MS_EXCEPTION_IF_NULL(prim);
 
@@ -92,6 +101,13 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
     return iter->second;
   }
 
+  if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
+    auto fprop = GetFprop(prim);
+    fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
+    bprop_registry_[prim::kPrimSwitchLayer] = fprop;
+    return fprop;
+  }
+
   if (prim->name() == "make_tuple") {
     return nullptr;
   }
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
index 274f63844cfad73ab5d2494ae0285b39edc22008..f80a0cdcc2d0c22190c254f6e96b70acb4dc7cb6 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -50,6 +50,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimHasType, {InferImplHasType, false}},
     {prim::kPrimDot, {InferImplDot, true}},
     {prim::kPrimSwitch, {InferImplSwitch, true}},
+    {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
     {prim::kPrimIs_, {InferImplIs_, true}},
     {prim::kPrimIsNot, {InferImplIsNot, true}},
     {prim::kPrimInDict, {InferImplInDict, true}},
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h
index be71f3200a6828e2880bde3b81b103060d215204..f0694bfb6f346433119d439da1b1da7959bc8fde 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.h
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h
@@ -174,6 +174,8 @@ AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &prim
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                     const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
diff --git a/mindspore/ops/_grad/grad_implementations.py b/mindspore/ops/_grad/grad_implementations.py
index ffdd5e6a5acb7981890d265a179230e4f7798571..ee3117c83af22a1f89667927083484c81f588e09 100644
--- a/mindspore/ops/_grad/grad_implementations.py
+++ b/mindspore/ops/_grad/grad_implementations.py
@@ -242,3 +242,9 @@ def bprop_switch(cond, tb, fb, out, dout):
     """Backpropagator for primitive `switch`."""
     return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
            F.switch(cond, C.zeros_like(fb), dout)
+
+def _fprop_switch_layer(index, layers):
+    """Forward output and backpropagator for primitive `switch_layer`."""
+    def _bprop_switch_layer(dout):
+        return dout, C.zeros_like(index), ()
+    return F.switch_layer(index, layers), _bprop_switch_layer
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 5f7cabc54d9f79d7c307a7973bb3236c187eeba8..f7e014d8a5b03791033c0c20117fcfe24a1a6974 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -135,6 +135,7 @@ env_getitem = Primitive('env_getitem')
 env_add = Primitive('env_add')
 J = Primitive('J')
 switch = Primitive('switch')
+switch_layer = Primitive('switch_layer')
 # for sum bprop
 reduced_shape = Primitive("reduced_shape")
 # shape_mul:input mush be shape multiply elemts in tuple(shape)
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index b17eea8dddad2826a3f82d7a61af19466523f86a..a6c15444e4ef0b32ceb0ae570e203eb5b1e2c357 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -19,6 +19,9 @@ from mindspore import nn
 from mindspore import Tensor
 from mindspore import context
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.common.parameter import Parameter, ParameterTuple
 
 context.set_context(mode=context.GRAPH_MODE)
 
@@ -358,3 +361,33 @@ def test_if_compile_true():
 def test_if_compile_false():
     output = if_compile_test(8, 3)
     print("test_if_compile_false:", output)
+
+
+def test_switch_layer():
+    class Layer1(nn.Cell):
+        def __init__(self):
+            super(Layer1, self).__init__()
+            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+        def construct(self, x):
+            return x * self.z1
+
+    class Layer2(nn.Cell):
+        def __init__(self):
+            super(Layer2, self).__init__()
+            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+        def construct(self, x):
+            return x * self.z2
+
+    class SwitchLayerCell(nn.Cell):
+        def __init__(self):
+            super(SwitchLayerCell, self).__init__()
+            self.layers = (Layer1(), Layer2())
+            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+        def construct(self, index, x):
+            ret = F.switch_layer(index, self.layers)(x) * self.z3
+            return ret
+
+    net = SwitchLayerCell()
+    net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
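
Reviewer note (not part of the patch): the sketch below condenses test_switch_layer above into a minimal end-to-end use of the new primitive. The integer index picks one Cell out of the tuple at run time, InferImplSwitchLayer joins the branch abstractions during inference, and the registered _fprop_switch_layer rule routes the gradient through the selected branch only. The Double/Triple cells, shapes, and values here are illustrative assumptions, not APIs introduced by this change.

import numpy as np
from mindspore import Tensor, context, nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F

context.set_context(mode=context.GRAPH_MODE)

class Double(nn.Cell):
    def construct(self, x):
        return x * 2

class Triple(nn.Cell):
    def construct(self, x):
        return x * 3

class SelectNet(nn.Cell):
    def __init__(self):
        super(SelectNet, self).__init__()
        self.layers = (Double(), Triple())  # candidate branches

    def construct(self, index, x):
        # switch_layer selects self.layers[index] at run time; the result is then called on x.
        return F.switch_layer(index, self.layers)(x)

net = SelectNet()
x = Tensor(np.ones([2, 3]).astype(np.float32))
out = net(0, x)                # forward pass through the Double branch
grads = C.grad_all(net)(1, x)  # gradients w.r.t. the inputs, flowing through the Triple branch only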