提交 12eaaf71 编写于 作者: H huanghui

mul_add_fusion pass supports when add's 2nd is mul

上级 c176bbe4
......@@ -24,40 +24,57 @@
#include "pre_activate/common/helper.h"
namespace mindspore {
bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(add);
for (size_t index = 1; index < add->size(); ++index) {
auto input = add->input(index);
MS_EXCEPTION_IF_NULL(input);
if (input->isa<CNode>()) {
auto cnode = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) {
if (!opt::IsUsedByOthers(graph, cnode)) {
*mul = cnode;
*mul_index = index;
return true;
}
}
}
}
return false;
}
namespace opt {
const BaseRef MulAddFusion::DefinePattern() const {
VarPtr mul_x_ = std::make_shared<Var>();
VarPtr mul_y_ = std::make_shared<Var>();
VarPtr add_y_ = std::make_shared<Var>();
VectorRef mul({prim::kPrimMul, mul_x_, mul_y_});
VectorRef add({prim::kPrimTensorAdd, mul, add_y_});
return add;
VarPtr x = std::make_shared<Var>();
VarPtr y = std::make_shared<Var>();
VectorRef pattern({prim::kPrimTensorAdd, x, y});
return pattern;
}
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
if (graph == nullptr || node == nullptr || equiv == nullptr) {
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
if (graph == nullptr || node == nullptr) {
return nullptr;
}
auto add = node->cast<CNodePtr>();
if (add == nullptr || add->inputs().size() != kAddInputNum) {
return nullptr;
}
auto mul_anf = add->input(1);
if (mul_anf == nullptr) {
return nullptr;
}
auto mul = mul_anf->cast<CNodePtr>();
if (mul == nullptr || mul->inputs().size() != kMulInputNum) {
return nullptr;
}
if (IsUsedByOthers(graph, mul)) {
MS_LOG(DEBUG) << "Mul is used by more then two nodes, cannot fuse";
CNodePtr mul = nullptr;
size_t mul_index = 0;
if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) {
MS_LOG(DEBUG) << "Cannot find used-by-only-one-op Mul in Add's inputs";
return nullptr;
}
auto prim = std::make_shared<Primitive>(kFusedMulAddOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul->input(1), mul->input(2), add->input(2)};
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
for (size_t index = 1; index < mul->size(); ++index) {
inputs.push_back(mul->input(index));
}
inputs.push_back(add->input(add->size() - mul_index));
auto fusion_node = graph->NewCNode(inputs);
fusion_node->set_scope(add->scope());
fusion_node->set_abstract(add->abstract());
......
......@@ -28,8 +28,28 @@ class TestHWMulAddFusion : public BackendCommon {
UT::PyFuncGraphFetcher get_py_fun_;
};
TEST_F(TestHWMulAddFusion, test_mul_add_fusion) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before");
TEST_F(TestHWMulAddFusion, test_mul_add_fusion1) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before1");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 3; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::MulAddFusion>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWMulAddFusion, test_mul_add_fusion2) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_mul_add_fusion", "before2");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
......
......@@ -21,7 +21,6 @@ fused_mul_add = Primitive('FusedMulAdd')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
class FnDict:
def __init__(self):
self.fnDict = {}
......@@ -32,16 +31,21 @@ class FnDict:
def __getitem__(self, name):
return self.fnDict[name]
def test_mul_add_fusion(tag):
fns = FnDict()
@fns
def before(x, y, z):
def before1(x, y, z):
res = mul(x, y)
res = add(res, z)
return res
@fns
def before2(x, y, z):
res = mul(x, y)
res = add(z, res)
return res
@fns
def after(x, y, z):
res = fused_mul_add(x, y, z)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册