Commit 71acaa53 authored by huanghui

enable ConfusionMulGrad fusion pass in bert only

Parent 387bcece
...
@@ -100,6 +100,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
   ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
   ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
   ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
   ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
   ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
...
...
@@ -74,10 +74,21 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An
 }
 bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf,
-                const AnfNodePtr &reduce_sum) {
+                const AnfNodePtr &reduce_sum, const AnfNodePtr &input2) {
   MS_EXCEPTION_IF_NULL(mul0_anf);
   MS_EXCEPTION_IF_NULL(mul1_anf);
   MS_EXCEPTION_IF_NULL(reduce_sum);
+  MS_EXCEPTION_IF_NULL(input2);
+  auto addn = input2->cast<CNodePtr>();
+  if (addn == nullptr || AnfAlgo::GetCNodeName(addn) != prim::kPrimAddN->name()) {
+    MS_LOG(INFO) << "mul's second input is not addn";
+    return true;
+  }
+  std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(addn, 0);
+  if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) {
+    MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]";
+    return true;
+  }
   if (!mul0_anf->isa<CNode>() || !mul1_anf->isa<CNode>()) {
     return true;
   }
@@ -86,11 +97,6 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
   auto mul0 = mul0_anf->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(mul0);
-  // when network is _VirtualDatasetCell, quit fusion
-  if (mul0->fullname_with_scope().find("network-_VirtualDatasetCell") != std::string::npos) {
-    return true;
-  }
   if (IsDepend(graph, mul0->input(1), reduce_sum)) {
     MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion";
     return true;
@@ -128,7 +134,7 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons
     MS_LOG(INFO) << "Mul0 do not exist, quit fusion";
     return nullptr;
   }
-  if (QuitFusion(graph, mul0, mul1, node)) {
+  if (QuitFusion(graph, mul0, mul1, node, input2)) {
     return nullptr;
   }
...
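The AddN shape gate added above is what scopes the pass to BERT: 1024 and 768 presumably correspond to the hidden sizes of BERT-large and BERT-base. A minimal Python sketch of that check (the function name and the BERT interpretation are illustrative, not from the commit):

    # Sketch of the AddN shape gate added to QuitFusion above.
    # Assumption: 1024/768 are the BERT-large/BERT-base hidden sizes.
    def should_quit_fusion(addn_shape):
        """Quit fusion unless the AddN output shape is [x, 1024] or [x, 768]."""
        return len(addn_shape) != 2 or addn_shape[1] not in (1024, 768)

    assert not should_quit_fusion([10, 1024])  # BERT-large-like width: fuse
    assert should_quit_fusion([10, 512])       # any other width: quit fusion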
...
@@ -84,8 +84,9 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
     inputs.push_back(mul->input(index));
   }
   auto another_input_node = add->input(add->size() - mul_index);
-  if (IsUsedByOthers(graph, another_input_node)) {
-    MS_LOG(INFO) << "Add's another input node is used by others, do not fuse";
+  if (another_input_node->isa<CNode>() &&
+      AnfAlgo::GetCNodeName(another_input_node) == prim::kPrimTupleGetItem->name()) {
+    MS_LOG(INFO) << "Add's another input node has multiple outputs, do not fuse";
     return nullptr;
   }
   inputs.push_back(another_input_node);
...
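MulAddFusion's guard is relaxed here: instead of refusing to fuse whenever the Add's other operand has additional users, it now refuses only when that operand is a tuple_getitem, i.e. one output of a multi-output node. A hedged sketch of the new rule, with illustrative names:

    # Sketch of the revised MulAddFusion guard above (names illustrative).
    def can_fuse_mul_add(other_is_cnode, other_op_name):
        """Only a tuple_getitem operand (one output of a multi-output node)
        blocks the FusedMulAdd rewrite under the new rule."""
        return not (other_is_cnode and other_op_name == "tuple_getitem")

    assert can_fuse_mul_add(True, "Mul")                 # ordinary CNode: fuse
    assert not can_fuse_mul_add(True, "tuple_getitem")   # multi-output source: skip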
...
@@ -32,7 +32,7 @@ class TestHWOptimizeConfusionMulGradFusion : public BackendCommon {
 TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) {
   FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "before");
   EXPECT_NE(g, nullptr);
-  std::vector<int> shp{1, 1, 1, 1};
+  std::vector<int> shp{10, 1024};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   AbstractBasePtrList args_spec_list;
   for (size_t i = 0; i < 3; ++i) {
@@ -49,6 +49,5 @@ TEST_F(TestHWOptimizeConfusionMulGradFusion, test_fusion) {
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_confusion_mul_grad_fusion", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
 }  // namespace opt
 }  // namespace mindspore
...
...
@@ -15,12 +15,13 @@
 from mindspore.ops import Primitive
 from mindspore.ops import operations as P

+addn = P.AddN()
 mul = P.Mul()
 reduce_sum = P.ReduceSum()
 confusion_mul_grad = Primitive('ConfusionMulGrad')
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
-axis = 2
+axis = 1

 class FnDict:
@@ -39,8 +40,10 @@ def test_confusion_mul_grad_fusion(tag):
     @fns
     def before(input1, input2, input3):
-        output1 = mul(input1, input2)
-        mul1 = mul(input3, input2)
+        addn0 = addn((input2, input2))
+        output1 = mul(input1, addn0)
+        mul1 = mul(input3, addn0)
         # input axis will be convert to attr in step ConstructKernelGraph
         output2 = reduce_sum(mul1, axis)
         res = make_tuple(output1, output2)
@@ -48,7 +51,8 @@ def test_confusion_mul_grad_fusion(tag):
     @fns
     def after(input1, input2, input3):
-        res = confusion_mul_grad(input1, input2, input3)
+        addn0 = addn(input2, input2)
+        res = confusion_mul_grad(input1, addn0, input3)
         item0 = tuple_getitem(res, 0)
         item1 = tuple_getitem(res, 1)
         res = make_tuple(item0, item1)
...
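Per the before/after graphs above, the fused op computes output1 = input1 * x and output2 = reduce_sum(input3 * x, axis), where x is the AddN result. A NumPy reference sketch of those semantics, not MindSpore's actual kernel:

    import numpy as np

    # Reference semantics of ConfusionMulGrad as exercised by the test above
    # (a sketch, not MindSpore's kernel implementation).
    def confusion_mul_grad_ref(input1, x, input3, axis=1):
        return input1 * x, np.sum(input3 * x, axis=axis)

    a = np.ones((10, 1024), dtype=np.float32)
    out1, out2 = confusion_mul_grad_ref(a, a + a, a)  # a + a plays the AddN role
    assert out1.shape == (10, 1024) and out2.shape == (10,)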