未验证 提交 799e4347 编写于 作者: 傅剑寒 提交者: GitHub

Add datatype for index_put in ops.yaml (#53715)

This PR add data_type for selecting which arg's datatype to instantiate template type T for index_put kernel
Related PR #53652
上级 2a696fb8
......@@ -877,6 +877,7 @@
func : IndexPutInferMeta
kernel :
func : index_put
data_type : x
inplace : (x -> out)
backward : index_put_grad
......
......@@ -183,7 +183,7 @@ void IndexPutGradKernel(const Context& dev_ctx,
x.dtype(),
value.dtype(),
phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type "
"The data type of tensor value must be same to the data type "
"of tensor x."));
std::vector<DenseTensor> tmp_args;
std::vector<const phi::DenseTensor*> int_indices_v =
......
......@@ -102,7 +102,7 @@ void IndexPutKernel(const Context& dev_ctx,
x.dtype(),
value.dtype(),
phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type "
"The data type of tensor value must be same to the data type "
"of tensor x."));
PADDLE_ENFORCE_EQ(indices.empty(),
false,
......
......@@ -214,7 +214,7 @@ void IndexPutGradKernel(const Context& dev_ctx,
x.dtype(),
value.dtype(),
phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type "
"The data type of tensor value must be same to the data type "
"of tensor x."));
std::vector<DenseTensor> tmp_args;
std::vector<const phi::DenseTensor*> int_indices_v =
......
......@@ -110,7 +110,7 @@ void IndexPutKernel(const Context& dev_ctx,
x.dtype(),
value.dtype(),
phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type "
"The data type of tensor value must be same to the data type "
"of tensor x."));
PADDLE_ENFORCE_EQ(indices.empty(),
false,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册