未验证 提交 021085e3 编写于 作者: Z zhoutianzi666 提交者: GitHub

forbid ops who have 1D intermediate tensor entering Paddle-TRT (#49378)

上级 121eaea7
......@@ -119,24 +119,21 @@ struct SimpleOpTypeSetTeller : public Teller {
#endif
}
// In static shape mode in TRT, we can't allow that op's input is a
// 1D-tensor So we filter it here. Some op like elementwise having "Y" too,
// but that is dealt with in the specified op, here just the common case
// In static shape in Paddle-TRT, we can't allow that one op has a
// 1D intermediate tensor as input.
if (!with_dynamic_shape) {
std::string X_name;
auto inputs = desc.Inputs();
if (inputs.count("X") && !desc.Input("X").empty()) {
X_name = desc.Input("X")[0];
} else if (inputs.count("Input") && !desc.Input("Input").empty()) {
X_name = desc.Input("Input")[0];
}
auto* block = desc.Block();
if (block) {
auto* x_var_desc = block->FindVar(X_name);
// Can't get feed op's TensorDesc
if (op_type != "feed" && x_var_desc && !x_var_desc->Persistable()) {
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) return false;
for (auto iter : inputs) {
for (auto var_name : iter.second) {
auto* block = desc.Block();
if (block) {
auto* var_desc = block->FindVar(var_name);
// Can't get feed op's TensorDesc
if (op_type != "feed" && var_desc && !var_desc->Persistable()) {
const auto shape = var_desc->GetShape();
if (shape.size() == 1) return false;
}
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册