提交 daccfef7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1361 Refactor multiple output pass

Merge pull request !1361 from huanghui/LambNextMVRule-fusion-pass
......@@ -99,11 +99,11 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond4>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
......
......@@ -27,82 +27,12 @@
namespace mindspore {
namespace opt {
namespace {
std::tuple<CNodePtr, CNodePtr, AnfNodePtr> GetSharedNodesByPattern(const AnfNodePtr &node) {
auto add3_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kAddInputNum);
MS_EXCEPTION_IF_NULL(add3_cnode);
auto real_div2_cnode = CheckAnfNodeIfCNodeAndInputSize(add3_cnode->input(1), kMulInputNum);
MS_EXCEPTION_IF_NULL(real_div2_cnode);
auto real_div0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(1), kRealDivInputNum);
MS_EXCEPTION_IF_NULL(real_div0_cnode);
auto sqrt0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div2_cnode->input(2), kSqrtInputNum);
MS_EXCEPTION_IF_NULL(sqrt0_cnode);
auto add2_cnode = CheckAnfNodeIfCNodeAndInputSize(sqrt0_cnode->input(1), kAddInputNum);
MS_EXCEPTION_IF_NULL(add2_cnode);
auto real_div1_cnode = CheckAnfNodeIfCNodeAndInputSize(add2_cnode->input(1), kRealDivInputNum);
auto constant_add2_y = add2_cnode->input(2);
return std::make_tuple(real_div0_cnode, real_div1_cnode, constant_add2_y);
}
bool MatchRealDiv4(const AnfNodePtr &real_div4, const AnfNodePtr &real_div1, const AnfNodePtr &constant_add2_y) {
if (real_div4 == nullptr || !real_div4->isa<CNode>()) {
return false;
}
auto real_div4_cnode = real_div4->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(real_div4_cnode);
if (AnfAlgo::GetCNodeName(real_div4_cnode) != kRealDivOpName || real_div4_cnode->inputs().size() < kRealDivInputNum) {
return false;
}
CNodePtr add4_cnode = nullptr;
if (!CheckIfCNodeAndInputSize(real_div4_cnode->input(2), kAddInputNum, &add4_cnode) ||
AnfAlgo::GetCNodeName(add4_cnode) != prim::kPrimTensorAdd->name()) {
return false;
}
CNodePtr sqrt1_cnode = nullptr;
if (!CheckIfCNodeAndInputSize(add4_cnode->input(1), kSqrtInputNum, &sqrt1_cnode) ||
AnfAlgo::GetCNodeName(sqrt1_cnode) != kSqrtOpName) {
return false;
}
MS_EXCEPTION_IF_NULL(add4_cnode->input(2));
MS_EXCEPTION_IF_NULL(constant_add2_y);
return sqrt1_cnode->input(1) == real_div1 && *(add4_cnode->input(2)) == *constant_add2_y;
}
} // namespace
const BaseRef LambNextMVRule::DefinePattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
MS_EXCEPTION_IF_NULL(prim_deal_div);
auto mul0 = VectorRef({prim::kPrimMul, input_varptr_[7], input_varptr_[4]});
auto mul1 = VectorRef({prim::kPrimMul, input_varptr_[8], input_varptr_[3]});
auto mul2 = VectorRef({prim::kPrimMul, input_varptr_[9], input_varptr_[1]});
auto mul3 = VectorRef({prim::kPrimMul, input_varptr_[10], input_varptr_[0]});
auto mul4 = VectorRef({prim::kPrimMul, input_varptr_[11], input_varptr_[6]});
auto add0 = VectorRef({prim::kPrimTensorAdd, mul0, mul1});
auto add1 = VectorRef({prim::kPrimTensorAdd, mul2, mul3});
auto real_div0 = VectorRef({prim_deal_div, add0, input_varptr_[5]});
auto real_div1 = VectorRef({prim_deal_div, add1, input_varptr_[2]});
auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, input_varptr_[12]});
auto sqrt0 = VectorRef({prim_rsqrt, add2});
auto real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
}
bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
std::vector<AnfNodePtr> *old_pattern_outputs) const {
MS_EXCEPTION_IF_NULL(func_graph);
CNodePtr real_div0 = nullptr;
CNodePtr real_div1 = nullptr;
AnfNodePtr constant_add2_y = nullptr;
std::tie(real_div0, real_div1, constant_add2_y) = GetSharedNodesByPattern(node);
MS_EXCEPTION_IF_NULL(equiv);
auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_);
auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
......@@ -112,19 +42,17 @@ bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNode
}
AnfNodeIndexSet real_div0_outputs = users[real_div0];
auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(),
[&node, &real_div1, &constant_add2_y](const std::pair<AnfNodePtr, int> &node_index) {
return node_index.first != node && node_index.second == 1 &&
MatchRealDiv4(node_index.first, real_div1, constant_add2_y);
[&real_div2, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
return node_index.first != real_div2 && node_index.second == 1 &&
MatchAnotherPattern(node_index.first, equiv);
});
if (iter == real_div0_outputs.end()) {
return false;
}
auto add0_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div0->input(1), kAddInputNum);
auto add1_cnode = CheckAnfNodeIfCNodeAndInputSize(real_div1->input(1), kAddInputNum);
(*old_pattern_outputs).push_back(node);
(*old_pattern_outputs).push_back(add0_cnode);
(*old_pattern_outputs).push_back(add1_cnode);
(*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_));
(*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_));
(*old_pattern_outputs).push_back(iter->first);
return true;
......@@ -136,8 +64,19 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(func_graph);
auto prim = std::make_shared<Primitive>(kLambNextMVOpName);
std::vector<AnfNodePtr> lamb_next_mv_rule_inputs = {NewValueNode(prim)};
(void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(lamb_next_mv_rule_inputs),
[&equiv](const VarPtr &in) { return utils::cast<AnfNodePtr>((*equiv)[in]); });
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input0_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input1_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input2_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input3_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input4_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input5_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[input6_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul0_x_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul1_sub_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul2_x_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[mul4_x_]));
lamb_next_mv_rule_inputs.push_back(utils::cast<AnfNodePtr>((*equiv)[add2_y_]));
auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs);
MS_EXCEPTION_IF_NULL(lamb_next_mv_rule);
......@@ -162,14 +101,60 @@ AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph,
return lamb_next_mv_rule_outputs[0];
}
bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) &&
IsSameNode(equiv1, equiv2, add2_y_);
}
const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
std::vector<AnfNodePtr> old_pattern_outputs;
if (!IsRuleMatched(func_graph, node, &old_pattern_outputs)) {
if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) {
return nullptr;
}
return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv);
}
const BaseRef LambNextMVRuleCond4::DefinePattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
auto add0 = VectorRef({add0_var_, mul0, mul1});
auto add1 = VectorRef({add1_var_, mul2, mul3});
auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
auto sqrt0 = VectorRef({prim_rsqrt, add2});
auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0});
return VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
}
BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
MS_EXCEPTION_IF_NULL(prim_sqrt);
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
MS_EXCEPTION_IF_NULL(prim_real_div);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
MS_EXCEPTION_IF_NULL(Ys);
// Two patterns share: real_div0, real_div1, add2_y_
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
return real_div4;
}
} // namespace opt
} // namespace mindspore
......@@ -29,23 +29,71 @@
namespace mindspore {
namespace opt {
class LambNextMVRule : public PatternProcessPass {
class LambNextMVRule : public MultipleOutputPatternProcessPass {
public:
explicit LambNextMVRule(bool multigraph = true) : PatternProcessPass("lamb_next_mv_rule", multigraph) {
for (size_t i = 0; i < kLambNextMVRuleInputNum - 1; ++i) {
input_varptr_.push_back(std::make_shared<Var>());
}
explicit LambNextMVRule(const std::string &name = "", bool multigraph = true)
: MultipleOutputPatternProcessPass(name, multigraph) {
input0_ = std::make_shared<Var>();
input1_ = std::make_shared<Var>();
input2_ = std::make_shared<Var>();
input3_ = std::make_shared<Var>();
input4_ = std::make_shared<Var>();
input5_ = std::make_shared<Var>();
input6_ = std::make_shared<Var>();
mul0_x_ = std::make_shared<Var>();
mul1_sub_ = std::make_shared<Var>();
mul2_x_ = std::make_shared<Var>();
mul3_sub1_ = std::make_shared<Var>();
mul4_x_ = std::make_shared<Var>();
add2_y_ = std::make_shared<Var>();
real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
}
~LambNextMVRule() override = default;
const BaseRef DefinePattern() const override;
const BaseRef DefinePattern() const override = 0;
BaseRef DefineAnotherPattern() const override = 0;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;
private:
std::vector<VarPtr> input_varptr_;
bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
protected:
bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
std::vector<AnfNodePtr> *old_pattern_outputs) const;
AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &old_pattern_outputs,
const EquivPtr &equiv) const;
VarPtr input0_;
VarPtr input1_;
VarPtr input2_;
VarPtr input3_;
VarPtr input4_;
VarPtr input5_;
VarPtr input6_;
VarPtr mul0_x_;
VarPtr mul1_sub_;
VarPtr mul2_x_;
VarPtr mul3_sub1_;
VarPtr mul4_x_;
VarPtr add2_y_;
// nodes which two patterns share, and add2_y_ also.
VarPtr real_div0_var_;
VarPtr real_div1_var_;
// part of output nodes
VarPtr add0_var_;
VarPtr add1_var_;
// other node
VarPtr real_div2_var_;
};
class LambNextMVRuleCond4 : public LambNextMVRule {
public:
explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {}
~LambNextMVRuleCond4() override = default;
const BaseRef DefinePattern() const override;
BaseRef DefineAnotherPattern() const override;
};
} // namespace opt
} // namespace mindspore
......
......@@ -79,63 +79,6 @@ AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGrap
return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv);
}
const BaseRef LambNextMVWithDecayRule::DefinePattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
MS_EXCEPTION_IF_NULL(prim_sqrt);
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
MS_EXCEPTION_IF_NULL(prim_deal_div);
VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
return add5;
}
const BaseRef LambNextMVWithDecayRule::DefineAnotherPattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
VarPtr Zs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
MS_EXCEPTION_IF_NULL(Ys);
MS_EXCEPTION_IF_NULL(Zs);
// Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
return add3;
}
bool LambNextMVWithDecayRule::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
VarPtr fg = std::make_shared<Var>("RootG");
auto empty_equiv = std::make_shared<Equiv>();
MS_EXCEPTION_IF_NULL(child_primitive_vars_);
EquivPtr another_equiv =
child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
*child_primitive_vars_, empty_equiv);
if (another_equiv != nullptr && !another_equiv->empty()) {
return IsShareNodes(equiv, another_equiv);
}
return false;
}
bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const {
return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) &&
IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_);
......@@ -164,7 +107,7 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
return nullptr;
}
const BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
VarPtr Xs = std::make_shared<SeqVar>();
......@@ -205,7 +148,7 @@ const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
return add5;
}
const BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
VarPtr Xs = std::make_shared<SeqVar>();
......@@ -246,7 +189,7 @@ const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
return add5;
}
const BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
VarPtr Xs = std::make_shared<SeqVar>();
......@@ -286,5 +229,47 @@ const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
return add5;
}
BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
VarPtr Zs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
MS_EXCEPTION_IF_NULL(Ys);
MS_EXCEPTION_IF_NULL(Zs);
// Two patterns share: real_div0, real_div1, mul4, constant_add2_y_
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4});
return add3;
}
const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
MS_EXCEPTION_IF_NULL(prim_sqrt);
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
MS_EXCEPTION_IF_NULL(prim_deal_div);
VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4});
return add5;
}
} // namespace opt
} // namespace mindspore
......@@ -24,15 +24,10 @@
namespace mindspore {
namespace opt {
class LambNextMVWithDecayRule : public PatternProcessPass {
class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass {
public:
explicit LambNextMVWithDecayRule(const std::string &name = "lamb_next_mv_with_decay_rule_cond4",
bool multigraph = true)
: PatternProcessPass(name, multigraph),
child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {
explicit LambNextMVWithDecayRule(const std::string &name = "", bool multigraph = true)
: MultipleOutputPatternProcessPass(name, multigraph) {
for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) {
input_vars_.push_back(std::make_shared<Var>());
}
......@@ -48,21 +43,16 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
}
~LambNextMVWithDecayRule() override = default;
const BaseRef DefinePattern() const override;
virtual const BaseRef DefineAnotherPattern() const;
const BaseRef DefinePattern() const override = 0;
BaseRef DefineAnotherPattern() const override = 0;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;
protected:
bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
// check two patterns whether share the same nodes or not
bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const;
AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const;
AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
const AnfNodePtr &add5, const EquivPtr &equiv) const;
PatternEngine child_pattern_engine_;
PrimitiveVarMapPtr child_primitive_vars_;
std::vector<VarPtr> input_vars_;
std::vector<VarPtr> constant_mul_input_vars_;
// nodes which two patterns share
......@@ -82,7 +72,7 @@ class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule {
~LambNextMVWithDecayRuleCond1() override = default;
const BaseRef DefinePattern() const override;
const BaseRef DefineAnotherPattern() const override;
BaseRef DefineAnotherPattern() const override;
};
class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
......@@ -92,7 +82,7 @@ class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
~LambNextMVWithDecayRuleCond2() override = default;
const BaseRef DefinePattern() const override;
const BaseRef DefineAnotherPattern() const override;
BaseRef DefineAnotherPattern() const override;
};
class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
......@@ -102,7 +92,17 @@ class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
~LambNextMVWithDecayRuleCond3() override = default;
const BaseRef DefinePattern() const override;
const BaseRef DefineAnotherPattern() const override;
BaseRef DefineAnotherPattern() const override;
};
class LambNextMVWithDecayRuleCond4 : public LambNextMVWithDecayRule {
public:
explicit LambNextMVWithDecayRuleCond4(bool multigraph = true)
: LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond4", multigraph) {}
~LambNextMVWithDecayRuleCond4() override = default;
const BaseRef DefinePattern() const override;
BaseRef DefineAnotherPattern() const override;
};
} // namespace opt
} // namespace mindspore
......
......@@ -62,6 +62,21 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode
return nullptr;
}
bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
VarPtr fg = std::make_shared<Var>("RootG");
auto empty_equiv = std::make_shared<Equiv>();
MS_EXCEPTION_IF_NULL(child_primitive_vars_);
EquivPtr another_equiv =
child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node,
*child_primitive_vars_, empty_equiv);
if (another_equiv != nullptr && !another_equiv->empty()) {
return IsShareNodes(equiv, another_equiv);
}
return false;
}
void GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) {
if (pass_manager != nullptr) {
pass_managers_.push_back(pass_manager);
......
......@@ -51,6 +51,25 @@ class PatternProcessPass : public NodePass {
PrimitiveVarMapPtr primitive_vars_;
};
class MultipleOutputPatternProcessPass : public PatternProcessPass {
public:
explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true)
: PatternProcessPass(name, multigraph),
child_pattern_engine_(PatternEngine(std::make_shared<DefaultVisitor>(),
std::function<bool(const BaseRef &, const BaseRef &)>(AnfEqual),
std::function<bool(const BaseRef &, const BaseRef &)>(CNodeTypeEqual))),
child_primitive_vars_(std::make_shared<PrimitiveVarMap>()) {}
~MultipleOutputPatternProcessPass() override = default;
virtual BaseRef DefineAnotherPattern() const = 0;
// check two patterns whether share the same nodes or not
virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0;
protected:
bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const;
PatternEngine child_pattern_engine_;
PrimitiveVarMapPtr child_primitive_vars_;
};
class GraphOptimizer {
public:
explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {}
......
......@@ -30,7 +30,7 @@ class TestHWLambNextMVRule : public BackendCommon {
UT::PyFuncGraphFetcher get_py_fun_;
};
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_matched) {
/*
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -54,7 +54,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -65,15 +65,15 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_matched) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRule>());
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div4) {
/*
* def before_unmatched_real_div4(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
* constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -97,7 +97,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div4");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div4");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -109,14 +109,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div4) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRule>());
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div2) {
/*
* def before_unmatched_real_div2(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
* constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -140,7 +140,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div2");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div2");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -152,14 +152,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div2) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRule>());
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div0) {
/*
* def before_unmatched_real_div0(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
* constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -183,7 +183,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div0");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div0");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -195,14 +195,14 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div0) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRule>());
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div1) {
/*
* def before_unmatched_real_div1(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x,
* constant_mul1_sub, constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -226,7 +226,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule", "before_unmatched_real_div1");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond4", "before_unmatched_real_div1");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -238,7 +238,7 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_unmatched_real_div1) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRule>());
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
......
......@@ -30,7 +30,7 @@ class TestHWLambNextMVWithDecayRule : public BackendCommon {
UT::PyFuncGraphFetcher get_py_fun_;
};
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_matched) {
/*
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -55,7 +55,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -66,15 +66,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_matched) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add3) {
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_add3) {
/*
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -99,7 +99,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_add3");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_add3");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -111,15 +111,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_add
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul4) {
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_mul4) {
/*
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -144,7 +144,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_mul4");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_mul4");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -156,15 +156,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_mul
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div0) {
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div0) {
/*
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -189,7 +189,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div0");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div0");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -201,15 +201,15 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_real_div1) {
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_cond4_unmatched_real_div1) {
/*
* def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
* constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
......@@ -234,7 +234,7 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
* output = tuple_getitem(outputs, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "before_unmatched_real_div1");
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "before_unmatched_real_div1");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......@@ -246,11 +246,11 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRule>());
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond4>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond4", "after");
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}
......
......@@ -36,7 +36,7 @@ class FnDict:
return self.fnDict[name]
def test_lamb_next_mv_rule(tag):
def test_lamb_next_mv_rule_cond4(tag):
fns = FnDict()
@fns
......
......@@ -34,7 +34,7 @@ class FnDict:
def __getitem__(self, name):
return self.fnDict[name]
def test_lamb_next_mv_with_decay_rule(tag):
def test_lamb_next_mv_with_decay_rule_cond4(tag):
fns = FnDict()
@fns
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册