提交 2b2c1730 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3696 Update AdjustAllReduce to use Pattern Matcher

Merge pull request !3696 from Giancarlo/update_adjust_allreduce
......@@ -95,37 +95,37 @@ AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePt
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
// Optimizer entry point for the rewrite
//   AddN(MakeTuple(Mul(AllReduce(X), Y), Z)) -> Mul(AllReduce(AddN(MakeTuple(Z, X))), Y).
// Returns the replacement node, or nullptr when `node` does not match.
// NOTE(review): this listing comes from a diff hunk and appears to interleave the removed
// AnfVisitor-based lines (Reset()/Match/x_/y_/z_ members) with the added PatternNode-based
// lines; as shown it is not a single compilable body — confirm against the repository.
AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
// Visitor-based variant: clear the member state left by a previous invocation.
Reset();
// {prim::kPrimAddN, Zs}
if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
return nullptr;
}
auto addn = node->cast<CNodePtr>();
// Only an AddN with exactly one argument (primitive + one input) is handled.
if (addn->size() != 2) {
return nullptr;
}
// Presumably calls back into Visit() for the two MakeTuple inputs, which fills x_, y_, z_ and
// all_reduce_fg_ — confirm against AnfVisitor::Match.
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
// Any member still null means the Mul(AllReduce(X), Y) / Z shape was not found.
if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
return nullptr;
}
auto addn_maketuple = addn->input(1);
auto fg = all_reduce_fg_;
// addn inputs cross the graph, make the inputs same as allreduce node.
if (z_->isa<CNode>() && fg != z_->func_graph()) {
auto cnode_z = z_->cast<CNodePtr>();
// Re-create Z with the same inputs inside the AllReduce node's graph.
z_ = NewCNode(cnode_z->inputs(), fg);
}
auto addn_op_node = addn->input(0);
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
// Pattern-based variant: describe AddN(MakeTuple(Mul(AllReduce(x), y), z)).
// The trailing `true` is presumably the commutative flag of PBinOperation — confirm against its constructor.
PatternNode x, y, z;
auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x);
auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true);
auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true);
auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat);
// Builds the replacement Mul(AllReduce(AddN(MakeTuple(z, x))), y) from the captured nodes.
auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr {
// All new nodes are created in the graph of the captured AllReduce node.
auto fg = all_reduce_pat.GetFuncGraph();
// Local z_ declared here; it holds the node obtained from pattern z.
auto z_ = z.GetNode(node);
// If addn inputs cross the graph, make the inputs same as allreduce node.
if (z_->isa<CNode>() && fg != z_->func_graph()) {
auto cnode_z = z_->cast<CNodePtr>();
z_ = NewCNode(cnode_z->inputs(), fg);
}
// Visitor-based construction of the replacement from the member captures.
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
ProcessDependEdge(fg, addn_maketuple, all_reduce);
return mul;
// Pattern-based construction: reuse the primitive (input 0) of each originally matched node.
auto addn_cnode = addn_pat.GetOriginalNode()->cast<CNodePtr>();
auto addn_op_node = addn_cnode->input(0);
auto make_tuple_op_node = addn_cnode->input(1)->cast<CNodePtr>()->input(0);
auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast<CNodePtr>()->input(0);
// Stores the matched Mul node in the member mul_cnode_ (presumably read by ProcessDependEdge — confirm).
mul_cnode_ = mul_pat.GetOriginalNode();
auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0);
auto addn_maketuple = admktup_pat.GetOriginalNode();
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x.GetNode(node)}, fg);
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg);
AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg);
ProcessDependEdge(fg, addn_maketuple, all_reduce);
return mul;
};
// Presumably returns adjust_lambda()'s result when addn_pat matches node — confirm against the macro definition.
MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda);
return nullptr;
}
void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple,
......@@ -146,48 +146,6 @@ void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfN
}
}
// Visitor callback used in two modes, selected by level_:
//   level_ == 1: `node` is an input of a candidate Mul. An AllReduce CNode with at least one
//                argument is recorded as the {AllReduce, X} part; any other node is kept in tmp_.
//   level_ == 0: `node` is an input of the MakeTuple. It is probed as {Mul, {AllReduce, X}, Y};
//                on success the Mul primitive, the Mul node and Y are recorded, otherwise the
//                node is recorded as Z.
// Any other level_ value leaves the state untouched.
// NOTE(review): presumably AnfVisitor::Match re-enters Visit for the inputs of the matched
// node — confirm against AnfVisitor.
void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) {
  if (level_ == 1) {
    // {prim::kPrimAllReduce, X}
    if (!IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
      tmp_ = node;
      return;
    }
    auto all_reduce_cnode = node->cast<CNodePtr>();
    if (all_reduce_cnode->size() > 1) {
      all_reduce_ = all_reduce_cnode->input(0);
      x_ = all_reduce_cnode->input(1);
      is_reduce_match_ = true;
      all_reduce_fg_ = all_reduce_cnode->func_graph();
    }
    return;
  }

  if (level_ != 0) {
    return;
  }

  // Probe the node one level down, then restore the outer level.
  level_ = 1;
  is_reduce_match_ = false;
  // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
  AnfVisitor::Match(prim::kPrimMul)(node);
  level_ = 0;

  if (!is_reduce_match_) {
    z_ = node;
    return;
  }
  auto matched_mul = node->cast<CNodePtr>();
  mul_ = matched_mul->input(0);
  mul_cnode_ = matched_mul;
  y_ = tmp_;
}
// Restores the traversal state and the captured X/Y/Z/tmp nodes and AllReduce graph to their
// initial values. all_reduce_, mul_ and mul_cnode_ are not cleared here.
void AdjustAllReduceMulAdd::Reset() {
  // Traversal state.
  level_ = 0;
  is_reduce_match_ = false;
  // Captured nodes.
  x_ = y_ = z_ = tmp_ = nullptr;
  all_reduce_fg_ = nullptr;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore
......@@ -38,20 +38,14 @@ namespace irpass {
// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
// Optimizer pass that moves an AddN inside an AllReduce (see the rewrite rule above the class).
// NOTE(review): this listing comes from a diff hunk and appears to show both the removed
// declaration (deriving from AnfVisitor, with Visit/Reset and the capture members) and the added
// one (deriving from OptimizerCaller, keeping only mul_cnode_); confirm against the repository.
class AdjustAllReduceMulAdd : public AnfVisitor {
class AdjustAllReduceMulAdd : public OptimizerCaller {
public:
// Returns the rewritten node, or nullptr when `node` does not match the pattern.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node);
// Visitor callback; its behavior depends on level_.
void Visit(const AnfNodePtr &node) override;
// Clears level_, is_reduce_match_, x_, y_, z_, tmp_ and all_reduce_fg_.
void Reset();
private:
// 0 while visiting MakeTuple inputs, 1 while probing the inputs of a candidate Mul.
int level_{0};
// Set by Visit when an AllReduce CNode with at least one argument is seen at level 1.
bool is_reduce_match_{false};
// x_: AllReduce argument; y_: the other Mul input; z_: the non-Mul tuple input; tmp_: scratch for y_.
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
// all_reduce_/mul_: input(0) of the matched AllReduce/Mul nodes; mul_cnode_: the matched Mul node.
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
// Graph owning the matched AllReduce node.
FuncGraphPtr all_reduce_fg_{nullptr};
// The matched Mul node (assigned in operator()).
AnfNodePtr mul_cnode_{nullptr};
};
class ArithmeticSimplify : public OptimizerCaller {
......
......@@ -94,8 +94,8 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
~PBinOperation() = default;
// Builds a new CNode {prim_, lhs, rhs} in the func graph of `node`, where lhs/rhs are the nodes
// produced by the two operand sub-patterns.
// NOTE(review): lhs/rhs are declared twice below; the listing comes from a diff hunk and
// appears to show the removed lines (passing node->func_graph()) followed by the added lines
// (passing node) — confirm against the repository.
AnfNodePtr GetNode(const AnfNodePtr &node) const {
AnfNodePtr lhs = x_.GetNode(node->func_graph());
AnfNodePtr rhs = y_.GetNode(node->func_graph());
AnfNodePtr lhs = x_.GetNode(node);
AnfNodePtr rhs = y_.GetNode(node);
AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs};
return NewCNode(list, node->func_graph());
}
......@@ -113,25 +113,42 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
if (!x_.TryCapture(inputs[2]) || !y_.TryCapture(inputs[1])) {
return false;
}
captured_binop_node_ = node;
return true;
}
return false;
}
captured_binop_node_ = node;
return true;
}
}
return false;
}
/// Gives access to the graph node recorded in captured_binop_node_ by a successful capture of
/// this binary-operation pattern.
/// Raises a ValueError exception (via MS_EXCEPTION) when no node is currently recorded.
AnfNodePtr GetOriginalNode() const {
  const bool has_capture = captured_binop_node_ != nullptr;
  if (!has_capture) {
    MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
  }
  return captured_binop_node_;
}
/// Clears the capture state: forgets the recorded node and resets both operand sub-patterns.
void Reset() const {
  captured_binop_node_ = nullptr;
  x_.Reset();
  y_.Reset();
}
using Internal = const PBinOperation<T, T2> &;
private:
const PrimitivePtr prim_;
typename T::Internal x_;
typename T2::Internal y_;
bool is_commutative_{false};
mutable AnfNodePtr captured_binop_node_{nullptr};
};
///
......@@ -265,10 +282,11 @@ class PCNode : public PBase<PCNode<TArgs...> > {
return *this;
}
using Internal = const PCNode<TArgs...> &;
void Reset() const {
tuple_utils::PTupleResetCapture reset;
tuple_utils::apply_func_tuple(&reset, args_);
has_min_extra_nodes_ = false;
extra_nodes_.clear();
}
......@@ -316,6 +334,9 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
AnfNodePtrList tokens(inputs.begin() + 1, inputs.end());
tuple_utils::PTupleCapture capture_func(tokens);
tuple_utils::apply_func_tuple(&capture_func, args_);
if (capture_func.captured_) {
captured_prim_node_ = node;
}
return capture_func.captured_;
}
return false;
......@@ -329,9 +350,11 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
tuple_utils::apply_func_tuple(&capture_func, args_);
// If it could capture the initial set of nodes specified in the Pattern
// and there are enough extra inputs to add
if (capture_func.captured_ && inputs.size() > pattern_arg_len + 1) {
extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end());
return true;
if (capture_func.captured_) {
captured_prim_node_ = node;
if (inputs.size() > pattern_arg_len + 1) {
extra_nodes_.insert(extra_nodes_.end(), inputs.begin() + 1 + pattern_arg_len, inputs.end());
}
}
return capture_func.captured_;
}
......@@ -349,19 +372,42 @@ class PPrimitive : public PBase<PPrimitive<TArgs...> > {
return *this;
}
/// Gives the FuncGraph that owns the node recorded in captured_prim_node_ by a successful
/// capture of this primitive pattern.
/// Raises a ValueError exception (via MS_EXCEPTION) when no node is currently recorded.
FuncGraphPtr GetFuncGraph() const {
  const bool has_capture = captured_prim_node_ != nullptr;
  if (!has_capture) {
    MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get its FuncGraph.";
  }
  return captured_prim_node_->func_graph();
}
/// Gives access to the graph node recorded in captured_prim_node_ by a successful capture of
/// this primitive pattern.
/// Raises a ValueError exception (via MS_EXCEPTION) when no node is currently recorded.
AnfNodePtr GetOriginalNode() const {
  const bool has_capture = captured_prim_node_ != nullptr;
  if (!has_capture) {
    MS_EXCEPTION(ValueError) << "A Node wasn't captured for this Pattern before attempting to get it.";
  }
  return captured_prim_node_;
}
void Reset() const {
tuple_utils::PTupleResetCapture reset;
tuple_utils::apply_func_tuple(&reset, args_);
has_min_extra_nodes_ = false;
extra_nodes_.clear();
captured_prim_node_ = nullptr;
}
using Internal = const PPrimitive<TArgs...> &;
private:
const PrimitivePtr prim_;
std::tuple<typename TArgs::Internal...> args_;
mutable AnfNodePtrList extra_nodes_;
mutable bool has_min_extra_nodes_{false};
mutable size_t min_extra_nodes_{0};
mutable AnfNodePtr captured_prim_node_{nullptr};
};
///
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册