Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
03517d8a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
03517d8a
编写于
6月 17, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
6月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix batch csr (#43553)
* fix to_sparse_csr
上级
5a5649c2
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
89 addition
and
17 deletion
+89
-17
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
+8
-2
paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
+28
-11
python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
+48
-0
python/paddle/incubate/sparse/creation.py
python/paddle/incubate/sparse/creation.py
+5
-4
未找到文件。
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
浏览文件 @
03517d8a
...
...
@@ -206,7 +206,11 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
if
(
batchs
>
1
)
{
for
(
int
i
=
0
;
i
<
non_zero_num
;
i
++
)
{
if
(
i
==
non_zero_num
-
1
||
batchs_ptr
[
i
]
!=
batchs_ptr
[
i
+
1
])
{
offsets
[
batchs_ptr
[
i
]]
=
i
+
1
;
const
int
start
=
batchs_ptr
[
i
];
const
int
end
=
i
==
non_zero_num
-
1
?
batchs
:
batchs_ptr
[
i
+
1
];
for
(
int
j
=
start
;
j
<
end
;
j
++
)
{
offsets
[
j
]
=
i
+
1
;
}
}
}
}
else
{
...
...
@@ -214,7 +218,6 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
}
for
(
int
b
=
0
;
b
<
batchs
;
b
++
)
{
if
(
offsets
[
b
]
==
0
)
continue
;
int
batch_start
=
0
;
int
batch_non_zero_num
=
offsets
[
b
];
if
(
b
>
0
)
{
...
...
@@ -233,6 +236,9 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
for
(
IntT
i
=
coo_rows_ptr
[
batch_non_zero_num
-
1
]
+
1
;
i
<
rows
+
1
;
i
++
)
{
csr_crows_data
[
b
*
(
rows
+
1
)
+
i
]
=
batch_non_zero_num
;
}
if
(
batch_non_zero_num
==
0
)
{
memset
(
csr_crows_data
+
b
*
(
rows
+
1
),
0
,
sizeof
(
IntT
)
*
(
rows
+
1
));
}
}
memcpy
(
csr_cols_data
,
coo_cols_data
,
sizeof
(
IntT
)
*
non_zero_num
);
...
...
paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
浏览文件 @
03517d8a
...
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/common_shape.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
...
...
@@ -283,19 +284,24 @@ void SparseCsrToCooKernel(const Context& dev_ctx,
template
<
typename
IntT
>
__global__
void
GetBatchsOffset
(
const
IntT
*
batchs_ptr
,
const
int
batchs
,
const
int
non_zero_num
,
IntT
*
batchs_offset
)
{
int
*
batchs_offset
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
for
(
int
i
=
tid
;
i
<
non_zero_num
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
if
(
i
==
non_zero_num
-
1
||
batchs_ptr
[
i
]
!=
batchs_ptr
[
i
+
1
])
{
batchs_offset
[
batchs_ptr
[
i
]]
=
i
+
1
;
const
int
start
=
batchs_ptr
[
i
];
const
int
end
=
i
==
non_zero_num
-
1
?
batchs
:
batchs_ptr
[
i
+
1
];
for
(
int
j
=
start
;
j
<
end
;
j
++
)
{
batchs_offset
[
j
]
=
i
+
1
;
}
}
}
}
template
<
typename
IntT
>
__global__
void
ConvertCooRowsToCsrCrows
(
const
IntT
*
batchs_offset
,
// can be null if batchs = 1
const
int
*
batchs_offset
,
// can be null if batchs = 1
const
IntT
*
coo_rows_data
,
IntT
*
csr_crows_data
,
const
int
rows
,
...
...
@@ -303,12 +309,12 @@ __global__ void ConvertCooRowsToCsrCrows(
const
int
b
=
blockIdx
.
y
;
int
batch_non_zero_num
=
batchs_offset
==
nullptr
?
non_zero_num
:
batchs_offset
[
b
];
if
(
batch_non_zero_num
==
0
)
return
;
IntT
batch_start
=
0
;
if
(
b
>
0
)
{
batch_start
=
batchs_offset
[
b
-
1
];
batch_non_zero_num
-=
batch_start
;
}
const
IntT
*
coo_rows_ptr
=
coo_rows_data
+
batch_start
;
const
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
for
(
int
i
=
tid
;
i
<
batch_non_zero_num
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
...
...
@@ -328,6 +334,11 @@ __global__ void ConvertCooRowsToCsrCrows(
}
}
}
if
(
batch_non_zero_num
==
0
)
{
for
(
int
i
=
tid
;
i
<
rows
+
1
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
csr_crows_data
[
b
*
(
rows
+
1
)
+
i
]
=
0
;
}
}
}
template
<
typename
T
,
typename
IntT
>
...
...
@@ -365,13 +376,19 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx,
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
batchs
,
1
);
if
(
batchs
>
1
)
{
phi
::
DenseTensor
batchs_offset
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
batchs
});
IntT
*
batchs_offset_ptr
=
batchs_offset
.
data
<
IntT
>
();
GetBatchsOffset
<
IntT
>
<<<
config
.
block_per_grid
.
x
,
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
phi
::
DenseTensor
batchs_offset
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
batchs
});
int
*
batchs_offset_ptr
=
batchs_offset
.
data
<
int
>
();
phi
::
funcs
::
SetConstant
<
GPUContext
,
int
>
set_zero
;
// set zero if the nnz=0 of batchs[0]
set_zero
(
dev_ctx
,
&
batchs_offset
,
static_cast
<
IntT
>
(
0
));
GetBatchsOffset
<
IntT
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
batchs_ptr
,
non_zero_num
,
batchs_offset_ptr
);
dev_ctx
.
stream
()
>>>
(
batchs_ptr
,
batchs
,
non_zero_num
,
batchs_offset_ptr
);
config
.
block_per_grid
.
y
=
batchs
;
ConvertCooRowsToCsrCrows
<
IntT
><<<
config
.
block_per_grid
,
config
.
thread_per_block
.
x
,
...
...
python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
浏览文件 @
03517d8a
...
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
import
paddle
from
paddle.incubate
import
sparse
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.framework
import
_test_eager_guard
...
...
@@ -315,6 +316,53 @@ class TestSparseConvert(unittest.TestCase):
assert
np
.
array_equal
(
values_sorted
,
sparse_x
.
values
().
numpy
())
def
test_batch_csr
(
self
):
with
_test_eager_guard
():
shape
=
[
3
,
3
,
3
]
def
verify
(
x
,
crows
,
cols
,
values
):
x
=
paddle
.
to_tensor
(
x
)
csr
=
x
.
to_sparse_csr
()
assert
np
.
allclose
(
crows
,
csr
.
crows
().
numpy
())
assert
np
.
allclose
(
cols
,
csr
.
cols
().
numpy
())
assert
np
.
allclose
(
values
,
csr
.
values
().
numpy
())
dense
=
csr
.
to_dense
()
assert
np
.
allclose
(
x
.
numpy
(),
dense
.
numpy
())
x
=
[
[[
1.0
,
0
,
0
],
[
0
,
2.0
,
0
],
[
0
,
0
,
3.0
]],
[[
0
,
0
,
0
],
[
0
,
0
,
0
],
[
0
,
0
,
0
]],
[[
1.0
,
0
,
0
],
[
0
,
2.0
,
0
],
[
0
,
0
,
3.0
]],
]
crows
=
[[
0
,
1
,
2
,
3
,
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
]]
cols
=
[
0
,
1
,
2
,
0
,
1
,
2
]
values
=
[
1.0
,
2.0
,
3.0
,
1.0
,
2.0
,
3.0
]
verify
(
x
,
crows
,
cols
,
values
)
x
=
[
[[
0
,
0
,
0
],
[
0
,
0
,
0
],
[
0
,
0
,
0
]],
[[
1.0
,
0
,
0
],
[
0
,
2.0
,
0
],
[
0
,
0
,
3.0
]],
[[
1.0
,
0
,
0
],
[
0
,
2.0
,
0
],
[
0
,
0
,
3.0
]],
]
crows
=
[[
0
,
0
,
0
,
0
,
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
]]
cols
=
[
0
,
1
,
2
,
0
,
1
,
2
]
values
=
[
1.0
,
2.0
,
3.0
,
1.0
,
2.0
,
3.0
]
verify
(
x
,
crows
,
cols
,
values
)
x
=
[
[[
1.0
,
0
,
0
],
[
0
,
2.0
,
0
],
[
0
,
0
,
3.0
]],
[[
1.0
,
0
,
0
],
[
0
,
2.0
,
0
],
[
0
,
0
,
3.0
]],
[[
0
,
0
,
0
],
[
0
,
0
,
0
],
[
0
,
0
,
0
]],
]
crows
=
[[
0
,
1
,
2
,
3
,
0
,
1
,
2
,
3
,
0
,
0
,
0
,
0
]]
cols
=
[
0
,
1
,
2
,
0
,
1
,
2
]
values
=
[
1.0
,
2.0
,
3.0
,
1.0
,
2.0
,
3.0
]
verify
(
x
,
crows
,
cols
,
values
)
class
TestCooError
(
unittest
.
TestCase
):
...
...
python/paddle/incubate/sparse/creation.py
浏览文件 @
03517d8a
...
...
@@ -249,6 +249,7 @@ def sparse_csr_tensor(crows,
raise
ValueError
(
"SparseCsrTensor only support 2-D or 3-D matrix. but get shape {}"
.
format
(
shape
))
rows
=
shape
[
len
(
shape
)
-
2
]
if
not
crows
.
place
.
_equals
(
place
):
crows
=
crows
.
_copy_to
(
place
,
False
)
...
...
@@ -268,10 +269,10 @@ def sparse_csr_tensor(crows,
raise
ValueError
(
"the length of cols must be same as length of values"
)
if
len
(
shape
)
==
2
:
if
crows
.
shape
[
0
]
!=
shape
[
0
]
+
1
:
if
crows
.
shape
[
0
]
!=
rows
+
1
:
raise
ValueError
(
"The length({}) of crows must be equal to the rows({})+1 of matrix."
.
format
(
crows
.
shape
[
0
],
shape
[
0
]
))
.
format
(
crows
.
shape
[
0
],
rows
))
if
crows
[
0
]
!=
0
:
raise
ValueError
(
"the 0th value of crows must be 0"
)
...
...
@@ -279,10 +280,10 @@ def sparse_csr_tensor(crows,
raise
ValueError
(
"the last value of crows must be equal the number of non-zero"
)
else
:
if
crows
.
shape
[
0
]
%
(
shape
[
0
]
+
1
)
!=
0
:
if
crows
.
shape
[
0
]
%
(
rows
+
1
)
!=
0
:
raise
ValueError
(
"The length({}) of crows must be divisible the rows({})+1 of matrix."
.
format
(
crows
.
shape
[
0
],
shape
[
0
]
))
.
format
(
crows
.
shape
[
0
],
rows
))
# TODO(zkh2016): check whether the value in crows and cols is legal
return
core
.
eager
.
sparse_csr_tensor
(
crows
,
cols
,
values
,
shape
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录