Unverified commit 5752643b, authored by zhangkaihuo, committed by GitHub

sparse conversion kernel support secondary dispatch (#43345)

* use GpuMemcpy and GpuMemset

* sparse convert kernel support double dispatch by indices dtype

* cudaMemcpyKind->gpuMemcpyKind
Parent c4c30e6f
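The change summarized above introduces a two-level ("secondary") dispatch: the public conversion kernels stay templated only on the value type T, and a second, runtime dispatch on the dtype of the indices selects an inner kernel templated on <T, IntT>, so int32 and int64 indices share one implementation. The snippet below is a minimal, self-contained sketch of that idea only; the IndexDType enum, the *Sketch names, and the raw-pointer parameters are illustrative stand-ins rather than Paddle APIs, and the actual patch performs this dispatch with the PD_VISIT_INTEGRAL_TYPES macro visible in the diff below.

#include <cstdint>
#include <stdexcept>

// Illustrative stand-in for phi::DataType; only the two index widths matter here.
enum class IndexDType { INT32, INT64 };

// Inner kernel: templated on the value type T and the index type IntT,
// so a single implementation serves both index widths.
template <typename T, typename IntT>
void SparseCsrToCooCPUKernelSketch(const IntT* csr_cols, T* /*values*/,
                                   int64_t nnz) {
  for (int64_t i = 0; i < nnz; ++i) {
    // ... read/convert indices here, as the real kernel does ...
    (void)csr_cols[i];
  }
}

// Outer kernel keeps a signature templated only on T, so kernel registration
// is unchanged; it only adds the runtime dispatch on the indices dtype.
template <typename T>
void SparseCsrToCooKernelSketch(IndexDType indices_dtype, const void* csr_cols,
                                T* values, int64_t nnz) {
  switch (indices_dtype) {
    case IndexDType::INT32:
      SparseCsrToCooCPUKernelSketch<T, int32_t>(
          static_cast<const int32_t*>(csr_cols), values, nnz);
      break;
    case IndexDType::INT64:
      SparseCsrToCooCPUKernelSketch<T, int64_t>(
          static_cast<const int64_t*>(csr_cols), values, nnz);
      break;
    default:
      throw std::invalid_argument("unsupported indices dtype");
  }
}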
......@@ -67,6 +67,16 @@ DECLARE_CONSTANT_FOR_GPU(gpuErrorOutOfMemory,
DECLARE_CONSTANT_FOR_GPU(gpuErrorNotReady, cudaErrorNotReady, hipErrorNotReady);
DECLARE_CONSTANT_FOR_GPU(gpuSuccess, cudaSuccess, hipSuccess);
DECLARE_CONSTANT_FOR_GPU(gpuMemcpyHostToDevice,
cudaMemcpyKind::cudaMemcpyHostToDevice,
hipMemcpyKind::hipMemcpyHostToDevice);
DECLARE_CONSTANT_FOR_GPU(gpuMemcpyDeviceToHost,
cudaMemcpyKind::cudaMemcpyDeviceToHost,
hipMemcpyKind::hipMemcpyDeviceToHost);
DECLARE_CONSTANT_FOR_GPU(gpuMemcpyDeviceToDevice,
cudaMemcpyKind::cudaMemcpyDeviceToDevice,
hipMemcpyKind::hipMemcpyDeviceToDevice);
#undef DECLARE_CONSTANT_FOR_GPU
} // namespace phi
......
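The first hunk above adds backend-neutral copy-direction constants (gpuMemcpyHostToDevice, gpuMemcpyDeviceToHost, gpuMemcpyDeviceToDevice), matching the "cudaMemcpyKind->gpuMemcpyKind" item in the commit message, so shared kernel code no longer spells out the CUDA or HIP enumerator directly. The snippet below is only a sketch of that idea under the assumption that selection is driven by the PADDLE_WITH_HIP build flag; the names ending in Sketch are hypothetical, and the real header does this through the DECLARE_CONSTANT_FOR_GPU macro while Paddle's own GpuMemcpy helpers issue the copies.

#include <cstddef>

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuMemcpyKindSketch = hipMemcpyKind;
constexpr auto gpuMemcpyDeviceToHostSketch = hipMemcpyDeviceToHost;
#else
#include <cuda_runtime.h>
using gpuMemcpyKindSketch = cudaMemcpyKind;
constexpr auto gpuMemcpyDeviceToHostSketch = cudaMemcpyDeviceToHost;
#endif

// Hypothetical thin wrapper: device-agnostic code writes one call and the
// preprocessor picks the matching runtime API.
inline void GpuMemcpySketch(void* dst, const void* src, size_t count,
                            gpuMemcpyKindSketch kind) {
#ifdef PADDLE_WITH_HIP
  (void)hipMemcpy(dst, src, count, kind);
#else
  (void)cudaMemcpy(dst, src, count, kind);
#endif
}

// Usage: copy `count` bytes from device memory back to the host with the
// backend-neutral constant instead of cudaMemcpyDeviceToHost / hipMemcpyDeviceToHost:
//   GpuMemcpySketch(host_dst, device_src, count, gpuMemcpyDeviceToHostSketch);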
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"
namespace phi {
......@@ -68,20 +69,23 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
SparseCooTensor* out) {
const T* x_data = x.data<T>();
const auto& x_dims = x.dims();
PADDLE_ENFORCE_LE(sparse_dim,
x_dims.size(),
phi::errors::InvalidArgument(
"sparse_dim must be less than the size of x.dims()"));
PADDLE_ENFORCE_GT(
sparse_dim, 0, phi::errors::InvalidArgument("sparse_dim must be >0"));
int64_t non_zero_num = GetNonZeroNum<T>(x, sparse_dim);
const auto place = dev_ctx.GetPlace();
const auto values_dims =
phi::funcs::sparse::InferDenseDims(x_dims, sparse_dim, non_zero_num);
DenseTensorMeta indices_meta(DataType::INT64,
{sparse_dim, static_cast<int64_t>(non_zero_num)},
DataLayout::NCHW);
DenseTensorMeta values_meta(x.meta().dtype, values_dims, x.meta().layout);
phi::DenseTensor indices = phi::Empty(dev_ctx, std::move(indices_meta));
phi::DenseTensor indices =
phi::Empty<int64_t>(dev_ctx, {sparse_dim, non_zero_num});
phi::DenseTensor values = phi::Empty(dev_ctx, std::move(values_meta));
int64_t* indices_data = indices.mutable_data<int64_t>(place);
T* values_data = values.mutable_data<T>(place);
int64_t* indices_data = indices.data<int64_t>();
T* values_data = values.data<T>();
auto dims_2d = flatten_to_2d(x_dims, sparse_dim);
const int rows = dims_2d[0];
......@@ -102,36 +106,32 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
out->SetMember(indices, values, x_dims, true);
}
template <typename T, typename Context>
void SparseCsrToCooKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCooTensor* out) {
template <typename T, typename IntT>
void SparseCsrToCooCPUKernel(const CPUContext& dev_ctx,
const SparseCsrTensor& x,
SparseCooTensor* out) {
const DDim& x_dims = x.dims();
const int64_t non_zero_num = x.non_zero_cols().numel();
const auto& csr_crows = x.non_zero_crows();
const auto& csr_cols = x.non_zero_cols();
const auto& csr_values = x.non_zero_elements();
const int64_t* csr_crows_data = csr_crows.data<int64_t>();
const int64_t* csr_cols_data = csr_cols.data<int64_t>();
const IntT* csr_crows_data = csr_crows.data<IntT>();
const IntT* csr_cols_data = csr_cols.data<IntT>();
const T* csr_values_data = csr_values.data<T>();
int64_t sparse_dim = 2;
if (x_dims.size() == 3) {
sparse_dim = 3;
}
const auto place = dev_ctx.GetPlace();
DenseTensorMeta indices_meta(
DataType::INT64, {sparse_dim, non_zero_num}, DataLayout::NCHW);
DenseTensorMeta values_meta(
x.dtype(), {non_zero_num}, x.non_zero_elements().layout());
phi::DenseTensor indices = phi::Empty(dev_ctx, std::move(indices_meta));
phi::DenseTensor values = phi::Empty(dev_ctx, std::move(values_meta));
int64_t* coo_indices = indices.mutable_data<int64_t>(place);
int64_t* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
int64_t* coo_rows_data =
phi::DenseTensor indices =
phi::Empty<IntT>(dev_ctx, {sparse_dim, non_zero_num});
phi::DenseTensor values = phi::Empty<T>(dev_ctx, {non_zero_num});
IntT* coo_indices = indices.data<IntT>();
IntT* batch_ptr = x_dims.size() == 2 ? nullptr : coo_indices;
IntT* coo_rows_data =
x_dims.size() == 2 ? coo_indices : batch_ptr + non_zero_num;
int64_t* coo_cols_data = coo_rows_data + non_zero_num;
T* coo_values_data = values.mutable_data<T>(place);
IntT* coo_cols_data = coo_rows_data + non_zero_num;
T* coo_values_data = values.data<T>();
int batch = x_dims.size() == 2 ? 1 : x_dims[0];
int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
......@@ -139,7 +139,7 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
int index = 0;
for (int b = 0; b < batch; b++) {
for (int i = 0; i < rows; i++) {
for (int j = csr_crows_data[b * (rows + 1) + i];
for (IntT j = csr_crows_data[b * (rows + 1) + i];
j < csr_crows_data[b * (rows + 1) + i + 1];
j++) {
coo_rows_data[index] = i;
......@@ -151,15 +151,25 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
}
}
memcpy(coo_cols_data, csr_cols_data, sizeof(int64_t) * non_zero_num);
memcpy(coo_cols_data, csr_cols_data, sizeof(IntT) * non_zero_num);
memcpy(coo_values_data, csr_values_data, sizeof(T) * non_zero_num);
out->SetMember(indices, values, x_dims, true);
}
template <typename T, typename Context>
void SparseCooToCsrKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCsrTensor* out) {
void SparseCsrToCooKernel(const Context& dev_ctx,
const SparseCsrTensor& x,
SparseCooTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_crows().dtype(), "SparseCsrToCooCPUKernel", ([&] {
SparseCsrToCooCPUKernel<T, data_t>(dev_ctx, x, out);
}));
}
template <typename T, typename IntT>
void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
SparseCsrTensor* out) {
const auto& x_dims = x.dims();
bool valid = x_dims.size() == 2 || x_dims.size() == 3;
PADDLE_ENFORCE_EQ(valid,
......@@ -174,11 +184,11 @@ void SparseCooToCsrKernel(const Context& dev_ctx,
phi::DenseTensor non_zero_crows;
non_zero_crows.Resize({batchs * (rows + 1)});
int64_t* csr_crows_data = dev_ctx.template Alloc<int64_t>(&non_zero_crows);
IntT* csr_crows_data = dev_ctx.template Alloc<IntT>(&non_zero_crows);
phi::DenseTensor non_zero_cols;
non_zero_cols.Resize({non_zero_num});
int64_t* csr_cols_data = dev_ctx.template Alloc<int64_t>(&non_zero_cols);
IntT* csr_cols_data = dev_ctx.template Alloc<IntT>(&non_zero_cols);
phi::DenseTensor non_zero_elements;
non_zero_elements.Resize({non_zero_num});
......@@ -186,16 +196,12 @@ void SparseCooToCsrKernel(const Context& dev_ctx,
const auto& coo_indices = x.non_zero_indices();
const auto& coo_values = x.non_zero_elements();
const int64_t* batchs_ptr = coo_indices.data<int64_t>();
const int64_t* coo_rows_data =
const IntT* batchs_ptr = coo_indices.data<IntT>();
const IntT* coo_rows_data =
batchs == 1 ? batchs_ptr : batchs_ptr + non_zero_num;
const int64_t* coo_cols_data = coo_rows_data + non_zero_num;
const IntT* coo_cols_data = coo_rows_data + non_zero_num;
const T* coo_values_data = coo_values.data<T>();
if (!x.coalesced()) {
// TODO(zhangkahuo): call coalesced() to distinct and sort the indices
}
std::vector<int64_t> offsets(batchs, 0);
if (batchs > 1) {
for (int i = 0; i < non_zero_num; i++) {
......@@ -220,25 +226,34 @@ void SparseCooToCsrKernel(const Context& dev_ctx,
csr_crows_data[b * (rows + 1) + i] = 0;
}
for (int64_t i = 1; i < batch_non_zero_num; i++) {
for (int j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) {
for (IntT j = coo_rows_ptr[i - 1]; j < coo_rows_ptr[i]; j++) {
csr_crows_data[b * (rows + 1) + j + 1] = i;
}
}
for (int64_t i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1;
i++) {
for (IntT i = coo_rows_ptr[batch_non_zero_num - 1] + 1; i < rows + 1; i++) {
csr_crows_data[b * (rows + 1) + i] = batch_non_zero_num;
}
}
memcpy(csr_cols_data, coo_cols_data, sizeof(int64_t) * non_zero_num);
memcpy(csr_cols_data, coo_cols_data, sizeof(IntT) * non_zero_num);
memcpy(csr_values_data, coo_values_data, sizeof(T) * non_zero_num);
out->SetMember(non_zero_crows, non_zero_cols, non_zero_elements, x_dims);
}
template <typename T, typename Context>
void SparseCooToDenseKernel(const Context& dev_ctx,
const SparseCooTensor& x,
DenseTensor* out) {
void SparseCooToCsrKernel(const Context& dev_ctx,
const SparseCooTensor& x,
SparseCsrTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseCooToCsrCPUKernel", ([&] {
SparseCooToCsrCPUKernel<T, data_t>(dev_ctx, x, out);
}));
}
template <typename T, typename IntT>
void SparseCooToDenseCPUKernel(const CPUContext& dev_ctx,
const SparseCooTensor& x,
DenseTensor* out) {
const auto non_zero_num = x.nnz();
const auto dense_dims = x.dims();
const auto indices = x.non_zero_indices();
......@@ -270,8 +285,7 @@ void SparseCooToDenseKernel(const Context& dev_ctx,
for (auto i = 0; i < non_zero_num; i++) {
int64_t index = 0;
for (int j = 0; j < sparse_dim; j++) {
index +=
indices.data<int64_t>()[j * non_zero_num + i] * sparse_offsets[j];
index += indices.data<IntT>()[j * non_zero_num + i] * sparse_offsets[j];
}
for (int j = 0; j < base_offset; j++) {
......@@ -280,6 +294,16 @@ void SparseCooToDenseKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void SparseCooToDenseKernel(const Context& dev_ctx,
const SparseCooTensor& x,
DenseTensor* out) {
PD_VISIT_INTEGRAL_TYPES(
x.non_zero_indices().dtype(), "SparseCooToDenseCPUKernel", ([&] {
SparseCooToDenseCPUKernel<T, data_t>(dev_ctx, x, out);
}));
}
} // namespace sparse
} // namespace phi
......
......@@ -168,31 +168,33 @@ class TestSparseConvert(unittest.TestCase):
with _test_eager_guard():
indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
paddle.to_tensor(indices),
paddle.to_tensor(values),
shape=[3, 4],
stop_gradient=False)
dense_tensor = sparse_x.to_dense()
#test to_dense_grad backward
out_grad = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]]
dense_tensor.backward(paddle.to_tensor(out_grad))
#mask the out_grad by sparse_x.indices()
correct_x_grad = [2.0, 4.0, 7.0, 9.0, 10.0]
assert np.array_equal(correct_x_grad,
sparse_x.grad.values().numpy())
paddle.device.set_device("cpu")
sparse_x_cpu = paddle.incubate.sparse.sparse_coo_tensor(
paddle.to_tensor(indices),
paddle.to_tensor(values),
shape=[3, 4],
stop_gradient=False)
dense_tensor_cpu = sparse_x_cpu.to_dense()
dense_tensor_cpu.backward(paddle.to_tensor(out_grad))
assert np.array_equal(correct_x_grad,
sparse_x_cpu.grad.values().numpy())
indices_dtypes = ['int32', 'int64']
for indices_dtype in indices_dtypes:
sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
paddle.to_tensor(indices, dtype=indices_dtype),
paddle.to_tensor(values),
shape=[3, 4],
stop_gradient=False)
dense_tensor = sparse_x.to_dense()
#test to_dense_grad backward
out_grad = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]]
dense_tensor.backward(paddle.to_tensor(out_grad))
#mask the out_grad by sparse_x.indices()
correct_x_grad = [2.0, 4.0, 7.0, 9.0, 10.0]
assert np.array_equal(correct_x_grad,
sparse_x.grad.values().numpy())
paddle.device.set_device("cpu")
sparse_x_cpu = paddle.incubate.sparse.sparse_coo_tensor(
paddle.to_tensor(indices, dtype=indices_dtype),
paddle.to_tensor(values),
shape=[3, 4],
stop_gradient=False)
dense_tensor_cpu = sparse_x_cpu.to_dense()
dense_tensor_cpu.backward(paddle.to_tensor(out_grad))
assert np.array_equal(correct_x_grad,
sparse_x_cpu.grad.values().numpy())
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def test_to_sparse_csr(self):
......