未验证 提交 f3ba7253 编写于 作者: S Shang Zhizhou 提交者: GitHub

Trt del pass (#29316)

* cherry-pick a22ea652

* fix analysis_config bug. (#29304)

* fix code format
Co-authored-by: NWilber <jiweibo@baidu.com>
上级 b94edba7
......@@ -170,12 +170,24 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#undef CP_MEMBER
// Update();
// Update() will reset all the passes, when some tensorRT pass is deleted in
// other.pass_builder(), it will set again, so just copy the passes.
pass_builder_->ClearPasses();
for (const std::string &pass : other.pass_builder()->AllPasses()) {
pass_builder_->AppendPass(pass);
Update();
if (use_tensorrt_) {
// Update() will reset all the passes, when some tensorRT pass is deleted in
// other.pass_builder(), it will set again, so we just remove the
// deleted_pass.
auto all_passes = kTRTSubgraphPasses;
auto other_passes = other.pass_builder()->AllPasses();
// We should sort them, because the user may call the SwitchIrDebug
// interface, which will change the pass.
std::sort(all_passes.begin(), all_passes.end());
std::sort(other_passes.begin(), other_passes.end());
std::vector<std::string> deleted_passes;
std::set_difference(all_passes.begin(), all_passes.end(),
other_passes.begin(), other_passes.end(),
std::inserter(deleted_passes, deleted_passes.begin()));
for (auto ps : deleted_passes) {
pass_builder_->DeletePass(ps);
}
}
}
......
......@@ -67,4 +67,17 @@ TEST(paddle_inference_api, get_version) {
ASSERT_FALSE(version.empty());
}
TEST(paddle_inference_api, AnalysisConfigCopyCtor) {
AnalysisConfig cfg1;
cfg1.EnableUseGpu(10);
cfg1.EnableTensorRtEngine();
std::string delete_pass("skip_layernorm_fuse_pass");
cfg1.pass_builder()->DeletePass(delete_pass);
AnalysisConfig cfg2(cfg1);
auto passes = cfg2.pass_builder()->AllPasses();
for (auto ps : passes) {
CHECK_NE(ps, delete_pass);
}
}
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册