Commit b3c6da90 authored by mindspore-ci-bot, committed by Gitee

!1714 Add 2 patterns for SoftmaxGradExt fusion pass

Merge pull request !1714 from huanghui/SoftmaxGradExt-fussion-pass
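The two new patterns are operand-order variants of the subgraph already fused by SoftmaxGradExtFusion: V2 and V3 differ from each other only in the operand order of the final Mul, and every match is rewritten into a single SoftmaxGradExt(input0, input1, input2) node with the keep-dims and axis attributes copied from the matched reduce-sum. Below is a minimal NumPy sketch of what the V2 subgraph computes; axes=(2, 3) and keepdims=True are assumed purely for illustration and are not taken from this change.

import numpy as np

# Reference computation for the V2 pattern (illustration only, not framework code).
# The fusion pass replaces this whole chain with one SoftmaxGradExt(input0, input1, input2) node.
def softmax_grad_ext_v2_reference(input0, input1, input2, axes=(2, 3)):
    mul = input1 * input0
    reduce_sum = np.sum(mul, axis=axes, keepdims=True)  # axis/keep-dims are carried over from the matched ReduceSum
    sub = input0 - reduce_sum
    mul1 = input1 * sub
    return input2 * mul1  # the V3 pattern matches the same graph with the final product written as mul1 * input2

x = np.ones((1, 1, 1, 1), dtype=np.float32)  # same shape/dtype the unit tests below use
print(softmax_grad_ext_v2_reference(x, x, x).shape)  # (1, 1, 1, 1)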
@@ -100,6 +100,8 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusion>());
ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV2>());
ir_fusion_pm->AddPass(std::make_shared<SoftmaxGradExtFusionV3>());
ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
@@ -31,6 +31,24 @@ const BaseRef SoftmaxGradExtFusion::DefinePattern() const {
return mul_grad;
}
const BaseRef SoftmaxGradExtFusionV2::DefinePattern() const {
VectorRef mul({prim::kPrimMul, input1_, input0_});
VectorRef sum({sum_var_, mul});
VectorRef sub({prim::kPrimSub, input0_, sum});
VectorRef mul1({prim::kPrimMul, input1_, sub});
VectorRef mul_grad({prim::kPrimMul, input2_, mul1});
return mul_grad;
}
const BaseRef SoftmaxGradExtFusionV3::DefinePattern() const {
VectorRef mul({prim::kPrimMul, input1_, input0_});
VectorRef sum({sum_var_, mul});
VectorRef sub({prim::kPrimSub, input0_, sum});
VectorRef mul1({prim::kPrimMul, input1_, sub});
VectorRef mul_grad({prim::kPrimMul, mul1, input2_});
return mul_grad;
}
const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
@@ -46,7 +64,7 @@ const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const
MS_EXCEPTION_IF_NULL(fusion_node);
fusion_node->set_scope(node->scope());
fusion_node->set_abstract(node->abstract());
-  AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, fusion_node);
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, "keepdims", sum, fusion_node);
AnfAlgo::CopyNodeAttr(kAttrAxis, sum, fusion_node);
return fusion_node;
}
@@ -17,13 +17,15 @@
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_
#include <memory>
#include <string>
#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
class SoftmaxGradExtFusion : public PatternProcessPass {
public:
-  explicit SoftmaxGradExtFusion(bool multigraph = true) : PatternProcessPass("softmax_grad_ext_fusion", multigraph) {
+  explicit SoftmaxGradExtFusion(const std::string &name = "softmax_grad_ext_fusion", bool multigraph = true)
+      : PatternProcessPass(name, multigraph) {
input0_ = std::make_shared<Var>();
input1_ = std::make_shared<Var>();
input2_ = std::make_shared<Var>();
@@ -33,12 +35,28 @@ class SoftmaxGradExtFusion : public PatternProcessPass {
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
- private:
+ protected:
VarPtr input0_;
VarPtr input1_;
VarPtr input2_;
VarPtr sum_var_;
};
class SoftmaxGradExtFusionV2 : public SoftmaxGradExtFusion {
public:
explicit SoftmaxGradExtFusionV2(bool multigraph = true)
: SoftmaxGradExtFusion("softmax_grad_ext_fusion_v2", multigraph) {}
~SoftmaxGradExtFusionV2() override = default;
const BaseRef DefinePattern() const override;
};
class SoftmaxGradExtFusionV3 : public SoftmaxGradExtFusion {
public:
explicit SoftmaxGradExtFusionV3(bool multigraph = true)
: SoftmaxGradExtFusion("softmax_grad_ext_fusion_v3", multigraph) {}
~SoftmaxGradExtFusionV3() override = default;
const BaseRef DefinePattern() const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_
@@ -49,5 +49,47 @@ TEST_F(TestHWOptSoftmaxGradExtFusion, test_fusion) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWOptSoftmaxGradExtFusion, test_fusion_v2) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion_v2", "before");
EXPECT_NE(g, nullptr);
std::vector<int> shp{1, 1, 1, 1};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 3; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::SoftmaxGradExtFusionV2>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion_v2", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWOptSoftmaxGradExtFusion, test_fusion_v3) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion_v3", "before");
EXPECT_NE(g, nullptr);
std::vector<int> shp{1, 1, 1, 1};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 3; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::SoftmaxGradExtFusionV3>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion_v3", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore
@@ -54,3 +54,43 @@ def test_softmax_grad_ext_fusion(tag):
return MakeTuple(res)
return fns[tag]
def test_softmax_grad_ext_fusion_v2(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2):
        mul = Mul(input1, input0)
        reduce_sum = ReduceSum(mul, axes)
        sub = Sub(input0, reduce_sum)
        mul1 = Mul(input1, sub)
        mul_grad = Mul(input2, mul1)
        return mul_grad

    @fns
    def after(input0, input1, input2):
        res = SoftmaxGradExt(input0, input1, input2)
        return MakeTuple(res)

    return fns[tag]


def test_softmax_grad_ext_fusion_v3(tag):
    fns = FnDict()

    @fns
    def before(input0, input1, input2):
        mul = Mul(input1, input0)
        reduce_sum = ReduceSum(mul, axes)
        sub = Sub(input0, reduce_sum)
        mul1 = Mul(input1, sub)
        mul_grad = Mul(mul1, input2)
        return mul_grad

    @fns
    def after(input0, input1, input2):
        res = SoftmaxGradExt(input0, input1, input2)
        return MakeTuple(res)

    return fns[tag]