diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 9d86eef7f157f435f32ea593f7d25162ed1f0052..3691df73ef676a936430ac90b7108aed60ff1d66 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -170,12 +170,24 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { #undef CP_MEMBER - // Update(); - // Update() will reset all the passes, when some tensorRT pass is deleted in - // other.pass_builder(), it will set again, so just copy the passes. - pass_builder_->ClearPasses(); - for (const std::string &pass : other.pass_builder()->AllPasses()) { - pass_builder_->AppendPass(pass); + Update(); + if (use_tensorrt_) { + // Update() will reset all the passes, when some tensorRT pass is deleted in + // other.pass_builder(), it will set again, so we just remove the + // deleted_pass. + auto all_passes = kTRTSubgraphPasses; + auto other_passes = other.pass_builder()->AllPasses(); + // We should sort them, because the user may call the SwitchIrDebug + // interface, which will change the pass. + std::sort(all_passes.begin(), all_passes.end()); + std::sort(other_passes.begin(), other_passes.end()); + std::vector deleted_passes; + std::set_difference(all_passes.begin(), all_passes.end(), + other_passes.begin(), other_passes.end(), + std::inserter(deleted_passes, deleted_passes.begin())); + for (auto ps : deleted_passes) { + pass_builder_->DeletePass(ps); + } } } diff --git a/paddle/fluid/inference/api/api_tester.cc b/paddle/fluid/inference/api/api_tester.cc index 2c450ef7cead4d5c3870d5e9186eb221e5dc19a0..80ec84c8d952246ed26a48e83e760d987864be24 100644 --- a/paddle/fluid/inference/api/api_tester.cc +++ b/paddle/fluid/inference/api/api_tester.cc @@ -67,4 +67,17 @@ TEST(paddle_inference_api, get_version) { ASSERT_FALSE(version.empty()); } +TEST(paddle_inference_api, AnalysisConfigCopyCtor) { + AnalysisConfig cfg1; + cfg1.EnableUseGpu(10); + cfg1.EnableTensorRtEngine(); + std::string delete_pass("skip_layernorm_fuse_pass"); + cfg1.pass_builder()->DeletePass(delete_pass); + AnalysisConfig cfg2(cfg1); + + auto passes = cfg2.pass_builder()->AllPasses(); + for (auto ps : passes) { + CHECK_NE(ps, delete_pass); + } +} } // namespace paddle