提交 4b5cbe5d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2186 Optimization for opt

Merge pull request !2186 from Kang/opt
......@@ -51,8 +51,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul});
special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
......@@ -72,9 +72,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
// Env Item Eliminate
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem);
add_env_get_item_ = MakeSubstitution(AddEnvGetItem(), "add_env_get_item", prim::kPrimEnvGetItem);
env_get_set_item_ = MakeSubstitution(EnvGetSetItem(), "env_get_set_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_ =
MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ =
......@@ -91,8 +90,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Gradient transforms
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ);
stop_gradient_eliminate_ =
MakeSubstitution(StopGradientEliminater(), "stop_gradient_eliminate", prim::kPrimStopGradient);
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem);
// branch culling
......@@ -113,9 +110,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph);
// Incorporation
incorporate_getitem_ = MakeSubstitution(IncorporateGetitem(), "incorporate_getitem", prim::kPrimTupleGetItem);
incorporate_getitem_switch_ =
MakeSubstitution(IncorporateGetitemSwitch(), "incorporate_getitem_switch", prim::kPrimTupleGetItem);
incorporate_getitem_set_ =
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup);
......
......@@ -50,9 +50,8 @@ class OptimizeIRPassLib {
SubstitutionPtr reset_defer_inline_;
// Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_;
SubstitutionPtr new_env_get_item_;
SubstitutionPtr add_env_get_item_;
SubstitutionPtr env_get_set_item_;
SubstitutionPtr incorporate_env_getitem_;
SubstitutionPtr incorporate_env_getitem_switch_;
......@@ -74,7 +73,6 @@ class OptimizeIRPassLib {
// Gradient irpasses
SubstitutionPtr expand_jprim_;
SubstitutionPtr stop_gradient_eliminate_;
SubstitutionPtr minmaximum_grad_;
// inline
......@@ -83,8 +81,7 @@ class OptimizeIRPassLib {
SubstitutionPtr specialize_transform_;
// Incorporation
SubstitutionPtr incorporate_getitem_;
SubstitutionPtr incorporate_getitem_switch_;
SubstitutionPtr incorporate_getitem_set_;
SubstitutionPtr incorporate_call_;
SubstitutionPtr incorporate_call_switch_;
......@@ -115,51 +112,30 @@ class InferenceOptPrepareLib {
// predicate functions
inline bool IsNode(const AnfNodePtr &) { return true; }
inline bool IsCNode(const AnfNodePtr &node) {
if (node != nullptr) {
return node->isa<CNode>();
}
return false;
}
inline bool IsCNode(const AnfNodePtr &node) { return node->isa<CNode>(); }
inline bool IsVNode(const AnfNodePtr &node) {
if (node != nullptr) {
return node->isa<ValueNode>();
}
return false;
}
inline bool IsVNode(const AnfNodePtr &node) { return node->isa<ValueNode>(); }
inline bool IsParam(const AnfNodePtr &node) {
if (node != nullptr) {
return node->isa<Parameter>();
}
return false;
}
inline bool IsParam(const AnfNodePtr &node) { return node->isa<Parameter>(); }
// Check if CNode Input 0 is Func Graph
inline bool IsCNodeGraph(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
if (!node->isa<CNode>()) {
return false;
}
auto inp0 = node->cast<CNodePtr>()->input(0);
if (IsValueNode<FuncGraph>(inp0)) {
return true;
}
return false;
return IsValueNode<FuncGraph>(inp0);
}
// Check if CNode Input 0 is CNode
inline bool IsCNodeDup(const AnfNodePtr &node) {
if (node == nullptr || !node->isa<CNode>()) {
if (!node->isa<CNode>()) {
return false;
}
auto inp0 = node->cast<CNodePtr>()->input(0);
if (inp0 != nullptr && inp0->isa<CNode>()) {
return true;
}
return false;
return (inp0 != nullptr) && inp0->isa<CNode>();
}
} // namespace irpass
} // namespace opt
......
......@@ -225,6 +225,33 @@ class EnvGetSetItem : public AnfVisitor {
bool is_match_{false};
};
class EnvGetItemEliminater {
public:
EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() {
eliminaters_.emplace_back(new_env_get_item_);
eliminaters_.emplace_back(add_env_get_item_);
eliminaters_.emplace_back(env_get_set_item_);
}
~EnvGetItemEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
}
return nullptr;
}
private:
NewEnvGetItem new_env_get_item_;
AddEnvGetItem add_env_get_item_;
EnvGetSetItem env_get_set_item_;
std::vector<TransformFuncType> eliminaters_{};
};
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
class IncorporateEnvGetitem : public AnfVisitor {
public:
......
......@@ -55,21 +55,6 @@ class ExpandJPrim : public AnfVisitor {
private:
ValueNodePtr x_{nullptr};
};
// stop_gradient(x) ==> x
class StopGradientEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
x_ = nullptr;
AnfVisitor::Match(prim::kPrimStopGradient)(node);
return x_;
}
void Visit(const AnfNodePtr &node) override { x_ = node; }
private:
AnfNodePtr x_{nullptr};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
......@@ -197,6 +197,31 @@ class IncorporateGetitemSwitch : public AnfVisitor {
std::vector<AnfNodePtr> args_{};
internal::GetitemTransform getitem_transform_;
};
class IncorporateGetitemSet {
public:
IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() {
eliminaters_.emplace_back(incorporate_getitem_);
eliminaters_.emplace_back(incorporate_getitem_switch_);
}
~IncorporateGetitemSet() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
}
return nullptr;
}
private:
IncorporateGetitem incorporate_getitem_;
IncorporateGetitemSwitch incorporate_getitem_switch_;
std::vector<TransformFuncType> eliminaters_{};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
......@@ -35,12 +35,14 @@ class SpecialOpEliminater {
public:
SpecialOpEliminater()
: insert_gradient_of_(prim::kPrimInsertGradientOf),
stop_gradient_(prim::kPrimStopGradient),
hook_backward_(prim::kPrimHookBackward),
print_shape_type_(prim::kPrimPrintShapeType),
get_ref_value_(prim::kPrimGetRefValue),
mirror_(prim::kPrimMirror),
virtual_div_(prim::kPrimVirtualDiv) {
eliminaters_.emplace_back(insert_gradient_of_);
eliminaters_.emplace_back(stop_gradient_);
eliminaters_.emplace_back(hook_backward_);
eliminaters_.emplace_back(print_shape_type_);
eliminaters_.emplace_back(get_ref_value_);
......@@ -61,7 +63,8 @@ class SpecialOpEliminater {
}
private:
PrimEliminater insert_gradient_of_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, virtual_div_;
PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
virtual_div_;
std::vector<TransformFuncType> eliminaters_{};
};
......
......@@ -44,8 +44,17 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
return false;
}
auto cnode = node->cast<CNodePtr>();
auto inp0 = cnode->input(0);
auto prim0 = GetValueNode<PrimitivePtr>(inp0);
if (prim0 == nullptr) {
return false;
}
auto hash = prim0->Hash();
auto const &name = prim0->name();
for (auto &prim : prims) {
if (IsPrimitiveCNode(node, prim)) {
if (hash == prim->Hash() && name == prim->name()) {
return true;
}
}
......@@ -172,7 +181,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo
}
#ifdef ENABLE_PROFILE
MsProfile::StatTime("opt.transform", GetTime() - start);
MsProfile::StatTime("opt.transform." + optimizer->name(), GetTime() - start);
#endif
return changes;
}
......
......@@ -79,16 +79,9 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Specialization
irpass.specialize_transform_,
// Arithmetic simplifications
irpass.arithmetic_simplify_,
irpass.addn_zero_filter_,
irpass.adjust_all_reduce_mul_add_,
// Miscellaneous
irpass.item_tuple_eliminate_,
irpass.env_get_set_item_,
irpass.new_env_get_item_,
irpass.add_env_get_item_,
irpass.env_get_item_eliminate_,
irpass.cast_eliminate_,
irpass.reshape_eliminate_,
irpass.reduce_eliminate_,
......@@ -96,13 +89,20 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.transpose_eliminate_,
irpass.minmaximum_grad_,
irpass.get_make_ref_eliminate_,
// Arithmetic simplifications
irpass.arithmetic_simplify_,
irpass.addn_zero_filter_,
irpass.adjust_all_reduce_mul_add_,
// Safe inlining
irpass.inline_,
});
opt::OptPassConfig a_2 = opt::OptPassConfig({
irpass.merge_addn_,
irpass.float_tuple_getitem_switch_,
irpass.float_env_getitem_switch_,
irpass.incorporate_getitem_,
irpass.incorporate_getitem_switch_,
irpass.incorporate_getitem_set_,
irpass.incorporate_call_,
irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_,
......@@ -145,7 +145,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.reset_defer_inline_,
irpass.inline_,
irpass.special_op_eliminate_,
irpass.stop_gradient_eliminate_,
irpass.get_make_ref_eliminate_,
});
opt::OptPassConfig b_2 = opt::OptPassConfig({
......
......@@ -401,7 +401,7 @@ TEST_F(TestOptLib, test_incorporate_getitem) {
FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_incorporate_getitem", "after1");
FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_incorporate_getitem", "after2");
auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_});
auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_set_});
ASSERT_TRUE(CheckOpt(before1, after1, patterns));
ASSERT_TRUE(CheckOpt(before2, after2, patterns));
......@@ -411,7 +411,7 @@ TEST_F(TestOptLib, test_incorporate_getitem_through_switch) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "before");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_incorporate_getitem_through_switch", "after");
auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_switch_});
auto patterns = std::vector<SubstitutionPtr>({irpass.incorporate_getitem_set_});
ASSERT_TRUE(CheckOpt(before, after, patterns));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册