未验证 提交 aeb8c2e2 编写于 作者: Z zhangkaihuo 提交者: GitHub

[Sparse]Fix the bug of elementwise_grad (#52102)

上级 8b622d58
...@@ -27,6 +27,7 @@ limitations under the License. */ ...@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/phi/kernels/elementwise_kernel.h" #include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h"
#include "paddle/phi/kernels/sparse/empty_kernel.h" #include "paddle/phi/kernels/sparse/empty_kernel.h"
namespace phi { namespace phi {
...@@ -39,6 +40,7 @@ void AllocCsrPtr(const Context& dev_ctx, ...@@ -39,6 +40,7 @@ void AllocCsrPtr(const Context& dev_ctx,
DenseTensor dx_crows = phi::EmptyLike<IntT>(dev_ctx, x.crows()); DenseTensor dx_crows = phi::EmptyLike<IntT>(dev_ctx, x.crows());
DenseTensor dx_cols = phi::EmptyLike<IntT>(dev_ctx, x.cols()); DenseTensor dx_cols = phi::EmptyLike<IntT>(dev_ctx, x.cols());
DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values()); DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
dx->set_meta(x.meta());
dx->SetMember(dx_crows, dx_cols, dx_values, x.dims()); dx->SetMember(dx_crows, dx_cols, dx_values, x.dims());
} }
...@@ -48,9 +50,117 @@ void AllocCooPtr(const Context& dev_ctx, ...@@ -48,9 +50,117 @@ void AllocCooPtr(const Context& dev_ctx,
SparseCooTensor* dx) { SparseCooTensor* dx) {
DenseTensor dx_indices = phi::EmptyLike<IntT>(dev_ctx, x.indices()); DenseTensor dx_indices = phi::EmptyLike<IntT>(dev_ctx, x.indices());
DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values()); DenseTensor dx_values = phi::EmptyLike<T>(dev_ctx, x.values());
dx->set_meta(x.meta());
dx->SetMember(dx_indices, dx_values, x.dims(), x.coalesced()); dx->SetMember(dx_indices, dx_values, x.dims(), x.coalesced());
} }
template <typename T, typename IntT, typename Context>
void CopyCooValues(const Context& dev_ctx,
const SparseCooTensor& dout,
const SparseCooTensor& x,
SparseCooTensor* dx) {
Copy(dev_ctx, x.indices(), dev_ctx.GetPlace(), false, dx->mutable_indices());
const int sparse_dim = x.sparse_dim();
std::vector<IntT> sparse_offsets(sparse_dim), dout_indexs(dout.nnz()),
x_indexs(x.nnz());
phi::funcs::sparse::CalcOffsetsPerDim<IntT>(
dout.dims(), sparse_dim, sparse_offsets.data());
phi::funcs::sparse::FlattenIndices(dout.indices().data<IntT>(),
sparse_offsets.data(),
dout.nnz(),
sparse_dim,
0,
1,
dout_indexs.data());
phi::funcs::sparse::FlattenIndices(x.indices().data<IntT>(),
sparse_offsets.data(),
x.nnz(),
sparse_dim,
0,
1,
x_indexs.data());
size_t i = 0, j = 0;
T* dx_values_ptr = dx->mutable_values()->data<T>();
const T* dout_values_ptr = dout.values().data<T>();
int64_t element_size = 1;
for (auto j = 1; j < x.values().dims().size(); ++j) {
element_size *= x.values().dims()[j];
}
while (i < dout_indexs.size() && j < x_indexs.size()) {
if (dout_indexs[i] == x_indexs[j]) {
memcpy(dx_values_ptr + j * element_size,
dout_values_ptr + i * element_size,
element_size * sizeof(T));
++i;
++j;
} else if (dout_indexs[i] > x_indexs[j]) {
memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
++j;
} else {
++i;
}
}
while (j < x_indexs.size()) {
memset(dx_values_ptr + j * element_size, 0, element_size * sizeof(T));
++j;
}
}
template <typename T, typename IntT, typename Context>
void CopyCsrValues(const Context& dev_ctx,
const SparseCsrTensor& dout,
const SparseCsrTensor& x,
SparseCsrTensor* dx) {
Copy(dev_ctx, x.crows(), dev_ctx.GetPlace(), false, dx->mutable_crows());
Copy(dev_ctx, x.cols(), dev_ctx.GetPlace(), false, dx->mutable_cols());
const auto& x_dims = x.dims();
int batch = x_dims.size() == 2 ? 1 : x_dims[0];
int rows = x_dims.size() == 2 ? x_dims[0] : x_dims[1];
const IntT* x_crows_ptr = x.crows().data<IntT>();
const IntT* x_cols_ptr = x.cols().data<IntT>();
const IntT* dout_crows_ptr = dout.crows().data<IntT>();
const IntT* dout_cols_ptr = dout.cols().data<IntT>();
const T* dout_values_ptr = dout.values().data<T>();
T* dx_values_ptr = dx->mutable_values()->data<T>();
for (int b = 0; b < batch; b++) {
for (int r = 0; r < rows; r++) {
int x_start = x_crows_ptr[b * (rows + 1) + r];
int dout_start = dout_crows_ptr[b * (rows + 1) + r];
int x_row_nnz = x_crows_ptr[b * (rows + 1) + r + 1] - x_start;
int dout_row_nnz = dout_crows_ptr[b * (rows + 1) + r + 1] - dout_start;
int i = 0, j = 0;
while (i < x_row_nnz && j < dout_row_nnz) {
if (x_cols_ptr[x_start + i] == dout_cols_ptr[dout_start + j]) {
dx_values_ptr[x_start + i] = dout_values_ptr[dout_start + j];
++i;
++j;
} else if (x_cols_ptr[x_start + i] < dout_cols_ptr[dout_start + j]) {
dx_values_ptr[x_start + i] = static_cast<T>(0);
++i;
} else {
++j;
}
}
while (i < x_row_nnz) {
dx_values_ptr[x_start + i] = static_cast<T>(0);
++i;
}
}
}
}
template <typename T, typename IntT, typename Context> template <typename T, typename IntT, typename Context>
void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx, void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
const SparseCsrTensor& x, const SparseCsrTensor& x,
...@@ -62,16 +172,16 @@ void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx, ...@@ -62,16 +172,16 @@ void ElementWiseAddCsrGradCPUKernel(const Context& dev_ctx,
if (dx != nullptr && dy == nullptr) { if (dx != nullptr && dy == nullptr) {
VLOG(4) << "Special case when dy is not needed"; VLOG(4) << "Special case when dy is not needed";
AllocCsrPtr<T, IntT>(dev_ctx, x, dx); AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
} else if (dx == nullptr && dy != nullptr) { } else if (dx == nullptr && dy != nullptr) {
VLOG(4) << "Special case when dx is not needed"; VLOG(4) << "Special case when dx is not needed";
AllocCsrPtr<T, IntT>(dev_ctx, y, dy); AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
} else { } else {
AllocCsrPtr<T, IntT>(dev_ctx, x, dx); AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
AllocCsrPtr<T, IntT>(dev_ctx, y, dy); AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
} }
} }
...@@ -84,12 +194,12 @@ void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx, ...@@ -84,12 +194,12 @@ void ElementWiseSubtractCsrGradCPUKernel(const Context& dev_ctx,
SparseCsrTensor* dy) { SparseCsrTensor* dy) {
if (dx) { if (dx) {
AllocCsrPtr<T, IntT>(dev_ctx, x, dx); AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); CopyCsrValues<T, IntT, Context>(dev_ctx, dout, x, dx);
} }
if (dy) { if (dy) {
AllocCsrPtr<T, IntT>(dev_ctx, y, dy); AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); CopyCsrValues<T, IntT, Context>(dev_ctx, dout, y, dy);
phi::NegativeKernel<T, Context>( phi::NegativeKernel<T, Context>(
dev_ctx, dout.values(), dy->mutable_values()); dev_ctx, dout.values(), dy->mutable_values());
} }
...@@ -105,13 +215,19 @@ void ElementWiseMultiplyCsrGradCPUKernel(const Context& dev_ctx, ...@@ -105,13 +215,19 @@ void ElementWiseMultiplyCsrGradCPUKernel(const Context& dev_ctx,
if (dx) { if (dx) {
// dout*y // dout*y
AllocCsrPtr<T, IntT>(dev_ctx, x, dx); AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, dx); SparseCsrTensor tmp_dx;
AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
} }
if (dy) { if (dy) {
// dout*x // dout*x
AllocCsrPtr<T, IntT>(dev_ctx, y, dy); AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, dy); SparseCsrTensor tmp_dy;
AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);
sparse::ElementWiseMultiplyCsrKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
} }
} }
...@@ -126,17 +242,24 @@ void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx, ...@@ -126,17 +242,24 @@ void ElementWiseDivideCsrGradCPUKernel(const Context& dev_ctx,
if (dx) { if (dx) {
// dout/y // dout/y
AllocCsrPtr<T, IntT>(dev_ctx, x, dx); AllocCsrPtr<T, IntT>(dev_ctx, x, dx);
sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, dx); SparseCsrTensor tmp_dx;
AllocCsrPtr<T, IntT>(dev_ctx, x, &tmp_dx);
sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
} }
if (dy) { if (dy) {
// -dout * out / y // -dout * out / y
AllocCsrPtr<T, IntT>(dev_ctx, y, dy); AllocCsrPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); SparseCsrTensor tmp_dy;
AllocCsrPtr<T, IntT>(dev_ctx, y, &tmp_dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
phi::NegativeKernel<T, Context>( phi::NegativeKernel<T, Context>(
dev_ctx, dout.values(), dy->mutable_values()); dev_ctx, dout.values(), tmp_dy.mutable_values());
auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, *dy, out); auto tmp = sparse::ElementWiseMultiplyCsr<T, Context>(dev_ctx, tmp_dy, out);
sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, dy); sparse::ElementWiseDivideCsrKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
CopyCsrValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
} }
} }
...@@ -151,16 +274,16 @@ void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx, ...@@ -151,16 +274,16 @@ void ElementWiseAddCooGradCPUKernel(const Context& dev_ctx,
if (dx != nullptr && dy == nullptr) { if (dx != nullptr && dy == nullptr) {
VLOG(4) << "Special case when dy is not needed"; VLOG(4) << "Special case when dy is not needed";
AllocCooPtr<T, IntT>(dev_ctx, x, dx); AllocCooPtr<T, IntT>(dev_ctx, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
} else if (dx == nullptr && dy != nullptr) { } else if (dx == nullptr && dy != nullptr) {
VLOG(4) << "Special case when dx is not needed"; VLOG(4) << "Special case when dx is not needed";
AllocCooPtr<T, IntT>(dev_ctx, y, dy); AllocCooPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
} else { } else {
AllocCooPtr<T, IntT>(dev_ctx, x, dx); AllocCooPtr<T, IntT>(dev_ctx, x, dx);
AllocCooPtr<T, IntT>(dev_ctx, y, dy); AllocCooPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
} }
} }
...@@ -173,12 +296,12 @@ void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx, ...@@ -173,12 +296,12 @@ void ElementWiseSubtractCooGradCPUKernel(const Context& dev_ctx,
SparseCooTensor* dy) { SparseCooTensor* dy) {
if (dx) { if (dx) {
AllocCooPtr<T, IntT>(dev_ctx, x, dx); AllocCooPtr<T, IntT>(dev_ctx, x, dx);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dx); CopyCooValues<T, IntT, Context>(dev_ctx, dout, x, dx);
} }
if (dy) { if (dy) {
AllocCooPtr<T, IntT>(dev_ctx, y, dy); AllocCooPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); CopyCooValues<T, IntT, Context>(dev_ctx, dout, y, dy);
phi::NegativeKernel<T, Context>( phi::NegativeKernel<T, Context>(
dev_ctx, dout.values(), dy->mutable_values()); dev_ctx, dout.values(), dy->mutable_values());
} }
...@@ -194,13 +317,19 @@ void ElementWiseMultiplyCooGradCPUKernel(const Context& dev_ctx, ...@@ -194,13 +317,19 @@ void ElementWiseMultiplyCooGradCPUKernel(const Context& dev_ctx,
if (dx) { if (dx) {
// dout*y // dout*y
AllocCooPtr<T, IntT>(dev_ctx, x, dx); AllocCooPtr<T, IntT>(dev_ctx, x, dx);
sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, dx); SparseCooTensor tmp_dx;
AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
} }
if (dy) { if (dy) {
// dout*x // dout*x
AllocCooPtr<T, IntT>(dev_ctx, y, dy); AllocCooPtr<T, IntT>(dev_ctx, y, dy);
sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, dy); SparseCooTensor tmp_dy;
AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
sparse::ElementWiseMultiplyCooKernel<T, Context>(dev_ctx, dout, x, &tmp_dy);
CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
} }
} }
...@@ -215,22 +344,26 @@ void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx, ...@@ -215,22 +344,26 @@ void ElementWiseDivideCooGradCPUKernel(const Context& dev_ctx,
if (dx) { if (dx) {
// dout/y // dout/y
AllocCooPtr<T, IntT>(dev_ctx, x, dx); AllocCooPtr<T, IntT>(dev_ctx, x, dx);
sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, dx); SparseCooTensor tmp_dx;
AllocCooPtr<T, IntT>(dev_ctx, x, &tmp_dx);
sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, dout, y, &tmp_dx);
CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dx, x, dx);
} }
if (dy) { if (dy) {
// -dout * out / y // -dout * out / y
AllocCooPtr<T, IntT>(dev_ctx, y, dy); AllocCooPtr<T, IntT>(dev_ctx, y, dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, dy); SparseCooTensor tmp_dy;
AllocCooPtr<T, IntT>(dev_ctx, y, &tmp_dy);
Copy(dev_ctx, dout, dev_ctx.GetPlace(), false, &tmp_dy);
phi::NegativeKernel<T, Context>( phi::NegativeKernel<T, Context>(
dev_ctx, dout.values(), dy->mutable_values()); dev_ctx, dout.values(), tmp_dy.mutable_values());
auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, *dy, out); auto tmp = sparse::ElementWiseMultiplyCoo<T, Context>(dev_ctx, tmp_dy, out);
sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, dy); sparse::ElementWiseDivideCooKernel<T, Context>(dev_ctx, tmp, y, &tmp_dy);
CopyCooValues<T, IntT, Context>(dev_ctx, tmp_dy, y, dy);
} }
} }
// CPU Kernel end
// Kernel
template <typename T, typename Context> template <typename T, typename Context>
void ElementWiseDivideCsrGradKernel(const Context& dev_ctx, void ElementWiseDivideCsrGradKernel(const Context& dev_ctx,
const SparseCsrTensor& x, const SparseCsrTensor& x,
......
...@@ -32,18 +32,13 @@ template <typename T, typename Functor> ...@@ -32,18 +32,13 @@ template <typename T, typename Functor>
struct BinaryOPWithZeroCompareFunctor { struct BinaryOPWithZeroCompareFunctor {
explicit BinaryOPWithZeroCompareFunctor(Functor functor) explicit BinaryOPWithZeroCompareFunctor(Functor functor)
: functor_(functor) {} : functor_(functor) {}
inline HOSTDEVICE bool operator()(const T* a, inline HOSTDEVICE void operator()(const T* a,
const T* b, const T* b,
T* result, T* result,
const int64_t len) const { const int64_t len) const {
bool is_zero = true;
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
result[i] = functor_(a[i], b[i]); result[i] = functor_(a[i], b[i]);
if (result[i] != 0) {
is_zero = false;
}
} }
return is_zero;
} }
Functor functor_; Functor functor_;
}; };
...@@ -88,55 +83,41 @@ void Merge(const IntT el_len, ...@@ -88,55 +83,41 @@ void Merge(const IntT el_len,
// merge // merge
while (a < len_a && b < (is_divide ? len_b_max : len_b)) { while (a < len_a && b < (is_divide ? len_b_max : len_b)) {
if (a_index[a] == b_index[b]) { if (a_index[a] == b_index[b]) {
if (!functor(a_values + a * el_len, functor(a_values + a * el_len,
b_values[b_index[b]], b_values[b_index[b]],
c_values + nnz * el_len, c_values + nnz * el_len,
el_len)) { el_len);
c_index[nnz] = a_index[a]; c_index[nnz] = a_index[a];
++nnz; ++nnz;
}
++a; ++a;
++b; ++b;
} else if (a_index[a] < b_index[b]) { // coordinate x[a] < coordinate y[b] } else if (a_index[a] < b_index[b]) { // coordinate x[a] < coordinate y[b]
if (!functor(a_values + a * el_len, functor(
zero.data(), a_values + a * el_len, zero.data(), c_values + nnz * el_len, el_len);
c_values + nnz * el_len, c_index[nnz] = a_index[a];
el_len)) { ++nnz;
c_index[nnz] = a_index[a];
++nnz;
}
++a; ++a;
} else if (a_index[a] > b_index[b]) { // coordinate x[a] > coordinate y[b] } else if (a_index[a] > b_index[b]) { // coordinate x[a] > coordinate y[b]
if (!functor(zero.data(), functor(
b_values[b_index[b]], zero.data(), b_values[b_index[b]], c_values + nnz * el_len, el_len);
c_values + nnz * el_len, c_index[nnz] = b_index[b];
el_len)) { ++nnz;
c_index[nnz] = b_index[b];
++nnz;
}
++b; ++b;
} }
} }
// a tail // a tail
while (a < len_a) { while (a < len_a) {
if (!functor(a_values + a * el_len, functor(
zero.data(), a_values + a * el_len, zero.data(), c_values + nnz * el_len, el_len);
c_values + nnz * el_len, c_index[nnz] = a_index[a];
el_len)) { ++nnz;
c_index[nnz] = a_index[a];
++nnz;
}
++a; ++a;
} }
// b tail // b tail
while (b < (is_divide ? len_b_max : len_b)) { while (b < (is_divide ? len_b_max : len_b)) {
if (!functor(zero.data(), functor(zero.data(), b_values[b_index[b]], c_values + nnz * el_len, el_len);
b_values[b_index[b]], c_index[nnz] = b_index[b];
c_values + nnz * el_len, ++nnz;
el_len)) {
c_index[nnz] = b_index[b];
++nnz;
}
++b; ++b;
} }
} }
......
...@@ -37,6 +37,11 @@ def get_actual_res(x, y, op): ...@@ -37,6 +37,11 @@ def get_actual_res(x, y, op):
return res return res
def mask_to_zero(x, mask):
x[mask == 0] = 0
return x
class TestSparseElementWiseAPI(unittest.TestCase): class TestSparseElementWiseAPI(unittest.TestCase):
""" """
test paddle.sparse.add, subtract, multiply, divide test paddle.sparse.add, subtract, multiply, divide
...@@ -45,14 +50,20 @@ class TestSparseElementWiseAPI(unittest.TestCase): ...@@ -45,14 +50,20 @@ class TestSparseElementWiseAPI(unittest.TestCase):
def setUp(self): def setUp(self):
np.random.seed(2022) np.random.seed(2022)
self.op_list = op_list self.op_list = op_list
self.csr_shape = [128, 256] self.csr_shape = [8, 8]
self.coo_shape = [4, 8, 3, 5] self.coo_shape = [4, 8, 3, 5]
self.support_dtypes = ['float32', 'float64', 'int32', 'int64'] self.support_dtypes = ['float32', 'float64', 'int32', 'int64']
def func_test_csr(self, op): def func_test_csr(self, op):
for dtype in self.support_dtypes: for dtype in self.support_dtypes:
x = np.random.randint(-255, 255, size=self.csr_shape).astype(dtype) x = np.random.randint(-255, 255, size=self.csr_shape)
y = np.random.randint(-255, 255, size=self.csr_shape).astype(dtype) y = np.random.randint(-255, 255, size=self.csr_shape)
mask_x = x / x
mask_y = y / y
mask_x[mask_x != 1] = 0
mask_y[mask_y != 1] = 0
x = x.astype(dtype)
y = y.astype(dtype)
dense_x = paddle.to_tensor(x, dtype=dtype, stop_gradient=False) dense_x = paddle.to_tensor(x, dtype=dtype, stop_gradient=False)
dense_y = paddle.to_tensor(y, dtype=dtype, stop_gradient=False) dense_y = paddle.to_tensor(y, dtype=dtype, stop_gradient=False)
...@@ -63,9 +74,10 @@ class TestSparseElementWiseAPI(unittest.TestCase): ...@@ -63,9 +74,10 @@ class TestSparseElementWiseAPI(unittest.TestCase):
csr_y = s_dense_y.to_sparse_csr() csr_y = s_dense_y.to_sparse_csr()
actual_res = get_actual_res(csr_x, csr_y, op) actual_res = get_actual_res(csr_x, csr_y, op)
actual_res.backward()
expect_res = op(dense_x, dense_y) expect_res = op(dense_x, dense_y)
expect_res.backward(expect_res) expect_res.backward()
np.testing.assert_allclose( np.testing.assert_allclose(
expect_res.numpy(), expect_res.numpy(),
...@@ -74,15 +86,14 @@ class TestSparseElementWiseAPI(unittest.TestCase): ...@@ -74,15 +86,14 @@ class TestSparseElementWiseAPI(unittest.TestCase):
equal_nan=True, equal_nan=True,
) )
if not (op == __truediv__ and dtype in ['int32', 'int64']): if not (op == __truediv__ and dtype in ['int32', 'int64']):
actual_res.backward(actual_res)
np.testing.assert_allclose( np.testing.assert_allclose(
dense_x.grad.numpy(), mask_to_zero(dense_x.grad.numpy(), mask_x),
csr_x.grad.to_dense().numpy(), csr_x.grad.to_dense().numpy(),
rtol=1e-05, rtol=1e-05,
equal_nan=True, equal_nan=True,
) )
np.testing.assert_allclose( np.testing.assert_allclose(
dense_y.grad.numpy(), mask_to_zero(dense_y.grad.numpy(), mask_y),
csr_y.grad.to_dense().numpy(), csr_y.grad.to_dense().numpy(),
rtol=1e-05, rtol=1e-05,
equal_nan=True, equal_nan=True,
...@@ -124,12 +135,14 @@ class TestSparseElementWiseAPI(unittest.TestCase): ...@@ -124,12 +135,14 @@ class TestSparseElementWiseAPI(unittest.TestCase):
rtol=1e-05, rtol=1e-05,
equal_nan=True, equal_nan=True,
) )
np.testing.assert_allclose(coo_x.shape, coo_x.grad.shape)
np.testing.assert_allclose( np.testing.assert_allclose(
dense_x.grad.numpy(), dense_x.grad.numpy(),
coo_x.grad.to_dense().numpy(), coo_x.grad.to_dense().numpy(),
rtol=1e-05, rtol=1e-05,
equal_nan=True, equal_nan=True,
) )
np.testing.assert_allclose(coo_y.shape, coo_y.grad.shape)
np.testing.assert_allclose( np.testing.assert_allclose(
dense_y.grad.numpy(), dense_y.grad.numpy(),
coo_y.grad.to_dense().numpy(), coo_y.grad.to_dense().numpy(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册