未验证 提交 d41a9373 编写于 作者: Z zhangkaihuo 提交者: GitHub

Fix batch csr (#43708)

上级 7673b39a
...@@ -198,7 +198,7 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx, ...@@ -198,7 +198,7 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
const auto& coo_values = x.non_zero_elements(); const auto& coo_values = x.non_zero_elements();
const IntT* batchs_ptr = coo_indices.data<IntT>(); const IntT* batchs_ptr = coo_indices.data<IntT>();
const IntT* coo_rows_data = const IntT* coo_rows_data =
batchs == 1 ? batchs_ptr : batchs_ptr + non_zero_num; x_dims.size() == 2 ? batchs_ptr : batchs_ptr + non_zero_num;
const IntT* coo_cols_data = coo_rows_data + non_zero_num; const IntT* coo_cols_data = coo_rows_data + non_zero_num;
const T* coo_values_data = coo_values.data<T>(); const T* coo_values_data = coo_values.data<T>();
......
...@@ -371,7 +371,7 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx, ...@@ -371,7 +371,7 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx,
const auto& coo_values = x.non_zero_elements(); const auto& coo_values = x.non_zero_elements();
const IntT* batchs_ptr = coo_indices.data<IntT>(); const IntT* batchs_ptr = coo_indices.data<IntT>();
const IntT* coo_rows_data = const IntT* coo_rows_data =
batchs == 1 ? batchs_ptr : batchs_ptr + non_zero_num; x_dims.size() == 2 ? batchs_ptr : batchs_ptr + non_zero_num;
const IntT* coo_cols_data = coo_rows_data + non_zero_num; const IntT* coo_cols_data = coo_rows_data + non_zero_num;
const T* coo_values_data = coo_values.data<T>(); const T* coo_values_data = coo_values.data<T>();
......
...@@ -318,50 +318,40 @@ class TestSparseConvert(unittest.TestCase): ...@@ -318,50 +318,40 @@ class TestSparseConvert(unittest.TestCase):
def test_batch_csr(self): def test_batch_csr(self):
with _test_eager_guard(): with _test_eager_guard():
shape = [3, 3, 3]
def verify(dense_x):
def verify(x, crows, cols, values): sparse_x = dense_x.to_sparse_csr()
x = paddle.to_tensor(x) out = sparse_x.to_dense()
csr = x.to_sparse_csr() assert np.allclose(out.numpy(), dense_x.numpy())
assert np.allclose(crows, csr.crows().numpy())
assert np.allclose(cols, csr.cols().numpy()) shape = np.random.randint(low=1, high=10, size=3)
assert np.allclose(values, csr.values().numpy()) shape = list(shape)
dense_x = paddle.randn(shape)
dense = csr.to_dense() dense_x = paddle.nn.functional.dropout(dense_x, p=0.5)
assert np.allclose(x.numpy(), dense.numpy()) verify(dense_x)
x = [ #test batchs=1
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]], shape[0] = 1
[[0, 0, 0], [0, 0, 0], [0, 0, 0]], dense_x = paddle.randn(shape)
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]], dense_x = paddle.nn.functional.dropout(dense_x, p=0.5)
] verify(dense_x)
crows = [[0, 1, 2, 3, 0, 0, 0, 0, 0, 1, 2, 3]]
cols = [0, 1, 2, 0, 1, 2] shape = np.random.randint(low=2, high=10, size=3)
values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0] shape = list(shape)
dense_x = paddle.randn(shape)
verify(x, crows, cols, values) #set the 0th batch to zero
dense_x[0] = 0
x = [ verify(dense_x)
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]], dense_x = paddle.randn(shape)
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]], #set the 1th batch to zero
] dense_x[1] = 0
crows = [[0, 0, 0, 0, 0, 1, 2, 3, 0, 1, 2, 3]] verify(dense_x)
cols = [0, 1, 2, 0, 1, 2]
values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0] dense_x = paddle.randn(shape)
#set the 2th batch to zero
verify(x, crows, cols, values) dense_x[2] = 0
verify(dense_x)
x = [
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
]
crows = [[0, 1, 2, 3, 0, 1, 2, 3, 0, 0, 0, 0]]
cols = [0, 1, 2, 0, 1, 2]
values = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
verify(x, crows, cols, values)
class TestCooError(unittest.TestCase): class TestCooError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册