From 799e43471c5d10ec5d8fcd211e1d9003de585c0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Fri, 12 May 2023 11:06:48 +0800 Subject: [PATCH] Add datatype for index_put in ops.yaml (#53715) This PR add data_type for selecting which arg's datatype to instantiate template type T for index_put kernel Related PR #53652 --- paddle/phi/api/yaml/ops.yaml | 1 + paddle/phi/kernels/cpu/index_put_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/index_put_kernel.cc | 2 +- paddle/phi/kernels/gpu/index_put_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/index_put_kernel.cu | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 8e3a6c2f204..38aca867e27 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -877,6 +877,7 @@ func : IndexPutInferMeta kernel : func : index_put + data_type : x inplace : (x -> out) backward : index_put_grad diff --git a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc index 7374bcd403d..7c8ac0624e2 100644 --- a/paddle/phi/kernels/cpu/index_put_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_grad_kernel.cc @@ -183,7 +183,7 @@ void IndexPutGradKernel(const Context& dev_ctx, x.dtype(), value.dtype(), phi::errors::InvalidArgument( - "The data type of tensor in indices must be same to the data type " + "The data type of tensor value must be same to the data type " "of tensor x.")); std::vector tmp_args; std::vector int_indices_v = diff --git a/paddle/phi/kernels/cpu/index_put_kernel.cc b/paddle/phi/kernels/cpu/index_put_kernel.cc index da3e37ac242..b35f2e1982e 100644 --- a/paddle/phi/kernels/cpu/index_put_kernel.cc +++ b/paddle/phi/kernels/cpu/index_put_kernel.cc @@ -102,7 +102,7 @@ void IndexPutKernel(const Context& dev_ctx, x.dtype(), value.dtype(), phi::errors::InvalidArgument( - "The data type of tensor in indices must be same to the data type " + "The data type of tensor value must be same to the data type " "of tensor x.")); PADDLE_ENFORCE_EQ(indices.empty(), false, diff --git a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu index 7ae1e42c067..9dca49ee7ff 100644 --- a/paddle/phi/kernels/gpu/index_put_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_grad_kernel.cu @@ -214,7 +214,7 @@ void IndexPutGradKernel(const Context& dev_ctx, x.dtype(), value.dtype(), phi::errors::InvalidArgument( - "The data type of tensor in indices must be same to the data type " + "The data type of tensor value must be same to the data type " "of tensor x.")); std::vector tmp_args; std::vector int_indices_v = diff --git a/paddle/phi/kernels/gpu/index_put_kernel.cu b/paddle/phi/kernels/gpu/index_put_kernel.cu index ad27993c352..fd4476fe11d 100644 --- a/paddle/phi/kernels/gpu/index_put_kernel.cu +++ b/paddle/phi/kernels/gpu/index_put_kernel.cu @@ -110,7 +110,7 @@ void IndexPutKernel(const Context& dev_ctx, x.dtype(), value.dtype(), phi::errors::InvalidArgument( - "The data type of tensor in indices must be same to the data type " + "The data type of tensor value must be same to the data type " "of tensor x.")); PADDLE_ENFORCE_EQ(indices.empty(), false, -- GitLab