Unverified commit 15251291, authored by Zhan Rongrui, committed by GitHub

[Sparse]fix bug in paddle.sparse.transpose and paddle.sparse.reshape (#53038)

Parent 5cfaa7dd
@@ -33,7 +33,7 @@ void TransposeCooKernel(const Context& dev_ctx,
   int64_t x_nnz = x.nnz();
   DDim out_dims = x.dims().transpose(perm);
   DenseTensor out_indices = EmptyLike<int64_t, Context>(dev_ctx, x.indices());
-  DenseTensor out_values(x.values());
+  const DenseTensor& out_values(x.values());
   out->SetMember(out_indices, out_values, out_dims, x.coalesced());
   // compute values of indices
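A plausible reading of this first hunk (the commit message does not elaborate): phi's DenseTensor copy constructor shares the underlying holder, so the old line built a throwaway shallow copy just to pass read-only values into SetMember; a const reference binds the same object with no copy at all. A minimal sketch with a hypothetical Handle type, not Paddle's actual tensor:

#include <memory>
#include <vector>

// Hypothetical shared-storage handle standing in for a tensor.
struct Handle {
  std::shared_ptr<std::vector<float>> buf;
};

void Consume(const Handle&) {}  // only needs read access, like SetMember

int main() {
  Handle x{std::make_shared<std::vector<float>>(4, 1.0f)};
  Handle copy(x);           // refcount bump + metadata copy, then discarded
  const Handle& alias = x;  // binds x directly; no object constructed
  Consume(copy);
  Consume(alias);
  return 0;
}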
@@ -108,7 +108,7 @@ void TransposeCsrKernel(const Context& dev_ctx,
       out_crows_data[i] = 0;
     }
     for (int i = 0; i < x_nnz; ++i) {
-      int j = x_cols_data[i];
+      int64_t j = x_cols_data[i];
       out_crows_data[j + 1]++;
     }
     out_crows_data[out_dims[0]] = x_nnz;
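Why the widening matters: the CSR cols/crows arrays hold int64_t, so assigning an entry to a plain int silently narrows it once a dimension exceeds INT_MAX, and the resulting index (implementation-defined, typically negative) corrupts the histogram write. A standalone illustration:

#include <cstdint>
#include <cstdio>

int main() {
  int64_t col = 3'000'000'000;  // a column index past INT_MAX
  int truncated = static_cast<int>(col);
  std::printf("int64_t: %lld, int: %d\n",
              static_cast<long long>(col), truncated);  // typically negative
  return 0;
}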
@@ -134,8 +134,8 @@ void TransposeCsrKernel(const Context& dev_ctx,
       }
     }
   } else {  // n_dim == 3
-    int out_n_rows = out_dims[1];
-    int x_n_rows = x.dims()[1];
+    int64_t out_n_rows = out_dims[1];
+    int64_t x_n_rows = x.dims()[1];
     for (int k = 0; k < out_dims[0]; ++k) {
       if (perm[0] == 0) {  // perm == {0, 2, 1}
         // compute out_crows_data by x_cols_data
@@ -143,7 +143,7 @@ void TransposeCsrKernel(const Context& dev_ctx,
           out_crows_data[i] = 0;
         }
         for (int i = 0; i < x_crows_data[x_n_rows]; ++i) {
-          int j = x_cols_data[i];
+          int64_t j = x_cols_data[i];
           out_crows_data[j + 1]++;
         }
         out_crows_data[out_n_rows] = x_crows_data[x_n_rows];
@@ -176,14 +176,14 @@ void TransposeCsrKernel(const Context& dev_ctx,
         for (int i = 0; i < out_n_rows; ++i) {
           out_crows_data[i] = 0;
         }
-        int x_cols_offset = 0;
+        int64_t x_cols_offset = 0;
         int out_cols_index = 0;
         for (int i = 0; i < x.dims()[0]; ++i) {
           int x_crows_index = i * (x_n_rows + 1);
-          int start = x_crows_data[x_crows_index + k];
-          int end = x_crows_data[x_crows_index + 1 + k];
+          int64_t start = x_crows_data[x_crows_index + k];
+          int64_t end = x_crows_data[x_crows_index + 1 + k];
           out_crows_data[i + 1] = end - start;
-          for (int j = start; j < end; ++j) {
+          for (int64_t j = start; j < end; ++j) {
             out_cols_data[out_cols_index] = x_cols_data[x_cols_offset + j];
             out_values_data[out_cols_index] = x_values_data[x_cols_offset + j];
             out_cols_index++;
......
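All of the CPU-side hunks above are the same fix applied at each site where an int64_t CSR entry was narrowed to int. For reference, the 2-D counting pass these hunks touch amounts to the following in isolation: histogram the column indices into out_crows, then prefix-sum to turn counts into row pointers of the transposed matrix (the prefix-sum step is elided from the hunks shown; names here are illustrative, not Paddle's API):

#include <cstdint>
#include <vector>

void TransposeCrows(const std::vector<int64_t>& x_cols,
                    int64_t out_rows,  // == number of columns of x
                    std::vector<int64_t>& out_crows) {
  out_crows.assign(out_rows + 1, 0);
  for (int64_t c : x_cols) {
    out_crows[c + 1]++;  // nnz per output row (the "int j" bug site)
  }
  for (int64_t i = 0; i < out_rows; ++i) {
    out_crows[i + 1] += out_crows[i];  // running sum -> CSR row pointers
  }
}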
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/sparse/unary_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
@@ -78,33 +79,31 @@ void ReshapeCooKernel(const Context& dev_ctx,
   int64_t *destination_x_sparse_part_strides,
       *destination_out_sparse_part_strides;
-#ifdef PADDLE_WITH_HIP
-  hipMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
-            sizeof(int64_t) * x_sparse_part_strides.size());
-  hipMemcpy(destination_x_sparse_part_strides,
-            x_sparse_part_strides.Get(),
-            sizeof(int64_t) * x_sparse_part_strides.size(),
-            hipMemcpyHostToDevice);
-  hipMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
-            sizeof(int64_t) * out_sparse_part_strides.size());
-  hipMemcpy(destination_out_sparse_part_strides,
-            out_sparse_part_strides.Get(),
-            sizeof(int64_t) * out_sparse_part_strides.size(),
-            hipMemcpyHostToDevice);
-#else
-  cudaMalloc(reinterpret_cast<void**>(&destination_x_sparse_part_strides),
-             sizeof(int64_t) * x_sparse_part_strides.size());
-  cudaMemcpy(destination_x_sparse_part_strides,
-             x_sparse_part_strides.Get(),
-             sizeof(int64_t) * x_sparse_part_strides.size(),
-             cudaMemcpyHostToDevice);
-  cudaMalloc(reinterpret_cast<void**>(&destination_out_sparse_part_strides),
-             sizeof(int64_t) * out_sparse_part_strides.size());
-  cudaMemcpy(destination_out_sparse_part_strides,
-             out_sparse_part_strides.Get(),
-             sizeof(int64_t) * out_sparse_part_strides.size(),
-             cudaMemcpyHostToDevice);
-#endif
+  auto destination_x_sparse_part_strides_tensor = memory_utils::Alloc(
+      dev_ctx.GetPlace(),
+      sizeof(int64_t) * x_sparse_part_strides.size(),
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  destination_x_sparse_part_strides = reinterpret_cast<int64_t*>(
+      destination_x_sparse_part_strides_tensor->ptr());
+  memory_utils::Copy(dev_ctx.GetPlace(),
+                     reinterpret_cast<void*>(destination_x_sparse_part_strides),
+                     phi::CPUPlace(),
+                     x_sparse_part_strides.Get(),
+                     sizeof(int64_t) * x_sparse_part_strides.size(),
+                     dev_ctx.stream());
+  auto destination_out_sparse_part_strides_tensor = memory_utils::Alloc(
+      dev_ctx.GetPlace(),
+      sizeof(int64_t) * out_sparse_part_strides.size(),
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  destination_out_sparse_part_strides = reinterpret_cast<int64_t*>(
+      destination_out_sparse_part_strides_tensor->ptr());
+  memory_utils::Copy(dev_ctx.GetPlace(),
+                     destination_out_sparse_part_strides,
+                     phi::CPUPlace(),
+                     out_sparse_part_strides.Get(),
+                     sizeof(int64_t) * out_sparse_part_strides.size(),
+                     dev_ctx.stream());
   auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz, 1);
   ReshapeCooCudaKernel<<<config.block_per_grid.x,
......
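The GPU hunks all make the same substitution: raw hipMalloc/cudaMalloc plus hipMemcpy/cudaMemcpy pairs, which in the code shown were never matched by a free, become allocator-managed, stream-ordered buffers that also need no backend #ifdef. The idiom, factored into a helper for illustration only (the helper and its name are ours; the Alloc/Copy calls are exactly the ones in the diff, and the snippet assumes the patched file's includes, in particular paddle/phi/common/memory_utils.h):

// Sketch: wraps the Alloc+Copy idiom used above. The returned holder is an
// RAII handle; the device buffer lives until the holder goes out of scope,
// so keep it alive until the kernel reading the buffer has been launched.
template <typename T, typename Context>
auto CopyHostToDevice(const Context& dev_ctx,
                      const T* host_src,
                      size_t n,
                      T** dev_dst) {
  auto holder = memory_utils::Alloc(
      dev_ctx.GetPlace(),
      sizeof(T) * n,
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
  *dev_dst = reinterpret_cast<T*>(holder->ptr());
  memory_utils::Copy(dev_ctx.GetPlace(),  // destination place (GPU)
                     *dev_dst,
                     phi::CPUPlace(),     // source place (host)
                     host_src,
                     sizeof(T) * n,
                     dev_ctx.stream());   // copy ordered on the same stream
  return holder;
}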
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/sparse/unary_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/elementwise_base.h"
@@ -170,15 +171,18 @@ void TransposeCooKernel(const Context &dev_ctx,
   const auto *x_indices_data = x_indices.data<int64_t>();
   auto *out_indices_data = out_indices.data<int64_t>();
   int *d_perm;
-#ifdef PADDLE_WITH_HIP
-  hipMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
-  hipMemcpy(
-      d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice);
-#else
-  cudaMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
-  cudaMemcpy(
-      d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice);
-#endif
+  auto d_perm_tensor = memory_utils::Alloc(
+      dev_ctx.GetPlace(),
+      sizeof(int) * perm.size(),
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  d_perm = reinterpret_cast<int *>(d_perm_tensor->ptr());
+  memory_utils::Copy(dev_ctx.GetPlace(),
+                     d_perm,
+                     phi::CPUPlace(),
+                     perm.data(),
+                     sizeof(int) * perm.size(),
+                     dev_ctx.stream());
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_nnz * n_dim, 1);
   TransposeCooCudaKernel<<<config.block_per_grid.x,
@@ -242,39 +246,41 @@ void TransposeCsrKernel(const Context &dev_ctx,
   const T *x_values_data = x_values.data<T>();
   int *d_perm;
   int64_t *d_x_dims, *d_out_dims;
-#ifdef PADDLE_WITH_HIP
-  hipMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
-  hipMemcpy(
-      d_perm, perm.data(), sizeof(int) * perm.size(), hipMemcpyHostToDevice);
-  hipMalloc(reinterpret_cast<void **>(&d_x_dims),
-            sizeof(int64_t) * x.dims().size());
-  hipMemcpy(d_x_dims,
-            x.dims().Get(),
-            sizeof(int64_t) * x.dims().size(),
-            hipMemcpyHostToDevice);
-  hipMalloc(reinterpret_cast<void **>(&d_out_dims),
-            sizeof(int64_t) * out_dims.size());
-  hipMemcpy(d_out_dims,
-            out_dims.Get(),
-            sizeof(int64_t) * out_dims.size(),
-            hipMemcpyHostToDevice);
-#else
-  cudaMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * perm.size());
-  cudaMemcpy(
-      d_perm, perm.data(), sizeof(int) * perm.size(), cudaMemcpyHostToDevice);
-  cudaMalloc(reinterpret_cast<void **>(&d_x_dims),
-             sizeof(int64_t) * x.dims().size());
-  cudaMemcpy(d_x_dims,
-             x.dims().Get(),
-             sizeof(int64_t) * x.dims().size(),
-             cudaMemcpyHostToDevice);
-  cudaMalloc(reinterpret_cast<void **>(&d_out_dims),
-             sizeof(int64_t) * out_dims.size());
-  cudaMemcpy(d_out_dims,
-             out_dims.Get(),
-             sizeof(int64_t) * out_dims.size(),
-             cudaMemcpyHostToDevice);
-#endif
+  auto d_perm_tensor = memory_utils::Alloc(
+      dev_ctx.GetPlace(),
+      sizeof(int) * perm.size(),
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  d_perm = reinterpret_cast<int *>(d_perm_tensor->ptr());
+  memory_utils::Copy(dev_ctx.GetPlace(),
+                     d_perm,
+                     phi::CPUPlace(),
+                     perm.data(),
+                     sizeof(int) * perm.size(),
+                     dev_ctx.stream());
+  auto d_x_dims_tensor = memory_utils::Alloc(
+      dev_ctx.GetPlace(),
+      sizeof(int64_t) * x.dims().size(),
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  d_x_dims = reinterpret_cast<int64_t *>(d_x_dims_tensor->ptr());
+  memory_utils::Copy(dev_ctx.GetPlace(),
+                     d_x_dims,
+                     phi::CPUPlace(),
+                     x.dims().Get(),
+                     sizeof(int64_t) * x.dims().size(),
+                     dev_ctx.stream());
+  auto d_out_dims_tensor = memory_utils::Alloc(
+      dev_ctx.GetPlace(),
+      sizeof(int64_t) * out_dims.size(),
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  d_out_dims = reinterpret_cast<int64_t *>(d_out_dims_tensor->ptr());
+  memory_utils::Copy(dev_ctx.GetPlace(),
+                     d_out_dims,
+                     phi::CPUPlace(),
+                     out_dims.Get(),
+                     sizeof(int64_t) * out_dims.size(),
+                     dev_ctx.stream());
   int64_t x_nnz = x.nnz();
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_dims[0], 1);
......
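For context on why the replaced GPU blocks were a bug and not just a style cleanup: each kernel invocation called hipMalloc/cudaMalloc with no matching free anywhere in the code shown, so device memory leaked on every transpose or reshape call. A minimal reproduction of that defect class (hypothetical kernel name, standard CUDA runtime API):

#include <cuda_runtime.h>

__global__ void UsePerm(const int *perm) { (void)perm; }

// Leaks sizeof(int) * n bytes of device memory per call: the cudaMalloc
// has no matching cudaFree, exactly the shape of the code removed above.
void LaunchWithLeak(const int *host_perm, int n) {
  int *d_perm = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&d_perm), sizeof(int) * n);
  cudaMemcpy(d_perm, host_perm, sizeof(int) * n, cudaMemcpyHostToDevice);
  UsePerm<<<1, 32>>>(d_perm);
  // missing: cudaFree(d_perm);
}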