未验证 提交 021085e3 编写于 作者: Z zhoutianzi666 提交者: GitHub

forbid ops who have 1D intermediate tensor entering Paddle-TRT (#49378)

上级 121eaea7
...@@ -119,24 +119,21 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -119,24 +119,21 @@ struct SimpleOpTypeSetTeller : public Teller {
#endif #endif
} }
// In static shape mode in TRT, we can't allow that op's input is a // In static shape in Paddle-TRT, we can't allow that one op has a
// 1D-tensor So we filter it here. Some op like elementwise having "Y" too, // 1D intermediate tensor as input.
// but that is dealt with in the specified op, here just the common case
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
std::string X_name;
auto inputs = desc.Inputs(); auto inputs = desc.Inputs();
if (inputs.count("X") && !desc.Input("X").empty()) { for (auto iter : inputs) {
X_name = desc.Input("X")[0]; for (auto var_name : iter.second) {
} else if (inputs.count("Input") && !desc.Input("Input").empty()) { auto* block = desc.Block();
X_name = desc.Input("Input")[0]; if (block) {
} auto* var_desc = block->FindVar(var_name);
auto* block = desc.Block(); // Can't get feed op's TensorDesc
if (block) { if (op_type != "feed" && var_desc && !var_desc->Persistable()) {
auto* x_var_desc = block->FindVar(X_name); const auto shape = var_desc->GetShape();
// Can't get feed op's TensorDesc if (shape.size() == 1) return false;
if (op_type != "feed" && x_var_desc && !x_var_desc->Persistable()) { }
const auto x_shape = x_var_desc->GetShape(); }
if (x_shape.size() == 1) return false;
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册