diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc index 46cc91443bbb16bbad711f12c9931aa70beb6649..7e2c989f4979ef7f327b6343bb88051cb517c8da 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -95,37 +95,37 @@ AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePt // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - // {prim::kPrimAddN, Zs} - if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { - return nullptr; - } - auto addn = node->cast(); - if (addn->size() != 2) { - return nullptr; - } - AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); - if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { - return nullptr; - } - auto addn_maketuple = addn->input(1); - - auto fg = all_reduce_fg_; - // addn inputs cross the graph, make the inputs same as allreduce node. - if (z_->isa() && fg != z_->func_graph()) { - auto cnode_z = z_->cast(); - z_ = NewCNode(cnode_z->inputs(), fg); - } - - auto addn_op_node = addn->input(0); - auto make_tuple_op_node = addn->input(1)->cast()->input(0); + PatternNode x, y, z; + auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x); + auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true); + auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true); + auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat); + auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr { + auto fg = all_reduce_pat.GetFuncGraph(); + auto z_ = z.GetNode(node); + // If addn inputs cross the graph, make the inputs same as allreduce node. + if (z_->isa() && fg != z_->func_graph()) { + auto cnode_z = z_->cast(); + z_ = NewCNode(cnode_z->inputs(), fg); + } - AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); - AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); - AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); - AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); - ProcessDependEdge(fg, addn_maketuple, all_reduce); - return mul; + auto addn_cnode = addn_pat.GetOriginalNode()->cast(); + auto addn_op_node = addn_cnode->input(0); + auto make_tuple_op_node = addn_cnode->input(1)->cast()->input(0); + auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast()->input(0); + mul_cnode_ = mul_pat.GetOriginalNode(); + auto mul_prim = mul_cnode_->cast()->input(0); + auto addn_maketuple = admktup_pat.GetOriginalNode(); + + AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x.GetNode(node)}, fg); + AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); + AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg); + AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg); + ProcessDependEdge(fg, addn_maketuple, all_reduce); + return mul; + }; + MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda); + return nullptr; } void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, @@ -146,48 +146,6 @@ void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfN } } -void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) { - if (level_ == 0) { - level_ = 1; - is_reduce_match_ = false; - // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} - AnfVisitor::Match(prim::kPrimMul)(node); - level_ = 0; - if (is_reduce_match_) { - mul_ = node->cast()->input(0); - mul_cnode_ = node->cast(); - y_ = tmp_; - } else { - z_ = node; - } - } - - if (level_ == 1) { - // {prim::kPrimAllReduce, X} - if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { - auto cnode = node->cast(); - if (cnode->size() > 1) { - all_reduce_ = cnode->input(0); - x_ = cnode->input(1); - is_reduce_match_ = true; - all_reduce_fg_ = cnode->func_graph(); - } - } else { - tmp_ = node; - } - } -} - -void AdjustAllReduceMulAdd::Reset() { - level_ = 0; - is_reduce_match_ = false; - x_ = nullptr; - y_ = nullptr; - z_ = nullptr; - tmp_ = nullptr; - all_reduce_fg_ = nullptr; -} - } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h index 699005f7bf98a96f834f8dcdb335b85d11530cff..177a66fb0910bd9a2a638966179f335122876bac 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h @@ -38,20 +38,14 @@ namespace irpass { // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} -class AdjustAllReduceMulAdd : public AnfVisitor { +class AdjustAllReduceMulAdd : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node); - void Visit(const AnfNodePtr &node) override; - void Reset(); private: - int level_{0}; - bool is_reduce_match_{false}; - AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; - AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr}; - FuncGraphPtr all_reduce_fg_{nullptr}; + AnfNodePtr mul_cnode_{nullptr}; }; class ArithmeticSimplify : public OptimizerCaller { diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index 7c1a856df64652e90d379b90818f9ea7e5d820d6..1ed559d656572285c2dd2851cafbf57df9ccd134 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -94,8 +94,8 @@ class PBinOperation : public PBase > { ~PBinOperation() = default; AnfNodePtr GetNode(const AnfNodePtr &node) const { - AnfNodePtr lhs = x_.GetNode(node->func_graph()); - AnfNodePtr rhs = y_.GetNode(node->func_graph()); + AnfNodePtr lhs = x_.GetNode(node); + AnfNodePtr rhs = y_.GetNode(node); AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs}; return NewCNode(list, node->func_graph()); } @@ -113,25 +113,42 @@ class PBinOperation : public PBase > { if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) { return false; } + captured_binop_node_ = node; return true; } return false; } + captured_binop_node_ = node; return true; } } return false; } + + /// Returns the original node captured by this Binary Operation Pattern. + /// Throws exception if a node was not captured before. + AnfNodePtr GetOriginalNode() const { + if (captured_binop_node_ == nullptr) { + MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; + } + + return captured_binop_node_; + } + void Reset() const { x_.Reset(); y_.Reset(); + captured_binop_node_ = nullptr; } + using Internal = const PBinOperation &; + private: const PrimitivePtr prim_; typename T::Internal x_; typename T2::Internal y_; bool is_commutative_{false}; + mutable AnfNodePtr captured_binop_node_{nullptr}; }; /// @@ -265,10 +282,11 @@ class PCNode : public PBase > { return *this; } + using Internal = const PCNode &; + void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); - has_min_extra_nodes_ = false; extra_nodes_.clear(); } @@ -316,6 +334,9 @@ class PPrimitive : public PBase > { AnfNodePtrList tokens(inputs.begin() + 1, inputs.end()); tuple_utils::PTupleCapture capture_func(tokens); tuple_utils::apply_func_tuple(&capture_func, args_); + if (capture_func.captured_) { + captured_prim_node_ = node; + } return capture_func.captured_; } return false; @@ -329,9 +350,11 @@ class PPrimitive : public PBase > { tuple_utils::apply_func_tuple(&capture_func, args_); // If it could capture the initial set of nodes specified in the Pattern // and there are enough extra inputs to add - if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) { - extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); - return true; + if (capture_func.captured_) { + captured_prim_node_ = node; + if (inputs.size() > pattern_arg_len + 1) { + extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end()); + } } return capture_func.captured_; } @@ -349,19 +372,42 @@ class PPrimitive : public PBase > { return *this; } + /// Returns the FuncGraph of the original node captured by this Primitive Pattern. + /// Throws exception if a node was not captured before. + FuncGraphPtr GetFuncGraph() const { + if (captured_prim_node_ == nullptr) { + MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get its FuncGraph."; + } + + return captured_prim_node_->func_graph(); + } + + /// Returns the original node captured by this Primitive Pattern. + /// Throws exception if a node was not captured before. + AnfNodePtr GetOriginalNode() const { + if (captured_prim_node_ == nullptr) { + MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it."; + } + + return captured_prim_node_; + } + void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); - has_min_extra_nodes_ = false; extra_nodes_.clear(); + captured_prim_node_ = nullptr; } + using Internal = const PPrimitive &; + private: const PrimitivePtr prim_; std::tuple args_; mutable AnfNodePtrList extra_nodes_; mutable bool has_min_extra_nodes_{false}; mutable size_t min_extra_nodes_{0}; + mutable AnfNodePtr captured_prim_node_{nullptr}; }; ///