未验证 提交 fd1730e4 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt] argmax support int64 dtype (#49167)

上级 baa98d1d
...@@ -691,7 +691,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -691,7 +691,7 @@ struct SimpleOpTypeSetTeller : public Teller {
: -1; : -1;
bool flatten = PADDLE_GET_CONST(bool, desc.GetAttr("flatten")); bool flatten = PADDLE_GET_CONST(bool, desc.GetAttr("flatten"));
int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype"));
if (axis == 0 || flatten || dtype != 2) return false; if (axis == 0 || flatten || (dtype != 2 && dtype != 3)) return false;
} }
if (op_type == "arg_min") { if (op_type == "arg_min") {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册