未验证 提交 a22ea652 编写于 作者: W Wilber 提交者: GitHub

Fix TensorRT delete_pass bug in the AnalysisConfig copy constructor. (#28763)

上级 1dad8cea
......@@ -175,12 +175,20 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#undef CP_MEMBER
// Update();
// Update() will reset all the passes, when some tensorRT pass is deleted in
// other.pass_builder(), it will set again, so just copy the passes.
pass_builder_->ClearPasses();
for (const std::string &pass : other.pass_builder()->AllPasses()) {
pass_builder_->AppendPass(pass);
Update();
if (use_tensorrt_) {
  // Update() resets all the passes, so a TensorRT pass that was deleted from
  // other.pass_builder() would be set again. Work out which passes of
  // kTRTSubgraphPasses are absent from `other` and delete them once more.
  auto all_passes = kTRTSubgraphPasses;
  auto other_passes = other.pass_builder()->AllPasses();
  // std::set_difference requires both input ranges to be sorted. Nothing
  // visible here guarantees that for either list, so sort the local copies
  // (the originals are left untouched).
  std::sort(all_passes.begin(), all_passes.end());
  std::sort(other_passes.begin(), other_passes.end());
  std::vector<std::string> deleted_passes;
  std::set_difference(all_passes.begin(), all_passes.end(),
                      other_passes.begin(), other_passes.end(),
                      std::inserter(deleted_passes, deleted_passes.begin()));
  // Iterate by const reference to avoid copying every pass name.
  for (const auto &ps : deleted_passes) {
    pass_builder_->DeletePass(ps);
  }
}
}
......
......@@ -77,4 +77,18 @@ TEST(paddle_inference_api, UpdateDllFlag) {
LOG(INFO) << e.what();
}
}
// Checks that a pass deleted from a TensorRT-enabled config's pass builder is
// also absent from the pass list of a copy-constructed config.
TEST(paddle_inference_api, AnalysisConfigCopyCtor) {
  const std::string delete_pass = "skip_layernorm_fuse_pass";

  // Source config: GPU + TensorRT enabled, with one pass removed by hand.
  AnalysisConfig source_cfg;
  source_cfg.EnableUseGpu(10);
  source_cfg.EnableTensorRtEngine();
  source_cfg.pass_builder()->DeletePass(delete_pass);

  // Copy-construct and inspect the resulting pass list.
  AnalysisConfig copied_cfg(source_cfg);
  const auto &copied_passes = copied_cfg.pass_builder()->AllPasses();
  for (const auto &ps : copied_passes) {
    // The deleted pass must not reappear in the copy.
    CHECK_NE(ps, delete_pass);
  }
}
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册