From a5d2a6d1addf918c1f9ea30d677e260c80e201d7 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Sun, 13 Jan 2019 11:12:56 +0000
Subject: [PATCH] add fuse pass of squared mat sub fusion

---
 paddle/fluid/framework/ir/CMakeLists.txt      |   1 +
 .../framework/ir/squared_mat_sub_fuse_pass.cc | 379 ++++++++++++++++++
 .../framework/ir/squared_mat_sub_fuse_pass.h  |  41 ++
 .../fluid/inference/api/paddle_pass_builder.h |   1 +
 4 files changed, 422 insertions(+)
 create mode 100644 paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index c888f96d910..84b53212647 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -44,6 +44,7 @@ pass_library(conv_bn_fuse_pass inference)
 pass_library(seqconv_eltadd_relu_fuse_pass inference)
 pass_library(seqpool_concat_fuse_pass inference)
 pass_library(repeated_fc_relu_fuse_pass inference)
+pass_library(squared_mat_sub_fuse_pass inference)
 pass_library(is_test_pass base)
 pass_library(conv_elementwise_add_act_fuse_pass inference)
 pass_library(conv_elementwise_add2_act_fuse_pass inference)
diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc
new file mode 100644
index 00000000000..7134fecf8d2
--- /dev/null
+++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc
@@ -0,0 +1,379 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
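+
+// This pass matches the subgraph that computes
+//   ( matmul(square(x), square(y)) - square(matmul(x, y)) ) * scalar
+// and replaces it with a single fusion_squared_mat_sub op.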
+
+#include "paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h"
+#include <string>
+#include <unordered_set>
+#include "paddle/fluid/framework/lod_tensor.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
+                                  const std::string& name_scope) {
+  auto var_is_op_input = [=](Node* x, const std::string& op_type,
+                             const std::string& arg_name = "") -> bool {
+    if (!(x && x->IsVar())) {
+      return false;
+    }
+    for (auto* op : x->outputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
+        if (arg_name.empty()) {
+          return true;
+        }
+        for (auto& name : op->Op()->Input(arg_name)) {
+          if (name == x->Name()) {
+            return true;
+          }
+        }
+      }
+    }
+    return false;
+  };
+
+  auto var_is_op_only_output = [](Node* x, const std::string& op_type) -> bool {
+    return x && x->IsVar() && x->inputs.size() == 1 && x->inputs[0] &&
+           x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == op_type &&
+           x->inputs[0]->outputs.size() == 1;
+  };
+
+  auto next_op = [=](Node* x, const std::string& op_type) -> Node* {
+    if (!(x && x->IsVar())) {
+      return nullptr;
+    }
+    for (auto* op : x->outputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
+        return op;
+      }
+    }
+    return nullptr;
+  };
+
+  auto get_op_input_var = [=](Node* x, const std::string& arg_name) -> Node* {
+    if (!(x && x->IsOp())) {
+      return nullptr;
+    }
+    for (auto* var : x->inputs) {
+      for (auto name : x->Op()->Input(arg_name)) {
+        if (var->Name() == name) {
+          return var;
+        }
+      }
+    }
+    return nullptr;
+  };
+
+  auto is_fusion_input_var = [=](Node* x, const std::string& arg_name) {
+    bool basic = var_is_op_input(x, "matmul", arg_name) &&
+                 var_is_op_input(x, "square", "X");
+    if (!basic) {
+      return false;
+    }
+    auto* squared_x_op = next_op(x, "square");
+    if (!(squared_x_op && squared_x_op->outputs.size() == 1)) {
+      return false;
+    }
+    auto* squared_x = squared_x_op->outputs[0];
+    bool next_is_matmul_from_arg =
+        var_is_op_input(squared_x, "matmul", arg_name) &&
+        squared_x->outputs.size() == 1 &&
+        squared_x->outputs[0]->outputs.size() == 1;
+    if (!next_is_matmul_from_arg) {
+      return false;
+    }
+    auto* sub_x = squared_x->outputs[0]->outputs[0];
+    return var_is_op_input(sub_x, "elementwise_sub", "X") &&
+           sub_x->outputs[0]->outputs.size() == 1 &&
+           var_is_op_input(sub_x->outputs[0]->outputs[0], "elementwise_mul");
+  };
+
+  auto is_fusion_first_mul_out = [=](Node* x) -> bool {
+    bool input_is_matmul_op = x && x->inputs.size() == 1 &&
+                              x->inputs[0]->IsOp() &&
+                              x->inputs[0]->Op()->Type() == "matmul";
+    if (!input_is_matmul_op) {
+      return false;
+    }
+    auto* mat_x = get_op_input_var(x->inputs[0], "X");
+    auto* mat_y = get_op_input_var(x->inputs[0], "Y");
+    bool input_mul_is_valid = mat_x && is_fusion_input_var(mat_x, "X") &&
+                              mat_y && is_fusion_input_var(mat_y, "Y");
+    if (!input_mul_is_valid) {
+      return false;
+    }
+
+    bool next_is_square = var_is_op_input(x, "square", "X") &&
+                          x->outputs.size() == 1 &&
+                          x->outputs[0]->outputs.size() == 1;
+    if (!next_is_square) {
+      return false;
+    }
+    auto* sub_y = x->outputs[0]->outputs[0];
+    return var_is_op_input(sub_y, "elementwise_sub", "Y") &&
+           sub_y->outputs[0]->outputs.size() == 1 &&
+           var_is_op_input(sub_y->outputs[0]->outputs[0], "elementwise_mul");
+  };
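+
+  // A fusion input (x or y) must feed both the direct matmul and a square op,
+  // and its squared value must feed a second matmul whose output reaches the
+  // trailing elementwise_sub -> elementwise_mul chain.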
+  auto* x = pattern->NewNode(
+      [=](Node* x) { return is_fusion_input_var(x, "X"); }, name_scope + "/x");
+
+  auto* y = pattern->NewNode(
+      [=](Node* x) { return is_fusion_input_var(x, "Y"); }, name_scope + "/y");
+
+  auto* square_x_op = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "square" &&
+               is_fusion_input_var(x->inputs[0], "X");
+      },
+      name_scope + "/squared_x_op");
+
+  auto* square_y_op = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "square" &&
+               is_fusion_input_var(x->inputs[0], "Y");
+      },
+      name_scope + "/squared_y_op");
+
+  auto* squared_x = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->inputs.size() == 1 && x->inputs[0]->inputs.size() == 1 &&
+               is_fusion_input_var(x->inputs[0]->inputs[0], "X");
+      },
+      name_scope + "/squared_x");
+
+  auto* squared_y = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->inputs.size() == 1 && x->inputs[0]->inputs.size() == 1 &&
+               is_fusion_input_var(x->inputs[0]->inputs[0], "Y");
+      },
+      name_scope + "/squared_y");
+
+  auto* matmuled_xy =
+      pattern->NewNode([=](Node* x) { return is_fusion_first_mul_out(x); },
+                       name_scope + "/matmuled_xy");
+
+  auto* matmul_xy_op = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "matmul" &&
+               is_fusion_first_mul_out(x->outputs[0]);
+      },
+      name_scope + "/matmul_xy_op");
+
+  auto* square_matmuled_xy_op = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "square" &&
+               is_fusion_first_mul_out(x->inputs[0]);
+      },
+      name_scope + "/square_matmuled_xy_op");
+
+  auto* squared_xmuly = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsVar() && x->inputs.size() == 1 &&
+               x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == "square" &&
+               is_fusion_first_mul_out(x->inputs[0]->inputs[0]);
+      },
+      name_scope + "/squared_xmuly");
+
+  auto is_fusion_mat_squared_x_y_op_out = [=](Node* x) -> bool {
+    bool basic = x && x->IsVar() && x->inputs.size() == 1 &&
+                 x->inputs[0]->IsOp() && x->inputs[0]->Op()->Type() == "matmul";
+    if (!basic) {
+      return false;
+    }
+    auto* sqx = get_op_input_var(x->inputs[0], "X");
+    auto* sqy = get_op_input_var(x->inputs[0], "Y");
+
+    return var_is_op_only_output(sqx, "square") &&
+           var_is_op_only_output(sqy, "square") && sqx->inputs[0] &&
+           sqx->inputs[0]->inputs.size() == 1 &&
+           is_fusion_input_var(sqx->inputs[0]->inputs[0], "X") &&
+           sqy->inputs[0] && sqy->inputs[0]->inputs.size() == 1 &&
+           is_fusion_input_var(sqy->inputs[0]->inputs[0], "Y");
+  };
+
+  auto* matmul_squared_x_y_op = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "matmul" &&
+               is_fusion_mat_squared_x_y_op_out(x->outputs[0]);
+      },
+      name_scope + "/matmul_squared_x_y_op");
+
+  auto* mat_squared_x_y_op_out = pattern->NewNode(
+      [=](Node* x) { return is_fusion_mat_squared_x_y_op_out(x); },
+      name_scope + "/mat_squared_x_y_op_out");
+
+  auto is_fusion_sub_op = [=](Node* x) -> bool {
+    bool is_sub_op = x && x->IsOp() && x->Op()->Type() == "elementwise_sub";
+    if (!is_sub_op) {
+      return false;
+    }
+    auto* matmul_sqx_sqy_var = get_op_input_var(x, "X");
+    return is_fusion_mat_squared_x_y_op_out(matmul_sqx_sqy_var);
+  };
+
+  auto* sub_op = pattern->NewNode([=](Node* x) { return is_fusion_sub_op(x); },
+                                  name_scope + "/sub_op");
+
+  auto* sub_op_out = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsVar() && x->inputs.size() == 1 &&
+               is_fusion_sub_op(x->inputs[0]);
+      },
+      name_scope + "/sub_op_out");
+
+  auto is_fusion_element_op = [=](Node* x) -> bool {
+    bool is_elemul_op = x && x->IsOp() && x->Op()->Type() == "elementwise_mul";
+    if (!is_elemul_op) {
+      return false;
+    }
+    for (auto* in : x->inputs) {
+      if (in && in->inputs[0] && is_fusion_sub_op(in->inputs[0])) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  auto* elementmul_op =
+      pattern->NewNode([=](Node* x) { return is_fusion_element_op(x); },
+                       name_scope + "/elementmul_op");
+
+  auto* constant_op = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsOp() && x->Op()->Type() == "fill_constant" &&
+               x->outputs.size() == 1 &&
+               is_fusion_element_op(x->outputs[0]->outputs[0]);
+      },
+      name_scope + "/fill_constant_op");
+
+  auto* constant_op_out = pattern->NewNode(
+      [=](Node* x) {
+        return x && x->IsVar() && var_is_op_input(x, "elementwise_mul") &&
+               x->inputs[0] && x->inputs[0]->IsOp() &&
+               x->inputs[0]->Op()->Type() == "fill_constant" && x->outputs[0] &&
+               is_fusion_element_op(x->outputs[0]);
+      },
+      name_scope + "/constant_op_out");
+
+  auto* last_out_var = pattern->NewNode(
+      [=](Node* x) {
+        return var_is_op_only_output(x, "elementwise_mul") &&
+               is_fusion_element_op(x->inputs[0]);
+      },
+      name_scope + "/out");
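+
+  // The wiring below encodes the dataflow:
+  //   squared_x              = square(x)
+  //   squared_y              = square(y)
+  //   mat_squared_x_y_op_out = matmul(squared_x, squared_y)
+  //   squared_xmuly          = square(matmul(x, y))
+  //   sub_op_out             = elementwise_sub(X=mat_squared_x_y_op_out,
+  //                                            Y=squared_xmuly)
+  //   out                    = elementwise_mul(sub_op_out, constant_op_out)
+  // where constant_op_out holds the scalar produced by fill_constant.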
+  square_x_op->LinksFrom({x}).LinksTo({squared_x});
+  square_y_op->LinksFrom({y}).LinksTo({squared_y});
+  matmul_xy_op->LinksFrom({x, y}).LinksTo({matmuled_xy});
+  matmul_squared_x_y_op->LinksFrom({squared_x, squared_y})
+      .LinksTo({mat_squared_x_y_op_out});
+  square_matmuled_xy_op->LinksFrom({matmuled_xy}).LinksTo({squared_xmuly});
+  sub_op->LinksFrom({mat_squared_x_y_op_out, squared_xmuly})
+      .LinksTo({sub_op_out});
+  constant_op->LinksFrom({}).LinksTo({constant_op_out});
+  elementmul_op->LinksFrom({constant_op_out, sub_op_out})
+      .LinksTo({last_out_var});
+
+  return last_out_var;
+}
+
+static int BuildFusion(Graph* graph, const std::string& name_scope) {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  BuildSquaredMatSubPattern(pattern, name_scope);
+
+  auto retrieve_node = [](const std::string& name,
+                          const GraphPatternDetector::subgraph_t& subgraph,
+                          const PDPattern& pat) -> Node* {
+    PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)),
+                   "pattern has no Node called %s", name.c_str());
+    Node* p = subgraph.at(pat.RetrieveNode(name));
+    PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str());
+    return p;
+  };
+
+  int fusion_count{0};
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    LOG(INFO) << "handle squared mat sub fuse";
+    auto& fused_pattern = gpd.pattern();
+
+    auto* matx = retrieve_node(name_scope + "/x", subgraph, fused_pattern);
+    auto* maty = retrieve_node(name_scope + "/y", subgraph, fused_pattern);
+    auto* squaredx =
+        retrieve_node(name_scope + "/squared_x", subgraph, fused_pattern);
+    auto* squaredy =
+        retrieve_node(name_scope + "/squared_y", subgraph, fused_pattern);
+    auto* squaredxy =
+        retrieve_node(name_scope + "/squared_xmuly", subgraph, fused_pattern);
+    auto* last_out_var =
+        retrieve_node(name_scope + "/out", subgraph, fused_pattern);
+    auto* fill_constant_op = retrieve_node(name_scope + "/fill_constant_op",
+                                           subgraph, fused_pattern);
+
+    // Create New OpDesc
+    OpDesc op_desc;
+    op_desc.SetType("fusion_squared_mat_sub");
+    op_desc.SetInput("X", {matx->Name()});
+    op_desc.SetInput("Y", {maty->Name()});
+    op_desc.SetOutput("SquaredX", {squaredx->Name()});
+    op_desc.SetOutput("SquaredY", {squaredy->Name()});
+    op_desc.SetOutput("SquaredXY", {squaredxy->Name()});
+    op_desc.SetOutput("Out", {last_out_var->Name()});
+    op_desc.SetAttr("scalar", fill_constant_op->Op()->GetAttr("value"));
+
+    auto* op = graph->CreateOpNode(&op_desc);
+    IR_NODE_LINK_TO(matx, op);
+    IR_NODE_LINK_TO(maty, op);
+    IR_NODE_LINK_TO(op, squaredx);
+    IR_NODE_LINK_TO(op, squaredy);
+    IR_NODE_LINK_TO(op, squaredxy);
+    IR_NODE_LINK_TO(op, last_out_var);
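+
+    // Drop every node matched by the pattern except the fused op's own
+    // inputs and outputs.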
+    std::unordered_set<const Node*> marked_nodes;
+    for (auto& item : subgraph) {
+      marked_nodes.insert(item.second);
+    }
+
+    marked_nodes.erase(matx);
+    marked_nodes.erase(maty);
+    marked_nodes.erase(squaredx);
+    marked_nodes.erase(squaredy);
+    marked_nodes.erase(squaredxy);
+    marked_nodes.erase(last_out_var);
+    GraphSafeRemoveNodes(graph, marked_nodes);
+    ++fusion_count;
+  };
+
+  gpd(graph, handler);
+  return fusion_count;
+}
+
+std::unique_ptr<ir::Graph> SquaredMatSubFusePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  FusePassBase::Init(name_scope_, graph.get());
+  int fusion_count = BuildFusion(graph.get(), name_scope_);
+  AddStatis(fusion_count);
+
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(squared_mat_sub_fuse_pass,
+              paddle::framework::ir::SquaredMatSubFusePass);
diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h
new file mode 100644
index 00000000000..5d94c3b07e2
--- /dev/null
+++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.h
@@ -0,0 +1,41 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+/**
+ * Fuse ( (A.^2 * B.^2) - (A * B).^2 ) .* scalar
+ */
+class SquaredMatSubFusePass : public FusePassBase {
+ public:
+  virtual ~SquaredMatSubFusePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+
+  const std::string name_scope_{"squared_mat_sub"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index aea0a6914ec..efe1ba106a2 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -99,6 +99,7 @@ class CpuPassStrategy : public PassStrategy {
         "seq_concat_fc_fuse_pass",       //
         "fc_fuse_pass",                  //
         "repeated_fc_relu_fuse_pass",    //
+        "squared_mat_sub_fuse_pass",     //
         "conv_bn_fuse_pass",             //
         "conv_eltwiseadd_bn_fuse_pass",  //
         "is_test_pass",                  //