diff --git a/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
index 98afed84d6643836c5a36779dc05a646315d4150..2a609d8c1f073e2c43cd54d044ead90867d09adb 100644
--- a/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/elementwise_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
 
 namespace phi {
@@ -39,6 +40,7 @@ void AllocCsrPtr(const Context& dev_ctx,
   DenseTensor dx_crows = phi::EmptyLike<IntT>(dev_ctx, x.crows());
   DenseTensor dx_cols = phi::EmptyLike<IntT>(dev_ctx, x.cols());
   DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
+  dx->set_meta(x.meta());
   dx->SetMember(dx_crows, dx_cols, dx_values, x.dims());
 }
 
@@ -48,9 +50,117 @@ void AllocCooPtr(const Context& dev_ctx,
                  SparseCooTensor* dx) {
   DenseTensor dx_indices = phi::EmptyLike<IntT>(dev_ctx, x.indices());
   DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
+  dx->set_meta(x.meta());
   dx->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
 }
 
+template <typename T, typename IntT, typename Context>
+void CopyCooValues(const Context& dev_ctx,
+                   const SparseCooTensor& dout,
+                   const SparseCooTensor& x,
+                   SparseCooTensor* dx) {
+  Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, dx->mutable_indices());
+
+  const int sparse_dim = x.sparse_dim();
+  std::vector<IntT> sparse_offsets(sparse_dim), dout_indexs(dout.nnz()),
+      x_indexs(x.nnz());
+
+  phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
+      dout.dims(), sparse_dim, sparse_offsets.data());
+
+  phi::funcs::sparse::FlattenIndices(dout.indices().data<IntT>(),
+                                     sparse_offsets.data(),
+                                     dout.nnz(),
+                                     sparse_dim,
+                                     0,
+                                     1,
+                                     dout_indexs.data());
+
+  phi::funcs::sparse::FlattenIndices(x.indices().data<IntT>(),
+                                     sparse_offsets.data(),
+                                     x.nnz(),
+                                     sparse_dim,
+                                     0,
+                                     1,
+                                     x_indexs.data());
+
+  size_t i = 0, j = 0;
+  T* dx_values_ptr = dx->mutable_values()->data<T>();
+  const T* dout_values_ptr = dout.values().data<T>();
+
+  int64_t element_size = 1;
+  for (auto j = 1; j < x.values().dims().size(); ++j) {
+    element_size *= x.values().dims()[j];
+  }
+
+  while (i < dout_indexs.size() && j < x_indexs.size()) {
+    if (dout_indexs[i] == x_indexs[j]) {
+      memcpy(dx_values_ptr + j * element_size,
+             dout_values_ptr + i * element_size,
+             element_size * sizeof(T));
+      ++i;
+      ++j;
+    } else if (dout_indexs[i] > x_indexs[j]) {
+      memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
+      ++j;
+    } else {
+      ++i;
+    }
+  }
+  while (j < x_indexs.size()) {
+    memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
+    ++j;
+  }
+}
+
+template <typename T, typename IntT, typename Context>
+void CopyCsrValues(const Context& dev_ctx,
+                   const SparseCsrTensor& dout,
+                   const SparseCsrTensor& x,
+                   SparseCsrTensor* dx) {
+  Copy(dev_ctx, x.crows(), dev_ctx.GetPlace(), false, dx->mutable_crows());
+  Copy(dev_ctx, x.cols(), dev_ctx.GetPlace(), false, dx->mutable_cols());
+
+  const auto& x_dims = x.dims();
+  int batch = x_dims.size() == 2 ? 1 : x_dims[0];
+  int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
+
+  const IntT* x_crows_ptr = x.crows().data<IntT>();
+  const IntT* x_cols_ptr = x.cols().data<IntT>();
+
+  const IntT* dout_crows_ptr = dout.crows().data<IntT>();
+  const IntT* dout_cols_ptr = dout.cols().data<IntT>();
+  const T* dout_values_ptr = dout.values().data<T>();
+
+  T* dx_values_ptr = dx->mutable_values()->data<T>();
+
+  for (int b = 0; b < batch; b++) {
+    for (int r = 0; r < rows; r++) {
+      int x_start = x_crows_ptr[b * (rows + 1) + r];
+      int dout_start = dout_crows_ptr[b * (rows + 1) + r];
+      int x_row_nnz = x_crows_ptr[b * (rows + 1) + r + 1] - x_start;
+      int dout_row_nnz = dout_crows_ptr[b * (rows + 1) + r + 1] - dout_start;
+      int i = 0, j = 0;
+      while (i < x_row_nnz && j < dout_row_nnz) {
+        if (x_cols_ptr[x_start + i] == dout_cols_ptr[dout_start + j]) {
+          dx_values_ptr[x_start + i] = dout_values_ptr[dout_start + j];
+          ++i;
+          ++j;
+        } else if (x_cols_ptr[x_start + i] < dout_cols_ptr[dout_start + j]) {
+          dx_values_ptr[x_start + i] = static_cast<T>(0);
+          ++i;
+        } else {
+          ++j;
+        }
+      }
+      while (i < x_row_nnz) {
+        dx_values_ptr[x_start + i] = static_cast<T>(0);
+        ++i;
+      }
+    }
+  }
+}
+
 template <typename T, typename IntT, typename Context>
 void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
                                     const SparseCsrTensor& x,
@@ -62,16 +172,16 @@ void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
   if (dx != nullptr && dy == nullptr) {
     VLOG(4) << "Special case when dy is not needed";
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   } else if (dx == nullptr && dy != nullptr) {
     VLOG(4) << "Special case when dx is not needed";
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   } else {
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   }
 }
 
@@ -84,12 +194,12 @@ void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx,
                                          SparseCsrTensor* dy) {
   if (dx) {
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   }
 
   if (dy) {
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.values(), dy->mutable_values());
   }
@@ -105,13 +215,19 @@ void ElementWiseMultiplyCsrGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout*y
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCsrTensor tmp_dx;
+    AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }
 
   if (dy) {
     // dout*x
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, dy);
+    SparseCsrTensor tmp_dy;
+    AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+    sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }
 
@@ -126,17 +242,24 @@ void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout/y
     AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCsrTensor tmp_dx;
+    AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }
 
   if (dy) {
     // -dout * out / y
     AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    SparseCsrTensor tmp_dy;
+    AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
     phi::NegativeKernel<T, Context>(
-        dev_ctx, dout.values(), dy->mutable_values());
-    auto tmp = sparse::ElementWiseMultiplyCsr<T>(dev_ctx, *dy, out);
-    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, dy);
+        dev_ctx, dout.values(), tmp_dy.mutable_values());
+    auto tmp = sparse::ElementWiseMultiplyCsr<T>(dev_ctx, tmp_dy, out);
+    sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
+    CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }
 
@@ -151,16 +274,16 @@ void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx,
   if (dx != nullptr && dy == nullptr) {
     VLOG(4) << "Special case when dy is not needed";
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
  } else if (dx == nullptr && dy != nullptr) {
     VLOG(4) << "Special case when dx is not needed";
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   } else {
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
   }
 }
 
@@ -173,12 +296,12 @@ void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx,
                                          SparseCooTensor* dy) {
   if (dx) {
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
   }
 
   if (dy) {
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
     phi::NegativeKernel<T, Context>(
         dev_ctx, dout.values(), dy->mutable_values());
   }
@@ -194,13 +317,19 @@ void ElementWiseMultiplyCooGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout*y
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCooTensor tmp_dx;
+    AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }
 
   if (dy) {
     // dout*x
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, dy);
+    SparseCooTensor tmp_dy;
+    AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+    sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }
 
@@ -215,22 +344,26 @@ void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx,
   if (dx) {
     // dout/y
     AllocCooPtr<T, IntT>(dev_ctx, x, dx);
-    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, dx);
+    SparseCooTensor tmp_dx;
+    AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
+    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
   }
 
   if (dy) {
     // -dout * out / y
     AllocCooPtr<T, IntT>(dev_ctx, y, dy);
-    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy);
+    SparseCooTensor tmp_dy;
+    AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
+    Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
     phi::NegativeKernel<T, Context>(
-        dev_ctx, dout.values(), dy->mutable_values());
-    auto tmp = sparse::ElementWiseMultiplyCoo<T>(dev_ctx, *dy, out);
-    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, dy);
+        dev_ctx, dout.values(), tmp_dy.mutable_values());
+    auto tmp = sparse::ElementWiseMultiplyCoo<T>(dev_ctx, tmp_dy, out);
+    sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
+    CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
   }
 }
 
-// CPU Kernel end
-// Kernel
 
 template <typename T, typename Context>
 void ElementWiseDivideCsrGradKernel(const Context& dev_ctx,
                                     const SparseCsrTensor& x,
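
The two new helpers implement the same idea for both formats: dout and the
input share a sorted coordinate order (flattened COO indices, or per-row CSR
columns), so a single two-pointer pass can re-lay dout's values onto the
input's sparsity pattern, writing an explicit zero wherever the input stores
an entry that dout does not cover. A minimal plain-Python sketch of that
matching step follows; the function name and list-based layout are
illustrative only and are not part of the patch:

    # Re-lay dout's values onto x's sparsity pattern. Both index lists are
    # sorted, mirroring the flattened indices used by CopyCooValues.
    def copy_values_with_pattern(dout_index, dout_values, x_index):
        dx_values = [0] * len(x_index)
        i = j = 0
        while i < len(dout_index) and j < len(x_index):
            if dout_index[i] == x_index[j]:
                dx_values[j] = dout_values[i]
                i += 1
                j += 1
            elif dout_index[i] > x_index[j]:
                j += 1  # x stores an entry dout lacks -> gradient is 0 there
            else:
                i += 1  # dout stores an entry x lacks -> not part of dx
        return dx_values

    # dout covers {0, 2, 5}, x covers {2, 3, 5}; dx keeps x's pattern.
    print(copy_values_with_pattern([0, 2, 5], [10.0, 20.0, 30.0], [2, 3, 5]))
    # -> [20.0, 0, 30.0]
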
diff --git a/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc
index 5c849c2f7bbad30fd26b80e0de38de00e3a731b1..72e3d00962b5dc7c0134274496cce4279069d97a 100644
--- a/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc
@@ -32,18 +32,13 @@ template <typename T, typename Functor>
 struct BinaryOPWithZeroCompareFunctor {
   explicit BinaryOPWithZeroCompareFunctor(Functor functor)
       : functor_(functor) {}
-  inline HOSTDEVICE bool operator()(const T* a,
+  inline HOSTDEVICE void operator()(const T* a,
                                     const T* b,
                                     T* result,
                                     const int64_t len) const {
-    bool is_zero = true;
     for (int64_t i = 0; i < len; ++i) {
       result[i] = functor_(a[i], b[i]);
-      if (result[i] != 0) {
-        is_zero = false;
-      }
     }
-    return is_zero;
   }
   Functor functor_;
 };
@@ -88,55 +83,41 @@ void Merge(const IntT el_len,
   // merge
   while (a < len_a && b < (is_divide ? len_b_max : len_b)) {
     if (a_index[a] == b_index[b]) {
-      if (!functor(a_values + a * el_len,
-                   b_values[b_index[b]],
-                   c_values + nnz * el_len,
-                   el_len)) {
-        c_index[nnz] = a_index[a];
-        ++nnz;
-      }
+      functor(a_values + a * el_len,
+              b_values[b_index[b]],
+              c_values + nnz * el_len,
+              el_len);
+      c_index[nnz] = a_index[a];
+      ++nnz;
       ++a;
       ++b;
     } else if (a_index[a] < b_index[b]) {
       // coordinate x[a] < coordinate y[b]
-      if (!functor(a_values + a * el_len,
-                   zero.data(),
-                   c_values + nnz * el_len,
-                   el_len)) {
-        c_index[nnz] = a_index[a];
-        ++nnz;
-      }
+      functor(
+          a_values + a * el_len, zero.data(), c_values + nnz * el_len, el_len);
+      c_index[nnz] = a_index[a];
+      ++nnz;
       ++a;
     } else if (a_index[a] > b_index[b]) {
       // coordinate x[a] > coordinate y[b]
-      if (!functor(zero.data(),
-                   b_values[b_index[b]],
-                   c_values + nnz * el_len,
-                   el_len)) {
-        c_index[nnz] = b_index[b];
-        ++nnz;
-      }
+      functor(
+          zero.data(), b_values[b_index[b]], c_values + nnz * el_len, el_len);
+      c_index[nnz] = b_index[b];
+      ++nnz;
       ++b;
     }
   }
   // a tail
   while (a < len_a) {
-    if (!functor(a_values + a * el_len,
-                 zero.data(),
-                 c_values + nnz * el_len,
-                 el_len)) {
-      c_index[nnz] = a_index[a];
-      ++nnz;
-    }
+    functor(
+        a_values + a * el_len, zero.data(), c_values + nnz * el_len, el_len);
+    c_index[nnz] = a_index[a];
+    ++nnz;
     ++a;
   }
   // b tail
   while (b < (is_divide ? len_b_max : len_b)) {
-    if (!functor(zero.data(),
-                 b_values[b_index[b]],
-                 c_values + nnz * el_len,
-                 el_len)) {
-      c_index[nnz] = b_index[b];
-      ++nnz;
-    }
+    functor(zero.data(), b_values[b_index[b]], c_values + nnz * el_len, el_len);
+    c_index[nnz] = b_index[b];
+    ++nnz;
     ++b;
   }
 }
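
With this change the compare functor always writes its result and Merge keeps
every matched coordinate, so the forward output's sparsity pattern covers the
union of the two input patterns; results that happen to equal zero are no
longer dropped, which is what the gradient helpers above rely on when they map
dout back onto each input. A dict-based sketch of the new behaviour (the
helper name and layout are illustrative, not Paddle code):

    # Union-pattern merge: keep an output entry for every stored coordinate of
    # either operand, even when op() evaluates to zero there.
    def merge_union(a, b, op):
        """a, b: {flattened index: value}; returns {index: op(av, bv)}."""
        return {
            idx: op(a.get(idx, 0), b.get(idx, 0))
            for idx in sorted(set(a) | set(b))
        }

    a = {0: 1.0, 3: 2.0}
    b = {3: 2.0, 7: 5.0}
    # Subtraction yields an explicit zero at index 3; it is kept, not dropped.
    print(merge_union(a, b, lambda u, v: u - v))  # {0: 1.0, 3: 0.0, 7: -5.0}
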
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
index bec8b1a3447701277f8cf6164adaa240dc00a63c..3583d861e49de15efe80a022aaf8df9011f48b6d 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
@@ -37,6 +37,11 @@ def get_actual_res(x, y, op):
     return res
 
 
+def mask_to_zero(x, mask):
+    x[mask == 0] = 0
+    return x
+
+
 class TestSparseElementWiseAPI(unittest.TestCase):
     """
     test paddle.sparse.add, subtract, multiply, divide
@@ -45,14 +50,20 @@ class TestSparseElementWiseAPI(unittest.TestCase):
     def setUp(self):
         np.random.seed(2022)
         self.op_list = op_list
-        self.csr_shape = [128, 256]
+        self.csr_shape = [8, 8]
         self.coo_shape = [4, 8, 3, 5]
         self.support_dtypes = ['float32', 'float64', 'int32', 'int64']
 
     def func_test_csr(self, op):
         for dtype in self.support_dtypes:
-            x = np.random.randint(-255, 255, size=self.csr_shape).astype(dtype)
-            y = np.random.randint(-255, 255, size=self.csr_shape).astype(dtype)
+            x = np.random.randint(-255, 255, size=self.csr_shape)
+            y = np.random.randint(-255, 255, size=self.csr_shape)
+            mask_x = x / x
+            mask_y = y / y
+            mask_x[mask_x != 1] = 0
+            mask_y[mask_y != 1] = 0
+            x = x.astype(dtype)
+            y = y.astype(dtype)
 
             dense_x = paddle.to_tensor(x, dtype=dtype, stop_gradient=False)
             dense_y = paddle.to_tensor(y, dtype=dtype, stop_gradient=False)
@@ -63,9 +74,10 @@ class TestSparseElementWiseAPI(unittest.TestCase):
             csr_y = s_dense_y.to_sparse_csr()
 
             actual_res = get_actual_res(csr_x, csr_y, op)
+            actual_res.backward()
 
             expect_res = op(dense_x, dense_y)
-            expect_res.backward(expect_res)
+            expect_res.backward()
 
             np.testing.assert_allclose(
                 expect_res.numpy(),
@@ -74,15 +86,14 @@ class TestSparseElementWiseAPI(unittest.TestCase):
                 equal_nan=True,
             )
             if not (op == __truediv__ and dtype in ['int32', 'int64']):
-                actual_res.backward(actual_res)
                 np.testing.assert_allclose(
-                    dense_x.grad.numpy(),
+                    mask_to_zero(dense_x.grad.numpy(), mask_x),
                     csr_x.grad.to_dense().numpy(),
                     rtol=1e-05,
                     equal_nan=True,
                 )
                 np.testing.assert_allclose(
-                    dense_y.grad.numpy(),
+                    mask_to_zero(dense_y.grad.numpy(), mask_y),
                     csr_y.grad.to_dense().numpy(),
                     rtol=1e-05,
                     equal_nan=True,
@@ -124,12 +135,14 @@ class TestSparseElementWiseAPI(unittest.TestCase):
                 rtol=1e-05,
                 equal_nan=True,
             )
+            np.testing.assert_allclose(coo_x.shape, coo_x.grad.shape)
             np.testing.assert_allclose(
                 dense_x.grad.numpy(),
                 coo_x.grad.to_dense().numpy(),
                 rtol=1e-05,
                 equal_nan=True,
             )
+            np.testing.assert_allclose(coo_y.shape, coo_y.grad.shape)
             np.testing.assert_allclose(
                 dense_y.grad.numpy(),
                 coo_y.grad.to_dense().numpy(),
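
At the Python level, the behaviour these tests now pin down is that the
gradient of a sparse input keeps that input's shape and sparsity pattern,
which is why the dense reference gradient has to be masked before comparison.
Below is a hedged usage sketch built only from the APIs exercised above
(paddle.sparse.add, to_sparse_csr, backward, grad.to_dense); the tensors are
made up for illustration and exact printed output may depend on the Paddle
version:

    import paddle

    dense_x = paddle.to_tensor([[0.0, 2.0], [3.0, 0.0]], stop_gradient=False)
    dense_y = paddle.to_tensor([[1.0, 0.0], [4.0, 5.0]], stop_gradient=False)
    csr_x = dense_x.to_sparse_csr()
    csr_y = dense_y.to_sparse_csr()

    out = paddle.sparse.add(csr_x, csr_y)
    out.backward()

    # csr_x.grad is a sparse CSR tensor with csr_x's shape; positions that are
    # structural zeros of x stay zero after densifying.
    print(csr_x.grad.shape)                # [2, 2]
    print(csr_x.grad.to_dense().numpy())   # non-zero only where x has entries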