未验证 提交 021085e3 编写于 作者: Z zhoutianzi666 提交者: GitHub

forbid ops who have 1D intermediate tensor entering Paddle-TRT (#49378)

上级 121eaea7
...@@ -119,24 +119,21 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -119,24 +119,21 @@ struct SimpleOpTypeSetTeller : public Teller {
#endif #endif
} }
// In static shape mode in TRT, we can't allow that op's input is a // In static shape in Paddle-TRT, we can't allow that one op has a
// 1D-tensor So we filter it here. Some op like elementwise having "Y" too, // 1D intermediate tensor as input.
// but that is dealt with in the specified op, here just the common case
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
std::string X_name;
auto inputs = desc.Inputs(); auto inputs = desc.Inputs();
if (inputs.count("X") && !desc.Input("X").empty()) { for (auto iter : inputs) {
X_name = desc.Input("X")[0]; for (auto var_name : iter.second) {
} else if (inputs.count("Input") && !desc.Input("Input").empty()) {
X_name = desc.Input("Input")[0];
}
auto* block = desc.Block(); auto* block = desc.Block();
if (block) { if (block) {
auto* x_var_desc = block->FindVar(X_name); auto* var_desc = block->FindVar(var_name);
// Can't get feed op's TensorDesc // Can't get feed op's TensorDesc
if (op_type != "feed" && x_var_desc && !x_var_desc->Persistable()) { if (op_type != "feed" && var_desc && !var_desc->Persistable()) {
const auto x_shape = x_var_desc->GetShape(); const auto shape = var_desc->GetShape();
if (x_shape.size() == 1) return false; if (shape.size() == 1) return false;
}
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册