未验证 提交 799e4347 编写于 作者: 傅剑寒 提交者: GitHub

Add datatype for index_put in ops.yaml (#53715)

This PR add data_type for selecting which arg's datatype to instantiate template type T for index_put kernel
Related PR #53652
上级 2a696fb8
...@@ -877,6 +877,7 @@ ...@@ -877,6 +877,7 @@
func : IndexPutInferMeta func : IndexPutInferMeta
kernel : kernel :
func : index_put func : index_put
data_type : x
inplace : (x -> out) inplace : (x -> out)
backward : index_put_grad backward : index_put_grad
......
...@@ -183,7 +183,7 @@ void IndexPutGradKernel(const Context& dev_ctx, ...@@ -183,7 +183,7 @@ void IndexPutGradKernel(const Context& dev_ctx,
x.dtype(), x.dtype(),
value.dtype(), value.dtype(),
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type " "The data type of tensor value must be same to the data type "
"of tensor x.")); "of tensor x."));
std::vector<DenseTensor> tmp_args; std::vector<DenseTensor> tmp_args;
std::vector<const phi::DenseTensor*> int_indices_v = std::vector<const phi::DenseTensor*> int_indices_v =
......
...@@ -102,7 +102,7 @@ void IndexPutKernel(const Context& dev_ctx, ...@@ -102,7 +102,7 @@ void IndexPutKernel(const Context& dev_ctx,
x.dtype(), x.dtype(),
value.dtype(), value.dtype(),
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type " "The data type of tensor value must be same to the data type "
"of tensor x.")); "of tensor x."));
PADDLE_ENFORCE_EQ(indices.empty(), PADDLE_ENFORCE_EQ(indices.empty(),
false, false,
......
...@@ -214,7 +214,7 @@ void IndexPutGradKernel(const Context& dev_ctx, ...@@ -214,7 +214,7 @@ void IndexPutGradKernel(const Context& dev_ctx,
x.dtype(), x.dtype(),
value.dtype(), value.dtype(),
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type " "The data type of tensor value must be same to the data type "
"of tensor x.")); "of tensor x."));
std::vector<DenseTensor> tmp_args; std::vector<DenseTensor> tmp_args;
std::vector<const phi::DenseTensor*> int_indices_v = std::vector<const phi::DenseTensor*> int_indices_v =
......
...@@ -110,7 +110,7 @@ void IndexPutKernel(const Context& dev_ctx, ...@@ -110,7 +110,7 @@ void IndexPutKernel(const Context& dev_ctx,
x.dtype(), x.dtype(),
value.dtype(), value.dtype(),
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"The data type of tensor in indices must be same to the data type " "The data type of tensor value must be same to the data type "
"of tensor x.")); "of tensor x."));
PADDLE_ENFORCE_EQ(indices.empty(), PADDLE_ENFORCE_EQ(indices.empty(),
false, false,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册