提交 5a68eba5 编写于 作者: H huanghui

Refactor LambNextMVWithDecayRule fusion pass

上级 93e7c97a
......@@ -112,7 +112,6 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<DereluFusion>());
ir_fusion_pm->AddPass(std::make_shared<ConfusionMulGradFusion>());
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
......
......@@ -20,28 +20,23 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
const AnfNodePtr &add3, const AnfNodePtr &add5, const AnfNodePtr &real_div0,
const AnfNodePtr &real_div1) {
AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph,
const AnfNodePtr &new_node, const AnfNodePtr &add3,
const AnfNodePtr &add5, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(new_node);
MS_EXCEPTION_IF_NULL(add3);
MS_EXCEPTION_IF_NULL(real_div0);
MS_EXCEPTION_IF_NULL(real_div1);
MS_EXCEPTION_IF_NULL(add5);
MS_EXCEPTION_IF_NULL(equiv);
auto add0 = GetAnfNodeByVar(equiv, add0_var_);
MS_EXCEPTION_IF_NULL(add0);
auto add1 = GetAnfNodeByVar(equiv, add1_var_);
MS_EXCEPTION_IF_NULL(add1);
// Set abstract of new node
AbstractBasePtrList new_node_list;
new_node_list.push_back(add3->abstract());
auto real_div0_cnode = real_div0->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_div0_cnode);
AnfNodePtr add0 = real_div0_cnode->input(1);
MS_EXCEPTION_IF_NULL(add0);
new_node_list.push_back(add0->abstract());
auto real_div1_cnode = real_div1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_div1_cnode);
AnfNodePtr add1 = real_div1_cnode->input(1);
MS_EXCEPTION_IF_NULL(add1);
new_node_list.push_back(add1->abstract());
new_node_list.push_back(add5->abstract());
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_list);
......@@ -58,94 +53,8 @@ AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const An
return new_node_outputs[3];
}
void GetSharedInputNodesByAdd5(const AnfNodePtr &node, AnfNodePtr *mul4, AnfNodePtr *real_div0, AnfNodePtr *real_div1,
AnfNodePtr *constant_add2_y_input) {
MS_EXCEPTION_IF_NULL(node);
auto add5_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add5_cnode);
if (add5_cnode->inputs().size() < kAddInputNum) {
MS_LOG(EXCEPTION) << "The input size of Add5 is less than " << kAddInputNum;
}
*mul4 = add5_cnode->input(2);
AnfNodePtr real_div4 = add5_cnode->input(1);
MS_EXCEPTION_IF_NULL(real_div4);
auto real_div4_cnode = real_div4->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_div4_cnode);
if (real_div4_cnode->inputs().size() < kRealDivInputNum) {
MS_LOG(EXCEPTION) << "The input size of RealDiv4 is less than " << kRealDivInputNum;
}
*real_div0 = real_div4_cnode->input(1);
AnfNodePtr add4 = real_div4_cnode->input(2);
MS_EXCEPTION_IF_NULL(add4);
auto add4_cnode = add4->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add4_cnode);
if (add4_cnode->inputs().size() < kAddInputNum) {
MS_LOG(EXCEPTION) << "The input size of Add4 is less than " << kAddInputNum;
}
AnfNodePtr sqrt1 = add4_cnode->input(1);
MS_EXCEPTION_IF_NULL(sqrt1);
auto sqrt1_cnode = sqrt1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sqrt1_cnode);
if (sqrt1_cnode->inputs().size() < kSqrtInputNum) {
MS_LOG(EXCEPTION) << "The input size of Sqrt1 is less than " << kSqrtInputNum;
}
*real_div1 = sqrt1_cnode->input(1);
*constant_add2_y_input = add4_cnode->input(2);
}
bool MatchAdd3(const AnfNodePtr &add3, const AnfNodePtr &mul4, const AnfNodePtr &real_div0, const AnfNodePtr &real_div1,
const AnfNodePtr &constant_add2_y) {
if (add3 == nullptr || !add3->isa<CNode>()) {
return false;
}
auto add3_cnode = add3->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add3_cnode);
if (AnfAlgo::GetCNodeName(add3_cnode) != prim::kPrimTensorAdd->name() ||
add3_cnode->inputs().size() != kAddInputNum) {
return false;
}
// Check the shared input nodes.
if (add3_cnode->input(2) != mul4) {
return false;
}
AnfNodePtr real_div2 = add3_cnode->input(1);
MS_EXCEPTION_IF_NULL(real_div2);
auto real_div2_cnode = real_div2->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_div2_cnode);
if (AnfAlgo::GetCNodeName(real_div2_cnode) != prim::kPrimMul->name() ||
real_div2_cnode->inputs().size() != kMulInputNum) {
return false;
}
if (real_div2_cnode->input(1) != real_div0) {
return false;
}
AnfNodePtr sqrt0 = real_div2_cnode->input(2);
MS_EXCEPTION_IF_NULL(sqrt0);
auto sqrt0_cnode = sqrt0->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sqrt0_cnode);
if (AnfAlgo::GetCNodeName(sqrt0_cnode) != kRsqrtOpName || sqrt0_cnode->inputs().size() != kRsqrtInputNum) {
return false;
}
AnfNodePtr add2 = sqrt0_cnode->input(1);
MS_EXCEPTION_IF_NULL(add2);
auto add2_cnode = add2->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add2_cnode);
if (AnfAlgo::GetCNodeName(add2_cnode) != prim::kPrimTensorAdd->name() ||
add2_cnode->inputs().size() != kAddInputNum) {
return false;
}
MS_EXCEPTION_IF_NULL(add2_cnode->input(2));
MS_EXCEPTION_IF_NULL(constant_add2_y);
return add2_cnode->input(1) == real_div1 && *(add2_cnode->input(2)) == *constant_add2_y;
}
} // namespace
AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph,
const AnfNodePtr &add3, const AnfNodePtr &add5,
const AnfNodePtr &real_div0,
const AnfNodePtr &real_div1,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(add3);
......@@ -167,7 +76,7 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap
MS_EXCEPTION_IF_NULL(constant_add2_y_node);
new_node_inputs.push_back(constant_add2_y_node);
auto new_node = func_graph->NewCNode(new_node_inputs);
return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, real_div0, real_div1);
return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
}
const BaseRef LambNextMVWithDecayRule::DefinePattern() const {
......@@ -175,44 +84,82 @@ const BaseRef LambNextMVWithDecayRule::DefinePattern() const {
MS_EXCEPTION_IF_NULL(prim_sqrt);
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
MS_EXCEPTION_IF_NULL(prim_deal_div);
VectorRef mul4 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add0 =
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}),
VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]})});
VectorRef real_div0 = VectorRef({prim_deal_div, add0, input_vars_[5]});
VectorRef add1 =
VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}),
VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]})});
VectorRef real_div1 = VectorRef({prim_deal_div, add1, input_vars_[2]});
VectorRef real_div4 = VectorRef(
{prim_deal_div, real_div0, VectorRef({prim::kPrimTensorAdd, VectorRef({prim_sqrt, real_div1}), constant_add2_y_})});
return VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
return add5;
}
const BaseRef LambNextMVWithDecayRule::DefineAnotherPattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
VarPtr Zs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
MS_EXCEPTION_IF_NULL(Ys);
MS_EXCEPTION_IF_NULL(Zs);
// Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
return add3;
}
bool LambNextMVWithDecayRule::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
VarPtr fg = std::make_shared<Var>("RootG");
auto empty_equiv = std::make_shared<Equiv>();
MS_EXCEPTION_IF_NULL(child_primitive_vars_);
EquivPtr another_equiv =
child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
*child_primitive_vars_, empty_equiv);
if (another_equiv != nullptr && !another_equiv->empty()) {
return IsShareNodes(equiv, another_equiv);
}
return false;
}
bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) &&
IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_);
}
const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
// Get the shared input nodes in patterns of add5 and add3
AnfNodePtr mul4 = nullptr;
AnfNodePtr real_div0 = nullptr;
AnfNodePtr real_div1 = nullptr;
AnfNodePtr constant_add2_y_input = nullptr;
GetSharedInputNodesByAdd5(node, &mul4, &real_div0, &real_div1, &constant_add2_y_input);
// Get add3 and try to match the add3 pattern
AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_);
MS_EXCEPTION_IF_NULL(mul4);
// Get add3 and match the add3 pattern
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if (manager->node_users().find(mul4) == manager->node_users().end()) {
MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input";
}
AnfNodeIndexSet mul4_output_node_index_set = manager->node_users()[mul4];
auto iter = std::find_if(
mul4_output_node_index_set.begin(), mul4_output_node_index_set.end(),
[&node, &mul4, &real_div0, &real_div1, &constant_add2_y_input](const std::pair<AnfNodePtr, int> &node_index) {
return node_index.first != node && MatchAdd3(node_index.first, mul4, real_div0, real_div1, constant_add2_y_input);
});
if (iter != mul4_output_node_index_set.end()) {
return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, real_div0, real_div1, equiv);
AnfNodeIndexSet mul4_outputs = manager->node_users()[mul4];
auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(),
[&node, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
return node_index.first != node && MatchAnotherPattern(node_index.first, equiv);
});
if (iter != mul4_outputs.end()) {
return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv);
}
return nullptr;
}
......
......@@ -18,6 +18,7 @@
#include <vector>
#include <memory>
#include <string>
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"
......@@ -25,8 +26,13 @@ namespace mindspore {
namespace opt {
class LambNextMVWithDecayRule : public PatternProcessPass {
public:
explicit LambNextMVWithDecayRule(bool multigraph = true)
: PatternProcessPass("lamb_next_mv_with_decay_rule", multigraph) {
explicit LambNextMVWithDecayRule(const std::string &name = "lamb_next_mv_with_decay_rule_cond4",
bool multigraph = true)
: PatternProcessPass(name, multigraph),
child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {
for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) {
input_vars_.push_back(std::make_shared<Var>());
}
......@@ -34,20 +40,39 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
constant_mul_input_vars_.push_back(std::make_shared<Var>());
}
constant_add2_y_ = std::make_shared<Var>();
mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
}
~LambNextMVWithDecayRule() override = default;
const BaseRef DefinePattern() const override;
virtual const BaseRef DefineAnotherPattern() const;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
const AnfNodePtr &add5, const AnfNodePtr &real_div0,
const AnfNodePtr &real_div1, const EquivPtr &equiv) const;
protected:
bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
// check two patterns whether share the same nodes or not
bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const;
AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const;
AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
const AnfNodePtr &add5, const EquivPtr &equiv) const;
PatternEngine child_pattern_engine_;
PrimitiveVarMapPtr child_primitive_vars_;
std::vector<VarPtr> input_vars_;
std::vector<VarPtr> constant_mul_input_vars_;
// nodes which two patterns share
VarPtr constant_add2_y_;
VarPtr mul4_var_;
VarPtr real_div0_var_;
VarPtr real_div1_var_;
// part of output nodes
VarPtr add0_var_;
VarPtr add1_var_;
};
} // namespace opt
} // namespace mindspore
......
......@@ -64,6 +64,8 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node);
AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node);
AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node);
auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node);
return new_node;
}
......
......@@ -64,6 +64,8 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node);
AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node);
AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node);
auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node);
return new_node;
}
......
......@@ -539,5 +539,169 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &i
primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
}
}
bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a);
auto b_node = utils::cast<AnfNodePtr>(b);
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
auto a_value_node = a_node->cast<ValueNodePtr>();
auto a_value = a_value_node->value();
auto a_prim = a_value->cast<PrimitivePtr>();
auto b_value_node = b_node->cast<ValueNodePtr>();
auto b_value = b_value_node->value();
auto b_prim = b_value->cast<PrimitivePtr>();
return a_prim->name() == b_prim->name();
} else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
if (a_value_node_ptr == nullptr) {
MS_LOG(EXCEPTION) << "cast value node ptr fail";
}
auto a_value_ptr = a_value_node_ptr->value();
if (a_value_ptr == nullptr) {
MS_LOG(EXCEPTION) << "value ptr is nullptr";
}
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
if (b_value_node_ptr == nullptr) {
MS_LOG(EXCEPTION) << "cast value node ptr fail";
}
auto b_value_ptr = b_value_node_ptr->value();
if (b_value_ptr == nullptr) {
MS_LOG(EXCEPTION) << "value ptr is nullptr";
}
return (*a_value_ptr) == (*b_value_ptr);
}
MS_LOG(DEBUG) << "check AnfNodePtr equal";
}
if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
MS_LOG(DEBUG) << "check GraphPtr equal";
}
return a == b;
}
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
// To matchCNode and Kernel's type
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
return true;
}
return a.type() == b.type();
}
namespace {
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
if (utils::isa<int>(sexp)) {
return NewValueNode(utils::cast<int>(sexp));
}
if (utils::isa<float>(sexp)) {
return NewValueNode(utils::cast<float>(sexp));
}
if (utils::isa<bool>(sexp)) {
return NewValueNode(utils::cast<bool>(sexp));
}
if (utils::isa<ValuePtr>(sexp)) {
return NewValueNode(utils::cast<ValuePtr>(sexp));
}
return nullptr;
}
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
if (utils::isa<FuncGraphPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
}
if (utils::isa<VarPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
}
return nullptr;
}
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
if (utils::isa<VarPtr>(graph)) {
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
}
if (utils::isa<FuncGraphPtr>(graph)) {
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
}
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
return nullptr;
}
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph) {
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
std::vector<AnfNodePtr> input_nodes;
const auto &tuple = utils::cast<VectorRef>(sexp);
if (multigraph && utils::isa<VarPtr>(graph)) {
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
input_nodes.push_back(node);
}
VarPtr var_ptr = utils::cast<VarPtr>(graph);
return std::make_shared<CNode>(input_nodes, var_ptr);
}
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
input_nodes.push_back(node);
}
return CreateCNodeWithGraph(input_nodes, graph);
}
} // namespace
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) {
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
MS_EXCEPTION_IF_NULL(primitive_vars);
if (utils::isa<VectorRef>(sexp)) {
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
}
if (utils::isa<VarPtr>(sexp)) {
auto var_ptr = utils::cast<VarPtr>(sexp);
MS_EXCEPTION_IF_NULL(var_ptr);
if (var_ptr->primitive()) {
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
return NewValueNode(var_ptr->primitive());
}
return CreateVarNodeWithSexp(sexp, graph);
}
if (utils::isa<AnfNodePtr>(sexp)) {
return utils::cast<AnfNodePtr>(sexp);
}
auto value_node = CreateValueNodeWithSexp(sexp);
if (value_node == nullptr) {
MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
}
return value_node;
}
bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) {
MS_EXCEPTION_IF_NULL(equiv1);
MS_EXCEPTION_IF_NULL(equiv2);
MS_EXCEPTION_IF_NULL(var_node);
auto equiv1_node = GetAnfNodeByVar(equiv1, var_node);
MS_EXCEPTION_IF_NULL(equiv1_node);
auto equiv2_node = GetAnfNodeByVar(equiv2, var_node);
MS_EXCEPTION_IF_NULL(equiv2_node);
return equiv1_node == equiv2_node;
}
AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) {
MS_EXCEPTION_IF_NULL(equiv);
MS_EXCEPTION_IF_NULL(var_node);
auto iter = (*equiv).find(var_node);
if (iter == (*equiv).end()) {
MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched.";
return nullptr;
}
auto res = utils::cast<AnfNodePtr>(iter->second);
if (res == nullptr) {
MS_LOG(EXCEPTION) << "Cast fail! Maybe var is not a anf node";
}
return res;
}
} // namespace opt
} // namespace mindspore
......@@ -23,6 +23,7 @@
#include "ir/func_graph.h"
#include "session/kernel_graph.h"
#include "common/utils.h"
#include "pre_activate/common/pattern_engine.h"
namespace mindspore {
namespace opt {
......@@ -162,6 +163,19 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
bool AnfEqual(const BaseRef &a, const BaseRef &b);
bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b);
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph = false);
// Check var_node in two equivs is the same node
bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node);
// Get anf_node from equiv by var_node
AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
......@@ -29,148 +29,6 @@
namespace mindspore {
namespace opt {
namespace {
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph);
ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) {
if (utils::isa<int>(sexp)) {
return NewValueNode(utils::cast<int>(sexp));
}
if (utils::isa<float>(sexp)) {
return NewValueNode(utils::cast<float>(sexp));
}
if (utils::isa<bool>(sexp)) {
return NewValueNode(utils::cast<bool>(sexp));
}
if (utils::isa<ValuePtr>(sexp)) {
return NewValueNode(utils::cast<ValuePtr>(sexp));
}
return nullptr;
}
CNodePtr CreateCNodeWithGraph(const std::vector<AnfNodePtr> &input_nodes, const BaseRef &graph) {
if (utils::isa<FuncGraphPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<FuncGraphPtr>(graph));
}
if (utils::isa<VarPtr>(graph)) {
return std::make_shared<CNode>(input_nodes, utils::cast<VarPtr>(graph));
}
return nullptr;
}
VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) {
if (utils::isa<VarPtr>(graph)) {
MS_LOG(DEBUG) << "make VarPtr " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), nullptr);
}
if (utils::isa<FuncGraphPtr>(graph)) {
MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString();
return std::make_shared<VarNode>(utils::cast<VarPtr>(sexp), utils::cast<FuncGraphPtr>(graph));
}
MS_LOG(ERROR) << "VarNode, should input a Var in graph. It's " + graph.ToString();
return nullptr;
}
AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph = false) {
MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString();
MS_EXCEPTION_IF_NULL(primitive_vars);
if (utils::isa<VectorRef>(sexp)) {
return HandleSexpVector(sexp, graph, primitive_vars, multigraph);
}
if (utils::isa<VarPtr>(sexp)) {
auto var_ptr = utils::cast<VarPtr>(sexp);
MS_EXCEPTION_IF_NULL(var_ptr);
if (var_ptr->primitive()) {
(*primitive_vars)[var_ptr->primitive()] = var_ptr;
return NewValueNode(var_ptr->primitive());
}
return CreateVarNodeWithSexp(sexp, graph);
}
if (utils::isa<AnfNodePtr>(sexp)) {
return utils::cast<AnfNodePtr>(sexp);
}
auto value_node = CreateValueNodeWithSexp(sexp);
if (value_node == nullptr) {
MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString();
}
return value_node;
}
AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars,
bool multigraph) {
MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString();
std::vector<AnfNodePtr> input_nodes;
const auto &tuple = utils::cast<VectorRef>(sexp);
if (multigraph && utils::isa<VarPtr>(graph)) {
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, std::make_shared<Var>("G"), primitive_vars, true);
input_nodes.push_back(node);
}
VarPtr var_ptr = utils::cast<VarPtr>(graph);
return std::make_shared<CNode>(input_nodes, var_ptr);
}
for (auto &x : tuple) {
AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph);
input_nodes.push_back(node);
}
return CreateCNodeWithGraph(input_nodes, graph);
}
} // namespace
static bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (utils::isa<AnfNodePtr>(a) && utils::isa<AnfNodePtr>(b)) {
auto a_node = utils::cast<AnfNodePtr>(a);
auto b_node = utils::cast<AnfNodePtr>(b);
if (IsValueNode<Primitive>(a_node) && IsValueNode<Primitive>(b_node)) {
auto a_value_node = a_node->cast<ValueNodePtr>();
auto a_value = a_value_node->value();
auto a_prim = a_value->cast<PrimitivePtr>();
auto b_value_node = b_node->cast<ValueNodePtr>();
auto b_value = b_value_node->value();
auto b_prim = b_value->cast<PrimitivePtr>();
return a_prim->name() == b_prim->name();
} else if (a_node->isa<ValueNode>() && b_node->isa<ValueNode>()) {
auto a_value_node_ptr = a_node->cast<ValueNodePtr>();
if (a_value_node_ptr == nullptr) {
MS_LOG(EXCEPTION) << "cast value node ptr fail";
}
auto a_value_ptr = a_value_node_ptr->value();
if (a_value_ptr == nullptr) {
MS_LOG(EXCEPTION) << "value ptr is nullptr";
}
auto b_value_node_ptr = b_node->cast<ValueNodePtr>();
if (b_value_node_ptr == nullptr) {
MS_LOG(EXCEPTION) << "cast value node ptr fail";
}
auto b_value_ptr = b_value_node_ptr->value();
if (b_value_ptr == nullptr) {
MS_LOG(EXCEPTION) << "value ptr is nullptr";
}
return (*a_value_ptr) == (*b_value_ptr);
}
MS_LOG(DEBUG) << "check AnfNodePtr equal";
}
if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
MS_LOG(DEBUG) << "check GraphPtr equal";
}
return a == b;
}
static bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) {
// To matchCNode and Kernel's type
if (utils::isa<CNode>(a) && utils::isa<CNode>(b)) {
return true;
}
return a.type() == b.type();
}
PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph)
: NodePass(name),
multigraph_(multigraph),
......
......@@ -28,6 +28,7 @@
#include "pre_activate/common/pattern_engine.h"
#include "utils/graph_utils.h"
#include "common/utils.h"
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
......
......@@ -17,6 +17,7 @@
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h"
#include "debug/anf_ir_dump.h"
namespace mindspore {
namespace opt {
......
......@@ -14,7 +14,6 @@
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
batch_norm_grad = G.BatchNormGrad(is_training=False)
......
......@@ -24,7 +24,6 @@ make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
LambNextMVWithDecay = Primitive('LambNextMVWithDecay')
class FnDict:
def __init__(self):
self.fnDict = {}
......@@ -35,7 +34,6 @@ class FnDict:
def __getitem__(self, name):
return self.fnDict[name]
def test_lamb_next_mv_with_decay_rule(tag):
fns = FnDict()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册