未验证 提交 a32c1391 编写于 作者: Z zhangyuqin1998 提交者: GitHub

delete overwrite from gather_grad (#52707)

* delete overwrite from gather_grad

* fix

* Update gather_grad_kernel.cc
上级 7634a18a
...@@ -155,7 +155,7 @@ class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { ...@@ -155,7 +155,7 @@ class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
"We don't support dynamic index from tensor for gather composite " "We don't support dynamic index from tensor for gather composite "
"grad for now. ")); "grad for now. "));
} else { } else {
prim::gather_grad<prim::DescTensor>(x, index, dout, axis, false, dx_ptr); prim::gather_grad<prim::DescTensor>(x, index, dout, axis, dx_ptr);
} }
this->RecoverOutputName(dx, dx_name); this->RecoverOutputName(dx, dx_name);
} }
......
...@@ -121,7 +121,6 @@ void gather_grad(const Tensor& x, ...@@ -121,7 +121,6 @@ void gather_grad(const Tensor& x,
const Tensor& index, const Tensor& index,
const Tensor& out_grad, const Tensor& out_grad,
const Scalar& axis, const Scalar& axis,
bool overwrite,
Tensor* grad_x) { Tensor* grad_x) {
auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype()); auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
std::vector<int> tmp_perm; std::vector<int> tmp_perm;
......
...@@ -418,7 +418,7 @@ ...@@ -418,7 +418,7 @@
- backward_op : gather_grad - backward_op : gather_grad
forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out) forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out)
args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0, bool overwrite=false) args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0)
output : Tensor(x_grad) output : Tensor(x_grad)
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
...@@ -426,7 +426,7 @@ ...@@ -426,7 +426,7 @@
kernel : kernel :
data_type: x data_type: x
func : gather_grad func : gather_grad
composite : gather_grad(x, index, out_grad, axis, overwrite, x_grad) composite : gather_grad(x, index, out_grad, axis, x_grad)
no_need_buffer : x no_need_buffer : x
- backward_op : group_norm_grad - backward_op : group_norm_grad
......
...@@ -955,8 +955,6 @@ ...@@ -955,8 +955,6 @@
- op : gather - op : gather
backward : gather_grad backward : gather_grad
extra :
attrs : [bool overwrite = true]
- op : gather_nd - op : gather_nd
backward : gather_nd_grad backward : gather_nd_grad
......
...@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx,
const DenseTensor& index, const DenseTensor& index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const Scalar& axis, const Scalar& axis,
bool overwrite,
DenseTensor* x_grad) { DenseTensor* x_grad) {
const auto& index_type = index.dtype(); const auto& index_type = index.dtype();
auto axis_v = axis.to<int>(); auto axis_v = axis.to<int>();
...@@ -52,19 +51,9 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -52,19 +51,9 @@ void GatherGradKernel(const Context& dev_ctx,
if (x_grad->numel() == 0) return; if (x_grad->numel() == 0) return;
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
if (overwrite) { phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, out_grad, index, x_grad);
phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, out_grad, index, x_grad);
} else {
phi::funcs::ScatterAssignAdd<T, int32_t>(
dev_ctx, out_grad, index, x_grad);
}
} else if (index_type == phi::DataType::INT64) { } else if (index_type == phi::DataType::INT64) {
if (overwrite) { phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, out_grad, index, x_grad);
phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, out_grad, index, x_grad);
} else {
phi::funcs::ScatterAssignAdd<T, int64_t>(
dev_ctx, out_grad, index, x_grad);
}
} else { } else {
PADDLE_THROW(phi::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"The data type of Input(Index) of gather_grad must be int32 or int64 " "The data type of Input(Index) of gather_grad must be int32 or int64 "
......
...@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx,
const DenseTensor& index, const DenseTensor& index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const Scalar& axis, const Scalar& axis,
bool overwrite,
DenseTensor* x_grad); DenseTensor* x_grad);
} // namespace phi } // namespace phi
...@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx,
const DenseTensor& index, const DenseTensor& index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const Scalar& axis, const Scalar& axis,
bool overwrite,
DenseTensor* x_grad) { DenseTensor* x_grad) {
const auto& index_type = index.dtype(); const auto& index_type = index.dtype();
auto axis_v = axis.to<int>(); auto axis_v = axis.to<int>();
...@@ -51,10 +50,10 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -51,10 +50,10 @@ void GatherGradKernel(const Context& dev_ctx,
if (out_grad.numel() == 0) return; if (out_grad.numel() == 0) return;
if (index_type == DataType::INT32) { if (index_type == DataType::INT32) {
phi::funcs::GPUScatterAssign<T, int>( phi::funcs::GPUScatterAssign<T, int>(
dev_ctx, out_grad, index, x_grad, overwrite); dev_ctx, out_grad, index, x_grad, false);
} else if (index_type == DataType::INT64) { } else if (index_type == DataType::INT64) {
phi::funcs::GPUScatterAssign<T, int64_t>( phi::funcs::GPUScatterAssign<T, int64_t>(
dev_ctx, out_grad, index, x_grad, overwrite); dev_ctx, out_grad, index, x_grad, false);
} else { } else {
PADDLE_THROW(phi::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"The data type of Input(Index) of gather_grad must be int32 or int64 " "The data type of Input(Index) of gather_grad must be int32 or int64 "
......
...@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx,
const DenseTensor& index, const DenseTensor& index,
const DenseTensor& out_grad, const DenseTensor& out_grad,
const Scalar& axis, const Scalar& axis,
bool overwrite,
DenseTensor* x_grad) { DenseTensor* x_grad) {
auto axis_v = axis.to<int>(); auto axis_v = axis.to<int>();
const auto& index_type = index.dtype(); const auto& index_type = index.dtype();
...@@ -68,7 +67,7 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -68,7 +67,7 @@ void GatherGradKernel(const Context& dev_ctx,
xshape, xshape,
index.dims().size() == 0 ? 1 : index.dims()[0], index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v, axis_v,
overwrite); false);
} else { } else {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* index_int_ptr_l3 = RAII_GUARD.alloc_l3_or_gm<int32_t>(index.numel()); int* index_int_ptr_l3 = RAII_GUARD.alloc_l3_or_gm<int32_t>(index.numel());
...@@ -86,7 +85,7 @@ void GatherGradKernel(const Context& dev_ctx, ...@@ -86,7 +85,7 @@ void GatherGradKernel(const Context& dev_ctx,
xshape, xshape,
index.dims().size() == 0 ? 1 : index.dims()[0], index.dims().size() == 0 ? 1 : index.dims()[0],
axis_v, axis_v,
overwrite); false);
} }
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather_grad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather_grad");
} }
......
...@@ -26,15 +26,11 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -26,15 +26,11 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Axis")) { if (ctx.HasInput("Axis")) {
return KernelSignature("gather_grad", return KernelSignature(
{"X", "Index", "Out@GRAD"}, "gather_grad", {"X", "Index", "Out@GRAD"}, {"Axis"}, {"X@GRAD"});
{"Axis", "overwrite"},
{"X@GRAD"});
} else { } else {
return KernelSignature("gather_grad", return KernelSignature(
{"X", "Index", "Out@GRAD"}, "gather_grad", {"X", "Index", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
{"axis", "overwrite"},
{"X@GRAD"});
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册