diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc
index d876d078541468385f0ec7fc8246025f5e590f06..c01f2d2d528cfc22ed39893326af8af51a226c97 100644
--- a/paddle/fluid/operators/gather_op.cc
+++ b/paddle/fluid/operators/gather_op.cc
@@ -155,7 +155,7 @@ class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
           "We don't support dynamic index from tensor for gather composite "
           "grad for now. "));
     } else {
-      prim::gather_grad<prim::DescTensor>(x, index, dout, axis, false, dx_ptr);
+      prim::gather_grad<prim::DescTensor>(x, index, dout, axis, dx_ptr);
     }
     this->RecoverOutputName(dx, dx_name);
   }
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index e5725b189fdb61cda442a057e9fe1d0175ba90ba..c722aa4858388cfbf16268e8abf5981554c74cc2 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -121,7 +121,6 @@ void gather_grad(const Tensor& x,
                  const Tensor& index,
                  const Tensor& out_grad,
                  const Scalar& axis,
-                 bool overwrite,
                  Tensor* grad_x) {
   auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
   std::vector<int> tmp_perm;
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 0674270a272515d1cdfed70a99a3774ce909c834..3f4acc31e6a991e448c76ca8ecd6af9d6777ca3c 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -418,7 +418,7 @@
 
 - backward_op : gather_grad
   forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out)
-  args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0, bool overwrite=false)
+  args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -426,7 +426,7 @@
   kernel :
     data_type: x
     func : gather_grad
-  composite : gather_grad(x, index, out_grad, axis, overwrite, x_grad)
+  composite : gather_grad(x, index, out_grad, axis, x_grad)
   no_need_buffer : x
 
 - backward_op : group_norm_grad
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 6e6007840892bd431bb87e583c50268a5f13a855..490aff92c192aa8186bb5b70037b7901b4e5e9bb 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -955,8 +955,6 @@
 
 - op : gather
   backward : gather_grad
-  extra :
-    attrs : [bool overwrite = true]
 
 - op : gather_nd
   backward : gather_nd_grad
diff --git a/paddle/phi/kernels/cpu/gather_grad_kernel.cc b/paddle/phi/kernels/cpu/gather_grad_kernel.cc
index f7f0ac6b2e0feead936f29047d4867580b6d97d8..0d0341234e60b16c28c12167871bb8b12d683cd8 100644
--- a/paddle/phi/kernels/cpu/gather_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/gather_grad_kernel.cc
@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad) {
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
@@ -52,19 +51,9 @@ void GatherGradKernel(const Context& dev_ctx,
 
   if (x_grad->numel() == 0) return;
   if (index_type == phi::DataType::INT32) {
-    if (overwrite) {
-      phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, out_grad, index, x_grad);
-    } else {
-      phi::funcs::ScatterAssignAdd<T, int32_t>(
-          dev_ctx, out_grad, index, x_grad);
-    }
+    phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, out_grad, index, x_grad);
   } else if (index_type == phi::DataType::INT64) {
-    if (overwrite) {
-      phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, out_grad, index, x_grad);
-    } else {
-      phi::funcs::ScatterAssignAdd<T, int64_t>(
-          dev_ctx, out_grad, index, x_grad);
-    }
+    phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, out_grad, index, x_grad);
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "The data type of Input(Index) of gather_grad must be int32 or int64 "
diff --git a/paddle/phi/kernels/gather_grad_kernel.h b/paddle/phi/kernels/gather_grad_kernel.h
index e53da7b471c7b82efef2319915cc57537ee824b5..7c978e139e62aaa95e79592a74e49e7c4f2b5b0a 100644
--- a/paddle/phi/kernels/gather_grad_kernel.h
+++ b/paddle/phi/kernels/gather_grad_kernel.h
@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad);
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/gather_grad_kernel.cu b/paddle/phi/kernels/gpu/gather_grad_kernel.cu
index 56b6f136723e6f5fa7f65d76704bb81041aa1deb..23c3eb399725755ad7ddca160423260d4444ebe4 100644
--- a/paddle/phi/kernels/gpu/gather_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/gather_grad_kernel.cu
@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad) {
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
@@ -51,10 +50,10 @@ void GatherGradKernel(const Context& dev_ctx,
   if (out_grad.numel() == 0) return;
   if (index_type == DataType::INT32) {
     phi::funcs::GPUScatterAssign<T, int>(
-        dev_ctx, out_grad, index, x_grad, overwrite);
+        dev_ctx, out_grad, index, x_grad, false);
   } else if (index_type == DataType::INT64) {
     phi::funcs::GPUScatterAssign<T, int64_t>(
-        dev_ctx, out_grad, index, x_grad, overwrite);
+        dev_ctx, out_grad, index, x_grad, false);
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "The data type of Input(Index) of gather_grad must be int32 or int64 "
diff --git a/paddle/phi/kernels/xpu/gather_grad_kernel.cc b/paddle/phi/kernels/xpu/gather_grad_kernel.cc
index 86a6a39f87cf5df8b70ad4eaed67aba3b668e79b..2bf33bfd485caa3e4e5a689d39ba01270b2afbc8 100644
--- a/paddle/phi/kernels/xpu/gather_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/gather_grad_kernel.cc
@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad) {
   auto axis_v = axis.to<int>();
   const auto& index_type = index.dtype();
@@ -68,7 +67,7 @@ void GatherGradKernel(const Context& dev_ctx,
         xshape,
         index.dims().size() == 0 ? 1 : index.dims()[0],
         axis_v,
-        overwrite);
+        false);
   } else {
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
     int* index_int_ptr_l3 = RAII_GUARD.alloc_l3_or_gm<int>(index.numel());
@@ -86,7 +85,7 @@ void GatherGradKernel(const Context& dev_ctx,
         xshape,
         index.dims().size() == 0 ? 1 : index.dims()[0],
         axis_v,
-        overwrite);
+        false);
   }
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather_grad");
 }
diff --git a/paddle/phi/ops/compat/gather_sig.cc b/paddle/phi/ops/compat/gather_sig.cc
index af9e50638ce7026000e53e73c68794c2c9b01cda..8618b5a226f4921fd8a7a12d8d5675834c0903db 100644
--- a/paddle/phi/ops/compat/gather_sig.cc
+++ b/paddle/phi/ops/compat/gather_sig.cc
@@ -26,15 +26,11 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
 
 KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.HasInput("Axis")) {
-    return KernelSignature("gather_grad",
-                           {"X", "Index", "Out@GRAD"},
-                           {"Axis", "overwrite"},
-                           {"X@GRAD"});
+    return KernelSignature(
+        "gather_grad", {"X", "Index", "Out@GRAD"}, {"Axis"}, {"X@GRAD"});
   } else {
-    return KernelSignature("gather_grad",
-                           {"X", "Index", "Out@GRAD"},
-                           {"axis", "overwrite"},
-                           {"X@GRAD"});
+    return KernelSignature(
+        "gather_grad", {"X", "Index", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
   }
 }
 