Commit d1c28c1d authored by mindspore-ci-bot, committed by Gitee

!675 [revert] AdjustAllReduceMulAdd

Merge pull request !675 from vlne-v1/revert-allreduce-addn-opt
@@ -230,7 +230,6 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
-const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
 
 // Debug ops
 const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
......
@@ -234,7 +234,6 @@ extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
 
 // Comm ops
-extern const PrimitivePtr kPrimAllReduce;
 extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;
......
@@ -48,7 +48,7 @@ namespace irpass {
 OptimizeIRPassLib::OptimizeIRPassLib() {
   arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify",
                                           {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
-                                           prim::kPrimAddN, prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
+                                           prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
   special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
                                            {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
                                             prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
......
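The hunk above reverts the trigger list of arithmetic_simplify so that prim::kPrimAddN no longer roots this substitution. As a rough illustration of how such a trigger list gates rewriting, here is a minimal Python sketch; Substitution, apply_rules, and the toy rule are hypothetical stand-ins for exposition, not MindSpore's C++ API:

    from typing import Callable, Optional

    Node = tuple  # toy IR node: (op_name, *inputs)

    class Substitution:
        def __init__(self, transform: Callable[[Node], Optional[Node]], name: str, triggers: set):
            self.transform = transform  # returns a replacement node, or None if no match
            self.name = name
            self.triggers = triggers    # op names allowed to root a match

    def apply_rules(node: Node, rules: list) -> Node:
        # Only rules whose trigger set contains this node's op are even tried.
        for rule in rules:
            if node[0] in rule.triggers:
                replaced = rule.transform(node)
                if replaced is not None:
                    return replaced
        return node

    # A toy x*1 -> x rule registered with a trigger list, analogous to the
    # C++ MakeSubstitution call above (after this revert, "AddN" is gone):
    mul_by_one = Substitution(lambda n: n[1] if n[0] == "Mul" and n[2] == ("One",) else None,
                              "arithmetic_simplify", {"ScalarAdd", "ScalarMul", "TensorAdd", "Mul"})
    print(apply_rules(("Mul", ("X",), ("One",)), [mul_by_one]))  # ('X',)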
@@ -228,86 +228,6 @@ class ConstantDuplicateMul : public AnfVisitor {
   CNodePtr cnode_;
 };
 
-// grad = AllReduce(grad) / worker_number
-// grad = grad + weight * decay
-// ->
-// grad = grad + weight * decay
-// grad = AllReduce(grad) / worker_number
-// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
-// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN, {prim::kPrimMakeTuple, Z, X}}}, Y}
-class AdjustAllReduceMulAdd : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    Reset();
-    // {prim::kPrimAddN, Zs}
-    if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
-      return nullptr;
-    }
-
-    auto addn = node->cast<CNodePtr>();
-    if (addn->size() != 2) {
-      return nullptr;
-    }
-
-    AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
-    if (x_ == nullptr || y_ == nullptr || z_ == nullptr) {
-      return nullptr;
-    }
-
-    auto addn_op_node = addn->input(0);
-    auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
-    auto fg = node->func_graph();
-    AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
-    AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
-    AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
-    return NewCNode({mul_, all_reduce, y_}, fg);
-  }
-
-  void Visit(const AnfNodePtr &node) override {
-    if (level_ == 0) {
-      level_ = 1;
-      is_reduce_match_ = false;
-      // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
-      AnfVisitor::Match(prim::kPrimMul)(node);
-      level_ = 0;
-      if (is_reduce_match_) {
-        mul_ = node->cast<CNodePtr>()->input(0);
-        y_ = tmp_;
-      } else {
-        z_ = node;
-      }
-    }
-
-    if (level_ == 1) {
-      // {prim::kPrimAllReduce, X}
-      if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
-        auto cnode = node->cast<CNodePtr>();
-        if (cnode->size() > 1) {
-          all_reduce_ = cnode->input(0);
-          x_ = cnode->input(1);
-          is_reduce_match_ = true;
-        }
-      } else {
-        tmp_ = node;
-      }
-    }
-  }
-
-  void Reset() {
-    level_ = 0;
-    is_reduce_match_ = false;
-    x_ = nullptr;
-    y_ = nullptr;
-    z_ = nullptr;
-    tmp_ = nullptr;
-  }
-
- private:
-  int level_{0};
-  bool is_reduce_match_{false};
-  AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
-  AnfNodePtr all_reduce_{nullptr}, mul_{nullptr};
-};
-
 class ArithmeticSimplify {
  public:
   ArithmeticSimplify()
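The header comment of the deleted class records the intent of the pass: reorder grad = AllReduce(grad) / worker_number followed by grad = grad + weight * decay so that the addition happens before the collective. That reordering is an identity only under narrow assumptions: the AllReduce performs a sum, Y is 1/worker_number, and Z is replicated identically on every worker. A minimal NumPy sketch of the equivalence under exactly those assumptions (the simulated all_reduce below is illustrative, not MindSpore's implementation):

    import numpy as np

    workers = 4
    rng = np.random.default_rng(0)
    grads = [rng.normal(size=3) for _ in range(workers)]  # per-worker gradient X
    weight_decay = np.array([0.1, 0.2, 0.3])              # Z: identical on every worker
    y = 1.0 / workers                                     # Y: 1 / worker_number

    def all_reduce(per_worker):
        # AllReduce with a sum reduction: every worker receives the global sum.
        return sum(per_worker)

    # Before: grad = AllReduce(grad) * (1/n), then grad = grad + Z
    before = all_reduce(grads) * y + weight_decay
    # After:  grad = grad + Z on each worker, then AllReduce * (1/n)
    after = all_reduce([g + weight_decay for g in grads]) * y

    # AllReduce(Z) == n * Z for replicated Z, and the 1/n factor cancels it,
    # so both orders agree up to floating-point error.
    assert np.allclose(before, after)

If Y is not 1/worker_number, or Z differs across workers, the two graphs compute different values, which is why the rewrite needs exactly the pattern matched above.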
@@ -323,7 +243,6 @@ class ArithmeticSimplify {
     eliminaters_.emplace_back(identity_);
     eliminaters_.emplace_back(opt_update_zero_tensor_);
     eliminaters_.emplace_back(constant_duplicate_mul_);
-    eliminaters_.emplace_back(adjust_allreduce_mul_add_);
   }
 
   ~ArithmeticSimplify() = default;
@@ -345,7 +264,6 @@ class ArithmeticSimplify {
   PrimEliminater identity_;
   OptUpdateZeroTensor opt_update_zero_tensor_;
   ConstantDuplicateMul constant_duplicate_mul_;
-  AdjustAllReduceMulAdd adjust_allreduce_mul_add_;
   std::vector<TransformFuncType> eliminaters_{};
 };
 }  // namespace irpass
......
@@ -1229,7 +1229,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
 
     Examples:
-        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
         >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
        >>> num_segments = 4
        >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
......
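For reference, the docstring example sums segments 0, 0, 1, 2 of [1, 2, 3, 4] into four buckets. A minimal NumPy sketch with the usual unsorted-segment-sum semantics (out[segment_ids[i]] += x[i], empty segments left at zero) reproduces it; note the revert restores the spelling mindspore.float, while mindspore.float32 is the standard dtype name:

    import numpy as np

    def unsorted_segment_sum(x, segment_ids, num_segments):
        # out[segment_ids[i]] += x[i]; segments that receive nothing stay zero.
        out = np.zeros(num_segments, dtype=x.dtype)
        np.add.at(out, segment_ids, x)
        return out

    x = np.array([1, 2, 3, 4], dtype=np.float32)
    ids = np.array([0, 0, 1, 2])
    print(unsorted_segment_sum(x, ids, 4))  # [3. 3. 4. 0.]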
@@ -1630,7 +1630,7 @@ class LayerNorm(Primitive):
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
 
     .. math::
-        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
+        y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
 
     where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
......
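The formula the docstring is reaching for (before the stray bracket this revert reintroduces) is y = (x - mean) / sqrt(variance + epsilon) * gamma + beta. A minimal NumPy sketch, assuming normalization over the last axis (the real primitive's axis arguments are not shown in this excerpt):

    import numpy as np

    def layer_norm(x, gamma, beta, eps=1e-7):
        # y = (x - mean) / sqrt(variance + eps) * gamma + beta
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * gamma + beta

    x = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]])
    print(layer_norm(x, gamma=np.ones(3), beta=np.zeros(3)))
    # each row comes out with zero mean and unit variance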
@@ -556,24 +556,5 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
   ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
   ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
 }
-
-TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
-  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
-  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
-  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
-  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
-  FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
-  FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
-  FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
-  FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
-
-  auto patterns = std::vector<SubstitutionPtr>({irpass.arithmetic_simplify_});
-  ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
-  ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
-  ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
-}
 }  // namespace opt
 }  // namespace mindspore
@@ -908,8 +908,8 @@ def test_print_tuple_wrapper(tag):
 def test_constant_duplicate_mul(tag):
     fns = FnDict()
-    Mul = Primitive('Mul')
-    Sqrt = Primitive('Sqrt')
+    Mul = Primitive('Mul');
+    Sqrt = Primitive('Sqrt');
     x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
     tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
@@ -936,44 +936,3 @@ def test_constant_duplicate_mul(tag):
         return Mul(Sqrt(x), Mul(tensor1, tensor2))
 
     return fns[tag]
-
-
-def test_adjust_allreduce_mul_add(tag):
-    fns = FnDict()
-    Mul = Primitive('Mul')
-    AddN = Primitive('AddN')
-    AllReduce = Primitive('AllReduce')
-
-    @fns
-    def beforell(x, y, z):
-        return AddN((z, Mul(y, AllReduce(x))))
-
-    @fns
-    def beforelr(x, y, z):
-        return AddN((z, Mul(AllReduce(x), y)))
-
-    @fns
-    def beforerl(x, y, z):
-        return AddN((Mul(y, AllReduce(x)), z))
-
-    @fns
-    def beforerr(x, y, z):
-        return AddN((Mul(AllReduce(x), y), z))
-
-    @fns
-    def after1(x, y, z):
-        return Mul(AllReduce(AddN((z, x))), y)
-
-    @fns
-    def before2r(x, y, z):
-        return AddN((Mul(AllReduce(x), y), Mul(z, z)))
-
-    @fns
-    def before2l(x, y, z):
-        return AddN((Mul(z, z), Mul(AllReduce(x), y)))
-
-    @fns
-    def after2(x, y, z):
-        return Mul(AllReduce(AddN((Mul(z, z), x))), y)
-
-    return fns[tag]
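The deleted Python tests rely on an FnDict helper whose implementation lies outside this excerpt. A hypothetical stand-in consistent with its usage above (calling the instance as a decorator registers a function under its own name; indexing by tag retrieves it) might look like:

    class FnDict:
        # Hypothetical reconstruction, not the MindSpore test utility itself.
        def __init__(self):
            self._fns = {}

        def __call__(self, fn):
            # Used as a decorator: register the function under its own name.
            self._fns[fn.__name__] = fn
            return fn

        def __getitem__(self, name):
            return self._fns[name]

    fns = FnDict()

    @fns
    def beforell(x, y, z):
        # Mirrors the deleted test case symbolically with nested tuples.
        return ("AddN", (z, ("Mul", y, ("AllReduce", x))))

    print(fns["beforell"]("x", "y", "z"))

This lets one file define many named before/after graphs and lets the C++ side fetch each by tag through CallAndParseRet, as the deleted TEST_F did.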