From 603ba5e01d71cf237e6506ea1e83a4d52a3c0ccc Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 19 Oct 2018 22:38:41 +0800 Subject: [PATCH] add seqconv eltadd relu pass --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/graph_pattern_detector.cc | 50 +++++++++ .../framework/ir/graph_pattern_detector.h | 25 +++++ .../ir/seqconv_eltadd_relu_fuse_pass.cc | 101 ++++++++++++++++++ .../ir/seqconv_eltadd_relu_fuse_pass.h | 38 +++++++ paddle/fluid/inference/analysis/analyzer.h | 23 ++-- 6 files changed, 227 insertions(+), 11 deletions(-) create mode 100644 paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index abab290e7..d2429d5b2 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -37,6 +37,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) pass_library(conv_bn_fuse_pass inference) +pass_library(seqconv_eltadd_relu_fuse_pass inference) if(WITH_MKLDNN) pass_library(mkldnn_placement_pass base) pass_library(conv_relu_mkldnn_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 4664953c6..067467097 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -349,6 +349,11 @@ PDNode *PDNode::assert_is_op() { return this; } +// PDNode *PDNode::assert_op_attr() { +// asserts_.emplace_back([](Node *x) { return x && x->IsOp(); }); +// return this; +// } + PDNode *PDNode::assert_is_op(const std::string &op_type) { asserts_.emplace_back([op_type](Node *x) { return x && x->IsOp() && x->Op()->Type() == op_type; @@ -761,6 +766,51 @@ PDNode *patterns::ConvReLU::operator()( return relu_out_var; } +PDNode *patterns::SeqConvEltAddRelu::operator()( + paddle::framework::ir::PDNode *seqconv_input) { + // Create Operators + seqconv_input->assert_is_op_input("sequence_conv", "X"); + auto *seqconv_op = + pattern->NewNode(seqconv_repr())->assert_is_op("sequence_conv"); + // ->assert_op_attr("paddingTrainable", false) + // ->assert_op_attr("contextStride", 1) + + auto *eltadd_op = + pattern->NewNode(eltadd_repr())->assert_is_op("elementwise_add"); + auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu"); + // Create variables + // Filter + auto *seqconv_weight_var = + pattern->NewNode(seqconv_weight_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("sequence_conv", "Filter"); + // Bias + auto *eltadd_bias_var = pattern->NewNode(eltadd_bias_repr()) + ->AsInput() + ->assert_is_op_input("elementwise_add"); + // intermediate variable, will be removed in the IR after fuse. + auto *seqconv_out_var = pattern->NewNode(seqconv_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("sequence_conv") + ->assert_is_op_input("elementwise_add"); + auto *eltadd_out_var = pattern->NewNode(eltadd_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("elementwise_add") + ->assert_is_only_input_of_op("relu"); + // output + auto *relu_out_var = pattern->NewNode(relu_out_repr()) + ->AsOutput() + ->assert_is_op_output("relu"); + + seqconv_op->LinksFrom({seqconv_input, seqconv_weight_var}) + .LinksTo({seqconv_out_var}); + eltadd_op->LinksFrom({seqconv_out_var, eltadd_bias_var}) + .LinksTo({eltadd_out_var}); + relu_op->LinksFrom({eltadd_out_var}).LinksTo({relu_out_var}); + return relu_out_var; +} + PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, bool with_bias) { // Create shared nodes. diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index cdd6413d9..558eea353 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -434,6 +434,31 @@ struct ConvReLU : public PatternBase { PATTERN_DECL_NODE(relu_out); }; +// SEQCONV with Elementwise_Add ReLU +// op: seqconv + elementwise_add + relu +// named nodes: +// seqconv_input, seqconv_weight, +// seqconv_out, seqconv, +// elementwise_add_bias, elementwise_add_out, elementwise_add +// relu_out, relu +struct SeqConvEltAddRelu : public PatternBase { + SeqConvEltAddRelu(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "seqconv_eltadd_relu") {} + + PDNode* operator()(PDNode* seqconv_input); + + // declare operator node's name + PATTERN_DECL_NODE(seqconv); + PATTERN_DECL_NODE(eltadd); + PATTERN_DECL_NODE(relu); + // declare variable node's name + PATTERN_DECL_NODE(seqconv_weight); + PATTERN_DECL_NODE(seqconv_out); + PATTERN_DECL_NODE(eltadd_bias); + PATTERN_DECL_NODE(eltadd_out); + PATTERN_DECL_NODE(relu_out); +}; + // FC with bias // op: mul + elementwise_add // named nodes: diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc new file mode 100644 index 000000000..0a1f65d27 --- /dev/null +++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h" +#include +#include "paddle/fluid/framework/lod_tensor.h" + +namespace paddle { +namespace framework { +namespace ir { + +int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X")) + ->assert_is_op_input("sequence_conv") + ->assert_var_not_persistable(); + patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope); + fuse_pattern(x); + + // Create New OpDesc + auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight, + Node* eltadd_bias, Node* relu_out) { + OpDesc op_desc; + op_desc.SetType("fusion_seqconv_eltadd_relu"); + op_desc.SetInput("X", {input->Name()}); + op_desc.SetInput("Filter", {seqconv_weight->Name()}); + op_desc.SetInput("Bias", {eltadd_bias->Name()}); + op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength")); + op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart")); + op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride")); + PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); + auto* scope = graph->Get(kParamScopeAttr); + const std::string ColMat = patterns::UniqueKey("SeqConvColMat"); + op_desc.SetOutput("ColMat", {ColMat}); + op_desc.SetOutput("Out", {relu_out->Name()}); + scope->Var(ColMat)->GetMutable(); + + auto* op = graph->CreateOpNode(&op_desc); + IR_NODE_LINK_TO(input, op); + IR_NODE_LINK_TO(seqconv_weight, op); + IR_NODE_LINK_TO(eltadd_bias, op); + IR_NODE_LINK_TO(op, relu_out); + return op; + }; + + int fusion_count{0}; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle SeqConv EltAdd Relu fuse"; + GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern); + + fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias, + relu_out); + std::unordered_set marked_nodes( + {seqconv, seqconv_out, eltadd, eltadd_out, relu}); + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + + gpd(graph, handler); + + return fusion_count; +} + +std::unique_ptr SeqConvEltAddReluFusePass::ApplyImpl( + std::unique_ptr graph) const { + FusePassBase::Init(name_scope_, graph.get()); + + int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope()); + AddStatis(fusion_count); + + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(seqconv_eltadd_relu_fuse_pass, + paddle::framework::ir::SeqConvEltAddReluFusePass); diff --git a/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h new file mode 100644 index 000000000..dac9de719 --- /dev/null +++ b/paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class SeqConvEltAddReluFusePass : public FusePassBase { + public: + virtual ~SeqConvEltAddReluFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + + const std::string name_scope_{"seqconv_eltadd_relu_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index 6f45c6bf7..84d622fa7 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -67,17 +67,18 @@ class Analyzer : public OrderedRegistry { // larger fusion. const std::vector all_ir_passes_{{ // Manual update the passes here. - "infer_clean_graph_pass", // - "attention_lstm_fuse_pass", // - "embedding_fc_lstm_fuse_pass", // - "fc_lstm_fuse_pass", // - "mul_lstm_fuse_pass", // - "fc_gru_fuse_pass", // - "mul_gru_fuse_pass", // - "seq_concat_fc_fuse_pass", // - "fc_fuse_pass", // - "conv_bn_fuse_pass", // - "conv_eltwiseadd_bn_fuse_pass", // + "infer_clean_graph_pass", // + "attention_lstm_fuse_pass", // + "seqconv_eltadd_relu_fuse_pass", // + "embedding_fc_lstm_fuse_pass", // + "fc_lstm_fuse_pass", // + "mul_lstm_fuse_pass", // + "fc_gru_fuse_pass", // + "mul_gru_fuse_pass", // + "seq_concat_fc_fuse_pass", // + "fc_fuse_pass", // + "conv_bn_fuse_pass", // + "conv_eltwiseadd_bn_fuse_pass", // #ifdef PADDLE_WITH_MKLDNN "conv_relu_mkldnn_fuse_pass", // #endif -- GitLab