未验证 提交 e4d20cdd 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT] forbid int32 tensor as arg_max's input (#51926)

* forbid int32 tensor as arg_max's input
上级 a2cbc81a
......@@ -733,6 +733,22 @@ struct SimpleOpTypeSetTeller : public Teller {
return false;
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
auto x_dtype = x_var_desc->GetDataType();
if (!(x_dtype == framework::proto::VarType::FP32 ||
x_dtype == framework::proto::VarType::FP16)) {
return false;
}
int axis = desc.HasAttr("axis")
? PADDLE_GET_CONST(int64_t, desc.GetAttr("axis"))
: -1;
......@@ -2483,7 +2499,21 @@ struct SimpleOpTypeSetTeller : public Teller {
if (op_type == "top_k_v2" || op_type == "top_k") {
auto* block = desc.Block();
auto x_var_name = desc.Input("X")[0];
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto* x_var_desc = block->FindVar(x_var_name);
auto x_dtype = x_var_desc->GetDataType();
if (!(x_dtype == framework::proto::VarType::FP32 ||
x_dtype == framework::proto::VarType::FP16)) {
return false;
}
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "top_k/top_k_v2 does not support 1-dimensional input in "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册