From bf653d8e6e33973a90c5cbc1f91ca2002aed92c8 Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 29 Dec 2020 14:24:26 +0800 Subject: [PATCH] [Inference] Solve 2.0 trt performance reduce compare 1.8. (#29925) (#29964) --- paddle/fluid/framework/ir/CMakeLists.txt | 6 +- .../ir/adaptive_pool2d_convert_global_pass.cc | 61 ++++++++ .../ir/adaptive_pool2d_convert_global_pass.h | 42 ++++++ ...ptive_pool2d_convert_global_pass_tester.cc | 67 +++++++++ .../fluid/framework/ir/pass_tester_helper.h | 23 ++- .../ir/unsqueeze2_eltwise_fuse_pass.cc | 134 ++++++++++++++++++ .../ir/unsqueeze2_eltwise_fuse_pass.h | 45 ++++++ .../ir/unsqueeze2_eltwise_fuse_pass_tester.cc | 65 +++++++++ .../inference/api/paddle_pass_builder.cc | 13 +- 9 files changed, 447 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc create mode 100644 paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h create mode 100644 paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass_tester.cc create mode 100644 paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h create mode 100644 paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index e1f9a236b7e..760e237bcc1 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -88,6 +88,8 @@ pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) pass_library(skip_layernorm_fuse_pass base) pass_library(multihead_matmul_fuse_pass inference) +pass_library(adaptive_pool2d_convert_global_pass inference) +pass_library(unsqueeze2_eltwise_fuse_pass inference) if(WITH_GPU) pass_library(cudnn_placement_pass base DEPS placement_pass_base) pass_library(embedding_eltwise_layernorm_fuse_pass inference) @@ -141,7 +143,9 @@ cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_test cc_test(test_fc_elementwise_layernorm_fuse_pass SRCS fc_elementwise_layernorm_fuse_pass_tester.cc DEPS fc_elementwise_layernorm_fuse_pass) cc_test(test_skip_layernorm_fuse_pass SRCS skip_layernorm_fuse_pass_tester.cc DEPS skip_layernorm_fuse_pass) cc_test(test_multihead_matmul_fuse_pass SRCS multihead_matmul_fuse_pass_tester.cc DEPS multihead_matmul_fuse_pass) -cc_test(test_conv_bn_fuse_pass SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass) +cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_fuse_pass) +cc_test(test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass) +cc_test(test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass) if(WITH_GPU) cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass) cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass) diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc new file mode 100644 index 00000000000..a05a2bfa777 --- /dev/null +++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h" + +#include +#include + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const { + std::string name_scope = "adaptive_pool2d_convert_global_pass"; + FusePassBase::Init(name_scope, graph); + int num = 0; + for (const Node* n : graph->Nodes()) { + if (n->IsOp()) { + auto* op = n->Op(); + if (op->HasAttr("adaptive") && op->HasAttr("ksize")) { + bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive")); + std::vector ksize = + BOOST_GET_CONST(std::vector, op->GetAttr("ksize")); + if (adaptive && ksize.size() == 2 && ksize[0] == 1 && ksize[1] == 1) { + op->SetAttr("adaptive", false); + op->SetAttr("global_pooling", true); + ++num; + } + } + } + } + // LOG(INFO) << "--- processed " << num << " nodes"; + AddStatis(num); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(adaptive_pool2d_convert_global_pass, + paddle::framework::ir::AdaptivePool2dConvertGlobalPass); + +REGISTER_PASS_CAPABILITY(adaptive_pool2d_convert_global_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "pool2d", 0)); diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h new file mode 100644 index 00000000000..f16f030d518 --- /dev/null +++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; + +/* + * Update pool2d's attr to speed up trt engine. + * + * when adaptive=true, ksize=[1,1], we turn to adaptive=false, + * global_pooling=true. + */ +class AdaptivePool2dConvertGlobalPass : public FusePassBase { + public: + virtual ~AdaptivePool2dConvertGlobalPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass_tester.cc b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass_tester.cc new file mode 100644 index 00000000000..19b0c5ca7fc --- /dev/null +++ b/paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass_tester.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.h" + +#include +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(AdaptivePool2dConvertGlobalPass, basic) { + Layers layers; + auto* x = layers.data("x", {1, 92, 28, 28}); + AttributeMap attrs; + attrs["adaptive"] = true; + attrs["ksize"] = std::vector{1, 1}; + layers.pool2d(x, false, &attrs); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = + PassRegistry::Instance().Get("adaptive_pool2d_convert_global_pass"); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + VLOG(3) << DebugString(graph); + + bool global_pooling = false; + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "pool2d") { + if (node->Op()->HasAttr("global_pooling")) { + global_pooling = + BOOST_GET_CONST(bool, node->Op()->GetAttr("global_pooling")); + } + } + } + PADDLE_ENFORCE_EQ( + global_pooling, true, + platform::errors::PreconditionNotMet( + "The attribute of pool2d global_pooling should be true after fuse")); +} + +TEST(AdaptivePool2dConvertGlobalPass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("adaptive_pool2d_convert_global_pass")); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(adaptive_pool2d_convert_global_pass); diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 9001402233b..6b187e538d1 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -81,18 +81,34 @@ struct Layers { return out; } - VarDesc* pool2d(VarDesc* x, bool use_cudnn) { + VarDesc* pool2d(VarDesc* x, bool use_cudnn, + const AttributeMap* attrs = nullptr) { VarDesc* out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); op->SetType("pool2d"); op->SetInput("X", {x->Name()}); op->SetOutput("Out", {out->Name()}); op->SetAttr("use_cudnn", use_cudnn); + if (attrs) { + for (auto& iter : *attrs) { + op->SetAttr(iter.first, iter.second); + } + } op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), static_cast(OpRole::kForward)); return out; } + VarDesc* unsqueeze2(VarDesc* x, const std::vector axes) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("unsqueeze2"); + op->SetInput("X", {x->Name()}); + op->SetOutput("Out", {out->Name()}); + op->SetAttr("axes", axes); + return out; + } + VarDesc* relu(VarDesc* x, VarDesc* out = nullptr) { return unary_op("relu", x, out); } @@ -188,8 +204,9 @@ struct Layers { return binary_op("elementwise_add", x, y, out); } - VarDesc* elementwise_mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) { - return binary_op("elementwise_mul", x, y, out); + VarDesc* elementwise_mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, + const AttributeMap* attrs = nullptr) { + return binary_op("elementwise_mul", x, y, out, attrs); } VarDesc* dropout(VarDesc* x, float dropout_prob, diff --git a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc new file mode 100644 index 00000000000..f984744532f --- /dev/null +++ b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.cc @@ -0,0 +1,134 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct UnsqueezeEltwise : public PatternBase { + UnsqueezeEltwise(PDPattern *pattern, const std::string &name_scope) + : PatternBase(pattern, name_scope, "unsqueeze2_eltwise_fuse_pass") {} + + PDNode *operator()(PDNode *x, PDNode *y); + + // declare operator node's name + PATTERN_DECL_NODE(unsqz); + PATTERN_DECL_NODE(elementwise); + // declare variable node's name + PATTERN_DECL_NODE(eltwise_in_x); + PATTERN_DECL_NODE(unsqz_in); + PATTERN_DECL_NODE(unsqz_out); + PATTERN_DECL_NODE(eltwise_out); +}; + +PDNode *UnsqueezeEltwise::operator()(PDNode *x, PDNode *y) { + x->assert_is_op_input("elementwise_mul", "X"); + y->assert_is_op_input("unsqueeze2", "X"); + + auto *unsqz = pattern->NewNode(unsqz_repr())->assert_is_op("unsqueeze2"); + auto *unsqz_out = pattern->NewNode(unsqz_out_repr()) + ->assert_is_op_output("unsqueeze2", "Out") + ->assert_is_op_input("elementwise_mul", "Y"); + unsqz->LinksFrom({y}).LinksTo({unsqz_out}); + + auto *elementwise = + pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_mul"); + auto *eltwise_out = pattern->NewNode(eltwise_out_repr()) + ->AsOutput() + ->assert_is_op_output("elementwise_mul"); + + elementwise->LinksFrom({x, unsqz_out}).LinksTo({eltwise_out}); + return eltwise_out; +} + +} // namespace patterns + +void UnsqueezeEltwiseFusePass::ApplyImpl(ir::Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + FusePassBase::Init("unsqueeze2_eltwise_fuse_pass", graph); + int found_subgraph_count = 0; + + GraphPatternDetector gpd; + auto *x = gpd.mutable_pattern() + ->NewNode("unsqueeze2_eltwise_fuse_pass/x") + ->AsInput() + ->assert_is_op_input("elementwise_mul", "X") + ->assert_var_not_persistable(); + auto *y = gpd.mutable_pattern() + ->NewNode("unsqueeze2_eltwise_fuse_pass/y") + ->AsInput() + ->assert_is_op_input("unsqueeze2", "X") + ->assert_var_not_persistable(); + patterns::UnsqueezeEltwise fused_pattern(gpd.mutable_pattern(), + "unsqueeze2_eltwise_fuse_pass"); + fused_pattern(x, y); + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *graph) { + if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) { + LOG(WARNING) << "The subgraph is empty."; + return; + } + + VLOG(4) << "handle UnsqueezeEltwise fuse"; + GET_IR_NODE_FROM_SUBGRAPH(eltwise_op, elementwise, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(unsqz_op, unsqz, fused_pattern); + GET_IR_NODE_FROM_SUBGRAPH(unsqz_out, unsqz_out, fused_pattern); + + size_t eltwise_in_x_rank = (subgraph.at(x)->Var()->GetShape()).size(); + size_t unsqz_in_rank = (subgraph.at(y)->Var()->GetShape()).size(); + std::vector unsqz_op_axes = + BOOST_GET_CONST(std::vector, unsqz_op->Op()->GetAttr("axes")); + int eltwise_op_axis = + BOOST_GET_CONST(int, eltwise_op->Op()->GetAttr("axis")); + + if (eltwise_in_x_rank == 4 && unsqz_in_rank == 2 && + unsqz_op_axes == std::vector{2, 3} && eltwise_op_axis == -1) { + eltwise_op->Op()->SetAttr("axis", 0); + eltwise_op->Op()->SetInput("Y", {subgraph.at(y)->Name()}); + IR_NODE_LINK_TO(subgraph.at(x), eltwise_op); + IR_NODE_LINK_TO(subgraph.at(y), eltwise_op); + IR_NODE_LINK_TO(eltwise_op, eltwise_out); + GraphSafeRemoveNodes(graph, {unsqz_op, unsqz_out}); + found_subgraph_count++; + } + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(unsqueeze2_eltwise_fuse_pass, + paddle::framework::ir::UnsqueezeEltwiseFusePass); +REGISTER_PASS_CAPABILITY(unsqueeze2_eltwise_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("unsqueeze2", 0) + .EQ("elementwise_mul", 0)); diff --git a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h new file mode 100644 index 00000000000..3be29f0e028 --- /dev/null +++ b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h @@ -0,0 +1,45 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; + +// |(rank 4) |(rank 2) |(rank 4) |(rank 2) +// | unsqueeze2(axes=[2,3]) | | +// | | fuse \ / +// |------elementwise_mul(axis=-1) -> elementwise_mul(axis=0) +// | | +// | | +// +// Notice: +// the rank of input is obtained from var_desc, +// it maybe change in runtime. +class UnsqueezeEltwiseFusePass : public FusePassBase { + public: + virtual ~UnsqueezeEltwiseFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass_tester.cc b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass_tester.cc new file mode 100644 index 00000000000..067a37c611a --- /dev/null +++ b/paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass_tester.cc @@ -0,0 +1,65 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/unsqueeze2_eltwise_fuse_pass.h" + +#include +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +TEST(UnsqueezeEltwiseFusePass, basic) { + Layers layers; + auto* x = layers.data("x", {1, 92, 28, 28}); + auto* y = layers.data("y", {1, 92}); + std::vector axes{2, 3}; + auto* unsqz_out = layers.unsqueeze2(y, axes); + AttributeMap attrs; + attrs["axis"] = -1; + layers.elementwise_mul(x, unsqz_out, nullptr, &attrs); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("unsqueeze2_eltwise_fuse_pass"); + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_fused_nodes_after = GetNumOpNodes(graph, "elementwise_mul"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 2, + platform::errors::PreconditionNotMet( + "The number of nodes before and after the fuse does " + "not meet expectations")); + PADDLE_ENFORCE_EQ( + num_fused_nodes_after, 1, + platform::errors::PreconditionNotMet( + "The number of fusion nodes does not meet expectations after fuse")); +} + +TEST(UnsqueezeEltwiseFusePass, pass_op_version_check) { + ASSERT_TRUE( + paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance() + .IsPassCompatible("unsqueeze2_eltwise_fuse_pass")); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(unsqueeze2_eltwise_fuse_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 1448d565661..6c255b67199 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -71,7 +71,8 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) { void PaddlePassBuilder::ClearPasses() { passes_.clear(); } const std::vector kTRTSubgraphPasses({ - "conv_affine_channel_fuse_pass", // + "conv_affine_channel_fuse_pass", // + "adaptive_pool2d_convert_global_pass", "conv_eltwiseadd_affine_channel_fuse_pass", // "shuffle_channel_detect_pass", // "quant_conv2d_dequant_fuse_pass", // @@ -81,10 +82,11 @@ const std::vector kTRTSubgraphPasses({ "embedding_eltwise_layernorm_fuse_pass", // "multihead_matmul_fuse_pass_v2", // "skip_layernorm_fuse_pass", // - "conv_bn_fuse_pass", // - "fc_fuse_pass", // - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // + "unsqueeze2_eltwise_fuse_pass", + "conv_bn_fuse_pass", // + "fc_fuse_pass", // + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 "conv_elementwise_add_act_fuse_pass", // @@ -207,6 +209,7 @@ void CpuPassStrategy::EnableMKLDNN() { "matmul_transpose_reshape_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass", + // "fc_act_mkldnn_fuse_pass", "batch_norm_act_fuse_pass", "mkldnn_inplace_pass", // This pass should be activated after // fuses -- GitLab