Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
acf3e526
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
acf3e526
编写于
5月 25, 2023
作者:
Z
zhangkaihuo
提交者:
GitHub
5月 25, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Sparse]fix sparse bug (#53390)
上级
4ea1d041
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
104 addition
and
36 deletion
+104
-36
paddle/phi/core/sparse_coo_tensor.h
paddle/phi/core/sparse_coo_tensor.h
+7
-2
paddle/phi/core/sparse_csr_tensor.h
paddle/phi/core/sparse_csr_tensor.h
+7
-2
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
+21
-9
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
+7
-3
paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
+41
-20
python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
+21
-0
未找到文件。
paddle/phi/core/sparse_coo_tensor.h
浏览文件 @
acf3e526
...
@@ -126,8 +126,13 @@ class SparseCooTensor : public TensorBase,
...
@@ -126,8 +126,13 @@ class SparseCooTensor : public TensorBase,
bool
valid
()
const
noexcept
override
{
return
non_zero_elements_
.
valid
();
}
bool
valid
()
const
noexcept
override
{
return
non_zero_elements_
.
valid
();
}
/// \brief Test whether the non_zero_elements_ storage is allocated.
/// \brief Test whether the non_zero_elements_ storage is allocated.
/// return Whether the non_zero_elements_ storage is allocated.
/// In special cases, when nnz=0, non_zero_elements_ will not need to be
bool
initialized
()
const
override
{
return
non_zero_elements_
.
initialized
();
}
/// initialized, but it is neccessary to return true here, otherwise the
/// gradient will be None. return Whether the non_zero_elements_ storage is
/// allocated.
bool
initialized
()
const
override
{
return
values
().
initialized
()
||
(
nnz
()
==
0
&&
numel
()
>
0
);
}
/// \brief resize sparse coo tensor.
/// \brief resize sparse coo tensor.
/// \param dense_dims The dims of original dense tensor.
/// \param dense_dims The dims of original dense tensor.
...
...
paddle/phi/core/sparse_csr_tensor.h
浏览文件 @
acf3e526
...
@@ -131,8 +131,13 @@ class SparseCsrTensor : public TensorBase,
...
@@ -131,8 +131,13 @@ class SparseCsrTensor : public TensorBase,
bool
valid
()
const
noexcept
override
{
return
non_zero_elements_
.
valid
();
}
bool
valid
()
const
noexcept
override
{
return
non_zero_elements_
.
valid
();
}
/// \brief Test whether the non_zero_elements_ storage is allocated.
/// \brief Test whether the non_zero_elements_ storage is allocated.
/// return Whether the non_zero_elements_ storage is allocated.
/// In special cases, when nnz=0, non_zero_elements_ will not need to be
bool
initialized
()
const
override
{
return
non_zero_elements_
.
initialized
();
}
/// initialized, but it is neccessary to return true here, otherwise the
/// gradient will be None. return Whether the non_zero_elements_ storage is
/// allocated.
bool
initialized
()
const
override
{
return
values
().
initialized
()
||
(
nnz
()
==
0
&&
numel
()
>
0
);
}
/// \brief resize sparse csr tensor.
/// \brief resize sparse csr tensor.
/// \param dense_dims The dims of original dense tensor.
/// \param dense_dims The dims of original dense tensor.
...
...
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
浏览文件 @
acf3e526
...
@@ -113,13 +113,6 @@ void CsrToCooCPUKernel(const CPUContext& dev_ctx,
...
@@ -113,13 +113,6 @@ void CsrToCooCPUKernel(const CPUContext& dev_ctx,
SparseCooTensor
*
out
)
{
SparseCooTensor
*
out
)
{
const
DDim
&
x_dims
=
x
.
dims
();
const
DDim
&
x_dims
=
x
.
dims
();
const
int64_t
non_zero_num
=
x
.
cols
().
numel
();
const
int64_t
non_zero_num
=
x
.
cols
().
numel
();
const
auto
&
csr_crows
=
x
.
crows
();
const
auto
&
csr_cols
=
x
.
cols
();
const
auto
&
csr_values
=
x
.
values
();
const
IntT
*
csr_crows_data
=
csr_crows
.
data
<
IntT
>
();
const
IntT
*
csr_cols_data
=
csr_cols
.
data
<
IntT
>
();
const
T
*
csr_values_data
=
csr_values
.
data
<
T
>
();
int64_t
sparse_dim
=
2
;
int64_t
sparse_dim
=
2
;
if
(
x_dims
.
size
()
==
3
)
{
if
(
x_dims
.
size
()
==
3
)
{
sparse_dim
=
3
;
sparse_dim
=
3
;
...
@@ -127,6 +120,17 @@ void CsrToCooCPUKernel(const CPUContext& dev_ctx,
...
@@ -127,6 +120,17 @@ void CsrToCooCPUKernel(const CPUContext& dev_ctx,
phi
::
DenseTensor
indices
=
phi
::
DenseTensor
indices
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
sparse_dim
,
non_zero_num
});
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
sparse_dim
,
non_zero_num
});
phi
::
DenseTensor
values
=
phi
::
Empty
<
T
>
(
dev_ctx
,
{
non_zero_num
});
phi
::
DenseTensor
values
=
phi
::
Empty
<
T
>
(
dev_ctx
,
{
non_zero_num
});
if
(
x
.
nnz
()
<=
0
)
{
out
->
SetMember
(
indices
,
values
,
x_dims
,
true
);
return
;
}
const
auto
&
csr_crows
=
x
.
crows
();
const
auto
&
csr_cols
=
x
.
cols
();
const
auto
&
csr_values
=
x
.
values
();
const
IntT
*
csr_crows_data
=
csr_crows
.
data
<
IntT
>
();
const
IntT
*
csr_cols_data
=
csr_cols
.
data
<
IntT
>
();
const
T
*
csr_values_data
=
csr_values
.
data
<
T
>
();
IntT
*
coo_indices
=
indices
.
data
<
IntT
>
();
IntT
*
coo_indices
=
indices
.
data
<
IntT
>
();
IntT
*
batch_ptr
=
x_dims
.
size
()
==
2
?
nullptr
:
coo_indices
;
IntT
*
batch_ptr
=
x_dims
.
size
()
==
2
?
nullptr
:
coo_indices
;
IntT
*
coo_rows_data
=
IntT
*
coo_rows_data
=
...
@@ -177,7 +181,6 @@ void CooToCsrCPUKernel(const CPUContext& dev_ctx,
...
@@ -177,7 +181,6 @@ void CooToCsrCPUKernel(const CPUContext& dev_ctx,
phi
::
errors
::
InvalidArgument
(
phi
::
errors
::
InvalidArgument
(
"SparseCsrTensor only support 2-D or 3-D matrix"
));
"SparseCsrTensor only support 2-D or 3-D matrix"
));
const
int64_t
non_zero_num
=
x
.
nnz
();
const
int64_t
non_zero_num
=
x
.
nnz
();
if
(
non_zero_num
<=
0
)
return
;
int
batchs
=
x_dims
.
size
()
==
2
?
1
:
x_dims
[
0
];
int
batchs
=
x_dims
.
size
()
==
2
?
1
:
x_dims
[
0
];
int
rows
=
x_dims
.
size
()
==
2
?
x_dims
[
0
]
:
x_dims
[
1
];
int
rows
=
x_dims
.
size
()
==
2
?
x_dims
[
0
]
:
x_dims
[
1
];
...
@@ -185,6 +188,10 @@ void CooToCsrCPUKernel(const CPUContext& dev_ctx,
...
@@ -185,6 +188,10 @@ void CooToCsrCPUKernel(const CPUContext& dev_ctx,
phi
::
DenseTensor
crows
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
batchs
*
(
rows
+
1
)});
phi
::
DenseTensor
crows
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
batchs
*
(
rows
+
1
)});
phi
::
DenseTensor
cols
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
non_zero_num
});
phi
::
DenseTensor
cols
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
non_zero_num
});
phi
::
DenseTensor
values
=
phi
::
EmptyLike
<
T
,
CPUContext
>
(
dev_ctx
,
x
.
values
());
phi
::
DenseTensor
values
=
phi
::
EmptyLike
<
T
,
CPUContext
>
(
dev_ctx
,
x
.
values
());
if
(
non_zero_num
<=
0
)
{
out
->
SetMember
(
crows
,
cols
,
values
,
x_dims
);
return
;
}
IntT
*
csr_crows_data
=
crows
.
data
<
IntT
>
();
IntT
*
csr_crows_data
=
crows
.
data
<
IntT
>
();
IntT
*
csr_cols_data
=
cols
.
data
<
IntT
>
();
IntT
*
csr_cols_data
=
cols
.
data
<
IntT
>
();
T
*
csr_values_data
=
values
.
data
<
T
>
();
T
*
csr_values_data
=
values
.
data
<
T
>
();
...
@@ -268,6 +275,12 @@ void CooToDenseCPUKernel(const CPUContext& dev_ctx,
...
@@ -268,6 +275,12 @@ void CooToDenseCPUKernel(const CPUContext& dev_ctx,
const
T
*
x_data
=
values
.
data
<
T
>
();
const
T
*
x_data
=
values
.
data
<
T
>
();
dev_ctx
.
template
Alloc
<
T
>(
out
);
dev_ctx
.
template
Alloc
<
T
>(
out
);
T
*
out_data
=
out
->
data
<
T
>
();
T
*
out_data
=
out
->
data
<
T
>
();
memset
(
out_data
,
0
,
sizeof
(
T
)
*
out
->
numel
());
if
(
x
.
nnz
()
<=
0
)
{
return
;
}
int64_t
base_offset
=
1
;
int64_t
base_offset
=
1
;
for
(
int64_t
i
=
0
;
i
<
dense_dim
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
dense_dim
;
i
++
)
{
base_offset
*=
dense_dims
[
sparse_dim
+
i
];
base_offset
*=
dense_dims
[
sparse_dim
+
i
];
...
@@ -279,7 +292,6 @@ void CooToDenseCPUKernel(const CPUContext& dev_ctx,
...
@@ -279,7 +292,6 @@ void CooToDenseCPUKernel(const CPUContext& dev_ctx,
offset
*=
dense_dims
[
i
];
offset
*=
dense_dims
[
i
];
}
}
memset
(
out_data
,
0
,
sizeof
(
T
)
*
out
->
numel
());
for
(
auto
i
=
0
;
i
<
non_zero_num
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
non_zero_num
;
i
++
)
{
int64_t
index
=
0
;
int64_t
index
=
0
;
for
(
int
j
=
0
;
j
<
sparse_dim
;
j
++
)
{
for
(
int
j
=
0
;
j
<
sparse_dim
;
j
++
)
{
...
...
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
浏览文件 @
acf3e526
...
@@ -61,6 +61,13 @@ void MaskCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -61,6 +61,13 @@ void MaskCooGPUKernel(const GPUContext& dev_ctx,
phi
::
errors
::
InvalidArgument
(
"the input x and mask must have the shape"
));
phi
::
errors
::
InvalidArgument
(
"the input x and mask must have the shape"
));
const
DenseTensor
&
indices
=
mask
.
indices
();
const
DenseTensor
&
indices
=
mask
.
indices
();
const
DenseTensor
&
values
=
mask
.
values
();
const
DenseTensor
&
values
=
mask
.
values
();
DenseTensor
out_indices
=
phi
::
EmptyLike
<
IntT
>
(
dev_ctx
,
indices
);
DenseTensor
out_values
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
values
);
if
(
mask
.
nnz
()
<=
0
)
{
out
->
SetMember
(
out_indices
,
out_values
,
dims
,
true
);
return
;
}
const
int
sparse_dim
=
mask
.
sparse_dim
();
const
int
sparse_dim
=
mask
.
sparse_dim
();
DenseTensor
sparse_offsets
=
phi
::
Empty
<
GPUContext
>
(
DenseTensor
sparse_offsets
=
phi
::
Empty
<
GPUContext
>
(
dev_ctx
,
dev_ctx
,
...
@@ -75,9 +82,6 @@ void MaskCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -75,9 +82,6 @@ void MaskCooGPUKernel(const GPUContext& dev_ctx,
gpuMemcpyHostToDevice
,
gpuMemcpyHostToDevice
,
dev_ctx
.
stream
());
dev_ctx
.
stream
());
DenseTensor
out_indices
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
indices
);
DenseTensor
out_values
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
values
);
phi
::
Copy
(
dev_ctx
,
indices
,
dev_ctx
.
GetPlace
(),
false
,
&
out_indices
);
phi
::
Copy
(
dev_ctx
,
indices
,
dev_ctx
.
GetPlace
(),
false
,
&
out_indices
);
const
IntT
*
indices_ptr
=
indices
.
data
<
IntT
>
();
const
IntT
*
indices_ptr
=
indices
.
data
<
IntT
>
();
...
...
paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu
浏览文件 @
acf3e526
...
@@ -164,18 +164,20 @@ void DenseToCooKernel(const Context& dev_ctx,
...
@@ -164,18 +164,20 @@ void DenseToCooKernel(const Context& dev_ctx,
T
*
sparse_data
=
dev_ctx
.
template
Alloc
<
T
>(
&
values
);
T
*
sparse_data
=
dev_ctx
.
template
Alloc
<
T
>(
&
values
);
// 3. calc indices by indexs and get values by indexs
// 3. calc indices by indexs and get values by indexs
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
if
(
non_zero_num
>
0
)
{
GetNonZeroElementsAndIndices
<<<
config
.
block_per_grid
.
x
,
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
config
.
thread_per_block
.
x
,
GetNonZeroElementsAndIndices
<<<
config
.
block_per_grid
.
x
,
0
,
config
.
thread_per_block
.
x
,
dev_ctx
.
stream
()
>>>
(
x_data
,
0
,
sparse_dim
,
dev_ctx
.
stream
()
>>>
(
x_data
,
cols
,
sparse_dim
,
d_x_dims
.
data
<
int64_t
>
(),
cols
,
non_zero_num
,
d_x_dims
.
data
<
int64_t
>
(),
temp_indexs_ptr
,
non_zero_num
,
indices_data
,
temp_indexs_ptr
,
sparse_data
);
indices_data
,
sparse_data
);
}
out
->
SetMember
(
indices
,
values
,
x_dims
,
true
);
out
->
SetMember
(
indices
,
values
,
x_dims
,
true
);
}
}
...
@@ -218,6 +220,21 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -218,6 +220,21 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx,
SparseCooTensor
*
out
)
{
SparseCooTensor
*
out
)
{
const
DDim
&
x_dims
=
x
.
dims
();
const
DDim
&
x_dims
=
x
.
dims
();
const
int64_t
non_zero_num
=
x
.
cols
().
numel
();
const
int64_t
non_zero_num
=
x
.
cols
().
numel
();
int64_t
sparse_dim
=
2
;
if
(
x_dims
.
size
()
==
3
)
{
sparse_dim
=
3
;
}
if
(
x
.
nnz
()
<=
0
)
{
#ifdef PADDLE_WITH_HIP
DenseTensor
indices
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
sparse_dim
,
non_zero_num
});
#else
DenseTensor
indices
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
sparse_dim
,
non_zero_num
});
#endif
DenseTensor
values
=
phi
::
EmptyLike
<
T
,
GPUContext
>
(
dev_ctx
,
x
.
values
());
out
->
SetMember
(
indices
,
values
,
x_dims
,
true
);
return
;
}
// rocsparse_csr2coo only support index with type 'rocsparse_int' (aka 'int')
// rocsparse_csr2coo only support index with type 'rocsparse_int' (aka 'int')
// now
// now
...
@@ -235,10 +252,6 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx,
...
@@ -235,10 +252,6 @@ void CsrToCooGPUKernel(const GPUContext& dev_ctx,
const
auto
&
csr_values
=
x
.
values
();
const
auto
&
csr_values
=
x
.
values
();
const
T
*
csr_values_data
=
csr_values
.
data
<
T
>
();
const
T
*
csr_values_data
=
csr_values
.
data
<
T
>
();
int64_t
sparse_dim
=
2
;
if
(
x_dims
.
size
()
==
3
)
{
sparse_dim
=
3
;
}
int
batches
=
x_dims
.
size
()
==
2
?
1
:
x_dims
[
0
];
int
batches
=
x_dims
.
size
()
==
2
?
1
:
x_dims
[
0
];
int
rows
=
x_dims
.
size
()
==
2
?
x_dims
[
0
]
:
x_dims
[
1
];
int
rows
=
x_dims
.
size
()
==
2
?
x_dims
[
0
]
:
x_dims
[
1
];
...
@@ -395,7 +408,6 @@ void CooToCsrGPUKernel(const GPUContext& dev_ctx,
...
@@ -395,7 +408,6 @@ void CooToCsrGPUKernel(const GPUContext& dev_ctx,
phi
::
errors
::
InvalidArgument
(
phi
::
errors
::
InvalidArgument
(
"SparseCsrTensor only support 2-D or 3-D matrix"
));
"SparseCsrTensor only support 2-D or 3-D matrix"
));
const
int64_t
non_zero_num
=
x
.
nnz
();
const
int64_t
non_zero_num
=
x
.
nnz
();
if
(
non_zero_num
<=
0
)
return
;
int
batchs
=
x_dims
.
size
()
==
2
?
1
:
x_dims
[
0
];
int
batchs
=
x_dims
.
size
()
==
2
?
1
:
x_dims
[
0
];
int
rows
=
x_dims
.
size
()
==
2
?
x_dims
[
0
]
:
x_dims
[
1
];
int
rows
=
x_dims
.
size
()
==
2
?
x_dims
[
0
]
:
x_dims
[
1
];
...
@@ -403,6 +415,10 @@ void CooToCsrGPUKernel(const GPUContext& dev_ctx,
...
@@ -403,6 +415,10 @@ void CooToCsrGPUKernel(const GPUContext& dev_ctx,
phi
::
DenseTensor
crows
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
batchs
*
(
rows
+
1
)});
phi
::
DenseTensor
crows
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
batchs
*
(
rows
+
1
)});
phi
::
DenseTensor
cols
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
non_zero_num
});
phi
::
DenseTensor
cols
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
non_zero_num
});
phi
::
DenseTensor
values
=
phi
::
EmptyLike
<
T
,
GPUContext
>
(
dev_ctx
,
x
.
values
());
phi
::
DenseTensor
values
=
phi
::
EmptyLike
<
T
,
GPUContext
>
(
dev_ctx
,
x
.
values
());
if
(
non_zero_num
<=
0
)
{
out
->
SetMember
(
crows
,
cols
,
values
,
x_dims
);
return
;
}
IntT
*
csr_crows_data
=
crows
.
data
<
IntT
>
();
IntT
*
csr_crows_data
=
crows
.
data
<
IntT
>
();
IntT
*
csr_cols_data
=
cols
.
data
<
IntT
>
();
IntT
*
csr_cols_data
=
cols
.
data
<
IntT
>
();
T
*
csr_values_data
=
values
.
data
<
T
>
();
T
*
csr_values_data
=
values
.
data
<
T
>
();
...
@@ -503,10 +519,17 @@ void CooToDenseGPUKernel(const GPUContext& dev_ctx,
...
@@ -503,10 +519,17 @@ void CooToDenseGPUKernel(const GPUContext& dev_ctx,
const
int64_t
dense_dim
=
values
.
dims
().
size
()
-
1
;
const
int64_t
dense_dim
=
values
.
dims
().
size
()
-
1
;
const
auto
place
=
dev_ctx
.
GetPlace
();
const
auto
place
=
dev_ctx
.
GetPlace
();
const
T
*
x_data
=
values
.
data
<
T
>
();
dev_ctx
.
template
Alloc
<
T
>(
out
);
dev_ctx
.
template
Alloc
<
T
>(
out
);
T
*
out_data
=
out
->
data
<
T
>
();
T
*
out_data
=
out
->
data
<
T
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_data
,
0
,
sizeof
(
T
)
*
out
->
numel
(),
dev_ctx
.
stream
());
if
(
x
.
nnz
()
<=
0
)
{
return
;
}
const
T
*
x_data
=
values
.
data
<
T
>
();
int64_t
base_offset
=
1
;
int64_t
base_offset
=
1
;
for
(
int64_t
i
=
0
;
i
<
dense_dim
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
dense_dim
;
i
++
)
{
base_offset
*=
dense_dims
[
sparse_dim
+
i
];
base_offset
*=
dense_dims
[
sparse_dim
+
i
];
...
@@ -525,8 +548,6 @@ void CooToDenseGPUKernel(const GPUContext& dev_ctx,
...
@@ -525,8 +548,6 @@ void CooToDenseGPUKernel(const GPUContext& dev_ctx,
sparse_dim
*
sizeof
(
int64_t
),
sparse_dim
*
sizeof
(
int64_t
),
gpuMemcpyHostToDevice
,
gpuMemcpyHostToDevice
,
dev_ctx
.
stream
());
dev_ctx
.
stream
());
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_data
,
0
,
sizeof
(
T
)
*
out
->
numel
(),
dev_ctx
.
stream
());
auto
config
=
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
...
...
python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
浏览文件 @
acf3e526
...
@@ -351,6 +351,27 @@ class TestSparseConvert(unittest.TestCase):
...
@@ -351,6 +351,27 @@ class TestSparseConvert(unittest.TestCase):
dense_x
[
2
]
=
0
dense_x
[
2
]
=
0
verify
(
dense_x
)
verify
(
dense_x
)
def
test_zero_nnz
(
self
):
for
device
in
devices
:
if
device
==
'cpu'
or
(
device
==
'gpu'
and
paddle
.
is_compiled_with_cuda
()
):
paddle
.
device
.
set_device
(
device
)
x1
=
paddle
.
zeros
([
2
,
2
,
2
])
x2
=
paddle
.
zeros
([
2
,
2
,
2
])
sp_csr_x
=
x1
.
to_sparse_csr
()
sp_coo_x
=
x2
.
to_sparse_coo
(
1
)
sp_coo_x
.
stop_gradient
=
False
out1
=
sp_csr_x
.
to_dense
()
out2
=
sp_coo_x
.
to_dense
()
out2
.
backward
()
np
.
testing
.
assert_allclose
(
out1
.
numpy
(),
x1
.
numpy
())
np
.
testing
.
assert_allclose
(
out2
.
numpy
(),
x2
.
numpy
())
np
.
testing
.
assert_allclose
(
sp_coo_x
.
grad
.
to_dense
().
numpy
().
sum
(),
0.0
)
class
TestCooError
(
unittest
.
TestCase
):
class
TestCooError
(
unittest
.
TestCase
):
def
test_small_shape
(
self
):
def
test_small_shape
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录