From a32c1391e0015ea636420189eee7fcaa3d1d77b8 Mon Sep 17 00:00:00 2001
From: zhangyuqin1998 <75946871+zhangyuqin1998@users.noreply.github.com>
Date: Sun, 23 Apr 2023 13:05:59 +0800
Subject: [PATCH] delete overwrite from gather_grad (#52707)

* delete overwrite from gather_grad

* fix

* Update gather_grad_kernel.cc
---
 paddle/fluid/operators/gather_op.cc             |  2 +-
 .../composite_backward/composite_backward_api.h |  1 -
 paddle/phi/api/yaml/legacy_backward.yaml        |  4 ++--
 paddle/phi/api/yaml/op_compat.yaml              |  2 --
 paddle/phi/kernels/cpu/gather_grad_kernel.cc    | 15 ++-------------
 paddle/phi/kernels/gather_grad_kernel.h         |  1 -
 paddle/phi/kernels/gpu/gather_grad_kernel.cu    |  5 ++---
 paddle/phi/kernels/xpu/gather_grad_kernel.cc    |  5 ++---
 paddle/phi/ops/compat/gather_sig.cc             | 12 ++++--------
 9 files changed, 13 insertions(+), 34 deletions(-)

diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc
index d876d078541..c01f2d2d528 100644
--- a/paddle/fluid/operators/gather_op.cc
+++ b/paddle/fluid/operators/gather_op.cc
@@ -155,7 +155,7 @@ class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
           "We don't support dynamic index from tensor for gather composite "
           "grad for now. "));
     } else {
-      prim::gather_grad<prim::DescTensor>(x, index, dout, axis, false, dx_ptr);
+      prim::gather_grad<prim::DescTensor>(x, index, dout, axis, dx_ptr);
     }
     this->RecoverOutputName(dx, dx_name);
   }
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index e5725b189fd..c722aa48583 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -121,7 +121,6 @@ void gather_grad(const Tensor& x,
                  const Tensor& index,
                  const Tensor& out_grad,
                  const Scalar& axis,
-                 bool overwrite,
                  Tensor* grad_x) {
   auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
   std::vector<int> tmp_perm;
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 0674270a272..3f4acc31e6a 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -418,7 +418,7 @@

 - backward_op : gather_grad
   forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out)
-  args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0, bool overwrite=false)
+  args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -426,7 +426,7 @@
   kernel :
     data_type: x
     func : gather_grad
-  composite : gather_grad(x, index, out_grad, axis, overwrite, x_grad)
+  composite : gather_grad(x, index, out_grad, axis, x_grad)
   no_need_buffer : x

 - backward_op : group_norm_grad
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 6e600784089..490aff92c19 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -955,8 +955,6 @@

 - op : gather
   backward : gather_grad
-  extra :
-    attrs : [bool overwrite = true]

 - op : gather_nd
   backward : gather_nd_grad
diff --git a/paddle/phi/kernels/cpu/gather_grad_kernel.cc b/paddle/phi/kernels/cpu/gather_grad_kernel.cc
index f7f0ac6b2e0..0d0341234e6 100644
--- a/paddle/phi/kernels/cpu/gather_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/gather_grad_kernel.cc
@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad) {
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
@@ -52,19 +51,9 @@ void GatherGradKernel(const Context& dev_ctx,
   if (x_grad->numel() == 0) return;

   if (index_type == phi::DataType::INT32) {
-    if (overwrite) {
-      phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, out_grad, index, x_grad);
-    } else {
-      phi::funcs::ScatterAssignAdd<T, int32_t>(
-          dev_ctx, out_grad, index, x_grad);
-    }
+    phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, out_grad, index, x_grad);
   } else if (index_type == phi::DataType::INT64) {
-    if (overwrite) {
-      phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, out_grad, index, x_grad);
-    } else {
-      phi::funcs::ScatterAssignAdd<T, int64_t>(
-          dev_ctx, out_grad, index, x_grad);
-    }
+    phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, out_grad, index, x_grad);
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "The data type of Input(Index) of gather_grad must be int32 or int64 "
diff --git a/paddle/phi/kernels/gather_grad_kernel.h b/paddle/phi/kernels/gather_grad_kernel.h
index e53da7b471c..7c978e139e6 100644
--- a/paddle/phi/kernels/gather_grad_kernel.h
+++ b/paddle/phi/kernels/gather_grad_kernel.h
@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad);

 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/gather_grad_kernel.cu b/paddle/phi/kernels/gpu/gather_grad_kernel.cu
index 56b6f136723..23c3eb39972 100644
--- a/paddle/phi/kernels/gpu/gather_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/gather_grad_kernel.cu
@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad) {
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
@@ -51,10 +50,10 @@ void GatherGradKernel(const Context& dev_ctx,
   if (out_grad.numel() == 0) return;
   if (index_type == DataType::INT32) {
     phi::funcs::GPUScatterAssign<T, int32_t>(
-        dev_ctx, out_grad, index, x_grad, overwrite);
+        dev_ctx, out_grad, index, x_grad, false);
   } else if (index_type == DataType::INT64) {
     phi::funcs::GPUScatterAssign<T, int64_t>(
-        dev_ctx, out_grad, index, x_grad, overwrite);
+        dev_ctx, out_grad, index, x_grad, false);
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "The data type of Input(Index) of gather_grad must be int32 or int64 "
diff --git a/paddle/phi/kernels/xpu/gather_grad_kernel.cc b/paddle/phi/kernels/xpu/gather_grad_kernel.cc
index 86a6a39f87c..2bf33bfd485 100644
--- a/paddle/phi/kernels/xpu/gather_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/gather_grad_kernel.cc
@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad) {
   auto axis_v = axis.to<int>();
   const auto& index_type = index.dtype();
@@ -68,7 +67,7 @@ void GatherGradKernel(const Context& dev_ctx,
         xshape,
         index.dims().size() == 0 ? 1 : index.dims()[0],
         axis_v,
-        overwrite);
+        false);
   } else {
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
     int* index_int_ptr_l3 = RAII_GUARD.alloc_l3_or_gm<int>(index.numel());
@@ -86,7 +85,7 @@ void GatherGradKernel(const Context& dev_ctx,
         xshape,
         index.dims().size() == 0 ? 1 : index.dims()[0],
         axis_v,
-        overwrite);
+        false);
   }
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather_grad");
 }
diff --git a/paddle/phi/ops/compat/gather_sig.cc b/paddle/phi/ops/compat/gather_sig.cc
index af9e50638ce..8618b5a226f 100644
--- a/paddle/phi/ops/compat/gather_sig.cc
+++ b/paddle/phi/ops/compat/gather_sig.cc
@@ -26,15 +26,11 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {

 KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.HasInput("Axis")) {
-    return KernelSignature("gather_grad",
-                           {"X", "Index", "Out@GRAD"},
-                           {"Axis", "overwrite"},
-                           {"X@GRAD"});
+    return KernelSignature(
+        "gather_grad", {"X", "Index", "Out@GRAD"}, {"Axis"}, {"X@GRAD"});
   } else {
-    return KernelSignature("gather_grad",
-                           {"X", "Index", "Out@GRAD"},
-                           {"axis", "overwrite"},
-                           {"X@GRAD"});
+    return KernelSignature(
+        "gather_grad", {"X", "Index", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
   }
 }
--
GitLab
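
Background on why the kernels above can drop the flag and always accumulate: the backward of gather scatters out_grad into a zero tensor shaped like x, and when the same index appears more than once, each gathered row contributes its own gradient, so the contributions must be summed. That is the ScatterAssignAdd path the CPU kernel now takes unconditionally, and the overwrite=false the GPU/XPU kernels now hard-code. The standalone C++ sketch below illustrates the accumulation rule on a 1-D gather with scalar rows; gather() and gather_grad_add() are illustrative names for this sketch only, not Paddle APIs.

// Sketch (assumed simplification, not Paddle code): gradient of gather
// along axis 0 must scatter-ADD; an overwrite-style scatter would
// silently drop all but the last contribution for a repeated index.
#include <cstddef>
#include <cstdio>
#include <vector>

// Forward: out[i] = x[index[i]].
std::vector<float> gather(const std::vector<float>& x,
                          const std::vector<int>& index) {
  std::vector<float> out(index.size());
  for (std::size_t i = 0; i < index.size(); ++i) out[i] = x[index[i]];
  return out;
}

// Backward: x_grad[index[i]] += out_grad[i] -- the accumulate behavior
// the patched kernels use unconditionally.
std::vector<float> gather_grad_add(const std::vector<float>& out_grad,
                                   const std::vector<int>& index,
                                   std::size_t x_size) {
  std::vector<float> x_grad(x_size, 0.0f);  // zero tensor shaped like x
  for (std::size_t i = 0; i < index.size(); ++i)
    x_grad[index[i]] += out_grad[i];
  return x_grad;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};
  std::vector<int> index = {0, 2, 0};  // row 0 is gathered twice
  std::vector<float> out_grad = {1.f, 1.f, 1.f};
  std::vector<float> x_grad = gather_grad_add(out_grad, index, x.size());
  for (float g : x_grad) std::printf("%g ", g);  // prints: 2 0 1
  std::printf("\n");
  return 0;
}

With index = {0, 2, 0}, row 0 receives two unit contributions and ends up with gradient 2; an overwrite-style scatter would keep only the last write and report 1. Accumulation is therefore the mathematically correct gradient regardless of the flag, which is presumably why this patch removes overwrite from the signatures end to end.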