未验证 提交 d41a9373 编写于 作者: Z zhangkaihuo 提交者: GitHub

Fix batch csr (#43708)

上级 7673b39a
......@@ -198,7 +198,7 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
const auto& coo_values = x.non_zero_elements();
const IntT* batchs_ptr = coo_indices.data<IntT>();
const IntT* coo_rows_data =
batchs == 1 ? batchs_ptr : batchs_ptr + non_zero_num;
x_dims.size() == 2 ? batchs_ptr : batchs_ptr + non_zero_num;
const IntT* coo_cols_data = coo_rows_data + non_zero_num;
const T* coo_values_data = coo_values.data<T>();
......
......@@ -371,7 +371,7 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx,
const auto& coo_values = x.non_zero_elements();
const IntT* batchs_ptr = coo_indices.data<IntT>();
const IntT* coo_rows_data =
batchs == 1 ? batchs_ptr : batchs_ptr + non_zero_num;
x_dims.size() == 2 ? batchs_ptr : batchs_ptr + non_zero_num;
const IntT* coo_cols_data = coo_rows_data + non_zero_num;
const T* coo_values_data = coo_values.data<T>();
......
......@@ -318,50 +318,40 @@ class TestSparseConvert(unittest.TestCase):
def test_batch_csr(self):
with _test_eager_guard():
shape = [3, 3, 3]
def verify(x, crows, cols, values):
x = paddle.to_tensor(x)
csr = x.to_sparse_csr()
assert np.allclose(crows, csr.crows().numpy())
assert np.allclose(cols, csr.cols().numpy())
assert np.allclose(values, csr.values().numpy())
dense = csr.to_dense()
assert np.allclose(x.numpy(), dense.numpy())
x = [
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
]
crows = [[0, 1, 2, 3, 0, 0, 0, 0, 0, 1, 2, 3]]
cols = [0, 1, 2, 0, 1, 2]
values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
verify(x, crows, cols, values)
x = [
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
]
crows = [[0, 0, 0, 0, 0, 1, 2, 3, 0, 1, 2, 3]]
cols = [0, 1, 2, 0, 1, 2]
values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
verify(x, crows, cols, values)
x = [
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
]
crows = [[0, 1, 2, 3, 0, 1, 2, 3, 0, 0, 0, 0]]
cols = [0, 1, 2, 0, 1, 2]
values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
verify(x, crows, cols, values)
def verify(dense_x):
sparse_x = dense_x.to_sparse_csr()
out = sparse_x.to_dense()
assert np.allclose(out.numpy(), dense_x.numpy())
shape = np.random.randint(low=1, high=10, size=3)
shape = list(shape)
dense_x = paddle.randn(shape)
dense_x = paddle.nn.functional.dropout(dense_x, p=0.5)
verify(dense_x)
#test batchs=1
shape[0] = 1
dense_x = paddle.randn(shape)
dense_x = paddle.nn.functional.dropout(dense_x, p=0.5)
verify(dense_x)
shape = np.random.randint(low=2, high=10, size=3)
shape = list(shape)
dense_x = paddle.randn(shape)
#set the 0th batch to zero
dense_x[0] = 0
verify(dense_x)
dense_x = paddle.randn(shape)
#set the 1th batch to zero
dense_x[1] = 0
verify(dense_x)
dense_x = paddle.randn(shape)
#set the 2th batch to zero
dense_x[2] = 0
verify(dense_x)
class TestCooError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册