未验证 提交 db47dec5 编写于 作者: J JingZhuangzhuang 提交者: GitHub

add function to disable trt op by output name (#49497)

上级 4779c2c1
...@@ -135,6 +135,16 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( ...@@ -135,6 +135,16 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
<< " is diabled by config in TensorRT"; << " is diabled by config in TensorRT";
return false; return false;
} }
for (const auto &out_var : node->Op()->OutputNames()) {
for (const auto &var_name : node->Op()->Output(out_var)) {
if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(), var_name) !=
trt_disabled_ops.end()) {
VLOG(3) << node->Op()->Type().c_str()
<< " is diabled by config in TensorRT";
return false;
}
}
}
bool is_ok = tensorrt::OpTeller::Global().Tell( bool is_ok = tensorrt::OpTeller::Global().Tell(
node, no_calib_int8, with_dynamic_shape); node, no_calib_int8, with_dynamic_shape);
if (!is_ok) if (!is_ok)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册