From 1448520d45d18c7272332f1d10247ab1c287b234 Mon Sep 17 00:00:00 2001 From: shentanyue <34421038+shentanyue@users.noreply.github.com> Date: Mon, 30 May 2022 21:39:23 +0800 Subject: [PATCH] [TensorRT] Fix delete fill_constant pass (#43053) * update lite compile cmake * Update delete_fill_constant_op_pass.cc * Update analysis_config.cc --- .../ir/delete_fill_constant_op_pass.cc | 20 ++++++++++++------- .../inference/analysis/ir_pass_manager.cc | 5 +++++ paddle/fluid/inference/api/analysis_config.cc | 5 ----- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc index e86bb2926b6..79a06572d14 100644 --- a/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc @@ -30,13 +30,19 @@ void FillConstData(LoDTensor* out_t, T value) { void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init("delete_fill_constant_op_pass", graph); GraphPatternDetector detector; - auto fill_constant_op = detector.mutable_pattern() - ->NewNode("fill_constant") - ->assert_is_op("fill_constant") - ->assert_is_not_op_input("ValueTensor") - ->assert_is_not_op_input("str_value") - ->assert_is_not_op_input("ShapeTensor") - ->assert_is_not_op_input("ShapeTensorList"); + auto fill_constant_op = + detector.mutable_pattern() + ->NewNode("fill_constant") + ->assert_is_op("fill_constant") + ->assert_is_not_op_input("ValueTensor") + ->assert_is_not_op_input("str_value") + ->assert_is_not_op_input("ShapeTensor") + ->assert_is_not_op_input("ShapeTensorList") + ->assert_more([&](Node* node) { + return node->Op() + ->GetAttrIfExists>("shape") + .size() == 1; + }); auto fill_constant_out = detector.mutable_pattern() ->NewNode("fill_constant_out") diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index b2d8afaa7b4..aafbe57e05f 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -273,6 +273,11 @@ std::unique_ptr IRPassManager::Apply(std::unique_ptr graph) { if (pass->Type() != "graph_viz_pass" && !disable_logs_) { PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type()); } + // delete_fill_constant_op_pass is not apply under trt dynamic shape + if (pass->Type() == "delete_fill_constant_op_pass") { + bool use_dynamic = pass->Get("with_dynamic_shape"); + if (use_dynamic) continue; + } graph.reset(pass->Apply(graph.release())); } return graph; diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index adc3fc46f72..735e1b7be4c 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -633,11 +633,6 @@ void AnalysisConfig::Update() { (pass == "conv_bn_fuse_pass")) { continue; } - // delete_fill_constant_op_pass is not used under trt dynamic shape - if ((!min_input_shape_.empty() || trt_tuned_dynamic_shape_) && - pass == "delete_fill_constant_op_pass") { - continue; - } pass_builder()->AppendPass(pass); } } -- GitLab