From e5e0b726e5c2c561d6afd4765bbb75d30e0ff417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awomir=20Siwek?= Date: Mon, 4 Apr 2022 10:01:42 +0200 Subject: [PATCH] conv + elementwise_add refactor (#41286) * DRY * change nodes names * add const prefix * change asX to as_x in all files --- .../framework/ir/graph_pattern_detector.cc | 23 +++ .../framework/ir/graph_pattern_detector.h | 16 ++ paddle/fluid/framework/ir/graph_traits.cc | 48 +++++ paddle/fluid/framework/ir/graph_traits.h | 3 + .../conv_elementwise_add_mkldnn_fuse_pass.cc | 166 ++---------------- .../conv_elementwise_add_mkldnn_fuse_pass.h | 16 +- ...t_mkldnn_conv_elementwise_add_fuse_pass.py | 136 +------------- 7 files changed, 113 insertions(+), 295 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 03da128920..8eb1b64a27 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2069,6 +2069,29 @@ PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var, return out_var; } +PDNode *patterns::ResidualElementwise::operator()( + PDNode *op_var, PDNode *residual_var, const std::string elementwise_type, + bool as_x) { + auto elementwise_op = + pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type); + + if (as_x) { + op_var->AsInput()->assert_is_op_input(elementwise_type, "X"); + residual_var->AsInput()->assert_is_op_input(elementwise_type, "Y"); + } else { + op_var->AsInput()->assert_is_op_input(elementwise_type, "Y"); + residual_var->AsInput()->assert_is_op_input(elementwise_type, "X"); + } + auto out_var = pattern->NewNode(elementwise_out_repr()) + ->AsOutput() + ->assert_is_op_output(elementwise_type, "Out"); + + elementwise_op->LinksFrom({op_var, residual_var}); + elementwise_op->LinksTo({out_var}); + + return out_var; +} + PDNode *patterns::Concat::operator()() { auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 1f253c6b91..434ede6cf7 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1032,6 +1032,22 @@ struct Elementwise : public PatternBase { PATTERN_DECL_NODE(elementwise_out); }; +// Residual Elementwise ops +// This pattern allows operator output to be X or Y +// and residual data Y or X, based on as_x flag +struct ResidualElementwise : public PatternBase { + ResidualElementwise(PDPattern* pattern, const std::string& name_scope, + bool as_x) + : PatternBase(pattern, name_scope, "residual_elementwise") {} + PDNode* operator()(PDNode* op_var, PDNode* residual_var, + const std::string elementwise_type, bool as_x); + + PATTERN_DECL_NODE(operator_output); + PATTERN_DECL_NODE(residual_data); + PATTERN_DECL_NODE(elementwise_op); + PATTERN_DECL_NODE(elementwise_out); +}; + // Transpose op // Forward pass for transpose. // transpose_out is a result of the operator. diff --git a/paddle/fluid/framework/ir/graph_traits.cc b/paddle/fluid/framework/ir/graph_traits.cc index 262a523bd8..b063145630 100644 --- a/paddle/fluid/framework/ir/graph_traits.cc +++ b/paddle/fluid/framework/ir/graph_traits.cc @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include + #include "paddle/fluid/framework/ir/graph_traits.h" namespace paddle { @@ -23,6 +26,51 @@ namespace ir { // class Node; +bool IsReachable(ir::Graph *graph, Node *from, Node *to) { + if (from == to) { + return true; + } + + std::map visited; + + for (auto &node : GraphTraits::DFS(*graph)) { + visited[&node] = false; + } + + visited[from] = true; + + std::list queue; + queue.push_back(from); + + while (!queue.empty()) { + auto cur = FindNode(graph, queue.front()); + queue.pop_front(); + + if (!cur) return false; + + for (const auto &n : cur->outputs) { + if (n == to) { + return true; + } + + if (!visited[n]) { + visited[n] = true; + queue.push_back(n); + } + } + } + return false; +} + +Node *FindNode(ir::Graph *graph, const Node *node) { + for (const auto &n : graph->Nodes()) { + if (n == node) { + return n; + } + } + return nullptr; +} + NodesDFSIterator::NodesDFSIterator(const std::vector &source) { for (auto *x : source) stack_.push(x); } diff --git a/paddle/fluid/framework/ir/graph_traits.h b/paddle/fluid/framework/ir/graph_traits.h index a54cc61a63..7e313e17f4 100644 --- a/paddle/fluid/framework/ir/graph_traits.h +++ b/paddle/fluid/framework/ir/graph_traits.h @@ -29,6 +29,9 @@ namespace ir { class Graph; class Node; +bool IsReachable(ir::Graph *graph, Node *from, Node *to); +Node *FindNode(ir::Graph *graph, const Node *node); + template class iterator_range { IteratorT begin_, end_; diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index fc2758c273..16c4f251e0 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -14,12 +14,6 @@ #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" -#include -#include -#include -#include -#include - #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/string/pretty_log.h" @@ -28,60 +22,6 @@ namespace paddle { namespace framework { namespace ir { -bool IsReachable(ir::Graph* graph, Node* from, Node* to) { - auto find_node = [](ir::Graph* graph, const Node* node) -> Node* { - for (auto n : graph->Nodes()) { - if (n == node) { - return n; - } - } - - return nullptr; - }; - - if (from == to) { - return true; - } - - std::map visited; - - for (auto& node : GraphTraits::DFS(*graph)) { - visited[&node] = false; - } - - visited[from] = true; - - std::list queue; - queue.push_back(from); - - while (!queue.empty()) { - auto cur = find_node(graph, queue.front()); - queue.pop_front(); - - if (!cur) return false; - - for (auto n : cur->outputs) { - if (n == to) { - return true; - } - - if (!visited[n]) { - visited[n] = true; - queue.push_back(n); - } - } - } - return false; -} - -template -paddle::optional HasAttribute(const Node& op, const std::string& attr) { - if (op.Op()->HasAttr(attr)) - return BOOST_GET_CONST(T, op.Op()->GetAttr(attr)); - else - return paddle::none; -} - ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() { AddOpCompat(OpCompat("conv2d")) .AddInput("Input") @@ -136,89 +76,22 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() { .End(); } -GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( - const std::string& name_scope, - const GraphWithStats& graph_with_stats) const { - GraphPatternDetector gpd; - auto pattern = gpd.mutable_pattern(); - - patterns::Conv conv_pattern{pattern, name_scope}; - auto conv_output = conv_pattern(); - - patterns::Elementwise elementwise_pattern{pattern, name_scope}; - elementwise_pattern( - conv_output, pattern->NewNode(elementwise_pattern.elementwise_y_repr()), - "elementwise_add"); - conv_output->AsIntermediate(); - - int found_conv_as_x_count = 0; - - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); - GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); - - GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op, - elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_identity, elementwise_y, - elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, - elementwise_pattern); - - if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return; - - if (!IsReachable(g, elementwise_identity, conv_output)) return; - - if (HasFusedActivation(conv_op)) return; - - if (!IsCompat(subgraph, g)) { - LOG(WARNING) - << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; - return; - } - - conv_op->Op()->SetInput("ResidualData", {elementwise_identity->Name()}); - conv_op->Op()->SetOutput("Output", {elementwise_out->Name()}); - conv_op->Op()->SetAttr("fuse_residual_connection", true); - - GraphSafeRemoveNodes(g, {conv_output, elementwise_op}); - - IR_NODE_LINK_TO(elementwise_identity, conv_op); - IR_NODE_LINK_TO(conv_op, elementwise_out); - - found_conv_as_x_count++; - }; - - gpd(graph_with_stats.first, handler); - if (!Has("disable_logs") || !Get("disable_logs")) { - std::stringstream msg_ss; - msg_ss << "--- Fused " << found_conv_as_x_count - << " conv (as x) + elementwise_add patterns"; - paddle::string::PrettyLogDetail(msg_ss.str().c_str()); - } - - return std::make_pair(graph_with_stats.first, - found_conv_as_x_count + graph_with_stats.second); -} - -GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( - const std::string& name_scope, - const GraphWithStats& graph_with_stats) const { +GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConv( + const std::string& name_scope, const GraphWithStats& graph_with_stats, + bool as_x) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); patterns::Conv conv_pattern{pattern, name_scope}; auto conv_output = conv_pattern(); - patterns::Elementwise elementwise_pattern{pattern, name_scope}; + patterns::ResidualElementwise elementwise_pattern{pattern, name_scope, as_x}; elementwise_pattern( - pattern->NewNode(elementwise_pattern.elementwise_x_repr()), conv_output, - "elementwise_add"); + conv_output, pattern->NewNode(elementwise_pattern.residual_data_repr()), + "elementwise_add", as_x); conv_output->AsIntermediate(); - int found_conv_as_y_count = 0; + int found_conv_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -229,15 +102,13 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op, elementwise_pattern); - GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x, + GET_IR_NODE_FROM_SUBGRAPH(residual_data, residual_data, elementwise_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, elementwise_pattern); if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return; - - if (!IsReachable(g, elementwise_x, conv_output)) return; - + if (!IsReachable(g, residual_data, conv_output)) return; if (HasFusedActivation(conv_op)) return; if (!IsCompat(subgraph, g)) { @@ -246,28 +117,29 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( return; } - conv_op->Op()->SetInput("ResidualData", {elementwise_x->Name()}); + conv_op->Op()->SetInput("ResidualData", {residual_data->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_out->Name()}); conv_op->Op()->SetAttr("fuse_residual_connection", true); GraphSafeRemoveNodes(g, {conv_output, elementwise_op}); - IR_NODE_LINK_TO(elementwise_x, conv_op); + IR_NODE_LINK_TO(residual_data, conv_op); IR_NODE_LINK_TO(conv_op, elementwise_out); - found_conv_as_y_count++; + found_conv_count++; }; gpd(graph_with_stats.first, handler); if (!Has("disable_logs") || !Get("disable_logs")) { std::stringstream msg_ss; - msg_ss << "--- Fused " << found_conv_as_y_count - << " conv (as y) + elementwise_add patterns"; + std::string fusionMode = as_x ? "x" : "y"; + msg_ss << "--- Fused " << found_conv_count << " conv (as " << fusionMode + << ") + elementwise_add patterns"; paddle::string::PrettyLogDetail(msg_ss.str().c_str()); } return std::make_pair(graph_with_stats.first, - found_conv_as_y_count + graph_with_stats.second); + found_conv_count + graph_with_stats.second); } GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( @@ -308,7 +180,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( if (!IsCompat(subgraph, g)) { LOG(WARNING) - << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; + << "op compat for conv_elementwise_add_mkldnn_fuse_pass failed."; return; } @@ -361,8 +233,8 @@ void ResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init(name_scope_, graph); auto graph_with_stats = FuseProjectionConv(name_scope_, std::make_pair(graph, 0)); - graph_with_stats = FuseConvAsX(name_scope_, graph_with_stats); - graph_with_stats = FuseConvAsY(name_scope_, graph_with_stats); + graph_with_stats = FuseConv(name_scope_, graph_with_stats, true); + graph_with_stats = FuseConv(name_scope_, graph_with_stats, false); AddStatis(graph_with_stats.second); } diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h index c4351b3821..7c6e992716 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h @@ -14,30 +14,20 @@ #pragma once -#include -#include -#include -#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include - namespace paddle { namespace framework { namespace ir { using GraphWithStats = std::pair; -bool IsReachable(ir::Graph* graph, Node* from, Node* to); - class ResidualConnectionMKLDNNFusePass : public FusePassBase { private: - GraphWithStats FuseConvAsX(const std::string& name_scope, - const GraphWithStats& graph_with_stats) const; - GraphWithStats FuseConvAsY(const std::string& name_scope, - const GraphWithStats& graph_with_stats) const; + GraphWithStats FuseConv(const std::string& name_scope, + const GraphWithStats& graph_with_stats, + bool as_x) const; GraphWithStats FuseProjectionConv( const std::string& name_scope, const GraphWithStats& graph_with_stats) const; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py index 2e84607e2f..58d09a8806 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_elementwise_add_fuse_pass.py @@ -26,7 +26,7 @@ import hypothesis.strategies as st # the two inputs of elementwise_add are tensor -class TestConvElementwiseAddMkldnnFusePass1(PassAutoScanTest): +class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: attrs = [ program_config.ops[i].attrs @@ -125,139 +125,5 @@ class TestConvElementwiseAddMkldnnFusePass1(PassAutoScanTest): quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"]) -''' -class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest): - def is_program_valid(self, program_config: ProgramConfig) -> bool: - attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) - ] - if "elementwise_weight" in program_config.weights: - if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[1]: - if attrs[2]['axis'] != 1: - return False - if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[3]: - if attrs[2]['axis'] != -1: - return False - return True - - def sample_program_config(self, draw): - data_format = draw(st.sampled_from(["NCHW", "NHWC"])) - dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) - padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"])) - groups = draw(st.sampled_from([1, 2, 4])) - paddings = draw(st.sampled_from([[0, 3], [1, 1], [1, 2, 3, 4]])) - strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]])) - axis = draw(st.sampled_from([-1, 0, 1])) - batch_size = draw(st.integers(min_value=1, max_value=4)) - - def generate_input1(): - if data_format == "NCHW": - return np.random.random( - [batch_size, 48, 64, 64]).astype(np.float32) - else: - return np.random.random( - [batch_size, 64, 64, 48]).astype(np.float32) - - def generate_weight1(): - return np.random.random( - [48, int(48 / groups), 3, 3]).astype(np.float32) - - def compute_out_shape(padding_alg): - import paddle - import paddle.nn as nn - - x_var = paddle.uniform( - (batch_size, 48, 64, 64), dtype='float32', min=-1., max=1.) - if padding_alg == "EXPLICIT": - conv = nn.Conv2D(48, 48, (3, 3), strides, paddings, dilations, - 1) - else: - conv = nn.Conv2D(48, 48, (3, 3), strides, padding_alg, - dilations, 1) - y_var = conv(x_var) - return y_var.shape - - def generate_weight2(): - return np.random.random([48]).astype(np.float32) - - if compute_out_shape(padding_algorithm) != (batch_size, 48, 64, 64): - axis = 1 - - relu_op = OpConfig( - type="relu", - inputs={"X": ["input_data1"]}, - outputs={"Out": ["sigmoid_out"]}, - attrs={}) - - conv2d_op = OpConfig( - type="conv2d", - inputs={"Input": ["sigmoid_out"], - "Filter": ["conv_weight"]}, - outputs={"Output": ["conv_output"]}, - attrs={ - "data_format": data_format, - "dilations": dilations, - "padding_algorithm": padding_algorithm, - "groups": groups, - "paddings": paddings, - "strides": strides - }) - - if axis == 0: - elt_op = OpConfig( - type="elementwise_add", - inputs={"X": ["input_data1"], - "Y": ["conv_output"]}, - outputs={"Out": ["elementwise_output"]}, - attrs={'axis': axis}) - else: - elt_op = OpConfig( - type="elementwise_add", - inputs={"X": ["conv_output"], - "Y": ["elementwise_weight"]}, - outputs={"Out": ["elementwise_output"]}, - attrs={'axis': axis}) - - model_net = [relu_op, conv2d_op, elt_op] - - if axis == 0: - program_config = ProgramConfig( - ops=model_net, - weights={ - "conv_weight": - TensorConfig(data_gen=partial(generate_weight1)) - }, - inputs={ - "input_data1": - TensorConfig(data_gen=partial(generate_input1)) - }, - outputs=["elementwise_output"]) - else: - program_config = ProgramConfig( - ops=model_net, - weights={ - "conv_weight": - TensorConfig(data_gen=partial(generate_weight1)), - "elementwise_weight": - TensorConfig(data_gen=partial(generate_weight2)) - }, - inputs={ - "input_data1": - TensorConfig(data_gen=partial(generate_input1)) - }, - outputs=["elementwise_output"]) - - return program_config - - def sample_predictor_configs(self, program_config): - config = self.create_inference_config(use_mkldnn=True) - yield config, ["relu", "conv2d"], (1e-5, 1e-5) - - def test(self): - self.run_and_statis( - quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"]) -''' - if __name__ == "__main__": unittest.main() -- GitLab