From acd08a9b4dce03ebe9cedfbe8c98c823799feeea Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 13 Apr 2022 09:50:21 +0800 Subject: [PATCH] Add kernel sparse_mask_helper; sparse_coo_tensor_grad (#41586) --- .../phi/kernels/funcs/sparse/common_shape.h | 39 ++++ .../kernels/sparse/cpu/sparse_mask_kernel.cc | 101 +++++++++-- .../kernels/sparse/cpu/sparse_utils_kernel.cc | 12 ++ .../kernels/sparse/gpu/sparse_mask_kernel.cu | 166 +++++++++++++++++- .../kernels/sparse/gpu/sparse_utils_kernel.cu | 12 ++ .../phi/kernels/sparse/sparse_mask_kernel.h | 6 + .../sparse/sparse_utils_grad_kernel.cc | 25 +++ .../kernels/sparse/sparse_utils_grad_kernel.h | 9 + .../phi/kernels/sparse/sparse_utils_kernel.h | 12 ++ .../tests/unittests/test_sparse_utils_op.py | 64 ++++++- python/paddle/sparse/creation.py | 36 +++- python/paddle/utils/code_gen/sparse_api.yaml | 8 + .../paddle/utils/code_gen/sparse_bw_api.yaml | 7 + 13 files changed, 476 insertions(+), 21 deletions(-) diff --git a/paddle/phi/kernels/funcs/sparse/common_shape.h b/paddle/phi/kernels/funcs/sparse/common_shape.h index 3617e3cd2f4..e4c836d1162 100644 --- a/paddle/phi/kernels/funcs/sparse/common_shape.h +++ b/paddle/phi/kernels/funcs/sparse/common_shape.h @@ -40,6 +40,45 @@ inline const DDim InferDenseDims(const DDim& x_dims, return values_dims; } +template +inline const IntT HOSTDEVICE IndicesToIndex(const IntT* indices, + const IntT* sparse_offsets, + const int64_t non_zero_num, + const int64_t sparse_dim, + const int i) { + IntT index = 0; + for (IntT j = 0; j < sparse_dim; j++) { + index += indices[j * non_zero_num + i] * sparse_offsets[j]; + } + return index; +} + +template +inline void HOSTDEVICE FlattenIndices(const IntT* indices, + const IntT* sparse_offsets, + const int64_t non_zero_num, + const int64_t sparse_dim, + const int start, + const int stride, + IntT* out) { + for (int i = start; i < non_zero_num; i += stride) { + out[i] = + IndicesToIndex(indices, sparse_offsets, non_zero_num, sparse_dim, i); + } +} + +// 1. indices.dims().size() == 2 +template +inline void CalcOffsetsPerDim(const DDim& dims, + const int64_t sparse_dim, + std::vector* offsets) { + IntT offset = 1; + for (IntT i = sparse_dim - 1; i >= 0; i--) { + (*offsets)[i] = offset; + offset *= dims[i]; + } +} + } // namespace sparse } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc index 0a5e145312e..a07a7fb2ecf 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sparse/common_shape.h" #include "paddle/phi/api/ext/dispatch.h" @@ -38,12 +39,6 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx, const DenseTensor& indices = mask.non_zero_indices(); const DenseTensor& values = mask.non_zero_elements(); int sparse_dim = indices.dims().size(); - std::vector sparse_offsets(sparse_dim); - int64_t offset = 1; - for (int i = sparse_dim - 1; i >= 0; i--) { - sparse_offsets[i] = offset; - offset *= dims[i]; - } DenseTensor out_indices = phi::EmptyLike(dev_ctx, indices); DenseTensor out_values = phi::EmptyLike(dev_ctx, values); @@ -51,21 +46,25 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx, // the out_indices is same as indices of mask phi::Copy(dev_ctx, indices, dev_ctx.GetPlace(), false, &out_indices); - const IntT* indices_ptr = indices.data(); T* out_values_ptr = out_values.data(); const T* x_ptr = x.data(); const int64_t non_zero_num = mask.nnz(); auto dims_2d = flatten_to_2d(dims, sparse_dim); const int cols = dims_2d[1]; + const IntT* indices_ptr = indices.data(); + + std::vector out_indexs(non_zero_num), sparse_offsets(sparse_dim); + + phi::funcs::sparse::CalcOffsetsPerDim( + dims, sparse_dim, &sparse_offsets); for (int64_t i = 0; i < non_zero_num; i++) { - int64_t index = 0; - for (int j = 0; j < sparse_dim; j++) { - index += indices_ptr[j * non_zero_num + i] * sparse_offsets[j]; - } + int64_t index = phi::funcs::sparse::IndicesToIndex( + indices_ptr, sparse_offsets.data(), non_zero_num, sparse_dim, i); memcpy(out_values_ptr + i * cols, x_ptr + index * cols, cols * sizeof(T)); } + out->SetMember(out_indices, out_values, dims, true); } @@ -85,6 +84,73 @@ void SparseMaskKernel(const Context& dev_ctx, })); } +template +void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& mask_indices, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + mask_indices.dims().size(), + 2, + phi::errors::InvalidArgument("the mask_indices must be 2-D tensor")); + + const int64_t sparse_dim = x.non_zero_indices().dims()[0]; + + std::vector sparse_offsets(sparse_dim), x_indexs(x.nnz()), + mask_indexs(mask_indices.dims()[1]); + phi::funcs::sparse::CalcOffsetsPerDim( + x.dims(), sparse_dim, &sparse_offsets); + + phi::funcs::sparse::FlattenIndices(x.non_zero_indices().data(), + sparse_offsets.data(), + x.nnz(), + sparse_dim, + 0, + 1, + x_indexs.data()); + phi::funcs::sparse::FlattenIndices(mask_indices.data(), + sparse_offsets.data(), + x.nnz(), + sparse_dim, + 0, + 1, + mask_indexs.data()); + + std::unordered_map x_indexs_map; + for (uint64_t i = 0; i < x_indexs.size(); i++) { + x_indexs_map[x_indexs[i]] = i; + } + *out = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + T* out_ptr = out->data(); + memset(out_ptr, static_cast(0), out->numel() * sizeof(T)); + const int64_t stride = + x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim; + const T* in_ptr = x.non_zero_elements().data(); + // TODO(zhangkaihuo): multithreading can be used for acceleration + for (uint64_t i = 0; i < mask_indexs.size(); i++) { + auto iter = x_indexs_map.find(mask_indexs[i]); + if (iter != x_indexs_map.end()) { + memcpy(out_ptr + i * stride, + in_ptr + iter->second * stride, + stride * sizeof(T)); + } + } +} + +/** + * @brief filter values from x.values() using mask_indices + */ +template +void SparseMaskHelperKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& mask_indices, + DenseTensor* out) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "SparseMaskHelperCPUKernel", ([&] { + SparseMaskHelperCPUKernel(dev_ctx, x, mask_indices, out); + })); +} + } // namespace sparse } // namespace phi @@ -101,3 +167,16 @@ PD_REGISTER_KERNEL(sparse_mask, int64_t) { kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); } + +PD_REGISTER_KERNEL(sparse_mask_helper, + CPU, + ALL_LAYOUT, + phi::sparse::SparseMaskHelperKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc index acc83426966..0499371a4dd 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc @@ -394,3 +394,15 @@ PD_REGISTER_KERNEL(csr_values, int64_t) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } + +PD_REGISTER_KERNEL(sparse_coo_tensor, + CPU, + ALL_LAYOUT, + phi::sparse::SparseCooTensorKernel, + float, + double, + phi::dtype::float16, + uint8_t, + int16_t, + int, + int64_t) {} diff --git a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu index d206d6bbc19..96ab56697b9 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include + #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/ddim.h" @@ -20,6 +22,7 @@ limitations under the License. */ #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sparse/common_shape.h" #include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" #include "paddle/phi/api/ext/dispatch.h" @@ -59,7 +62,7 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx, const DenseTensor& indices = mask.non_zero_indices(); const DenseTensor& values = mask.non_zero_elements(); int sparse_dim = indices.dims().size(); - DenseTensor sparse_offsets = phi::Empty( + DenseTensor sparse_offsets = phi::Empty( dev_ctx, DenseTensorMeta(DataType::INT64, {sparse_dim}, DataLayout::NCHW)); std::vector h_sparse_offsets(sparse_dim); @@ -121,6 +124,153 @@ void SparseMaskKernel(const Context& dev_ctx, })); } +// TODO(zhangkaihuo): Use an op to realize the function of FlattenIndices +template +__global__ void FlattenIndicesKernel(const IntT* indices, + const IntT* sparse_offsets, + const int64_t non_zero_num, + const int64_t sparse_dim, + IntT* out) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + phi::funcs::sparse::FlattenIndices(indices, + sparse_offsets, + non_zero_num, + sparse_dim, + tid, + gridDim.x * blockDim.x, + out); +} + +template +__global__ void SparseMaskCopyKernel(const IntT* x_indexs, + const IntT* mask_indexs, + const IntT* bound_out, + const T* x_values, + const int64_t n, + const int64_t stride, + T* out_values) { + CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { + const IntT j = bound_out[i]; + if (j >= 0 && j < n && mask_indexs[i] == x_indexs[j]) { + for (int k = 0; k < stride; k++) { + out_values[i * stride + k] = x_values[j * stride + k]; + } + } + } +} + +template +void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& mask_indices, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + mask_indices.dims().size(), + 2, + phi::errors::InvalidArgument("the mask_indices must be 2-D tensor")); + + const int64_t sparse_dim = x.non_zero_indices().dims()[0]; + auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); + + std::vector sparse_offsets(sparse_dim); + + DenseTensorMeta x_indexs_meta(indices_dtype, {x.nnz()}, DataLayout::NCHW); + DenseTensorMeta mask_indexs_meta( + indices_dtype, {mask_indices.dims()[1]}, DataLayout::NCHW); + DenseTensorMeta sparse_offset_meta( + indices_dtype, {sparse_dim}, DataLayout::NCHW); + + DenseTensor x_indexs = + phi::Empty(dev_ctx, std::move(x_indexs_meta)); + DenseTensor mask_indexs = + phi::Empty(dev_ctx, std::move(mask_indexs_meta)); + DenseTensor bound_out = + phi::Empty(dev_ctx, std::move(mask_indexs_meta)); + DenseTensor d_sparse_offsets = + phi::Empty(dev_ctx, std::move(sparse_offset_meta)); + IntT* x_indexs_ptr = x_indexs.data(); + IntT* mask_indexs_ptr = mask_indexs.data(); + IntT* bound_out_ptr = bound_out.data(); + + // 1. calc the offsets of per dim + phi::funcs::sparse::CalcOffsetsPerDim(x.dims(), sparse_dim, &sparse_offsets); + // 2. copy sparse_offsets to device + phi::backends::gpu::GpuMemcpyAsync(d_sparse_offsets.data(), + sparse_offsets.data(), + sizeof(IntT) * sparse_dim, +#ifdef PADDLE_WITH_HIP + hipMemcpyHostToDevice, +#else + cudaMemcpyHostToDevice, +#endif + dev_ctx.stream()); + + // 3. flatten x indices and mask indices + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_indexs.numel(), 1); + FlattenIndicesKernel<<>>(x.non_zero_indices().data(), + d_sparse_offsets.data(), + x_indexs.numel(), + sparse_dim, + x_indexs_ptr); + + config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1); + FlattenIndicesKernel<<>>(mask_indices.data(), + d_sparse_offsets.data(), + mask_indexs.numel(), + sparse_dim, + mask_indexs_ptr); +// 4. call thrust::lower_bound +#ifdef PADDLE_WITH_HIP + thrust::lower_bound(thrust::hip::par.on(dev_ctx.stream()), +#else + thrust::lower_bound(thrust::cuda::par.on(dev_ctx.stream()), +#endif + x_indexs_ptr, + x_indexs_ptr + x_indexs.numel(), + mask_indexs_ptr, + mask_indexs_ptr + mask_indexs.numel(), + bound_out_ptr); + + // 5. copy value to out + *out = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, out, static_cast(0)); + T* out_ptr = out->data(); + + const int64_t stride = + x.dims().size() == sparse_dim ? 1 : x.dims().size() - sparse_dim; + + SparseMaskCopyKernel<<>>(x_indexs_ptr, + mask_indexs_ptr, + bound_out_ptr, + x.non_zero_elements().data(), + mask_indexs.numel(), + stride, + out_ptr); +} + +template +void SparseMaskHelperKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& mask_indices, + DenseTensor* out) { + PD_DISPATCH_INTEGRAL_TYPES( + x.non_zero_indices().dtype(), "SparseMaskHelperGPUKernel", ([&] { + SparseMaskHelperGPUKernel(dev_ctx, x, mask_indices, out); + })); +} + } // namespace sparse } // namespace phi @@ -138,3 +288,17 @@ PD_REGISTER_KERNEL(sparse_mask, int64_t) { kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); } + +PD_REGISTER_KERNEL(sparse_mask_helper, + GPU, + ALL_LAYOUT, + phi::sparse::SparseMaskHelperKernel, + float, + double, + phi::dtype::float16, + uint8_t, + int16_t, + int, + int64_t) { + kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); +} diff --git a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu index 1109baf92e3..0b6ac1aed01 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu @@ -665,3 +665,15 @@ PD_REGISTER_KERNEL(csr_values, int64_t) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } + +PD_REGISTER_KERNEL(sparse_coo_tensor, + GPU, + ALL_LAYOUT, + phi::sparse::SparseCooTensorKernel, + float, + double, + phi::dtype::float16, + uint8_t, + int16_t, + int, + int64_t) {} diff --git a/paddle/phi/kernels/sparse/sparse_mask_kernel.h b/paddle/phi/kernels/sparse/sparse_mask_kernel.h index 210412abd86..88899e3dc67 100644 --- a/paddle/phi/kernels/sparse/sparse_mask_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_mask_kernel.h @@ -26,5 +26,11 @@ void SparseMaskKernel(const Context& dev_ctx, const SparseCooTensor& mask, SparseCooTensor* out); +template +void SparseMaskHelperKernel(const Context& dev_ctx, + const SparseCooTensor& x, + const DenseTensor& mask_indices, + DenseTensor* out); + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc index 35329807e77..15d78692f4f 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc @@ -66,6 +66,19 @@ PD_REGISTER_KERNEL(sparse_coo_to_dense_grad, kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } +PD_REGISTER_KERNEL(sparse_coo_tensor_grad, + CPU, + ALL_LAYOUT, + phi::sparse::SparseCooTensorGradKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t) { + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); +} + #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_REGISTER_KERNEL(coo_values_grad, GPU, @@ -95,4 +108,16 @@ PD_REGISTER_KERNEL(sparse_coo_to_dense_grad, int64_t) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } +PD_REGISTER_KERNEL(sparse_coo_tensor_grad, + GPU, + ALL_LAYOUT, + phi::sparse::SparseCooTensorGradKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t) { + kernel->InputAt(1).SetDataLayout(phi::DataLayout::SPARSE_COO); +} #endif diff --git a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h index 0775582bf1f..a00b9c275c2 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" +#include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" namespace phi { namespace sparse { @@ -32,5 +33,13 @@ void SparseCooToDenseGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, SparseCooTensor* x_grad); +template +void SparseCooTensorGradKernel(const Context& dev_ctx, + const DenseTensor& indices, + const SparseCooTensor& out_grad, + DenseTensor* values_grad) { + SparseMaskHelperKernel(dev_ctx, out_grad, indices, values_grad); +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/kernels/sparse/sparse_utils_kernel.h b/paddle/phi/kernels/sparse/sparse_utils_kernel.h index 961cd9f829e..8cf9c0a2864 100644 --- a/paddle/phi/kernels/sparse/sparse_utils_kernel.h +++ b/paddle/phi/kernels/sparse/sparse_utils_kernel.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/phi/api/lib/utils/storage.h" +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" @@ -147,5 +148,16 @@ void CsrValuesKernel(const Context& dev_ctx, *out = x.non_zero_elements(); } +template +void SparseCooTensorKernel(const Context& dev_ctx, + const DenseTensor& values, + const DenseTensor& indices, + const IntArray& dense_shape, + SparseCooTensor* out) { + *out = + SparseCooTensor(indices, values, phi::make_ddim(dense_shape.GetData())); + // TODO(zhangkaihuo): sort and merge the dumplicate indices +} + } // namespace sparse } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py index 04488ac58c5..89cfc711910 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py @@ -134,9 +134,11 @@ class TestSparseConvert(unittest.TestCase): #test to_sparse_coo_grad backward out_grad_indices = [[0, 1], [0, 1]] out_grad_values = [2.0, 3.0] - out_grad = core.eager.sparse_coo_tensor( + out_grad = paddle.sparse.sparse_coo_tensor( paddle.to_tensor(out_grad_indices), - paddle.to_tensor(out_grad_values), out.shape, True) + paddle.to_tensor(out_grad_values), + shape=out.shape, + stop_gradient=True) out.backward(out_grad) assert np.array_equal(dense_x.grad.numpy(), out_grad.to_dense().numpy()) @@ -145,9 +147,11 @@ class TestSparseConvert(unittest.TestCase): with _test_eager_guard(): indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]] values = [1.0, 2.0, 3.0, 4.0, 5.0] - sparse_x = core.eager.sparse_coo_tensor( + sparse_x = paddle.sparse.sparse_coo_tensor( paddle.to_tensor(indices), - paddle.to_tensor(values), [3, 4], False) + paddle.to_tensor(values), + shape=[3, 4], + stop_gradient=False) dense_tensor = sparse_x.to_dense() #test to_dense_grad backward out_grad = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], @@ -158,6 +162,17 @@ class TestSparseConvert(unittest.TestCase): assert np.array_equal(correct_x_grad, sparse_x.grad.values().numpy()) + paddle.device.set_device("cpu") + sparse_x_cpu = paddle.sparse.sparse_coo_tensor( + paddle.to_tensor(indices), + paddle.to_tensor(values), + shape=[3, 4], + stop_gradient=False) + dense_tensor_cpu = sparse_x_cpu.to_dense() + dense_tensor_cpu.backward(paddle.to_tensor(out_grad)) + assert np.array_equal(correct_x_grad, + sparse_x_cpu.grad.values().numpy()) + def test_to_sparse_csr(self): with _test_eager_guard(): x = [[0, 1, 0, 2], [0, 0, 3, 0], [4, 5, 0, 0]] @@ -177,15 +192,52 @@ class TestSparseConvert(unittest.TestCase): with _test_eager_guard(): indices = [[0, 0, 1, 2, 2], [1, 3, 2, 0, 1]] values = [1.0, 2.0, 3.0, 4.0, 5.0] - sparse_x = core.eager.sparse_coo_tensor( + sparse_x = paddle.sparse.sparse_coo_tensor( paddle.to_tensor(indices), - paddle.to_tensor(values), [3, 4], False) + paddle.to_tensor(values), + shape=[3, 4], + stop_gradient=False) values_tensor = sparse_x.values() out_grad = [2.0, 3.0, 5.0, 8.0, 9.0] # test coo_values_grad values_tensor.backward(paddle.to_tensor(out_grad)) assert np.array_equal(out_grad, sparse_x.grad.values().numpy()) + def test_sparse_coo_tensor_grad(self): + with _test_eager_guard(): + indices = [[0, 1], [0, 1]] + values = [1, 2] + indices = paddle.to_tensor(indices, dtype='int32') + values = paddle.to_tensor( + values, dtype='float32', stop_gradient=False) + sparse_x = paddle.sparse.sparse_coo_tensor( + indices, values, shape=[2, 2], stop_gradient=False) + grad_indices = [[0, 1], [1, 1]] + grad_values = [2, 3] + grad_indices = paddle.to_tensor(grad_indices, dtype='int32') + grad_values = paddle.to_tensor(grad_values, dtype='float32') + sparse_out_grad = paddle.sparse.sparse_coo_tensor( + grad_indices, grad_values, shape=[2, 2]) + sparse_x.backward(sparse_out_grad) + correct_values_grad = [0, 3] + assert np.array_equal(correct_values_grad, values.grad.numpy()) + + place = core.CPUPlace() + indices_cpu = paddle.to_tensor(indices, dtype='int32', place=place) + values_cpu = paddle.to_tensor( + values, dtype='float32', place=place, stop_gradient=False) + sparse_x_cpu = paddle.sparse.sparse_coo_tensor( + indices_cpu, + values_cpu, + shape=[2, 2], + place=place, + stop_gradient=False) + + sparse_out_grad_cpu = paddle.sparse.sparse_coo_tensor( + grad_indices, grad_values, shape=[2, 2], place=place) + sparse_x_cpu.backward(sparse_out_grad_cpu) + assert np.array_equal(correct_values_grad, values_cpu.grad.numpy()) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/sparse/creation.py b/python/paddle/sparse/creation.py index e29351e3d17..ac9276f3142 100644 --- a/python/paddle/sparse/creation.py +++ b/python/paddle/sparse/creation.py @@ -14,6 +14,7 @@ from paddle import _C_ops from ..framework import core, dygraph_only +from ..framework import _current_expected_place, _get_paddle_place from ..tensor import to_tensor from ..tensor import max from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype @@ -38,6 +39,18 @@ def _infer_dense_shape(indices): return list(lens.numpy()) +def _get_place(place): + place = _get_paddle_place(place) + if place is None: + place = _current_expected_place() + elif not isinstance(place, (core.Place, core.CPUPlace, core.CUDAPinnedPlace, + core.CUDAPlace)): + raise ValueError( + "'place' must be any of paddle.Place, paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace" + ) + return place + + @dygraph_only def sparse_coo_tensor(indices, values, @@ -94,6 +107,8 @@ def sparse_coo_tensor(indices, # values=[1., 2., 3.]) """ + place = _get_place(place) + if not isinstance(indices, core.eager.Tensor): indices = to_tensor( indices, dtype=None, place=place, stop_gradient=True) @@ -101,13 +116,20 @@ def sparse_coo_tensor(indices, values = to_tensor(values, dtype, place, stop_gradient) if len(indices.shape) != 2: raise ValueError("'indices' must be 2-D.") - if place is not None: + + if not indices.place._equals(place): indices = indices._copy_to(place, False) + + if not values.place._equals(place): values = values._copy_to(place, False) values = _handle_dtype(values, dtype) + values.stop_gradient = stop_gradient + if shape is None: shape = _infer_dense_shape(indices) - return core.eager.sparse_coo_tensor(indices, values, shape, stop_gradient) + + return _C_ops.final_state_sparse_create_sparse_coo_tensor(values, indices, + shape) #TODO: need to support shape is None @@ -171,6 +193,9 @@ def sparse_csr_tensor(crows, # cols=[1, 3, 2, 0, 1], # values=[1, 2, 3, 4, 5]) """ + + place = _get_place(place) + if not isinstance(crows, core.eager.Tensor): crows = to_tensor(crows, dtype=None, place=place, stop_gradient=True) if not isinstance(cols, core.eager.Tensor): @@ -182,10 +207,15 @@ def sparse_csr_tensor(crows, "SparseCsrTensor only support 2-D or 3-D matrix. The 'crows', 'cols' and 'values' must be 1-D." ) - if place is not None: + if not crows.place._equals(place): crows = crows._copy_to(place, False) + + if not cols.place._equals(place): cols = cols._copy_to(place, False) + + if not values.place._equals(place): values = values._copy_to(place, False) values = _handle_dtype(values, dtype) + values.stop_gradient = stop_gradient return core.eager.sparse_csr_tensor(crows, cols, values, shape, stop_gradient) diff --git a/python/paddle/utils/code_gen/sparse_api.yaml b/python/paddle/utils/code_gen/sparse_api.yaml index 7bdd77e27bc..2187d4abb2d 100644 --- a/python/paddle/utils/code_gen/sparse_api.yaml +++ b/python/paddle/utils/code_gen/sparse_api.yaml @@ -21,6 +21,14 @@ layout : x backward : coo_values_grad +- api : create_sparse_coo_tensor + args : (Tensor values, Tensor indices, IntArray dense_shape) + output : Tensor(out@SparseCooTensor) + kernel : + func : sparse_coo_tensor + layout : values + backward : create_sparse_coo_tensor_grad + - api : csr_values args : (Tensor x) output : Tensor(out@DenseTensor) diff --git a/python/paddle/utils/code_gen/sparse_bw_api.yaml b/python/paddle/utils/code_gen/sparse_bw_api.yaml index 800145b06e0..e3946cbf72b 100644 --- a/python/paddle/utils/code_gen/sparse_bw_api.yaml +++ b/python/paddle/utils/code_gen/sparse_bw_api.yaml @@ -19,6 +19,13 @@ kernel : func : coo_values_grad +- backward_api : create_sparse_coo_tensor_grad + forward : create_sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out@SparseCooTensor) + args : (Tensor indices, Tensor out_grad) + output : Tensor(values_grad@DenseTensor) + kernel : + func : sparse_coo_tensor_grad + - backward_api : dense_to_coo_grad forward : dense_to_coo(Tensor x, int64_t sparse_dim) -> Tensor(out@SparseCooTensor) args : (Tensor out_grad) -- GitLab