diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index b7f7e2ee8ef590c0d0d8307de4400a8ce8ad4e7d..6d795e1e2d5407ecacf5fb4af539919d72bff404 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -45,6 +45,7 @@ pass_library(is_test_pass base) pass_library(conv_elementwise_add_act_fuse_pass inference) pass_library(conv_elementwise_add2_act_fuse_pass inference) pass_library(conv_elementwise_add_fuse_pass inference) +pass_library(conv_affine_channel_fuse_pass inference) if(WITH_MKLDNN) pass_library(mkldnn_placement_pass base) pass_library(depthwise_conv_mkldnn_pass base) diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..a7bfb8cf1ee09e78051e2f140c9a7ab4c40db60c --- /dev/null +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc @@ -0,0 +1,222 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_CONV_BN_NODES(pattern_name) \ + /* OPERATORS */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(affine_channel, affine_channel, pattern_name); \ + /* CONV inputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name); \ + /* CONV outputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name); \ + /* Affine Channel inputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(ac_scale, ac_scale, pattern_name); \ + GET_IR_NODE_FROM_SUBGRAPH(ac_bias, ac_bias, pattern_name); \ + /* Affine channel outputs */ \ + GET_IR_NODE_FROM_SUBGRAPH(ac_out, ac_out, pattern_name); /* Out */ + +void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight, + const ir::Node& ac_scale, + const LoDTensor& ac_bias_tensor, + LoDTensor* eltwise_y_in_tensor) { + using EigenVectorArrayMap = + Eigen::Map>; + using ConstEigenVectorArrayMap = + Eigen::Map>; + using EigenMatrixArrayMap = Eigen::Map< + Eigen::Array>; + + // Re-compute bias of conv2d from AffineChannel + PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), ac_bias_tensor.dims()); + + auto* scale_tensor = scope->FindVar(ac_scale.Name())->GetMutable(); + + ConstEigenVectorArrayMap scale_array(scale_tensor->data(), + scale_tensor->numel(), 1); + ConstEigenVectorArrayMap ac_bias_array(ac_bias_tensor.data(), + ac_bias_tensor.numel(), 1); + + EigenVectorArrayMap eltwise_y_in_array( + eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), + eltwise_y_in_tensor->numel(), 1); + + eltwise_y_in_array = (eltwise_y_in_array * scale_array) + ac_bias_array; + + // Re-compute weight of conv2d from AffineChannel + auto* weights = scope->FindVar(conv_weight->Name())->GetMutable(); + auto weights_shape = weights->dims(); + auto weights_shape_2d = flatten_to_2d(weights_shape, 1); + + EigenMatrixArrayMap weights_array_2d( + weights->mutable_data(platform::CPUPlace()), weights_shape_2d[0], + weights_shape_2d[1]); + + weights_array_2d.colwise() *= scale_array; +} + +std::unique_ptr ConvAffineChannelFusePass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init(name_scope_, graph.get()); + + auto* scope = param_scope(); + PADDLE_ENFORCE(scope); + + GraphPatternDetector gpd; + auto* conv_input = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(), + name_scope_); + conv_ac_pattern(conv_input, false /*with_eltwise_add*/); + + int found_conv_ac_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle ConvAffineChannel fuse"; + + GET_CONV_BN_NODES(conv_ac_pattern); + + // check if fuse can be done and if MKL-DNN should be used + FuseOptions fuse_option = FindFuseOption(*conv, *affine_channel); + if (fuse_option == DO_NOT_FUSE) { + VLOG(3) << "do not perform conv+affinechannel fuse"; + return; + } + + // Create eltwise_y (conv bias) variable + VarDesc eltwise_y_in_desc( + patterns::PDNodeName(name_scope_, "eltwise_y_in")); + eltwise_y_in_desc.SetPersistable(true); + auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc); + auto* eltwise_y_in_tensor = + scope->Var(eltwise_y_in_node->Name())->GetMutable(); + + // Get affine_channel bias + auto* ac_bias_tensor = + scope->FindVar(ac_bias->Name())->GetMutable(); + + // Initialize eltwise_y + eltwise_y_in_tensor->Resize(ac_bias_tensor->dims()); + std::fill_n(eltwise_y_in_tensor->mutable_data(platform::CPUPlace()), + eltwise_y_in_tensor->numel(), 0.0f); + + // update weights and biases + recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor, + eltwise_y_in_tensor); + + // create an elementwise add node. + OpDesc desc; + desc.SetInput("X", std::vector({conv_out->Name()})); + desc.SetInput("Y", std::vector({eltwise_y_in_node->Name()})); + desc.SetOutput("Out", std::vector({ac_out->Name()})); + desc.SetType("elementwise_add"); + desc.SetAttr("axis", 1); + auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. + + GraphSafeRemoveNodes(graph.get(), {ac_scale, ac_bias, affine_channel}); + + IR_NODE_LINK_TO(conv_out, eltwise_op); + IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op); + IR_NODE_LINK_TO(eltwise_op, ac_out); + found_conv_ac_count++; + }; + + gpd(graph.get(), handler); + + AddStatis(found_conv_ac_count); + return graph; +} + +std::unique_ptr ConvEltwiseAddAffineChannelFusePass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + FusePassBase::Init(name_scope_, graph.get()); + + auto* scope = param_scope(); + PADDLE_ENFORCE(scope); + + GraphPatternDetector gpd; + auto* conv_input = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvAffineChannel conv_ac_pattern(gpd.mutable_pattern(), + name_scope_); + conv_ac_pattern(conv_input, true /*with_eltwise_add*/); + + int found_conv_ac_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle ConvBN fuse"; + + GET_CONV_BN_NODES(conv_ac_pattern); + // OPERATORS + GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_ac_pattern); + // BIAS inputs + GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_ac_pattern); + // BIAS outputs + GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_ac_pattern); + + // Get eltwise_y (conv bias) variable + auto* eltwise_y_in_tensor = + scope->FindVar(eltwise_y_in->Name())->GetMutable(); + + // Get batch norm bias + auto* ac_bias_tensor = + scope->FindVar(ac_bias->Name())->GetMutable(); + + recompute_bias_and_weights(scope, conv_weight, *ac_scale, *ac_bias_tensor, + eltwise_y_in_tensor); + + // Update the elementwise_add node + eltwise->Op()->SetAttr("axis", 1); + eltwise->Op()->SetOutput("Out", std::vector({ac_out->Name()})); + + GraphSafeRemoveNodes(graph.get(), + {ac_scale, ac_bias, affine_channel, eltwise_out}); + + IR_NODE_LINK_TO(eltwise, ac_out); + + found_conv_ac_count++; + }; + + gpd(graph.get(), handler); + AddStatis(found_conv_ac_count); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv_affine_channel_fuse_pass, + paddle::framework::ir::ConvAffineChannelFusePass); +REGISTER_PASS(conv_eltwiseadd_affine_channel_fuse_pass, + paddle::framework::ir::ConvEltwiseAddAffineChannelFusePass); diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..ad966e11e6222a4ed4c730089c454b0d1c7bd0b3 --- /dev/null +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h @@ -0,0 +1,49 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the Conv and ConvAffineChannel. + */ +class ConvAffineChannelFusePass : public FusePassBase { + public: + virtual ~ConvAffineChannelFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + const std::string name_scope_{"conv_affine_channel_fuse"}; +}; + +class ConvEltwiseAddAffineChannelFusePass : public FusePassBase { + public: + virtual ~ConvEltwiseAddAffineChannelFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + const std::string name_scope_{"conv_eltwiseadd_affine_channel_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 73d1a3da8fc921eea77a2bddbe1cb63bd9832ea3..c513fe2dd8f5733c87802f6fa9980ad885dfd865 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1234,6 +1234,78 @@ PDNode *patterns::ConvElementwiseadd::operator()(PDNode *conv_in) { return elementwise_add_out; } +PDNode *patterns::ConvAffineChannel::operator()( + paddle::framework::ir::PDNode *conv_input, bool with_eltwise_add) { + // Create Operators + conv_input->assert_is_op_input("conv2d", "Input"); + auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); + + PDNode *eltwise_op = nullptr; + if (with_eltwise_add) { + eltwise_op = + pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add"); + } + + auto *affine_channel_op = + pattern->NewNode(affine_channel_repr())->assert_is_op("affine_channel"); + // Create variables + // Conv Filter + auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("conv2d", "Filter"); + + auto *conv_out_var = pattern->NewNode(conv_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("conv2d"); + + PDNode *eltwise_y_in_var = nullptr; + PDNode *eltwise_out_var = nullptr; + if (with_eltwise_add) { + // Conv output as Bias input + conv_out_var->assert_is_op_input("elementwise_add", "X"); + // Bias + eltwise_y_in_var = pattern->NewNode(eltwise_y_in_repr()) + ->assert_is_op_input("elementwise_add", "Y") + ->AsInput(); + eltwise_out_var = pattern->NewNode(eltwise_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("elementwise_add"); + } else { + // Conv output as AffineChannel input + conv_out_var->assert_is_op_input("affine_channel", "X"); + } + + // AC Scale + auto *ac_scale_var = pattern->NewNode(ac_scale_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("affine_channel", "Scale"); + // AC Bias + auto *ac_bias_var = pattern->NewNode(ac_bias_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("affine_channel", "Bias"); + + // AC output + auto *ac_out_var = pattern->NewNode(ac_out_repr()) + ->AsOutput() + ->assert_is_op_output("affine_channel"); + + conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); + + if (with_eltwise_add) { + eltwise_op->LinksFrom({conv_out_var, eltwise_y_in_var}) + .LinksTo({eltwise_out_var}); + affine_channel_op->LinksFrom({eltwise_out_var, ac_scale_var, ac_bias_var}) + .LinksTo({ac_out_var}); + } else { + affine_channel_op->LinksFrom({conv_out_var, ac_scale_var, ac_bias_var}) + .LinksTo({ac_out_var}); + } + return ac_out_var; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index eaedd9d08e0fab820481d6eaacb6e7bfc1ab6d1d..61a53003449710da2a52c90197c9f2f3ac56c7bb 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -734,6 +734,38 @@ struct ConvElementwiseadd : public PatternBase { PATTERN_DECL_NODE(elementwise_add_out); }; +// Conv with affine_channel +// op: conv + (elementwise_add +) affine_channel +// named nodes: +// conv_weight, conv_out, conv, +// ac_x, ac_scale, ac_bias +// affine_channel, ac_out +struct ConvAffineChannel : public PatternBase { + ConvAffineChannel(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "conv_affine_channel") {} + + PDNode* operator()(PDNode* conv_input, bool with_eltwise_add); + + // declare operator node's name + PATTERN_DECL_NODE(conv); + PATTERN_DECL_NODE(affine_channel); + PATTERN_DECL_NODE(eltwise); // ELEMENTWISE_ADD + // CONV inputs + PATTERN_DECL_NODE(conv_weight); // Filter + // CONV outputs + PATTERN_DECL_NODE(conv_out); // tmp + // ELTWISE inputs + PATTERN_DECL_NODE(eltwise_y_in); + // ELTWISE outputs + PATTERN_DECL_NODE(eltwise_out); // tmp + + // AC(Affine_Channel) inputs + PATTERN_DECL_NODE(ac_scale); + PATTERN_DECL_NODE(ac_bias); + // AC outputs + PATTERN_DECL_NODE(ac_out); // Out +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 83d411eecf6d706615243fd78cb7e4330d904fc1..2db5705d0944b2ab10defdda9a7b616daa8fd47e 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -127,6 +127,7 @@ struct Argument { std::function); DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int); DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int); + DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int); // The program transformed by IR analysis phase. DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program, IrAnalyzedProgram, diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 51bca8039d4531536cd7a3c39ef8a27f1a5412a1..b8c9426ed3b62d35f78247269cb32d2f6344b092 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -75,6 +75,8 @@ void IRPassManager::CreatePasses(Argument *argument, argument->tensorrt_node_teller_ptr()); pass->Set("workspace_size", new int(argument->tensorrt_workspace_size())); pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size())); + pass->Set("min_subgraph_size", + new int(argument->tensorrt_min_subgraph_size())); } // graph_ = pass->Apply(std::move(graph_)); diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 9c42b83e7add348433635b1899087324e4e370d4..ad10010e42be9717e3298fc88c89764e4ae2690b 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" +#include #include #include + #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" +#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" namespace paddle { namespace inference { @@ -36,7 +38,8 @@ std::unique_ptr analysis::TensorRtSubgraphPass::ApplyImpl( auto teller = Get("tensorrt_node_teller"); - SubGraphFuser fuser(graph.get(), teller, 2 /*min subgraph size*/); + SubGraphFuser fuser(graph.get(), teller, + Get("min_subgraph_size") /*min subgraph size*/); fuser(); for (auto *node : graph->Nodes()) { @@ -197,10 +200,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node, std::vector ExtractParameters( const std::unordered_set &nodes) { + // We can judge whether a variable is a parameter by + // its presistable property, but sometimes the presistable + // of the feed op output is true, so we have to identify it. + std::vector feed_outputs; + for (const auto &node : nodes) { + if (!node->IsOp()) continue; + std::string op_type = node->Op()->Type(); + if (op_type == "feed") { + std::vector output_names = node->Op()->OutputArgumentNames(); + std::copy(output_names.begin(), output_names.end(), + std::back_inserter(feed_outputs)); + } + } + std::vector parameters; for (const auto &node : nodes) { if (!node->IsVar()) continue; - if (node->Var()->Persistable()) { + if (node->Var()->Persistable() && + std::find(feed_outputs.begin(), feed_outputs.end(), node->Name()) == + feed_outputs.end()) { parameters.push_back(node->Name()); } } @@ -215,4 +234,5 @@ REGISTER_PASS(tensorrt_subgraph_pass, paddle::inference::analysis::TensorRtSubgraphPass) .RequirePassAttr("tensorrt_node_teller") .RequirePassAttr("max_batch_size") - .RequirePassAttr("workspace_size"); + .RequirePassAttr("workspace_size") + .RequirePassAttr("min_subgraph_size"); diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index dcefdd92f5157dce7426f2f3e4a2bc053ce24775..6d6e799fdec9c67b4714f203b91b8bccb61510ba 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -57,6 +57,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) { use_tensorrt_ = other.use_tensorrt_; tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_; tensorrt_workspace_size_ = other.tensorrt_workspace_size_; + tensorrt_min_subgraph_size_ = other.tensorrt_min_subgraph_size_; model_from_memory_ = other.model_from_memory_; if (use_gpu) { @@ -89,6 +90,7 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) { use_tensorrt_ = other.use_tensorrt_; tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_; tensorrt_workspace_size_ = other.tensorrt_workspace_size_; + tensorrt_min_subgraph_size_ = other.tensorrt_min_subgraph_size_; model_from_memory_ = other.model_from_memory_; pass_builder_ = std::move(other.pass_builder_); @@ -105,12 +107,14 @@ void contrib::AnalysisConfig::EnableMKLDNN() { } void contrib::AnalysisConfig::EnableTensorRtEngine(int workspace_size, - int max_batch_size) { + int max_batch_size, + int min_subgraph_size) { use_tensorrt_ = true; tensorrt_workspace_size_ = workspace_size; tensorrt_max_batchsize_ = max_batch_size; - // Append after the infer_clean pass. - pass_builder()->InsertPass(1, "tensorrt_subgraph_pass"); + tensorrt_min_subgraph_size_ = min_subgraph_size; + // Append after the conv+affine_channel fuse pass. + pass_builder()->InsertPass(3, "tensorrt_subgraph_pass"); } void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer, diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 3937884ce4a5a16a1093ac8977033eaa98b2678e..3f8feaaa1e9f91c2ea342ba9227305ad6eb34033 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -328,6 +328,7 @@ void AnalysisPredictor::OptimizeInferenceProgram() { argument_.SetUseTensorRT(true); argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_); argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_); + argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_); } if (config_.use_mkldnn_) { diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index f05b9832da55f10b34eb2df914e443a478e5a4a4..e7ccea6587a250d9d931fa0e85146e32af714d26 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -49,7 +49,7 @@ struct AnalysisConfig : public NativeConfig { bool use_feed_fetch_ops{true}; void EnableTensorRtEngine(int workspace_size = 1 << 20, - int max_batch_size = 1); + int max_batch_size = 1, int min_subgraph_size = 3); bool use_tensorrt() const { return use_tensorrt_; } void EnableMKLDNN(); @@ -69,8 +69,19 @@ struct AnalysisConfig : public NativeConfig { bool use_tensorrt_{false}; bool use_mkldnn_{false}; std::unordered_set mkldnn_enabled_op_types_; + // For workspace_size, refer it from here: + // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting int tensorrt_workspace_size_; + // While TensorRT allows an engine optimized for a given max batch size + // to run at any smaller size, the performance for those smaller + // sizes may not be as well-optimized. Therefore, Max batch is best + // equivalent to the runtime batch size. int tensorrt_max_batchsize_; + // We transform the Ops that can be converted into TRT layer in the model, + // and aggregate these Ops into subgraphs for TRT execution. + // We set this variable to control the minimum number of nodes in the + // subgraph, 3 as default value. + int tensorrt_min_subgraph_size_{3}; std::unique_ptr pass_builder_; bool model_from_memory_{false}; }; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 40ca0d287ccde113a20abb1036af289a36f54e6c..1062ac5f58b90d8649dae8bacc9ce154b8b9d844 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -118,11 +118,13 @@ class GpuPassStrategy : public PassStrategy { public: GpuPassStrategy() : PassStrategy({}) { passes_.assign({ - "infer_clean_graph_pass", // - "conv_bn_fuse_pass", // - "conv_elementwise_add_act_fuse_pass", // - "conv_elementwise_add2_act_fuse_pass", // - "conv_elementwise_add_fuse_pass", // + "infer_clean_graph_pass", // + "conv_affine_channel_fuse_pass", // + "conv_eltwiseadd_affine_channel_fuse_pass", // + "conv_bn_fuse_pass", // + "conv_elementwise_add_act_fuse_pass", // + "conv_elementwise_add2_act_fuse_pass", // + "conv_elementwise_add_fuse_pass", // }); } diff --git a/paddle/fluid/operators/conv_fusion_op.cu.cc b/paddle/fluid/operators/conv_fusion_op.cu.cc index acceadab16493177754d3a8e27ee800a90c0adc8..e73762f5fb2386633212c5aa9fc768153cf63f85 100644 --- a/paddle/fluid/operators/conv_fusion_op.cu.cc +++ b/paddle/fluid/operators/conv_fusion_op.cu.cc @@ -161,9 +161,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, "workspace_size to be allocated exceeds the limit"); - if ((activation == "identity") && - (algo != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) && - (!residual)) { + if ((activation == "identity") && (!residual)) { // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib. // But test in some case, the speed is slower, change to use