diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a595a8ab4299298f625b8322a0adbed6d0b4fda3..42fb6a1aa5375bfbb266454cfbc7f0fb756f779c 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -48,6 +48,17 @@ pass_library(conv_elementwise_add_act_fuse_pass inference) pass_library(conv_elementwise_add2_act_fuse_pass inference) pass_library(conv_elementwise_add_fuse_pass inference) pass_library(conv_affine_channel_fuse_pass inference) +pass_library(transpose_flatten_concat_fuse_pass inference) + +# There may be many transpose-flatten structures in a model, and the output of +# these structures will be used as inputs to the concat Op. This pattern will +# be detected by our pass. The index here represents the number of structures in the +# pattern. We use index 3 ~ 6, because these quantities of structures are +# common in the models. +foreach (index RANGE 3 6) + file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n") +endforeach() + if(WITH_MKLDNN) pass_library(mkldnn_placement_pass base) pass_library(depthwise_conv_mkldnn_pass base) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index c513fe2dd8f5733c87802f6fa9980ad885dfd865..6282ced1e47329915bb3626b410e55ad8251071d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1306,6 +1306,69 @@ PDNode *patterns::ConvAffineChannel::operator()( return ac_out_var; } +// a -> transpose_op(1) -> transpose_out_a -> flatten_op(1) -> flatten_out_a +// b -> transpose_op(2) -> transpose_out_b -> flatten_op(2) -> flatten_out_b +// ... +// z -> transpose_op(n) -> transpose_out_z -> flatten_op(n) -> flatten_out_z +// flatten_out_a -> concat_op flatten_out_b -> concat_op ... flatten_out_z -> +// concat_op +PDNode *patterns::TransposeFlattenConcat::operator()( + std::vector conv_in, int times) { + // The times represents the repeat times of the + // {trans, trans_out, flatten, flatten_out} + const int kNumFields = 4; + const int kTransOutOffset = 1; + const int kFlattenOffset = 2; + const int kFlattenOutOffset = 3; + + std::vector nodes; + + for (int i = 0; i < times; i++) { + nodes.push_back( + pattern->NewNode(GetNodeName("transpose" + std::to_string(i))) + ->assert_is_op("transpose2")); + nodes.push_back( + pattern->NewNode(GetNodeName("transpose_out" + std::to_string(i))) + ->assert_is_op_output("transpose2") + ->assert_is_op_input("flatten2", "X") + ->AsIntermediate()); + nodes.push_back(pattern->NewNode(GetNodeName("flatten" + std::to_string(i))) + ->assert_is_op("flatten2")); + + nodes.push_back( + pattern->NewNode(GetNodeName("flatten_out" + std::to_string(i))) + ->assert_is_op_output("flatten2") + ->assert_is_op_nth_input("concat", "X", i) + ->AsIntermediate()); + } + + auto concat_op = pattern->NewNode(GetNodeName("concat")) + ->assert_is_op("concat") + ->assert_op_has_n_inputs("concat", times); + auto concat_out = pattern->NewNode(GetNodeName("concat_out")) + ->assert_is_op_output("concat") + ->AsOutput(); + + std::vector flatten_outs; + for (int i = 0; i < times; i++) { + conv_in[i]->AsInput(); + // trans + nodes[i * kNumFields]->LinksFrom({conv_in[i]}); + // trans_out + nodes[i * kNumFields + kTransOutOffset]->LinksFrom({nodes[i * kNumFields]}); + // flatten + nodes[i * kNumFields + kFlattenOffset]->LinksFrom( + {nodes[i * kNumFields + kTransOutOffset]}); + // flatten_out + nodes[i * kNumFields + kFlattenOutOffset]->LinksFrom( + {nodes[i * kNumFields + kFlattenOffset]}); + flatten_outs.push_back(nodes[i * kNumFields + kFlattenOutOffset]); + } + + concat_op->LinksFrom(flatten_outs).LinksTo({concat_out}); + return concat_out; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 61a53003449710da2a52c90197c9f2f3ac56c7bb..c8be586f546dc604375401b13a801841efbf08d2 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -766,6 +766,21 @@ struct ConvAffineChannel : public PatternBase { PATTERN_DECL_NODE(ac_out); // Out }; +struct TransposeFlattenConcat : public PatternBase { + TransposeFlattenConcat(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "transpose_flatten_concat") {} + + PDNode* operator()(std::vector conv_inputs, int times); + + std::string GetNodeName(const std::string& op_type) { + return PDNodeName(name_scope_, repr_, id_, op_type); + } + + PDNode* GetPDNode(const std::string& op_type) { + return pattern->RetrieveNode(GetNodeName(op_type)); + } +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..fda43948d567689103815e3ad7ba285719dae80f --- /dev/null +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +template +std::unique_ptr TransposeFlattenConcatFusePass::ApplyImpl( + std::unique_ptr graph) const { + const std::string pattern_name = + "transpose_flatten" + std::to_string(times) + "_concat_fuse"; + FusePassBase::Init(pattern_name, graph.get()); + + GraphPatternDetector gpd; + std::vector input_nodes; + for (int i = 0; i < times; i++) { + input_nodes.push_back(gpd.mutable_pattern() + ->NewNode("x" + std::to_string(i)) + ->assert_is_op_input("transpose2", "X") + ->AsInput()); + } + + patterns::TransposeFlattenConcat pattern(gpd.mutable_pattern(), pattern_name); + pattern(input_nodes, times); + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + const int kNumFields = 5; + const int kTransOffset = 1; + const int kTransOutOffset = 2; + const int kFlattenOffset = 3; + const int kFlattenOutOffset = 4; + std::vector nodes; + + for (int i = 0; i < times; i++) { + PADDLE_ENFORCE( + subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i)))); + PADDLE_ENFORCE( + subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i)))); + PADDLE_ENFORCE( + subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i)))); + PADDLE_ENFORCE( + subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i)))); + PADDLE_ENFORCE(subgraph.at(input_nodes[i])); + + nodes.push_back(subgraph.at(input_nodes[i])); + nodes.push_back( + subgraph.at(pattern.GetPDNode("transpose" + std::to_string(i)))); + nodes.push_back( + subgraph.at(pattern.GetPDNode("transpose_out" + std::to_string(i)))); + nodes.push_back( + subgraph.at(pattern.GetPDNode("flatten" + std::to_string(i)))); + nodes.push_back( + subgraph.at(pattern.GetPDNode("flatten_out" + std::to_string(i)))); + } + + Node *concat_op = subgraph.at(pattern.GetPDNode("concat")); + Node *concat_out = subgraph.at(pattern.GetPDNode("concat_out")); + std::vector input_names; + std::vector trans_axis = boost::get>( + nodes[kTransOffset]->Op()->GetAttr("axis")); + int flatten_axis = + boost::get(nodes[kFlattenOffset]->Op()->GetAttr("axis")); + int concat_axis = boost::get(concat_op->Op()->GetAttr("axis")); + std::string output_name = concat_out->Name(); + + for (int i = 0; i < times; i++) { + input_names.push_back(nodes[i * kNumFields]->Name()); + } + + framework::OpDesc new_op_desc; + new_op_desc.SetType("fusion_transpose_flatten_concat"); + new_op_desc.SetInput("X", input_names); + new_op_desc.SetAttr("trans_axis", trans_axis); + new_op_desc.SetAttr("flatten_axis", flatten_axis); + new_op_desc.SetAttr("concat_axis", concat_axis); + new_op_desc.SetOutput("Out", {output_name}); + new_op_desc.Flush(); + + // Create a new node for the fused op. + auto *new_conv_op = graph->CreateOpNode(&new_op_desc); + + std::unordered_set delete_nodes; + + for (int i = 0; i < times; i++) { + nodes[i * kNumFields]->outputs.push_back(new_conv_op); + new_conv_op->inputs.push_back(nodes[i * kNumFields]); + delete_nodes.insert(nodes[i * kNumFields + kTransOffset]); + delete_nodes.insert(nodes[i * kNumFields + kTransOutOffset]); + delete_nodes.insert(nodes[i * kNumFields + kFlattenOffset]); + delete_nodes.insert(nodes[i * kNumFields + kFlattenOutOffset]); + } + delete_nodes.insert(concat_op); + + new_conv_op->outputs.push_back(concat_out); + concat_out->inputs.push_back(new_conv_op); + + // Delete the unneeded nodes. + GraphSafeRemoveNodes(graph.get(), delete_nodes); + }; + + gpd(graph.get(), handler); + return graph; +} + +template class TransposeFlattenConcatFusePass<1>; +template class TransposeFlattenConcatFusePass<3>; +template class TransposeFlattenConcatFusePass<4>; +template class TransposeFlattenConcatFusePass<5>; +template class TransposeFlattenConcatFusePass<6>; + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(transpose_flatten_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<1>); + +REGISTER_PASS(transpose_flatten3_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<3>); + +REGISTER_PASS(transpose_flatten4_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<4>); + +REGISTER_PASS(transpose_flatten5_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<5>); + +REGISTER_PASS(transpose_flatten6_concat_fuse_pass, + paddle::framework::ir::TransposeFlattenConcatFusePass<6>); diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..fb0f0ae9efdc5a25a799d6123fa658a99860cd86 --- /dev/null +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +// There may be many transpose-flatten structures in a model, and the output of +// these structures will be used as inputs to the concat Op. This pattern will +// be detected by our pass. The times here represents the repeat times of this +// structure. +template +class TransposeFlattenConcatFusePass : public FusePassBase { + public: + virtual ~TransposeFlattenConcatFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 211c691504de2c0bd8ff50f34b92cbc01397d5c9..336ab426c21d9de93693c44d8fc6bc5b37b58864 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -127,6 +127,7 @@ void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size, use_tensorrt_ = true; tensorrt_workspace_size_ = workspace_size; tensorrt_max_batchsize_ = max_batch_size; + Update(); } void contrib::AnalysisConfig::Update() { diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 1e5712e1638ea802dfa9c3b41ab1d3f7f62f090b..de9650735adfe158e72213d4f6d5d3569aa90d55 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -141,6 +141,10 @@ class GpuPassStrategy : public PassStrategy { "conv_elementwise_add_fuse_pass", // }); + for (int i = 6; i >= 3; i--) { + passes_.push_back("transpose_flatten" + std::to_string(i) + + "_concat_fuse_pass"); + } use_gpu_ = true; } diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 6975086193d991dc9f53b2d9d988f960c8ad118d..79362f9677010247dffa4fbaa155a7a56eed6f85 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -39,6 +39,7 @@ class ElementwiseWeightOpConverter : public OpConverter { const framework::Scope& scope, bool test_mode) override { // Here the two nullptr looks strange, that's because the // framework::OpDesc's constructor is strange. + nvinfer1::ILayer* layer = nullptr; framework::OpDesc op_desc(op, nullptr); VLOG(3) << "Convert a fluid elementwise op to TensorRT IScaleLayer"; @@ -98,13 +99,21 @@ class ElementwiseWeightOpConverter : public OpConverter { 0}; TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, 0}; + if (op_type_ == "add") { + nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *X, scale_mode, shift_weights.get(), + scale_weights.get(), power_weights.get()); + layer = scale_layer; + } else if (op_type_ == "mul") { + nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *X, scale_mode, scale_weights.get(), + shift_weights.get(), power_weights.get()); + layer = scale_layer; + } - nvinfer1::IScaleLayer* layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *const_cast(X), scale_mode, - shift_weights.get(), scale_weights.get(), power_weights.get()); auto output_name = op_desc.Output("Out")[0]; - - layer->setName(("elementwise_add (Output: " + output_name + ")").c_str()); + layer->setName( + ("elementwise_" + op_type_ + "(Output: " + output_name + ")").c_str()); layer->getOutput(0)->setName(output_name.c_str()); engine_->weight_map[op_desc.Input("Y").front()] = std::move(weight_tensor); engine_->SetITensor(output_name, layer->getOutput(0)); @@ -113,6 +122,9 @@ class ElementwiseWeightOpConverter : public OpConverter { engine_->DeclareOutput(output_name); } } + + protected: + std::string op_type_; }; class ElementwiseTensorOpConverter : public OpConverter { @@ -188,6 +200,16 @@ const std::unordered_map {"max", nvinfer1::ElementWiseOperation::kMAX}, }; +class ElementwiseWeightAddOpConverter : public ElementwiseWeightOpConverter { + public: + ElementwiseWeightAddOpConverter() { op_type_ = "add"; } +}; + +class ElementwiseWeightMulOpConverter : public ElementwiseWeightOpConverter { + public: + ElementwiseWeightMulOpConverter() { op_type_ = "mul"; } +}; + class ElementwiseTensorAddOpConverter : public ElementwiseTensorOpConverter { public: ElementwiseTensorAddOpConverter() { op_type_ = "add"; } @@ -227,7 +249,10 @@ class ElementwiseTensorPowOpConverter : public ElementwiseTensorOpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(elementwise_add_weight, ElementwiseWeightOpConverter); +REGISTER_TRT_OP_CONVERTER(elementwise_add_weight, + ElementwiseWeightAddOpConverter); +REGISTER_TRT_OP_CONVERTER(elementwise_mul_weight, + ElementwiseWeightMulOpConverter); REGISTER_TRT_OP_CONVERTER(elementwise_add_tensor, ElementwiseTensorAddOpConverter);