Unverified commit a32c1391, authored by zhangyuqin1998, committed by GitHub

delete overwrite from gather_grad (#52707)

* delete overwrite from gather_grad

* fix

* Update gather_grad_kernel.cc
Parent 7634a18a
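Context for the diff below: the commit removes the `overwrite` attribute from `gather_grad` everywhere it appears (composite grad rule, YAML op definitions, CPU/GPU/XPU kernels, and the argument mapping), keeping only the accumulating scatter path. The underlying reason is that the gradient of `gather` is a scatter-add: when an index appears more than once, each gathered copy contributes to the same input slot, so the contributions must sum. A minimal standalone sketch of these semantics (the helpers `gather_axis0` and `gather_grad_axis0` are hypothetical illustrations, not Paddle's kernels):

```cpp
#include <cstdio>
#include <vector>

// Forward: out[i] = x[index[i]] along axis 0.
std::vector<float> gather_axis0(const std::vector<float>& x,
                                const std::vector<int>& index) {
  std::vector<float> out(index.size());
  for (size_t i = 0; i < index.size(); ++i) out[i] = x[index[i]];
  return out;
}

// Backward: x_grad[index[i]] += out_grad[i] -- a scatter-ADD, never a
// scatter-assign, because repeated indices must accumulate.
std::vector<float> gather_grad_axis0(size_t x_size,
                                     const std::vector<int>& index,
                                     const std::vector<float>& out_grad) {
  std::vector<float> x_grad(x_size, 0.0f);  // start from zeros
  for (size_t i = 0; i < index.size(); ++i) x_grad[index[i]] += out_grad[i];
  return x_grad;
}

int main() {
  std::vector<float> x = {10.f, 20.f, 30.f};
  std::vector<int> index = {0, 2, 2};  // index 2 appears twice
  std::vector<float> out = gather_axis0(x, index);       // {10, 30, 30}
  std::vector<float> out_grad = {1.f, 1.f, 1.f};
  std::vector<float> gx = gather_grad_axis0(x.size(), index, out_grad);
  for (float v : gx) std::printf("%g ", v);  // prints: 1 0 2
  std::printf("\n");
  return 0;
}
```

Since an overwriting scatter would silently drop all but one of the duplicate contributions, a `gather_grad` with `overwrite=true` could never produce a correct gradient, which is what makes the flag safe to delete.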
@@ -155,7 +155,7 @@ class GatherCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
           "We don't support dynamic index from tensor for gather composite "
           "grad for now. "));
     } else {
-      prim::gather_grad<prim::DescTensor>(x, index, dout, axis, false, dx_ptr);
+      prim::gather_grad<prim::DescTensor>(x, index, dout, axis, dx_ptr);
     }
     this->RecoverOutputName(dx, dx_name);
   }
......
@@ -121,7 +121,6 @@ void gather_grad(const Tensor& x,
                  const Tensor& index,
                  const Tensor& out_grad,
                  const Scalar& axis,
-                 bool overwrite,
                  Tensor* grad_x) {
   auto zero_tensor = full<T>(phi::vectorize(x.dims()), 0.0, x.dtype());
   std::vector<int> tmp_perm;
......
@@ -418,7 +418,7 @@
 - backward_op : gather_grad
   forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out)
-  args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0, bool overwrite=false)
+  args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -426,7 +426,7 @@
   kernel :
     data_type: x
     func : gather_grad
-  composite : gather_grad(x, index, out_grad, axis, overwrite, x_grad)
+  composite : gather_grad(x, index, out_grad, axis, x_grad)
   no_need_buffer : x

 - backward_op : group_norm_grad
......
@@ -955,8 +955,6 @@
 - op : gather
   backward : gather_grad
-  extra :
-    attrs : [bool overwrite = true]

 - op : gather_nd
   backward : gather_nd_grad
......
@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad) {
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
@@ -52,19 +51,9 @@ void GatherGradKernel(const Context& dev_ctx,
   if (x_grad->numel() == 0) return;
   if (index_type == phi::DataType::INT32) {
-    if (overwrite) {
-      phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, out_grad, index, x_grad);
-    } else {
-      phi::funcs::ScatterAssignAdd<T, int32_t>(
-          dev_ctx, out_grad, index, x_grad);
-    }
+    phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, out_grad, index, x_grad);
   } else if (index_type == phi::DataType::INT64) {
-    if (overwrite) {
-      phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, out_grad, index, x_grad);
-    } else {
-      phi::funcs::ScatterAssignAdd<T, int64_t>(
-          dev_ctx, out_grad, index, x_grad);
-    }
+    phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, out_grad, index, x_grad);
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "The data type of Input(Index) of gather_grad must be int32 or int64 "
......
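The branch deleted above is exactly the overwrite-vs-accumulate distinction. A tiny sketch of the two scatter modes on a duplicated index (free functions for illustration, not `phi::funcs`):

```cpp
#include <cstdio>
#include <vector>

// overwrite mode: later updates replace earlier ones (the deleted path).
void scatter_assign(const std::vector<float>& src,
                    const std::vector<int>& index,
                    std::vector<float>* dst) {
  for (size_t i = 0; i < index.size(); ++i) (*dst)[index[i]] = src[i];
}

// accumulate mode: updates to the same slot sum (the path that remains).
void scatter_assign_add(const std::vector<float>& src,
                        const std::vector<int>& index,
                        std::vector<float>* dst) {
  for (size_t i = 0; i < index.size(); ++i) (*dst)[index[i]] += src[i];
}

int main() {
  std::vector<int> index = {1, 1};       // duplicate target slot
  std::vector<float> src = {1.f, 1.f};

  std::vector<float> a(3, 0.f), b(3, 0.f);
  scatter_assign(src, index, &a);        // a[1] == 1: second write wins
  scatter_assign_add(src, index, &b);    // b[1] == 2: both writes counted
  std::printf("assign: %g  assign_add: %g\n", a[1], b[1]);
  return 0;
}
```

For a gradient, only the accumulating form is correct when `index` contains duplicates, which is why the CPU kernel now calls `ScatterAssignAdd` unconditionally and, in the GPU and XPU kernels below, the old `overwrite` argument is pinned to `false`.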
@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad);
 }  // namespace phi
@@ -28,7 +28,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad) {
   const auto& index_type = index.dtype();
   auto axis_v = axis.to<int>();
@@ -51,10 +50,10 @@ void GatherGradKernel(const Context& dev_ctx,
   if (out_grad.numel() == 0) return;
   if (index_type == DataType::INT32) {
     phi::funcs::GPUScatterAssign<T, int>(
-        dev_ctx, out_grad, index, x_grad, overwrite);
+        dev_ctx, out_grad, index, x_grad, false);
   } else if (index_type == DataType::INT64) {
     phi::funcs::GPUScatterAssign<T, int64_t>(
-        dev_ctx, out_grad, index, x_grad, overwrite);
+        dev_ctx, out_grad, index, x_grad, false);
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "The data type of Input(Index) of gather_grad must be int32 or int64 "
......
@@ -25,7 +25,6 @@ void GatherGradKernel(const Context& dev_ctx,
                       const DenseTensor& index,
                       const DenseTensor& out_grad,
                       const Scalar& axis,
-                      bool overwrite,
                       DenseTensor* x_grad) {
   auto axis_v = axis.to<int>();
   const auto& index_type = index.dtype();
@@ -68,7 +67,7 @@ void GatherGradKernel(const Context& dev_ctx,
         xshape,
         index.dims().size() == 0 ? 1 : index.dims()[0],
         axis_v,
-        overwrite);
+        false);
   } else {
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
     int* index_int_ptr_l3 = RAII_GUARD.alloc_l3_or_gm<int32_t>(index.numel());
@@ -86,7 +85,7 @@ void GatherGradKernel(const Context& dev_ctx,
         xshape,
         index.dims().size() == 0 ? 1 : index.dims()[0],
         axis_v,
-        overwrite);
+        false);
   }
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather_grad");
 }
......
@@ -26,15 +26,11 @@ KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
 KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.HasInput("Axis")) {
-    return KernelSignature("gather_grad",
-                           {"X", "Index", "Out@GRAD"},
-                           {"Axis", "overwrite"},
-                           {"X@GRAD"});
+    return KernelSignature(
+        "gather_grad", {"X", "Index", "Out@GRAD"}, {"Axis"}, {"X@GRAD"});
   } else {
-    return KernelSignature("gather_grad",
-                           {"X", "Index", "Out@GRAD"},
-                           {"axis", "overwrite"},
-                           {"X@GRAD"});
+    return KernelSignature(
+        "gather_grad", {"X", "Index", "Out@GRAD"}, {"axis"}, {"X@GRAD"});
   }
 }
......