未验证 提交 c53e92fc 编写于 作者: W wanghuancoder 提交者: GitHub

CastPyArg2IntArray use int64_t (#45919)

上级 0b82fb32
......@@ -1387,14 +1387,12 @@ paddle::experimental::IntArray CastPyArg2IntArray(PyObject* obj,
auto type_name = std::string(type->tp_name);
if (type_name == "list" || type_name == "tuple" ||
type_name == "numpy.ndarray") {
std::vector<int> value = CastPyArg2Ints(obj, op_type, arg_pos);
std::vector<int64_t> value = CastPyArg2Longs(obj, op_type, arg_pos);
return paddle::experimental::IntArray(value);
} else if (type_name == "paddle.Tensor" || type_name == "Tensor") {
paddle::experimental::Tensor& value = GetTensorFromPyObject(
op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/);
return paddle::experimental::IntArray(value);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册