未验证 提交 b110085f 编写于 作者: Z zhangyikun02 提交者: GitHub

Remove the CPU copy of the index tensor in the gather_nd_grad and scatter_nd_add ops on XPU (#51871)

上级 298a1a0b
...@@ -65,14 +65,10 @@ void GatherNdGradKernel(const Context &ctx, ...@@ -65,14 +65,10 @@ void GatherNdGradKernel(const Context &ctx,
xpu::VectorParam<int64_t> x_vec = { xpu::VectorParam<int64_t> x_vec = {
x_shape.data(), static_cast<int>(x_shape.size()), nullptr}; x_shape.data(), static_cast<int>(x_shape.size()), nullptr};
DenseTensor index_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), false, &index_cpu);
int index_size = static_cast<int>(index.numel()); int index_size = static_cast<int>(index.numel());
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
auto index_data = const_cast<int *>(index.data<int>()); auto index_data = const_cast<int *>(index.data<int>());
xpu::VectorParam<int> index_vec{ xpu::VectorParam<int> index_vec{nullptr, index_size, index_data};
index_cpu.data<int>(), index_size, index_data};
r = xpu::scatter_nd<T, int>(ctx.x_context(), r = xpu::scatter_nd<T, int>(ctx.x_context(),
nullptr, nullptr,
out_grad.data<T>(), out_grad.data<T>(),
...@@ -83,8 +79,7 @@ void GatherNdGradKernel(const Context &ctx, ...@@ -83,8 +79,7 @@ void GatherNdGradKernel(const Context &ctx,
false); false);
} else { } else {
auto index_data = const_cast<int64_t *>(index.data<int64_t>()); auto index_data = const_cast<int64_t *>(index.data<int64_t>());
xpu::VectorParam<int64_t> index_vec{ xpu::VectorParam<int64_t> index_vec{nullptr, index_size, index_data};
index_cpu.data<int64_t>(), index_size, index_data};
r = xpu::scatter_nd<T, int64_t>(ctx.x_context(), r = xpu::scatter_nd<T, int64_t>(ctx.x_context(),
nullptr, nullptr,
out_grad.data<T>(), out_grad.data<T>(),
......
...@@ -70,14 +70,11 @@ void ScatterNdAddKernel(const Context &ctx, ...@@ -70,14 +70,11 @@ void ScatterNdAddKernel(const Context &ctx,
xpu::VectorParam<int64_t> x_vec = { xpu::VectorParam<int64_t> x_vec = {
x_shape.data(), static_cast<int>(x_shape.size()), nullptr}; x_shape.data(), static_cast<int>(x_shape.size()), nullptr};
DenseTensor index_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), false, &index_cpu);
int index_size = static_cast<int>(index.numel()); int index_size = static_cast<int>(index.numel());
if (index_type == phi::DataType::INT32) { if (index_type == phi::DataType::INT32) {
xpu::VectorParam<int> index_vec{index_cpu.data<int>(), index_size, nullptr}; auto index_data = const_cast<int *>(index.data<int>());
xpu::VectorParam<int> index_vec{nullptr, index_size, index_data};
r = xpu::scatter_nd<T, int>(ctx.x_context(), r = xpu::scatter_nd<T, int>(ctx.x_context(),
nullptr, nullptr,
updates_ptr, updates_ptr,
...@@ -87,9 +84,8 @@ void ScatterNdAddKernel(const Context &ctx, ...@@ -87,9 +84,8 @@ void ScatterNdAddKernel(const Context &ctx,
index_shape, index_shape,
false); false);
} else { } else {
xpu::VectorParam<int64_t> index_vec{ auto index_data = const_cast<int64_t *>(index.data<int64_t>());
index_cpu.data<int64_t>(), index_size, nullptr}; xpu::VectorParam<int64_t> index_vec{nullptr, index_size, index_data};
r = xpu::scatter_nd<T, int64_t>(ctx.x_context(), r = xpu::scatter_nd<T, int64_t>(ctx.x_context(),
nullptr, nullptr,
updates_ptr, updates_ptr,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册