提交 69d1b4c0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2468 Update Ref Eliminate and TileEliminate to Pattern Matcher

Merge pull request !2468 from Giancarlo/optimizer_update
......@@ -39,6 +39,10 @@ namespace mindspore {
template <typename T>
class PBase {
public:
bool CheckFunc(const opt::PredicateFuncType &func, const AnfNodePtr &node) {
return func(get_object().GetNode(node));
}
const T &get_object() const { return *static_cast<const T *>(this); }
template <typename TN>
......
......@@ -45,7 +45,7 @@ class SwitchSimplify : public OptimizerCaller {
};
MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda,
IsValueNode<BoolImm>(cond.GetNode(node)));
cond.CheckFunc(IsValueNode<BoolImm>, node));
return nullptr;
}
......@@ -61,7 +61,7 @@ class FloatTupleGetItemSwitch : public OptimizerCaller {
PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x),
PPrimitive(prim::kPrimTupleGetItem, false_br, x)),
IsVNode(x.GetNode(node)));
x.CheckFunc(IsVNode, node));
return nullptr;
}
};
......@@ -72,11 +72,10 @@ class FloatEnvGetItemSwitch : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
MATCH_REPLACE_IF(node,
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)),
IsNode(x.GetNode(node)) && IsNode(x2.GetNode(node)));
MATCH_REPLACE(node,
PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)));
return nullptr;
}
......@@ -142,9 +141,9 @@ class ConvertSwitchReplacement : public OptimizerCaller {
return nnode;
};
MATCH_REPLACE_LAMBDA_IF(node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda,
IsNode(cond.GetNode(node_)) && IsValueNode<FuncGraph>(true_br.GetNode(node_)) &&
IsValueNode<FuncGraph>(false_br.GetNode(node_)));
MATCH_REPLACE_LAMBDA_IF(
node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda,
true_br.CheckFunc(IsValueNode<FuncGraph>, node_) && false_br.CheckFunc(IsValueNode<FuncGraph>, node_));
return nullptr;
}
......
......@@ -21,109 +21,70 @@
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "utils/graph_utils.h"
#include "operator/composite/composite.h"
#include "ir/pattern_matcher.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimMakeRef, X, Y, Z} -> Y
class MakeRefEliminater : public AnfVisitor {
class MakeRefEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
y_ = nullptr;
auto gety = [this](const AnfNodePtr &node) -> bool {
this->y_ = node;
return true;
};
AnfVisitor::Match(prim::kPrimMakeRef, {IsNode, gety, IsNode})(node);
return y_;
PatternNode<AnfNodePtr> x, y, z;
MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y);
return nullptr;
}
void Visit(const AnfNodePtr &) override {}
private:
AnfNodePtr y_{nullptr};
};
// {prim::kPrimGetRefValue, Parameter} -> Parameter
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
class GetRefParamEliminater : public AnfVisitor {
class GetRefParamEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
x_ = nullptr;
AnfVisitor::Match(prim::kPrimGetRefOrigin, {IsParam})(node);
if (x_ != nullptr) {
return x_;
}
AnfVisitor::Match(prim::kPrimGetRefValue, {IsParam})(node);
return x_;
PatternNode<AnfNodePtr> x;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node));
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node));
return nullptr;
}
void Visit(const AnfNodePtr &node) override { x_ = node; }
private:
AnfNodePtr x_{nullptr};
};
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
class GetMakeRefEliminater : public AnfVisitor {
class GetMakeRefEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() != 2) {
return nullptr;
}
// {prim::kPrimGetRefKey/Value, {...}}
auto ref = cnode->input(1)->cast<CNodePtr>();
if (ref == nullptr || !ref->IsApply(prim::kPrimMakeRef) || ref->size() != 4) {
return nullptr;
}
// {prim::kPrimMakeRef, X, Y, Z}
if (cnode->IsApply(prim::kPrimGetRefKey)) {
return ref->input(1);
}
if (cnode->IsApply(prim::kPrimGetRefValue)) {
return ref->input(2);
}
if (cnode->IsApply(prim::kPrimGetRefOrigin)) {
return ref->input(3);
}
PatternNode<AnfNodePtr> x, y, z;
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
return nullptr;
}
};
// IsValueNode<RefKey>
class ReplaceRefkeyByParam : public AnfVisitor {
class ReplaceRefkeyByParam : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
if (!IsValueNode<RefKey>(node)) {
return nullptr;
}
auto refkey = GetValueNode<RefKeyPtr>(node);
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
MS_EXCEPTION_IF_NULL(resource);
auto top_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(top_graph);
for (const auto &tnode : top_graph->parameters()) {
auto para = tnode->cast<ParameterPtr>();
if (para != nullptr && para->name() == refkey->tag()) {
return para;
auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr {
auto refkey = GetValueNode<RefKeyPtr>(node);
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
MS_EXCEPTION_IF_NULL(resource);
auto top_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(top_graph);
for (const auto &tnode : top_graph->parameters()) {
auto para = tnode->cast<ParameterPtr>();
if (para != nullptr && para->name() == refkey->tag()) {
return para;
}
}
}
return nullptr;
};
PatternNode<AnfNodePtr> x;
MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node));
return nullptr;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册