提交 8f29e724 编写于 作者: B BowenK

add pass to eliminate depend value

上级 d638578d
...@@ -70,6 +70,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() { ...@@ -70,6 +70,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend);
// Env Item Eliminate // Env Item Eliminate
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
......
...@@ -48,6 +48,7 @@ class OptimizeIRPassLib { ...@@ -48,6 +48,7 @@ class OptimizeIRPassLib {
SubstitutionPtr same_eliminate_; SubstitutionPtr same_eliminate_;
SubstitutionPtr check_bprop_eliminate_; SubstitutionPtr check_bprop_eliminate_;
SubstitutionPtr reset_defer_inline_; SubstitutionPtr reset_defer_inline_;
SubstitutionPtr depend_value_elim_;
// Env Item Eliminate // Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_; SubstitutionPtr env_get_item_eliminate_;
......
...@@ -24,9 +24,11 @@ ...@@ -24,9 +24,11 @@
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "ir/optimizer_caller.h"
#include "optimizer/irpass/prim_eliminate.h" #include "optimizer/irpass/prim_eliminate.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "ir/pattern_matcher.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -191,6 +193,17 @@ class ZeroLikeFillZero : public AnfVisitor { ...@@ -191,6 +193,17 @@ class ZeroLikeFillZero : public AnfVisitor {
AnfNodePtr y_{nullptr}; AnfNodePtr y_{nullptr};
PrimitivePtr PrimFill_, PrimShape_, PrimDType_; PrimitivePtr PrimFill_, PrimShape_, PrimDType_;
}; };
// {prim::kPrimDepend, X, ValueCond}->X
class DependValueElim : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x, cond;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node)));
return nullptr;
}
};
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
...@@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { ...@@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_,
irpass.incorporate_env_getitem_switch_, irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_, irpass.new_env_get_item_,
irpass.depend_value_elim_,
}); });
opt::OptPassConfig a_3 = opt::OptPassConfig({ opt::OptPassConfig a_3 = opt::OptPassConfig({
irpass.same_eliminate_, irpass.same_eliminate_,
......
...@@ -257,6 +257,14 @@ TEST_F(TestOptLib, test_elim_transpose) { ...@@ -257,6 +257,14 @@ TEST_F(TestOptLib, test_elim_transpose) {
ASSERT_TRUE(CheckOpt(before, after, patterns)); ASSERT_TRUE(CheckOpt(before, after, patterns));
} }
TEST_F(TestOptLib, test_elim_depend_value) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_depend_value", "before");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_depend_value", "after");
auto patterns = std::vector<SubstitutionPtr>({irpass.depend_value_elim_});
ASSERT_TRUE(CheckOpt(before, after, patterns));
}
TEST_F(TestOptLib, test_elim_tile_multiply_one) { TEST_F(TestOptLib, test_elim_tile_multiply_one) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before"); FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after"); FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after");
......
...@@ -494,6 +494,21 @@ def test_elim_transpose(tag): ...@@ -494,6 +494,21 @@ def test_elim_transpose(tag):
return fns[tag] return fns[tag]
def test_elim_depend_value(tag):
""" test_elim_depend_value """
fns = FnDict()
depend = P.Depend()
@fns
def before(x):
return depend(x, None)
@fns
def after(x):
return x
return fns[tag]
def test_elim_tile_multiply_one(tag): def test_elim_tile_multiply_one(tag):
""" test_elim_tile_multiply_one """ """ test_elim_tile_multiply_one """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册