未验证 提交 b110085f 编写于 作者: Z zhangyikun02 提交者: GitHub

remove copy of index for gather_nd_grad and scatter_nd_add op in xpu (#51871)

上级 298a1a0b
......@@ -65,14 +65,10 @@ void GatherNdGradKernel(const Context &ctx,
xpu::VectorParam<int64_t> x_vec = {
x_shape.data(), static_cast<int>(x_shape.size()), nullptr};
DenseTensor index_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), false, &index_cpu);
int index_size = static_cast<int>(index.numel());
if (index_type == phi::DataType::INT32) {
auto index_data = const_cast<int *>(index.data<int>());
xpu::VectorParam<int> index_vec{
index_cpu.data<int>(), index_size, index_data};
xpu::VectorParam<int> index_vec{nullptr, index_size, index_data};
r = xpu::scatter_nd<T, int>(ctx.x_context(),
nullptr,
out_grad.data<T>(),
......@@ -83,8 +79,7 @@ void GatherNdGradKernel(const Context &ctx,
false);
} else {
auto index_data = const_cast<int64_t *>(index.data<int64_t>());
xpu::VectorParam<int64_t> index_vec{
index_cpu.data<int64_t>(), index_size, index_data};
xpu::VectorParam<int64_t> index_vec{nullptr, index_size, index_data};
r = xpu::scatter_nd<T, int64_t>(ctx.x_context(),
nullptr,
out_grad.data<T>(),
......
......@@ -70,14 +70,11 @@ void ScatterNdAddKernel(const Context &ctx,
xpu::VectorParam<int64_t> x_vec = {
x_shape.data(), static_cast<int>(x_shape.size()), nullptr};
DenseTensor index_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), false, &index_cpu);
int index_size = static_cast<int>(index.numel());
if (index_type == phi::DataType::INT32) {
xpu::VectorParam<int> index_vec{index_cpu.data<int>(), index_size, nullptr};
auto index_data = const_cast<int *>(index.data<int>());
xpu::VectorParam<int> index_vec{nullptr, index_size, index_data};
r = xpu::scatter_nd<T, int>(ctx.x_context(),
nullptr,
updates_ptr,
......@@ -87,9 +84,8 @@ void ScatterNdAddKernel(const Context &ctx,
index_shape,
false);
} else {
xpu::VectorParam<int64_t> index_vec{
index_cpu.data<int64_t>(), index_size, nullptr};
auto index_data = const_cast<int64_t *>(index.data<int64_t>());
xpu::VectorParam<int64_t> index_vec{nullptr, index_size, index_data};
r = xpu::scatter_nd<T, int64_t>(ctx.x_context(),
nullptr,
updates_ptr,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册