Unverified commit 8300e618 authored by zhangkaihuo, committed by GitHub

[cherry-pick] Add Sparse API to_dense, to_sparse_coo and values (#41394) (#41834)

Add paddle.sparse and three Sparse API (#41276)
Add Sparse API to_dense, to_sparse_coo and values (#41394)
Parent 86bbb0f2
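For orientation, a minimal usage sketch of the methods this cherry-pick renames and wires up (to_sparse_coo, to_dense, values, indices), assuming an eager-mode build that contains this commit; the example values mirror the unit tests below:

import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    dense_x = paddle.to_tensor(
        [[0.0, 1.0, 0.0, 2.0], [0.0, 0.0, 3.0, 0.0], [4.0, 5.0, 0.0, 0.0]])
    coo = dense_x.to_sparse_coo(2)  # keep 2 sparse dimensions
    print(coo.indices())   # [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
    print(coo.values())    # [1., 2., 3., 4., 5.]
    print(coo.to_dense())  # round-trips back to dense_x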
@@ -1271,21 +1271,6 @@ static PyObject* tensor_method_is_sparse_csr(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
-static PyObject* tensor_method_to_sparse_coo(TensorObject* self, PyObject* args,
-                                             PyObject* kwargs) {
-  EAGER_TRY
-  int64_t sparse_dim = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
-  auto coo_tensor = self->tensor.to_sparse_coo(sparse_dim);
-  egr::EagerUtils::autograd_meta(&coo_tensor)
-      ->SetStopGradient(
-          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient());
-  egr::EagerUtils::autograd_meta(&coo_tensor)
-      ->SetPersistable(
-          egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
-  return ToPyObject(coo_tensor);
-  EAGER_CATCH_AND_THROW_RETURN_NULL
-}
static PyObject* tensor_method_to_sparse_csr(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
@@ -1300,20 +1285,6 @@ static PyObject* tensor_method_to_sparse_csr(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
-static PyObject* tensor_method_to_dense(TensorObject* self, PyObject* args,
-                                        PyObject* kwargs) {
-  EAGER_TRY
-  auto dense_tensor = self->tensor.to_dense();
-  egr::EagerUtils::autograd_meta(&dense_tensor)
-      ->SetStopGradient(
-          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient());
-  egr::EagerUtils::autograd_meta(&dense_tensor)
-      ->SetPersistable(
-          egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
-  return ToPyObject(dense_tensor);
-  EAGER_CATCH_AND_THROW_RETURN_NULL
-}
static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
@@ -1530,17 +1501,13 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)(void))tensor__copy_gradient_from,
METH_VARARGS | METH_KEYWORDS, NULL},
/***the method of sparse tensor****/
{"non_zero_indices",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
{"indices", (PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
METH_VARARGS | METH_KEYWORDS, NULL},
{"non_zero_elements",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
{"values", (PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
METH_VARARGS | METH_KEYWORDS, NULL},
{"non_zero_crows",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
{"crows", (PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
METH_VARARGS | METH_KEYWORDS, NULL},
{"non_zero_cols",
(PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
{"cols", (PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_sparse", (PyCFunction)(void (*)(void))tensor_method_is_sparse,
METH_VARARGS | METH_KEYWORDS, NULL},
@@ -1548,12 +1515,8 @@ PyMethodDef variable_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"is_sparse_csr", (PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
METH_VARARGS | METH_KEYWORDS, NULL},
{"to_sparse_coo", (PyCFunction)(void (*)(void))tensor_method_to_sparse_coo,
METH_VARARGS | METH_KEYWORDS, NULL},
{"to_sparse_csr", (PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
METH_VARARGS | METH_KEYWORDS, NULL},
{"to_dense", (PyCFunction)(void (*)(void))tensor_method_to_dense,
METH_VARARGS | METH_KEYWORDS, NULL},
{"element_size", (PyCFunction)(void (*)(void))tensor_method_element_size,
METH_VARARGS | METH_KEYWORDS, NULL},
/***the method of sparse tensor****/
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/api/ext/dispatch.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT>
void SparseMaskCPUKernel(const CPUContext& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
const DDim& dims = x.dims();
PADDLE_ENFORCE_EQ(
x.dims(),
mask.dims(),
phi::errors::InvalidArgument("the input x and mask must have the same shape"));
const DenseTensor& indices = mask.non_zero_indices();
const DenseTensor& values = mask.non_zero_elements();
int sparse_dim = indices.dims().size();
std::vector<int64_t> sparse_offsets(sparse_dim);
int64_t offset = 1;
for (int i = sparse_dim - 1; i >= 0; i--) {
sparse_offsets[i] = offset;
offset *= dims[i];
}
DenseTensor out_indices = phi::EmptyLike<T>(dev_ctx, indices);
DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, values);
// out_indices is the same as the indices of the mask
phi::Copy(dev_ctx, indices, dev_ctx.GetPlace(), false, &out_indices);
const IntT* indices_ptr = indices.data<IntT>();
T* out_values_ptr = out_values.data<T>();
const T* x_ptr = x.data<T>();
const int64_t non_zero_num = mask.nnz();
auto dims_2d = flatten_to_2d(dims, sparse_dim);
const int cols = dims_2d[1];
for (int64_t i = 0; i < non_zero_num; i++) {
int64_t index = 0;
for (int j = 0; j < sparse_dim; j++) {
index += indices_ptr[j * non_zero_num + i] * sparse_offsets[j];
}
memcpy(out_values_ptr + i * cols, x_ptr + index * cols, cols * sizeof(T));
}
out->SetMember(out_indices, out_values, dims, true);
}
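The stride loop above computes row-major offsets over the sparse dimensions, and each non-zero then selects one row of the 2-D flattened view of x. A NumPy sketch of the same index math (illustrative only, not part of the patch):

import numpy as np

x = np.arange(24.0).reshape(3, 4, 2)   # dims = [3, 4, 2]
indices = np.array([[0, 2], [1, 3]])   # sparse_dim = 2, two non-zeros
dims = x.shape
sparse_dim = indices.shape[0]
offsets = [0] * sparse_dim
off = 1
for i in range(sparse_dim - 1, -1, -1):  # same loop as in the kernel
    offsets[i] = off
    off *= dims[i]
# offsets == [4, 1]: row strides of the sparse dims in the 2-D view
x2d = x.reshape(-1, int(np.prod(dims[sparse_dim:])))  # flatten_to_2d
for n in range(indices.shape[1]):
    flat = sum(int(indices[j, n]) * offsets[j] for j in range(sparse_dim))
    print(x2d[flat])  # the row that memcpy copies into out_values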
/**
 * @brief Filter the DenseTensor x by mask.non_zero_indices() and
 * output a SparseCooTensor; x and mask must have the same shape.
 **/
template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
PD_DISPATCH_INTEGRAL_TYPES(
mask.non_zero_indices().dtype(), "SparseMaskCPUKernel", ([&] {
SparseMaskCPUKernel<T, data_t>(dev_ctx, x, mask, out);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sparse_mask,
CPU,
ALL_LAYOUT,
phi::sparse::SparseMaskKernel,
float,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
@@ -364,3 +364,33 @@ PD_REGISTER_KERNEL(sparse_csr_to_dense,
int16_t,
int,
int64_t) {}
PD_REGISTER_KERNEL(coo_values,
CPU,
ALL_LAYOUT,
phi::sparse::CooValuesKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(csr_values,
CPU,
ALL_LAYOUT,
phi::sparse::CsrValuesKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"
#include "paddle/phi/api/ext/dispatch.h"
namespace phi {
namespace sparse {
template <typename T, typename IntT>
__global__ void MaskKernel(const T* x_ptr,
const IntT* indices_ptr,
const int64_t* sparse_offsets,
const int64_t non_zero_num,
const int cols,
const int sparse_dim,
T* out_values_ptr) {
CUDA_KERNEL_LOOP_TYPE(i, non_zero_num * cols, int64_t) {
int64_t out_i = i / cols;
int64_t col_i = i - out_i * cols;
int64_t index = 0;
for (int j = 0; j < sparse_dim; j++) {
      index += indices_ptr[j * non_zero_num + out_i] * sparse_offsets[j];
}
out_values_ptr[out_i * cols + col_i] = x_ptr[index * cols + col_i];
}
}
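Each thread of the grid-stride loop above handles one (non-zero, column) pair rather than one whole row, which keeps adjacent threads on adjacent columns for coalesced reads. A Python analogue of a full pass over the work items (illustrative only, not part of the patch):

def mask_gather(x_flat, indices, sparse_offsets, non_zero_num, cols):
    # Python analogue of one full pass of the CUDA grid-stride loop
    out_values = [0.0] * (non_zero_num * cols)
    for i in range(non_zero_num * cols):  # one work item per (non-zero, column)
        out_i, col_i = divmod(i, cols)
        index = sum(indices[j][out_i] * sparse_offsets[j]
                    for j in range(len(sparse_offsets)))
        out_values[out_i * cols + col_i] = x_flat[index * cols + col_i]
    return out_values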
template <typename T, typename IntT>
void SparseMaskGPUKernel(const GPUContext& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
const DDim& dims = x.dims();
PADDLE_ENFORCE_EQ(
x.dims(),
mask.dims(),
phi::errors::InvalidArgument("the input x and mask must have the same shape"));
const DenseTensor& indices = mask.non_zero_indices();
const DenseTensor& values = mask.non_zero_elements();
int sparse_dim = indices.dims().size();
DenseTensor sparse_offsets = phi::Empty(
dev_ctx,
DenseTensorMeta(DataType::INT64, {sparse_dim}, DataLayout::NCHW));
std::vector<int64_t> h_sparse_offsets(sparse_dim);
int64_t offset = 1;
for (int i = sparse_dim - 1; i >= 0; i--) {
h_sparse_offsets[i] = offset;
offset *= dims[i];
}
phi::backends::gpu::GpuMemcpyAsync(sparse_offsets.data<int64_t>(),
&h_sparse_offsets[0],
sizeof(int64_t) * sparse_dim,
#ifdef PADDLE_WITH_HIP
hipMemcpyHostToDevice,
#else
cudaMemcpyHostToDevice,
#endif
dev_ctx.stream());
DenseTensor out_indices = phi::EmptyLike<T>(dev_ctx, indices);
DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, values);
phi::Copy(dev_ctx, indices, dev_ctx.GetPlace(), false, &out_indices);
const IntT* indices_ptr = indices.data<IntT>();
T* out_values_ptr = out_values.data<T>();
const T* x_ptr = x.data<T>();
const int64_t non_zero_num = mask.nnz();
auto dims_2d = flatten_to_2d(dims, sparse_dim);
const int cols = dims_2d[1];
// launch one work item per (non-zero, column) pair
auto config =
    phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num * cols, 1);
MaskKernel<T, IntT><<<config.block_per_grid, config.thread_per_block>>>(
x_ptr,
indices_ptr,
sparse_offsets.data<int64_t>(),
non_zero_num,
cols,
sparse_dim,
out_values_ptr);
out->SetMember(out_indices, out_values, dims, true);
}
/**
 * @brief Filter the DenseTensor x by mask.non_zero_indices() and
 * output a SparseCooTensor; x and mask must have the same shape.
 **/
template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out) {
PD_DISPATCH_INTEGRAL_TYPES(
mask.non_zero_indices().dtype(), "SparseMaskGPUKernel", ([&] {
SparseMaskGPUKernel<T, data_t>(dev_ctx, x, mask, out);
}));
}
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(sparse_mask,
GPU,
ALL_LAYOUT,
phi::sparse::SparseMaskKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
@@ -635,3 +635,33 @@ PD_REGISTER_KERNEL(sparse_csr_to_dense,
int16_t,
int,
int64_t) {}
PD_REGISTER_KERNEL(coo_values,
GPU,
ALL_LAYOUT,
phi::sparse::CooValuesKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(csr_values,
GPU,
ALL_LAYOUT,
phi::sparse::CsrValuesKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void SparseMaskKernel(const Context& dev_ctx,
const DenseTensor& x,
const SparseCooTensor& mask,
SparseCooTensor* out);
} // namespace sparse
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void CooValuesGradKernel(const Context& dev_ctx,
                         const SparseCooTensor& x,
                         const DenseTensor& out_grad,
                         SparseCooTensor* x_grad) {
  // values() gathers the non-zero elements, so the gradient simply
  // re-attaches x's indices to out_grad
  x_grad->SetMember(x.non_zero_indices(), out_grad, x.dims(), true);
}

template <typename T, typename Context>
void SparseCooToDenseGradKernel(const Context& dev_ctx,
                                const SparseCooTensor& x,
                                const DenseTensor& out_grad,
                                SparseCooTensor* x_grad) {
  // to_dense() scatters x into a dense tensor, so the gradient gathers
  // out_grad back at x's indices via the sparse mask kernel
  SparseMaskKernel<T, Context>(dev_ctx, out_grad, x, x_grad);
}
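Both backward rules are thin re-wrappings: the gradient of values() re-attaches x's indices to out_grad, and the gradient of to_dense() gathers out_grad at x's indices. A hedged end-to-end check from Python (illustrative, assuming an eager-mode build of this branch):

import paddle
from paddle.fluid import core
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    sparse_x = core.eager.sparse_coo_tensor(
        paddle.to_tensor([[0, 1], [1, 2]]),  # indices
        paddle.to_tensor([1.0, 2.0]),        # values
        [2, 3], False)                       # dense shape, stop_gradient
    sparse_x.values().backward(paddle.to_tensor([10.0, 20.0]))
    print(sparse_x.grad.values())  # [10., 20.]: same indices, out_grad as values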
} // namespace sparse
} // namespace phi
PD_REGISTER_KERNEL(coo_values_grad,
CPU,
ALL_LAYOUT,
phi::sparse::CooValuesGradKernel,
float,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(sparse_coo_to_dense_grad,
CPU,
ALL_LAYOUT,
phi::sparse::SparseCooToDenseGradKernel,
float,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(coo_values_grad,
GPU,
ALL_LAYOUT,
phi::sparse::CooValuesGradKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
PD_REGISTER_KERNEL(sparse_coo_to_dense_grad,
GPU,
ALL_LAYOUT,
phi::sparse::SparseCooToDenseGradKernel,
float,
double,
phi::dtype::float16,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {
kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
namespace phi {
namespace sparse {
template <typename T, typename Context>
void CooValuesGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& out_grad,
SparseCooTensor* x_grad);
template <typename T, typename Context>
void SparseCooToDenseGradKernel(const Context& dev_ctx,
const SparseCooTensor& x,
const DenseTensor& out_grad,
SparseCooTensor* x_grad);
} // namespace sparse
} // namespace phi
@@ -133,5 +133,19 @@ DenseTensor SparseCsrToDense(const Context& dev_ctx, const SparseCsrTensor& x) {
return dense;
}
template <typename T, typename Context>
void CooValuesKernel(const Context& dev_ctx,
                     const SparseCooTensor& x,
                     DenseTensor* out) {
  // shallow copy: out shares the storage of x's non-zero elements
  *out = x.non_zero_elements();
}

template <typename T, typename Context>
void CsrValuesKernel(const Context& dev_ctx,
                     const SparseCsrTensor& x,
                     DenseTensor* out) {
  // shallow copy: out shares the storage of x's non-zero elements
  *out = x.non_zero_elements();
}
} // namespace sparse
} // namespace phi
@@ -872,6 +872,38 @@ def monkey_patch_varbase():
res.persistable = self.persistable
return res
+    @framework.dygraph_only
+    def values(self):
+        if self.is_sparse_coo():
+            return _C_ops.final_state_sparse_coo_values(self)
+        elif self.is_sparse_csr():
+            return _C_ops.final_state_sparse_csr_values(self)
+        else:
+            raise ValueError(
+                "only SparseCooTensor and SparseCsrTensor have method values")
+
+    @framework.dygraph_only
+    def to_dense(self):
+        if self.is_sparse_coo():
+            return _C_ops.final_state_sparse_coo_to_dense(self)
+        elif self.is_sparse_csr():
+            return _C_ops.final_state_sparse_to_dense(self)
+        else:
+            return self
+
+    @framework.dygraph_only
+    def to_sparse_coo(self, sparse_dim):
+        if self.is_sparse_csr():
+            return _C_ops.final_state_sparse_to_sparse_coo(self, sparse_dim)
+        elif self.is_sparse_coo():
+            return self
+        elif self.is_selected_rows():
+            raise ValueError(
+                "SelectedRows does not support to_sparse_coo method")
+        else:
+            #is dense tensor
+            return _C_ops.final_state_sparse_dense_to_coo(self, sparse_dim)
if framework._in_eager_mode_ and not hasattr(core, "eager"):
return
@@ -884,7 +916,8 @@ def monkey_patch_varbase():
("__repr__", __str__), ("__deepcopy__", __deepcopy__),
("__module__", "paddle"), ("__array__", __array__),
("__getitem__", __getitem__), ("item", item),
("__setitem__", __setitem__), ("_to", _to)):
("__setitem__", __setitem__), ("_to", _to), ("values", values),
("to_dense", to_dense), ("to_sparse_coo", to_sparse_coo)):
if framework._in_eager_mode_:
setattr(core.eager.Tensor, method_name, method)
else:
......
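The patched methods above dispatch on the tensor's storage format, so the same name works on COO, CSR, and plain dense tensors. An illustrative sketch of the resulting behavior (assuming an eager-mode build of this branch):

import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():
    dense = paddle.to_tensor([[0.0, 1.0], [2.0, 0.0]])
    csr = dense.to_sparse_csr()
    coo = csr.to_sparse_coo(2)        # CSR -> COO conversion path
    assert coo.is_sparse_coo()
    assert dense.to_dense() is dense  # dense tensors pass through unchanged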
@@ -23,19 +23,28 @@ class TestSparseActivation(unittest.TestCase):
    def test_sparse_relu(self):
        with _test_eager_guard():
            x = [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]]
+
+            def dense_relu(x):
+                dense_x = paddle.to_tensor(
+                    x, dtype='float32', stop_gradient=False)
+                dense_relu = paddle.nn.ReLU()
+                dense_out = dense_relu(dense_x)
+                dense_out.backward(dense_out)
+                return dense_out, dense_x.grad
+
            dense_x = paddle.to_tensor(x, dtype='float32', stop_gradient=False)
            sparse_dim = 2
            sparse_x = dense_x.to_sparse_coo(sparse_dim)
            sparse_relu = paddle.sparse.ReLU()
            sparse_out = sparse_relu(sparse_x)
-            dense_relu = paddle.nn.ReLU()
-            #TODO: replace non_zero_elements() as values()
-            dense_out = dense_relu(sparse_x.non_zero_elements())
-            actual_result = sparse_out.non_zero_elements().numpy()
-            assert np.array_equal(dense_out.numpy(), actual_result)
-            dense_out.backward(dense_out)
+            sparse_out.backward(sparse_out)
+
+            dense_out, dense_x_grad = dense_relu(x)
+            assert np.array_equal(dense_out.numpy(),
+                                  sparse_out.to_dense().numpy())
+            assert np.array_equal(dense_x_grad.numpy(),
+                                  sparse_x.grad.to_dense().numpy())
if __name__ == "__main__":
unittest.main()
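The expected numbers in this test can be checked by hand: backward(out) seeds the gradient with the forward output, and ReLU's derivative masks it (an illustrative NumPy check, not part of the suite):

import numpy as np

x = np.array([[0., -1., 0., 2.], [0., 0., -3., 0.], [4., 5., 0., 0.]])
out = np.maximum(x, 0.)  # forward ReLU, identical for dense and sparse paths
x_grad = out * (x > 0)   # gradient seeded with `out`, masked by relu'
print(out)     # expected sparse_out.to_dense()
print(x_grad)  # expected dense_x.grad and sparse_x.grad.to_dense()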
@@ -46,9 +46,8 @@ class TestSparseConv(unittest.TestCase):
out.backward(out)
#At present, only backward can be verified to work normally
#TODO(zhangkaihuo): compare the result with dense conv
-            print(sparse_input.grad.non_zero_elements())
-            assert np.array_equal(correct_out_values,
-                                  out.non_zero_elements().numpy())
+            print(sparse_input.grad.values())
+            assert np.array_equal(correct_out_values, out.values().numpy())
#TODO: Add more test case
@@ -33,8 +33,7 @@ class TestSparseCopy(unittest.TestCase):
dense_x_2 = paddle.to_tensor(np_x_2, dtype='float32')
coo_x_2 = dense_x_2.to_sparse_coo(2)
coo_x_2.copy_(coo_x, True)
-            assert np.array_equal(np_values,
-                                  coo_x_2.non_zero_elements().numpy())
+            assert np.array_equal(np_values, coo_x_2.values().numpy())
def test_copy_sparse_csr(self):
with _test_eager_guard():
@@ -47,5 +46,4 @@ class TestSparseCopy(unittest.TestCase):
dense_x_2 = paddle.to_tensor(np_x_2, dtype='float32')
csr_x_2 = dense_x_2.to_sparse_csr()
csr_x_2.copy_(csr_x, True)
-            assert np.array_equal(np_values,
-                                  csr_x_2.non_zero_elements().numpy())
+            assert np.array_equal(np_values, csr_x_2.values().numpy())
@@ -23,18 +23,15 @@ from paddle.fluid.framework import _test_eager_guard
class TestSparseCreate(unittest.TestCase):
    def test_create_coo_by_tensor(self):
        with _test_eager_guard():
-            non_zero_indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
-            non_zero_elements = [1, 2, 3, 4, 5]
+            indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
+            values = [1, 2, 3, 4, 5]
            dense_shape = [3, 4]
-            dense_indices = paddle.to_tensor(non_zero_indices)
-            dense_elements = paddle.to_tensor(
-                non_zero_elements, dtype='float32')
+            dense_indices = paddle.to_tensor(indices)
+            dense_elements = paddle.to_tensor(values, dtype='float32')
            coo = paddle.sparse.sparse_coo_tensor(
                dense_indices, dense_elements, dense_shape, stop_gradient=False)
-            assert np.array_equal(non_zero_indices,
-                                  coo.non_zero_indices().numpy())
-            assert np.array_equal(non_zero_elements,
-                                  coo.non_zero_elements().numpy())
+            assert np.array_equal(indices, coo.indices().numpy())
+            assert np.array_equal(values, coo.values().numpy())
def test_create_coo_by_np(self):
with _test_eager_guard():
@@ -42,20 +39,18 @@ class TestSparseCreate(unittest.TestCase):
values = [1.0, 2.0, 3.0]
dense_shape = [2, 3]
coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
-            print(coo)
-            assert np.array_equal(indices, coo.non_zero_indices().numpy())
-            assert np.array_equal(values, coo.non_zero_elements().numpy())
+            assert np.array_equal(indices, coo.indices().numpy())
+            assert np.array_equal(values, coo.values().numpy())
def test_create_csr_by_tensor(self):
with _test_eager_guard():
-            non_zero_crows = [0, 2, 3, 5]
-            non_zero_cols = [1, 3, 2, 0, 1]
-            non_zero_elements = [1, 2, 3, 4, 5]
+            crows = [0, 2, 3, 5]
+            cols = [1, 3, 2, 0, 1]
+            values = [1, 2, 3, 4, 5]
            dense_shape = [3, 4]
-            dense_crows = paddle.to_tensor(non_zero_crows)
-            dense_cols = paddle.to_tensor(non_zero_cols)
-            dense_elements = paddle.to_tensor(
-                non_zero_elements, dtype='float32')
+            dense_crows = paddle.to_tensor(crows)
+            dense_cols = paddle.to_tensor(cols)
+            dense_elements = paddle.to_tensor(values, dtype='float32')
stop_gradient = False
csr = paddle.sparse.sparse_csr_tensor(
dense_crows,
@@ -63,7 +58,6 @@ class TestSparseCreate(unittest.TestCase):
dense_elements,
dense_shape,
stop_gradient=stop_gradient)
-            print(csr)
def test_create_csr_by_np(self):
with _test_eager_guard():
@@ -73,9 +67,9 @@ class TestSparseCreate(unittest.TestCase):
dense_shape = [3, 4]
csr = paddle.sparse.sparse_csr_tensor(crows, cols, values,
dense_shape)
-            assert np.array_equal(crows, csr.non_zero_crows().numpy())
-            assert np.array_equal(cols, csr.non_zero_cols().numpy())
-            assert np.array_equal(values, csr.non_zero_elements().numpy())
+            assert np.array_equal(crows, csr.crows().numpy())
+            assert np.array_equal(cols, csr.cols().numpy())
+            assert np.array_equal(values, csr.values().numpy())
def test_place(self):
with _test_eager_guard():
@@ -86,8 +80,8 @@ class TestSparseCreate(unittest.TestCase):
coo = paddle.sparse.sparse_coo_tensor(
indices, values, dense_shape, place=place)
assert coo.place.is_cpu_place()
-            assert coo.non_zero_elements().place.is_cpu_place()
-            assert coo.non_zero_indices().place.is_cpu_place()
+            assert coo.values().place.is_cpu_place()
+            assert coo.indices().place.is_cpu_place()
crows = [0, 2, 3, 5]
cols = [1, 3, 2, 0, 1]
@@ -95,9 +89,9 @@ class TestSparseCreate(unittest.TestCase):
csr = paddle.sparse.sparse_csr_tensor(
crows, cols, values, [3, 5], place=place)
assert csr.place.is_cpu_place()
-            assert csr.non_zero_crows().place.is_cpu_place()
-            assert csr.non_zero_cols().place.is_cpu_place()
-            assert csr.non_zero_elements().place.is_cpu_place()
+            assert csr.crows().place.is_cpu_place()
+            assert csr.cols().place.is_cpu_place()
+            assert csr.values().place.is_cpu_place()
def test_dtype(self):
with _test_eager_guard():
@@ -131,37 +125,67 @@ class TestSparseConvert(unittest.TestCase):
def test_to_sparse_coo(self):
with _test_eager_guard():
x = [[0, 1, 0, 2], [0, 0, 3, 0], [4, 5, 0, 0]]
-            non_zero_indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
-            non_zero_elements = [1, 2, 3, 4, 5]
-            dense_x = paddle.to_tensor(x)
+            indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
+            values = [1.0, 2.0, 3.0, 4.0, 5.0]
+            dense_x = paddle.to_tensor(x, dtype='float32', stop_gradient=False)
            out = dense_x.to_sparse_coo(2)
-            print(out)
-            assert np.array_equal(out.non_zero_indices().numpy(),
-                                  non_zero_indices)
-            assert np.array_equal(out.non_zero_elements().numpy(),
-                                  non_zero_elements)
-            dense_tensor = out.to_dense()
-            assert np.array_equal(dense_tensor.numpy(), x)
+            assert np.array_equal(out.indices().numpy(), indices)
+            assert np.array_equal(out.values().numpy(), values)
+            #test to_sparse_coo_grad backward
+            out_grad_indices = [[0, 1], [0, 1]]
+            out_grad_values = [2.0, 3.0]
+            out_grad = core.eager.sparse_coo_tensor(
+                paddle.to_tensor(out_grad_indices),
+                paddle.to_tensor(out_grad_values), out.shape, True)
+            out.backward(out_grad)
+            assert np.array_equal(dense_x.grad.numpy(),
+                                  out_grad.to_dense().numpy())
+    def test_coo_to_dense(self):
+        with _test_eager_guard():
+            indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
+            values = [1.0, 2.0, 3.0, 4.0, 5.0]
+            sparse_x = core.eager.sparse_coo_tensor(
+                paddle.to_tensor(indices),
+                paddle.to_tensor(values), [3, 4], False)
+            dense_tensor = sparse_x.to_dense()
+            #test to_dense_grad backward
+            out_grad = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
+                        [9.0, 10.0, 11.0, 12.0]]
+            dense_tensor.backward(paddle.to_tensor(out_grad))
+            #mask the out_grad by sparse_x.indices()
+            correct_x_grad = [2.0, 4.0, 7.0, 9.0, 10.0]
+            assert np.array_equal(correct_x_grad,
+                                  sparse_x.grad.values().numpy())
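The expected gradient above is just out_grad read at the mask's coordinates (an illustrative NumPy check of the same numbers):

import numpy as np

out_grad = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
                     [9.0, 10.0, 11.0, 12.0]])
indices = np.array([[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]])
# SparseMaskKernel in one line: gather out_grad at the sparse coordinates
print(out_grad[indices[0], indices[1]])  # [ 2.  4.  7.  9. 10.]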
def test_to_sparse_csr(self):
with _test_eager_guard():
x = [[0, 1, 0, 2], [0, 0, 3, 0], [4, 5, 0, 0]]
-            non_zero_crows = [0, 2, 3, 5]
-            non_zero_cols = [1, 3, 2, 0, 1]
-            non_zero_elements = [1, 2, 3, 4, 5]
+            crows = [0, 2, 3, 5]
+            cols = [1, 3, 2, 0, 1]
+            values = [1, 2, 3, 4, 5]
dense_x = paddle.to_tensor(x)
out = dense_x.to_sparse_csr()
-            print(out)
-            assert np.array_equal(out.non_zero_crows().numpy(), non_zero_crows)
-            assert np.array_equal(out.non_zero_cols().numpy(), non_zero_cols)
-            assert np.array_equal(out.non_zero_elements().numpy(),
-                                  non_zero_elements)
+            assert np.array_equal(out.crows().numpy(), crows)
+            assert np.array_equal(out.cols().numpy(), cols)
+            assert np.array_equal(out.values().numpy(), values)
dense_tensor = out.to_dense()
-            print(dense_tensor)
assert np.array_equal(dense_tensor.numpy(), x)
+    def test_coo_values_grad(self):
+        with _test_eager_guard():
+            indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]]
+            values = [1.0, 2.0, 3.0, 4.0, 5.0]
+            sparse_x = core.eager.sparse_coo_tensor(
+                paddle.to_tensor(indices),
+                paddle.to_tensor(values), [3, 4], False)
+            values_tensor = sparse_x.values()
+            out_grad = [2.0, 3.0, 5.0, 8.0, 9.0]
+            # test coo_values_grad
+            values_tensor.backward(paddle.to_tensor(out_grad))
+            assert np.array_equal(out_grad, sparse_x.grad.values().numpy())
if __name__ == "__main__":
unittest.main()
@@ -291,11 +291,11 @@ def sparse_tensor_to_string(tensor, prefix='Tensor'):
indent = len(prefix) + 1
if tensor.is_sparse_coo():
_template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient}, \n{indent}{indices}, \n{indent}{values})"
-        indices_tensor = tensor.non_zero_indices()
-        elements_tensor = tensor.non_zero_elements()
+        indices_tensor = tensor.indices()
+        values_tensor = tensor.values()
indices_data = 'indices=' + _format_dense_tensor(indices_tensor, indent
+ len('indices='))
-        values_data = 'values=' + _format_dense_tensor(elements_tensor, indent +
+        values_data = 'values=' + _format_dense_tensor(values_tensor, indent +
len('values='))
return _template.format(
prefix=prefix,
@@ -308,9 +308,9 @@ def sparse_tensor_to_string(tensor, prefix='Tensor'):
values=values_data)
else:
_template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient}, \n{indent}{crows}, \n{indent}{cols}, \n{indent}{values})"
-        crows_tensor = tensor.non_zero_crows()
-        cols_tensor = tensor.non_zero_cols()
-        elements_tensor = tensor.non_zero_elements()
+        crows_tensor = tensor.crows()
+        cols_tensor = tensor.cols()
+        elements_tensor = tensor.values()
crows_data = 'crows=' + _format_dense_tensor(crows_tensor, indent +
len('crows='))
cols_data = 'cols=' + _format_dense_tensor(cols_tensor, indent +
......
@@ -7,6 +7,33 @@
intermediate : rulebook
backward : conv3d_grad
- api : coo_to_dense
args : (Tensor x)
output : Tensor(out@DenseTensor)
invoke : to_dense_impl(x)
backward : coo_to_dense_grad
- api : coo_values
args : (Tensor x)
output : Tensor(out@DenseTensor)
kernel :
func : coo_values
layout : x
backward : coo_values_grad
- api : csr_values
args : (Tensor x)
output : Tensor(out@DenseTensor)
kernel :
func : csr_values
layout : x
- api : dense_to_coo
args : (Tensor x, int64_t sparse_dim)
output : Tensor(out@SparseCooTensor)
invoke : to_sparse_coo_impl(x, sparse_dim)
backward : dense_to_coo_grad
- api : relu
args : (Tensor x)
output : Tensor(out@SparseCooTensor)
......
@@ -5,6 +5,26 @@
kernel :
func : sparse_conv3d_grad
- backward_api : coo_to_dense_grad
forward : coo_to_dense(Tensor x) -> Tensor(out@DenseTensor)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad@SparseCooTensor)
kernel :
func : sparse_coo_to_dense_grad
- backward_api : coo_values_grad
forward : coo_values(Tensor x) -> Tensor(out@DenseTensor)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad@SparseCooTensor)
kernel :
func : coo_values_grad
- backward_api : dense_to_coo_grad
forward : dense_to_coo(Tensor x, int64_t sparse_dim) -> Tensor(out@SparseCooTensor)
args : (Tensor out_grad)
output : Tensor(x_grad@DenseTensor)
invoke : to_dense_impl(out_grad)
- backward_api : sparse_relu_grad
forward : sparse_relu(Tensor x) -> Tensor(out@SparseCooTensor)
args : (Tensor x, Tensor out_grad)
......