提交 252ed4f7 编写于 作者: W Wei Luning

use the old op

上级 27a88a6b
......@@ -248,17 +248,18 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
if (addn->size() != 2) {
return nullptr;
}
AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
if (x_ == nullptr || y_ == nullptr || z_ == nullptr) {
return nullptr;
}
auto addn_op_node = addn->input(0);
auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
auto fg = node->func_graph();
AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg);
AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg);
AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg);
return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg);
AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
return NewCNode({mul_, all_reduce, y_}, fg);
}
void Visit(const AnfNodePtr &node) override {
......@@ -269,6 +270,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
AnfVisitor::Match(prim::kPrimMul)(node);
level_ = 0;
if (is_reduce_match_) {
mul_ = node->cast<CNodePtr>()->input(0);
y_ = tmp_;
} else {
z_ = node;
......@@ -280,6 +282,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
auto cnode = node->cast<CNodePtr>();
if (cnode->size() > 1) {
all_reduce_ = cnode->input(0);
x_ = cnode->input(1);
is_reduce_match_ = true;
}
......@@ -302,6 +305,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
int level_{0};
bool is_reduce_match_{false};
AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
AnfNodePtr all_reduce_{nullptr}, mul_{nullptr};
};
class ArithmeticSimplify {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册