提交 55027096 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1400 Add 3 patterns for lamb_next_mv fusion pass

Merge pull request !1400 from huanghui/TMP
......@@ -104,6 +104,9 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond4>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond1>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond2>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond3>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
......
......@@ -116,9 +116,116 @@ const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const A
return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv);
}
const BaseRef LambNextMVRuleCond1::DefinePattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
auto add0 = VectorRef({add0_var_, mul0, mul1});
auto add1 = VectorRef({add1_var_, mul2, mul3});
auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
auto sqrt0 = VectorRef({prim_rsqrt, add2});
auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
}
BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
// Two patterns share: real_div0, real_div1, add2_y_
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1});
VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
return real_div4;
}
const BaseRef LambNextMVRuleCond2::DefinePattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
auto add0 = VectorRef({add0_var_, mul0, mul1});
auto add1 = VectorRef({add1_var_, mul2, mul3});
auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
auto sqrt0 = VectorRef({prim_rsqrt, add2});
auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
}
BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
// Two patterns share: real_div0, real_div1, add2_y_
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
return real_div4;
}
const BaseRef LambNextMVRuleCond3::DefinePattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_});
auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
auto add0 = VectorRef({add0_var_, mul0, mul1});
auto add1 = VectorRef({add1_var_, mul2, mul3});
auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
auto real_div1 = VectorRef({real_div1_var_, add1, input2_});
auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
auto sqrt0 = VectorRef({prim_rsqrt, add2});
auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});
return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
}
BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
// Two patterns share: real_div0, real_div1, add2_y_
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_});
VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4});
return real_div4;
}
const BaseRef LambNextMVRuleCond4::DefinePattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
......@@ -140,13 +247,9 @@ const BaseRef LambNextMVRuleCond4::DefinePattern() const {
BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
MS_EXCEPTION_IF_NULL(prim_sqrt);
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
MS_EXCEPTION_IF_NULL(prim_real_div);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
MS_EXCEPTION_IF_NULL(Ys);
// Two patterns share: real_div0, real_div1, add2_y_
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
......
......@@ -87,6 +87,33 @@ class LambNextMVRule : public MultipleOutputPatternProcessPass {
VarPtr real_div2_var_;
};
class LambNextMVRuleCond1 : public LambNextMVRule {
public:
explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {}
~LambNextMVRuleCond1() override = default;
const BaseRef DefinePattern() const override;
BaseRef DefineAnotherPattern() const override;
};
class LambNextMVRuleCond2 : public LambNextMVRule {
public:
explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {}
~LambNextMVRuleCond2() override = default;
const BaseRef DefinePattern() const override;
BaseRef DefineAnotherPattern() const override;
};
class LambNextMVRuleCond3 : public LambNextMVRule {
public:
explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {}
~LambNextMVRuleCond3() override = default;
const BaseRef DefinePattern() const override;
BaseRef DefineAnotherPattern() const override;
};
class LambNextMVRuleCond4 : public LambNextMVRule {
public:
explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {}
......
......@@ -244,5 +244,125 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div1) {
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond1_fusion) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "before");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond1>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond1_unmatched) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "un_match");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond1>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond2_fusion) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "before");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond2>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond2_unmatched) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "un_match");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond2>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond3_fusion) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "before");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond3>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond3_unmatched) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "un_match");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVRuleCond3>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
} // namespace opt
} // namespace mindspore
......@@ -24,7 +24,6 @@ make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
LambNextMV = Primitive('LambNextMV')
class FnDict:
def __init__(self):
self.fnDict = {}
......@@ -35,7 +34,6 @@ class FnDict:
def __getitem__(self, name):
return self.fnDict[name]
def test_lamb_next_mv_rule_cond4(tag):
fns = FnDict()
......@@ -170,3 +168,192 @@ def test_lamb_next_mv_rule_cond4(tag):
return output
return fns[tag]
def test_lamb_next_mv_rule_cond1(tag):
fns = FnDict()
@fns
def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul0 = Mul(constant_mul0_x, input4)
mul1 = Mul(constant_mul1_sub, input3)
add0 = Add(mul0, mul1)
mul2 = Mul(constant_mul2_x, input1)
mul3 = Mul(constant_mul3_sub1, input0)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(constant_add2_y, real_div1)
sqrt0 = Rsqrt(add2)
sqrt1 = Sqrt(real_div1)
add4 = Add(constant_add2_y, sqrt1)
real_div0 = RealDiv(add0, input5)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
mul4 = Mul(constant_mul4_x, input6)
add3 = Add(mul4, real_div2)
outputs = make_tuple(add3, add0, add1, real_div4)
output = tuple_getitem(outputs, 0)
return output
@fns
def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
constant_mul4_x, constant_add2_y)
outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
output = tuple_getitem(outputs, 0)
return make_tuple(output)
@fns
def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul0 = Mul(constant_mul0_x, input4)
mul1 = Mul(constant_mul1_sub, input3)
add0 = Add(mul0, mul1)
mul2 = Mul(constant_mul2_x, input1)
mul3 = Mul(constant_mul3_sub1, input0)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(constant_add2_y, real_div1)
sqrt0 = Rsqrt(add2)
sqrt1 = Sqrt(real_div1)
# un match
add4 = Add(sqrt1, constant_add2_y)
real_div0 = RealDiv(add0, input5)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
mul4 = Mul(constant_mul4_x, input6)
add3 = Add(mul4, real_div2)
outputs = make_tuple(add3, add0, add1, real_div4)
output = tuple_getitem(outputs, 0)
return output
return fns[tag]
def test_lamb_next_mv_rule_cond2(tag):
fns = FnDict()
@fns
def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul0 = Mul(input4, constant_mul0_x)
mul1 = Mul(input3, constant_mul1_sub)
add0 = Add(mul0, mul1)
mul2 = Mul(input1, constant_mul2_x)
mul3 = Mul(constant_mul3_sub1, input0)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(constant_add2_y, real_div1)
sqrt0 = Rsqrt(add2)
sqrt1 = Sqrt(real_div1)
add4 = Add(sqrt1, constant_add2_y)
real_div0 = RealDiv(add0, input5)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
mul4 = Mul(input6, constant_mul4_x)
add3 = Add(mul4, real_div2)
outputs = make_tuple(add3, add0, add1, real_div4)
output = tuple_getitem(outputs, 0)
return output
@fns
def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
constant_mul4_x, constant_add2_y)
outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
output = tuple_getitem(outputs, 0)
return make_tuple(output)
@fns
def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul0 = Mul(input4, constant_mul0_x)
mul1 = Mul(input3, constant_mul1_sub)
add0 = Add(mul0, mul1)
mul2 = Mul(input1, constant_mul2_x)
mul3 = Mul(constant_mul3_sub1, input0)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(constant_add2_y, real_div1)
sqrt0 = Rsqrt(add2)
sqrt1 = Sqrt(real_div1)
# un match
add4 = Add(constant_add2_y, sqrt1)
real_div0 = RealDiv(add0, input5)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
mul4 = Mul(input6, constant_mul4_x)
add3 = Add(mul4, real_div2)
outputs = make_tuple(add3, add0, add1, real_div4)
output = tuple_getitem(outputs, 0)
return output
return fns[tag]
def test_lamb_next_mv_rule_cond3(tag):
fns = FnDict()
@fns
def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul0 = Mul(input4, constant_mul0_x)
mul1 = Mul(input3, constant_mul1_sub)
add0 = Add(mul0, mul1)
mul2 = Mul(input1, constant_mul2_x)
mul3 = Mul(input0, constant_mul3_sub1)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(real_div1, constant_add2_y)
sqrt0 = Rsqrt(add2)
sqrt1 = Sqrt(real_div1)
add4 = Add(sqrt1, constant_add2_y)
real_div0 = RealDiv(add0, input5)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
mul4 = Mul(input6, constant_mul4_x)
add3 = Add(mul4, real_div2)
outputs = make_tuple(add3, add0, add1, real_div4)
output = tuple_getitem(outputs, 0)
return output
@fns
def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
constant_mul4_x, constant_add2_y)
outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
output = tuple_getitem(outputs, 0)
return make_tuple(output)
@fns
def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul0 = Mul(input4, constant_mul0_x)
mul1 = Mul(input3, constant_mul1_sub)
add0 = Add(mul0, mul1)
mul2 = Mul(input1, constant_mul2_x)
mul3 = Mul(input0, constant_mul3_sub1)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(real_div1, constant_add2_y)
sqrt0 = Rsqrt(add2)
sqrt1 = Sqrt(real_div1)
# un match
add4 = Add(constant_add2_y, sqrt1)
real_div0 = RealDiv(add0, input5)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
mul4 = Mul(input6, constant_mul4_x)
add3 = Add(mul4, real_div2)
outputs = make_tuple(add3, add0, add1, real_div4)
output = tuple_getitem(outputs, 0)
return output
return fns[tag]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册