PaddlePaddle / Paddle
Commit 36739748 (unverified)
Authored by zhangkaihuo on Sep 07, 2022; committed via GitHub on Sep 07, 2022.

[Sparse] Rename sparse kernel (#45730)

Parent: c084a7b1

Showing 16 changed files with 271 additions and 310 deletions (+271 -310).
Changed files (16):

  paddle/phi/api/lib/sparse_api_custom_impl.cc                 +6    -6
  paddle/phi/api/yaml/sparse_api.yaml                          +23   -27
  paddle/phi/api/yaml/sparse_bw_api.yaml                       +23   -22
  paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc          +3    -3
  paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc         +46   -48
  paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu          +2    -2
  paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu         +54   -56
  paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc        +13   -13
  paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h         +5    -5
  paddle/phi/kernels/sparse/sparse_utils_kernel.h              +39   -41
  paddle/phi/tests/api/test_sparse_utils_api.cc                +1    -1
  paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc   +1    -1
  paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc  +39   -48
  paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc        +12   -14
  python/paddle/fluid/dygraph/varbase_patch_methods.py         +3    -22
  python/paddle/incubate/sparse/creation.py                    +1    -1
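Taken together, the diffs below apply one mechanical rename: the redundant `sparse_`/`Sparse` prefix is dropped from the conversion kernels, and the values kernels move to a `values_<layout>` naming scheme. The mapping, as it appears in the kernel registrations:

  dense_to_sparse_coo  -> dense_to_coo   (DenseToSparseCooKernel -> DenseToCooKernel)
  sparse_csr_to_coo    -> csr_to_coo     (SparseCsrToCooKernel   -> CsrToCooKernel)
  sparse_coo_to_csr    -> coo_to_csr     (SparseCooToCsrKernel   -> CooToCsrKernel)
  dense_to_sparse_csr  -> dense_to_csr   (DenseToSparseCsrKernel -> DenseToCsrKernel)
  sparse_coo_to_dense  -> coo_to_dense   (SparseCooToDenseKernel -> CooToDenseKernel)
  sparse_csr_to_dense  -> csr_to_dense   (SparseCsrToDenseKernel -> CsrToDenseKernel)
  coo_values           -> values_coo     (CooValuesKernel        -> ValuesCooKernel)
  csr_values           -> values_csr     (CsrValuesKernel        -> ValuesCsrKernel)

At the API layer, `create_sparse_coo_tensor` becomes `sparse_coo_tensor`, and `to_dense` / `to_sparse_coo` / `to_sparse_csr` switch from `invoke`-based implementations to direct multi-kernel dispatch.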
paddle/phi/api/lib/sparse_api_custom_impl.cc

@@ -30,9 +30,9 @@ Tensor to_sparse_coo_impl(const Tensor& x, const int64_t sparse_dim) {
   }
   // 1. Get kernel signature and kernel
-  std::string kernel_name = "dense_to_sparse_coo";
+  std::string kernel_name = "dense_to_coo";
   if (x.layout() == phi::DataLayout::SPARSE_CSR) {
-    kernel_name = "sparse_csr_to_coo";
+    kernel_name = "csr_to_coo";
   }
   auto kernel_key_set = ParseKernelKeyByInputArgs(x);
@@ -88,9 +88,9 @@ Tensor to_sparse_csr_impl(const Tensor& x) {
     return x;
   }
   // 1. Get kernel signature and kernel
-  std::string kernel_name = "dense_to_sparse_csr";
+  std::string kernel_name = "dense_to_csr";
   if (x.layout() == phi::DataLayout::SPARSE_COO) {
-    kernel_name = "sparse_coo_to_csr";
+    kernel_name = "coo_to_csr";
   }
   auto kernel_key_set = ParseKernelKeyByInputArgs(x);
@@ -151,9 +151,9 @@ Tensor to_dense_impl(const Tensor& x) {
   }
   // 1. Get kernel signature and kernel
-  std::string kernel_name = "sparse_coo_to_dense";
+  std::string kernel_name = "coo_to_dense";
   if (x.layout() == phi::DataLayout::SPARSE_CSR) {
-    kernel_name = "sparse_csr_to_dense";
+    kernel_name = "csr_to_dense";
   }
   auto kernel_key_set = ParseKernelKeyByInputArgs(x);
paddle/phi/api/yaml/sparse_api.yaml

@@ -89,27 +89,6 @@
   intermediate : rulebook, counter
   backward : conv3d_coo_grad
 
-- api : coo_to_dense
-  args : (Tensor x)
-  output : Tensor(out)
-  invoke : to_dense_impl(x)
-  backward : coo_to_dense_grad
-
-- api : create_sparse_coo_tensor
-  args : (Tensor values, Tensor indices, IntArray dense_shape)
-  output : Tensor(out)
-  kernel :
-    func : sparse_coo_tensor{dense, dense -> sparse_coo}
-    layout : values
-    data_type : values
-  backward : create_sparse_coo_tensor_grad
-
-- api : dense_to_coo
-  args : (Tensor x, int64_t sparse_dim)
-  output : Tensor(out)
-  invoke : to_sparse_coo_impl(x, sparse_dim)
-  backward : dense_to_coo_grad
-
 - api : divide
   args : (Tensor x, Tensor y)
   output : Tensor(out)
@@ -224,6 +203,15 @@
     layout : x
   backward : softmax_grad
 
+- api : sparse_coo_tensor
+  args : (Tensor values, Tensor indices, IntArray dense_shape)
+  output : Tensor(out)
+  kernel :
+    func : sparse_coo_tensor{dense, dense -> sparse_coo}
+    layout : values
+    data_type : values
+  backward : sparse_coo_tensor_grad
+
 - api : sqrt
   args : (Tensor x)
   output : Tensor(out)
@@ -272,24 +260,32 @@
 - api : to_dense
   args : (Tensor x)
   output : Tensor(out)
-  invoke : to_dense_impl(x)
+  kernel :
+    func : coo_to_dense {sparse_coo -> dense},
+           csr_to_dense {sparse_csr -> dense}
+  backward : to_dense_grad
 
 - api : to_sparse_coo
   args : (Tensor x, int64_t sparse_dim)
   output : Tensor(out)
-  invoke : to_sparse_coo_impl(x, sparse_dim)
+  kernel :
+    func : dense_to_coo { dense -> sparse_coo },
+           csr_to_coo { sparse_csr -> sparse_coo}
+  backward : to_sparse_coo_grad
 
 - api : to_sparse_csr
   args : (Tensor x)
   output : Tensor(out)
-  invoke : to_sparse_csr_impl(x)
+  kernel :
+    func : dense_to_csr {dense -> sparse_csr},
+           coo_to_csr {sparse_coo -> sparse_csr}
 
 - api : values
   args : (Tensor x)
   output : Tensor(out)
   kernel :
-    func : coo_values {sparse_coo -> dense},
-           csr_values {sparse_csr -> dense}
+    func : values_coo {sparse_coo -> dense},
+           values_csr {sparse_csr -> dense}
     layout : x
   backward : values_grad
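In these YAML entries the `{input_layouts -> output_layout}` annotation on `func` records each kernel's tensor-layout signature, so a single API such as `to_dense` can dispatch to `coo_to_dense` or `csr_to_dense` based on the input's layout; the earlier `invoke : to_dense_impl(x)` indirection through the custom C++ implementation is no longer needed.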
paddle/phi/api/yaml/sparse_bw_api.yaml

@@ -88,26 +88,6 @@
   kernel :
     func : conv3d_coo_grad{sparse_coo, dense, sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}
 
-- backward_api : coo_to_dense_grad
-  forward : coo_to_dense(Tensor x) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad)
-  output : Tensor(x_grad)
-  kernel :
-    func : sparse_coo_to_dense_grad{sparse_coo, dense-> sparse_coo}
-
-- backward_api : create_sparse_coo_tensor_grad
-  forward : create_sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out)
-  args : (Tensor indices, Tensor out_grad)
-  output : Tensor(values_grad)
-  kernel :
-    func : sparse_coo_tensor_grad{dense, sparse_coo -> dense}
-
-- backward_api : dense_to_coo_grad
-  forward : dense_to_coo(Tensor x, int64_t sparse_dim) -> Tensor(out)
-  args : (Tensor out_grad)
-  output : Tensor(x_grad)
-  invoke : coo_to_dense(out_grad)
-
 - backward_api : divide_grad
   forward : divide(Tensor x, Tensor y) -> Tensor(out)
   args : (Tensor x, Tensor y, Tensor out, Tensor out_grad)
@@ -239,6 +219,13 @@
   kernel :
     func : softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
 
+- backward_api : sparse_coo_tensor_grad
+  forward : sparse_coo_tensor(Tensor values, Tensor indices, IntArray dense_shape) -> Tensor(out)
+  args : (Tensor indices, Tensor out_grad)
+  output : Tensor(values_grad)
+  kernel :
+    func : sparse_coo_tensor_grad{dense, sparse_coo -> dense}
+
 - backward_api : sqrt_grad
   forward : sqrt(Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
@@ -279,12 +266,26 @@
     func : tanh_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
            tanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
 
+- backward_api : to_dense_grad
+  forward : to_dense(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : coo_to_dense_grad{sparse_coo, dense -> sparse_coo}
+
+- backward_api : to_sparse_coo_grad
+  forward : to_sparse_coo(Tensor x, int64_t sparse_dim) -> Tensor(out)
+  args : (Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : coo_to_dense { sparse_coo -> dense }
+
 - backward_api : values_grad
-  forward : coo_values (Tensor x) -> Tensor(out)
+  forward : values_coo (Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   kernel :
-    func : coo_values_grad{sparse_coo, dense-> sparse_coo}
+    func : values_coo_grad{sparse_coo, dense-> sparse_coo}
 
 - backward_api : fused_attention_grad
   forward : fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
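Note the new `to_sparse_coo_grad` entry: since `to_sparse_coo` only changes the storage format, its gradient is simply the densification of `out_grad`, which is why it reuses the forward `coo_to_dense {sparse_coo -> dense}` kernel instead of a dedicated grad kernel (the removed `dense_to_coo_grad` did the same via `invoke : coo_to_dense(out_grad)`).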
paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc

@@ -270,15 +270,15 @@ void ElementWiseCooKernelImpl(const Context& dev_ctx,
       const SparseCsrTensor& y,                                          \
       SparseCsrTensor* out) {                                            \
     funcs::name##Functor<T> functor;                                     \
-    auto coo_x = SparseCsrToCoo<T>(dev_ctx, x);                          \
-    auto coo_y = SparseCsrToCoo<T>(dev_ctx, y);                          \
+    auto coo_x = CsrToCoo<T>(dev_ctx, x);                                \
+    auto coo_y = CsrToCoo<T>(dev_ctx, y);                                \
     DenseTensor indeces;                                                 \
     DenseTensor values;                                                  \
     SparseCooTensor coo_out;                                             \
     coo_out.SetMember(indeces, values, x.dims());                        \
     ElementWiseCooKernelImpl<T, IntT, Context, funcs::name##Functor<T>>( \
         dev_ctx, coo_x, coo_y, &coo_out, functor);                       \
-    *out = SparseCooToCsr<T>(dev_ctx, coo_out);                          \
+    *out = CooToCsr<T>(dev_ctx, coo_out);                                \
   }
 
 #define DEFINE_CSR_ELEMENTWISE_KERNEL(name)                              \
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc

@@ -63,7 +63,7 @@ inline int64_t GetNonZeroNum(const DenseTensor& dense,
 }
 
 template <typename T, typename Context>
-void DenseToSparseCooKernel(const Context& dev_ctx,
+void DenseToCooKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const int64_t sparse_dim,
                       SparseCooTensor* out) {
@@ -107,7 +107,7 @@ void DenseToSparseCooKernel(const Context& dev_ctx,
 }
 
 template <typename T, typename IntT>
-void SparseCsrToCooCPUKernel(const CPUContext& dev_ctx,
+void CsrToCooCPUKernel(const CPUContext& dev_ctx,
                        const SparseCsrTensor& x,
                        SparseCooTensor* out) {
   const DDim& x_dims = x.dims();
@@ -157,17 +157,16 @@ void SparseCsrToCooCPUKernel(const CPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void SparseCsrToCooKernel(const Context& dev_ctx,
+void CsrToCooKernel(const Context& dev_ctx,
                     const SparseCsrTensor& x,
                     SparseCooTensor* out) {
   PD_VISIT_BASE_INTEGRAL_TYPES(
-      x.crows().dtype(), "SparseCsrToCooCPUKernel", ([&] {
-        SparseCsrToCooCPUKernel<T, data_t>(dev_ctx, x, out);
+      x.crows().dtype(), "CsrToCooCPUKernel", ([&] {
+        CsrToCooCPUKernel<T, data_t>(dev_ctx, x, out);
       }));
 }
 
 template <typename T, typename IntT>
-void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
+void CooToCsrCPUKernel(const CPUContext& dev_ctx,
                        const SparseCooTensor& x,
                        SparseCsrTensor* out) {
   const auto& x_dims = x.dims();
@@ -247,17 +246,16 @@ void SparseCooToCsrCPUKernel(const CPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void SparseCooToCsrKernel(const Context& dev_ctx,
+void CooToCsrKernel(const Context& dev_ctx,
                     const SparseCooTensor& x,
                     SparseCsrTensor* out) {
   PD_VISIT_BASE_INTEGRAL_TYPES(
-      x.indices().dtype(), "SparseCooToCsrCPUKernel", ([&] {
-        SparseCooToCsrCPUKernel<T, data_t>(dev_ctx, x, out);
+      x.indices().dtype(), "CooToCsrCPUKernel", ([&] {
+        CooToCsrCPUKernel<T, data_t>(dev_ctx, x, out);
       }));
 }
 
 template <typename T, typename IntT>
-void SparseCooToDenseCPUKernel(const CPUContext& dev_ctx,
+void CooToDenseCPUKernel(const CPUContext& dev_ctx,
                          const SparseCooTensor& x,
                          DenseTensor* out) {
   const auto non_zero_num = x.nnz();
@@ -300,22 +298,22 @@ void SparseCooToDenseCPUKernel(const CPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void SparseCooToDenseKernel(const Context& dev_ctx,
+void CooToDenseKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       DenseTensor* out) {
   PD_VISIT_BASE_INTEGRAL_TYPES(
-      x.indices().dtype(), "SparseCooToDenseCPUKernel", ([&] {
-        SparseCooToDenseCPUKernel<T, data_t>(dev_ctx, x, out);
+      x.indices().dtype(), "CooToDenseCPUKernel", ([&] {
+        CooToDenseCPUKernel<T, data_t>(dev_ctx, x, out);
       }));
 }
 
 }  // namespace sparse
 }  // namespace phi
 
-PD_REGISTER_KERNEL(dense_to_sparse_coo,
+PD_REGISTER_KERNEL(dense_to_coo,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::DenseToSparseCooKernel,
+                   phi::sparse::DenseToCooKernel,
                    float,
                    double,
                    paddle::float16,
@@ -325,10 +323,10 @@ PD_REGISTER_KERNEL(dense_to_sparse_coo,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(sparse_csr_to_coo,
+PD_REGISTER_KERNEL(csr_to_coo,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCsrToCooKernel,
+                   phi::sparse::CsrToCooKernel,
                    float,
                    double,
                    paddle::float16,
@@ -338,10 +336,10 @@ PD_REGISTER_KERNEL(sparse_csr_to_coo,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(sparse_coo_to_csr,
+PD_REGISTER_KERNEL(coo_to_csr,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCooToCsrKernel,
+                   phi::sparse::CooToCsrKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -351,10 +349,10 @@ PD_REGISTER_KERNEL(sparse_coo_to_csr,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(dense_to_sparse_csr,
+PD_REGISTER_KERNEL(dense_to_csr,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::DenseToSparseCsrKernel,
+                   phi::sparse::DenseToCsrKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -364,10 +362,10 @@ PD_REGISTER_KERNEL(dense_to_sparse_csr,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(sparse_coo_to_dense,
+PD_REGISTER_KERNEL(coo_to_dense,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCooToDenseKernel,
+                   phi::sparse::CooToDenseKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -377,10 +375,10 @@ PD_REGISTER_KERNEL(sparse_coo_to_dense,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(sparse_csr_to_dense,
+PD_REGISTER_KERNEL(csr_to_dense,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCsrToDenseKernel,
+                   phi::sparse::CsrToDenseKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -390,10 +388,10 @@ PD_REGISTER_KERNEL(sparse_csr_to_dense,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(coo_values,
+PD_REGISTER_KERNEL(values_coo,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::CooValuesKernel,
+                   phi::sparse::ValuesCooKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -405,10 +403,10 @@ PD_REGISTER_KERNEL(coo_values,
                    int64_t) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
-PD_REGISTER_KERNEL(csr_values,
+PD_REGISTER_KERNEL(values_csr,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::CsrValuesKernel,
+                   phi::sparse::ValuesCsrKernel,
                    float,
                    double,
                    phi::dtype::float16,
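The `PD_VISIT_BASE_INTEGRAL_TYPES` calls above dispatch on the runtime dtype of the index tensor (`crows` or `indices`) and run the lambda with `data_t` bound to the matching C++ integer type. A minimal sketch of that visitor pattern, assuming a simplified `DataType` enum; `VisitBaseIntegralTypes` here is an illustrative stand-in, not Paddle's actual macro:

    #include <cstdint>
    #include <stdexcept>
    #include <string>

    enum class DataType { INT32, INT64 };  // hypothetical, stands in for phi's dtype

    // Invokes f with a tag value whose static type plays the role of `data_t`.
    template <typename Functor>
    void VisitBaseIntegralTypes(DataType dtype, const std::string& name, Functor f) {
      switch (dtype) {
        case DataType::INT32: f(int32_t{0}); break;
        case DataType::INT64: f(int64_t{0}); break;
        default: throw std::runtime_error(name + ": unsupported index dtype");
      }
    }

    // Usage, mirroring the dispatch in CsrToCooKernel:
    //   VisitBaseIntegralTypes(crows_dtype, "CsrToCooCPUKernel", [&](auto tag) {
    //     using data_t = decltype(tag);
    //     CsrToCooCPUKernel<T, data_t>(dev_ctx, x, out);
    //   });

This keeps the public kernel templated only on the value type `T` while still supporting both 32- and 64-bit index tensors at runtime.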
paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu

@@ -43,10 +43,10 @@ void MatmulCooDenseGradKernel(const Context& dev_ctx,
     // 'cusparseSDDMM' only support CSR now, so use COO->CSR->COO,
     // which will increase some expenses.
     EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
-    SparseCsrTensor dx_csr = SparseCooToCsr<T, Context>(dev_ctx, *dx);
+    SparseCsrTensor dx_csr = CooToCsr<T, Context>(dev_ctx, *dx);
     sparse_blas.SDDMM(
         false, true, static_cast<T>(1), dout, y, static_cast<T>(0), &dx_csr);
-    SparseCsrToCooKernel<T, Context>(dev_ctx, dx_csr, dx);
+    CsrToCooKernel<T, Context>(dev_ctx, dx_csr, dx);
   }
 
   // dy{Dense} = x'{SparseCoo} * dout{Dense}
paddle/phi/kernels/sparse/gpu/sparse_utils_kernel.cu

@@ -93,7 +93,7 @@ __global__ void GetNonZeroElementsAndIndices(const T* dense_data,
 }
 
 template <typename T, typename Context>
-void DenseToSparseCooKernel(const Context& dev_ctx,
+void DenseToCooKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const int64_t sparse_dim,
                       SparseCooTensor* out) {
@@ -208,7 +208,7 @@ __global__ void ConvertCsrCrowsToCooRows(const IntT* crows_ptr,
 }
 
 template <typename T, typename IntT>
-void SparseCsrToCooGPUKernel(const GPUContext& dev_ctx,
+void CsrToCooGPUKernel(const GPUContext& dev_ctx,
                        const SparseCsrTensor& x,
                        SparseCooTensor* out) {
   const DDim& x_dims = x.dims();
@@ -274,12 +274,11 @@ void SparseCsrToCooGPUKernel(const GPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void SparseCsrToCooKernel(const Context& dev_ctx,
+void CsrToCooKernel(const Context& dev_ctx,
                     const SparseCsrTensor& x,
                     SparseCooTensor* out) {
   PD_VISIT_BASE_INTEGRAL_TYPES(
-      x.crows().dtype(), "SparseCsrToCooGPUKernel", ([&] {
-        SparseCsrToCooGPUKernel<T, data_t>(dev_ctx, x, out);
+      x.crows().dtype(), "CsrToCooGPUKernel", ([&] {
+        CsrToCooGPUKernel<T, data_t>(dev_ctx, x, out);
       }));
 }
@@ -343,7 +342,7 @@ __global__ void ConvertCooRowsToCsrCrows(
 }
 
 template <typename T, typename IntT>
-void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx,
+void CooToCsrGPUKernel(const GPUContext& dev_ctx,
                        const SparseCooTensor& x,
                        SparseCsrTensor* out) {
   const auto& x_dims = x.dims();
@@ -416,17 +415,16 @@ void SparseCooToCsrGPUKernel(const GPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void SparseCooToCsrKernel(const Context& dev_ctx,
+void CooToCsrKernel(const Context& dev_ctx,
                     const SparseCooTensor& x,
                     SparseCsrTensor* out) {
   PD_VISIT_BASE_INTEGRAL_TYPES(
-      x.indices().dtype(), "SparseCooToCsrGPUKernel", ([&] {
-        SparseCooToCsrGPUKernel<T, data_t>(dev_ctx, x, out);
+      x.indices().dtype(), "CooToCsrGPUKernel", ([&] {
+        CooToCsrGPUKernel<T, data_t>(dev_ctx, x, out);
       }));
 }
 
 template <typename ValueT, typename IndicesT>
-__global__ void KernelSparseCooToDense(const IndicesT* indices,
+__global__ void KernelCooToDense(const IndicesT* indices,
                                  const int64_t* sparse_offsets,
                                  const ValueT* data,
                                  ValueT* dense_data,
@@ -447,7 +445,7 @@ __global__ void KernelSparseCooToDense(const IndicesT* indices,
 }
 
 template <typename T, typename IntT>
-void SparseCooToDenseGPUKernel(const GPUContext& dev_ctx,
+void CooToDenseGPUKernel(const GPUContext& dev_ctx,
                          const SparseCooTensor& x,
                          DenseTensor* out) {
   const auto non_zero_num = x.nnz();
@@ -490,7 +488,7 @@ void SparseCooToDenseGPUKernel(const GPUContext& dev_ctx,
   auto config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1);
 
-  KernelSparseCooToDense<T, IntT>
+  KernelCooToDense<T, IntT>
       <<<config.block_per_grid.x,
         config.thread_per_block.x,
        0,
@@ -504,22 +502,22 @@ void SparseCooToDenseGPUKernel(const GPUContext& dev_ctx,
 }
 
 template <typename T, typename Context>
-void SparseCooToDenseKernel(const Context& dev_ctx,
+void CooToDenseKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       DenseTensor* out) {
   PD_VISIT_BASE_INTEGRAL_TYPES(
-      x.indices().dtype(), "SparseCooToDenseGPUKernel", ([&] {
-        SparseCooToDenseGPUKernel<T, data_t>(dev_ctx, x, out);
+      x.indices().dtype(), "CooToDenseGPUKernel", ([&] {
+        CooToDenseGPUKernel<T, data_t>(dev_ctx, x, out);
       }));
 }
 
 }  // namespace sparse
 }  // namespace phi
 
-PD_REGISTER_KERNEL(dense_to_sparse_coo,
+PD_REGISTER_KERNEL(dense_to_coo,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::DenseToSparseCooKernel,
+                   phi::sparse::DenseToCooKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -529,10 +527,10 @@ PD_REGISTER_KERNEL(dense_to_sparse_coo,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(sparse_csr_to_coo,
+PD_REGISTER_KERNEL(csr_to_coo,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCsrToCooKernel,
+                   phi::sparse::CsrToCooKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -542,10 +540,10 @@ PD_REGISTER_KERNEL(sparse_csr_to_coo,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(sparse_coo_to_csr,
+PD_REGISTER_KERNEL(coo_to_csr,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCooToCsrKernel,
+                   phi::sparse::CooToCsrKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -555,10 +553,10 @@ PD_REGISTER_KERNEL(sparse_coo_to_csr,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(dense_to_sparse_csr,
+PD_REGISTER_KERNEL(dense_to_csr,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::DenseToSparseCsrKernel,
+                   phi::sparse::DenseToCsrKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -568,10 +566,10 @@ PD_REGISTER_KERNEL(dense_to_sparse_csr,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(sparse_coo_to_dense,
+PD_REGISTER_KERNEL(coo_to_dense,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCooToDenseKernel,
+                   phi::sparse::CooToDenseKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -581,10 +579,10 @@ PD_REGISTER_KERNEL(sparse_coo_to_dense,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(sparse_csr_to_dense,
+PD_REGISTER_KERNEL(csr_to_dense,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCsrToDenseKernel,
+                   phi::sparse::CsrToDenseKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -594,10 +592,10 @@ PD_REGISTER_KERNEL(sparse_csr_to_dense,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(coo_values,
+PD_REGISTER_KERNEL(values_coo,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::CooValuesKernel,
+                   phi::sparse::ValuesCooKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -609,10 +607,10 @@ PD_REGISTER_KERNEL(coo_values,
                    int64_t) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
-PD_REGISTER_KERNEL(csr_values,
+PD_REGISTER_KERNEL(values_csr,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::CsrValuesKernel,
+                   phi::sparse::ValuesCsrKernel,
                    float,
                    double,
                    phi::dtype::float16,
paddle/phi/kernels/sparse/sparse_utils_grad_kernel.cc

@@ -20,7 +20,7 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void CooValuesGradKernel(const Context& dev_ctx,
+void ValuesCooGradKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    const DenseTensor& out_grad,
                    SparseCooTensor* x_grad) {
@@ -28,7 +28,7 @@ void CooValuesGradKernel(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void SparseCooToDenseGradKernel(const Context& dev_ctx,
+void CooToDenseGradKernel(const Context& dev_ctx,
                          const SparseCooTensor& x,
                          const DenseTensor& out_grad,
                          SparseCooTensor* x_grad) {
@@ -38,10 +38,10 @@ void SparseCooToDenseGradKernel(const Context& dev_ctx,
 }
 
 }  // namespace sparse
 }  // namespace phi
 
-PD_REGISTER_KERNEL(coo_values_grad,
+PD_REGISTER_KERNEL(values_coo_grad,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::CooValuesGradKernel,
+                   phi::sparse::ValuesCooGradKernel,
                    float,
                    double,
                    uint8_t,
@@ -52,10 +52,10 @@ PD_REGISTER_KERNEL(coo_values_grad,
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
-PD_REGISTER_KERNEL(sparse_coo_to_dense_grad,
+PD_REGISTER_KERNEL(coo_to_dense_grad,
                    CPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCooToDenseGradKernel,
+                   phi::sparse::CooToDenseGradKernel,
                    float,
                    double,
                    uint8_t,
@@ -80,10 +80,10 @@ PD_REGISTER_KERNEL(sparse_coo_tensor_grad,
 }
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(coo_values_grad,
+PD_REGISTER_KERNEL(values_coo_grad,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::CooValuesGradKernel,
+                   phi::sparse::ValuesCooGradKernel,
                    float,
                    double,
                    phi::dtype::float16,
@@ -94,10 +94,10 @@ PD_REGISTER_KERNEL(coo_values_grad,
                    int64_t) {
   kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
 }
 
-PD_REGISTER_KERNEL(sparse_coo_to_dense_grad,
+PD_REGISTER_KERNEL(coo_to_dense_grad,
                    GPU,
                    ALL_LAYOUT,
-                   phi::sparse::SparseCooToDenseGradKernel,
+                   phi::sparse::CooToDenseGradKernel,
                    float,
                    double,
                    phi::dtype::float16,
paddle/phi/kernels/sparse/sparse_utils_grad_kernel.h

@@ -22,13 +22,13 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void CooValuesGradKernel(const Context& dev_ctx,
+void ValuesCooGradKernel(const Context& dev_ctx,
                    const SparseCooTensor& x,
                    const DenseTensor& out_grad,
                    SparseCooTensor* x_grad);
 
 template <typename T, typename Context>
-void SparseCooToDenseGradKernel(const Context& dev_ctx,
+void CooToDenseGradKernel(const Context& dev_ctx,
                          const SparseCooTensor& x,
                          const DenseTensor& out_grad,
                          SparseCooTensor* x_grad);
paddle/phi/kernels/sparse/sparse_utils_kernel.h

@@ -24,55 +24,53 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void DenseToSparseCooKernel(const Context& dev_ctx,
+void DenseToCooKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const int64_t sparse_dim,
                       SparseCooTensor* out);
 
 template <typename T, typename Context>
-SparseCooTensor DenseToSparseCoo(const Context& dev_ctx,
+SparseCooTensor DenseToCoo(const Context& dev_ctx,
                            const DenseTensor& x,
                            const int64_t sparse_dim) {
   DenseTensor indices;
   DenseTensor values;
   SparseCooTensor coo(indices, values, x.dims());
-  DenseToSparseCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
+  DenseToCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
   return coo;
 }
 
 template <typename T, typename Context>
-void SparseCsrToCooKernel(const Context& dev_ctx,
+void CsrToCooKernel(const Context& dev_ctx,
                     const SparseCsrTensor& x,
                     SparseCooTensor* out);
 
 template <typename T, typename Context>
-SparseCooTensor SparseCsrToCoo(const Context& dev_ctx,
-                               const SparseCsrTensor& x) {
+SparseCooTensor CsrToCoo(const Context& dev_ctx, const SparseCsrTensor& x) {
   DenseTensor indices;
   DenseTensor values;
   SparseCooTensor coo(indices, values, x.dims());
-  SparseCsrToCooKernel<T, Context>(dev_ctx, x, &coo);
+  CsrToCooKernel<T, Context>(dev_ctx, x, &coo);
   return coo;
 }
 
 template <typename T, typename Context>
-void SparseCooToCsrKernel(const Context& dev_ctx,
+void CooToCsrKernel(const Context& dev_ctx,
                     const SparseCooTensor& x,
                     SparseCsrTensor* out);
 
 template <typename T, typename Context>
-SparseCsrTensor SparseCooToCsr(const Context& dev_ctx,
-                               const SparseCooTensor& x) {
+SparseCsrTensor CooToCsr(const Context& dev_ctx, const SparseCooTensor& x) {
   DenseTensor crows;
   DenseTensor cols;
   DenseTensor non_zero_elements;
   SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims());
-  SparseCooToCsrKernel<T, Context>(dev_ctx, x, &csr);
+  CooToCsrKernel<T, Context>(dev_ctx, x, &csr);
   return csr;
 }
 
 template <typename T, typename Context>
-void DenseToSparseCsrKernel(const Context& dev_ctx,
+void DenseToCsrKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       SparseCsrTensor* out) {
   const auto& x_dims = x.dims();
@@ -85,61 +83,61 @@ void DenseToSparseCsrKernel(const Context& dev_ctx,
   DenseTensor indices;
   DenseTensor values;
   SparseCooTensor coo(indices, values, x.dims());
-  DenseToSparseCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
-  SparseCooToCsrKernel<T, Context>(dev_ctx, coo, out);
+  DenseToCooKernel<T, Context>(dev_ctx, x, sparse_dim, &coo);
+  CooToCsrKernel<T, Context>(dev_ctx, coo, out);
 }
 
 template <typename T, typename Context>
-SparseCsrTensor DenseToSparseCsr(const Context& dev_ctx, const DenseTensor& x) {
+SparseCsrTensor DenseToCsr(const Context& dev_ctx, const DenseTensor& x) {
   DenseTensor crows;
   DenseTensor cols;
   DenseTensor non_zero_elements;
   SparseCsrTensor csr(crows, cols, non_zero_elements, x.dims());
-  DenseToSparseCsrKernel<T, Context>(dev_ctx, x, &csr);
+  DenseToCsrKernel<T, Context>(dev_ctx, x, &csr);
   return csr;
 }
 
 template <typename T, typename Context>
-void SparseCooToDenseKernel(const Context& dev_ctx,
+void CooToDenseKernel(const Context& dev_ctx,
                       const SparseCooTensor& x,
                       DenseTensor* out);
 
 template <typename T, typename Context>
-DenseTensor SparseCooToDense(const Context& dev_ctx, const SparseCooTensor& x) {
+DenseTensor CooToDense(const Context& dev_ctx, const SparseCooTensor& x) {
   DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout());
   DenseTensor dense = phi::Empty(dev_ctx, std::move(meta));
-  SparseCooToDenseKernel<T, Context>(dev_ctx, x, &dense);
+  CooToDenseKernel<T, Context>(dev_ctx, x, &dense);
   return dense;
 }
 
 template <typename T, typename Context>
-void SparseCsrToDenseKernel(const Context& dev_ctx,
+void CsrToDenseKernel(const Context& dev_ctx,
                       const SparseCsrTensor& x,
                       DenseTensor* out) {
   DenseTensor indices;
   DenseTensor values;
   SparseCooTensor coo(indices, values, x.dims());
-  SparseCsrToCooKernel<T, Context>(dev_ctx, x, &coo);
-  SparseCooToDenseKernel<T, Context>(dev_ctx, coo, out);
+  CsrToCooKernel<T, Context>(dev_ctx, x, &coo);
+  CooToDenseKernel<T, Context>(dev_ctx, coo, out);
 }
 
 template <typename T, typename Context>
-DenseTensor SparseCsrToDense(const Context& dev_ctx, const SparseCsrTensor& x) {
+DenseTensor CsrToDense(const Context& dev_ctx, const SparseCsrTensor& x) {
   DenseTensorMeta meta(x.dtype(), x.dims(), x.non_zero_elements().layout());
   DenseTensor dense = phi::Empty(dev_ctx, std::move(meta));
-  SparseCsrToDenseKernel<T, Context>(dev_ctx, x, &dense);
+  CsrToDenseKernel<T, Context>(dev_ctx, x, &dense);
   return dense;
 }
 
 template <typename T, typename Context>
-void CooValuesKernel(const Context& dev_ctx,
+void ValuesCooKernel(const Context& dev_ctx,
                const SparseCooTensor& x,
                DenseTensor* out) {
   *out = x.non_zero_elements();
 }
 
 template <typename T, typename Context>
-void CsrValuesKernel(const Context& dev_ctx,
+void ValuesCsrKernel(const Context& dev_ctx,
               const SparseCsrTensor& x,
               DenseTensor* out) {
   *out = x.non_zero_elements();
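As a quick orientation for the renamed helpers, here is a hedged usage sketch that chains them the way the updated tests do; `RoundTrip` is a hypothetical function written for illustration, and it assumes an initialized `phi::CPUContext` such as the `dev_ctx_cpu` set up in test_sparse_utils_dev_api.cc:

    #include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"

    // Round-trips a dense tensor through COO and CSR storage using the
    // renamed conversion helpers declared above.
    template <typename T>
    phi::DenseTensor RoundTrip(const phi::CPUContext& dev_ctx,
                               const phi::DenseTensor& dense_x,
                               int64_t sparse_dim) {
      // dense -> COO, keeping the first `sparse_dim` dimensions sparse
      auto coo = phi::sparse::DenseToCoo<T>(dev_ctx, dense_x, sparse_dim);
      // COO -> CSR
      auto csr = phi::sparse::CooToCsr<T>(dev_ctx, coo);
      // CSR -> dense
      return phi::sparse::CsrToDense<T>(dev_ctx, csr);
    }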
paddle/phi/tests/api/test_sparse_utils_api.cc

@@ -23,7 +23,7 @@ limitations under the License. */
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/sparse_coo_tensor.h"
 
-PD_DECLARE_KERNEL(dense_to_sparse_coo, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(dense_to_coo, CPU, ALL_LAYOUT);
 
 TEST(API, to_sparse_coo) {
   const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc

@@ -47,7 +47,7 @@ TEST(DEV_API, sparse_relu) {
       phi::Empty(dev_ctx_cpu,
                  DenseTensorMeta(DataType::FLOAT32, {3, 4}, DataLayout::NCHW));
   memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
-  auto sparse_coo = sparse::DenseToSparseCoo<float>(dev_ctx_cpu, dense_x, 2);
+  auto sparse_coo = sparse::DenseToCoo<float>(dev_ctx_cpu, dense_x, 2);
   auto sparse_out = sparse::ReluCoo<float>(dev_ctx_cpu, sparse_coo);
   DenseTensor dense_out =
paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc

@@ -49,12 +49,9 @@ namespace tests {
       const Sparse##type##Tensor& y,                                        \
       const DDim& dense_dims) {                                             \
     auto out = sparse::ElementWise##name##type<T>(dev_ctx_cpu, x, y);       \
-    const DenseTensor denseX =                                              \
-        sparse::Sparse##type##ToDense<T>(dev_ctx_cpu, x);                   \
-    const DenseTensor denseY =                                              \
-        sparse::Sparse##type##ToDense<T>(dev_ctx_cpu, y);                   \
-    const DenseTensor denseOut =                                            \
-        sparse::Sparse##type##ToDense<T>(dev_ctx_cpu, out);                 \
+    const DenseTensor denseX = sparse::type##ToDense<T>(dev_ctx_cpu, x);    \
+    const DenseTensor denseY = sparse::type##ToDense<T>(dev_ctx_cpu, y);    \
+    const DenseTensor denseOut = sparse::type##ToDense<T>(dev_ctx_cpu, out); \
     auto expectResult = name<T>(dev_ctx_cpu, denseX, denseY);               \
     for (int j = 0; j < denseOut.numel(); ++j) {                            \
       auto actualResultRow = denseOut.template data<T>()[j];                \
@@ -114,8 +111,8 @@ TEST(DEV_API, sparse_elementwise_coo_kernel_double) {
           .GetAllocator(paddle::platform::CPUPlace())
           .get());
 
-  auto coo_x = sparse::DenseToSparseCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
-  auto coo_y = sparse::DenseToSparseCoo<T>(dev_ctx_cpu, dense_y, sparse_dim);
+  auto coo_x = sparse::DenseToCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
+  auto coo_y = sparse::DenseToCoo<T>(dev_ctx_cpu, dense_y, sparse_dim);
 
   TestElementWiseAddCoo<T>(dev_ctx_cpu, coo_x, coo_y, dense_dims);
   TestElementWiseSubtractCoo<T>(dev_ctx_cpu, coo_x, coo_y, dense_dims);
@@ -159,8 +156,8 @@ TEST(DEV_API, sparse_elementwise_csr_kernel_float) {
           .GetAllocator(paddle::platform::CPUPlace())
           .get());
 
-  auto csr_x = sparse::DenseToSparseCsr<T>(dev_ctx_cpu, dense_x);
-  auto csr_y = sparse::DenseToSparseCsr<T>(dev_ctx_cpu, dense_y);
+  auto csr_x = sparse::DenseToCsr<T>(dev_ctx_cpu, dense_x);
+  auto csr_y = sparse::DenseToCsr<T>(dev_ctx_cpu, dense_y);
 
   TestElementWiseAddCsr<T>(dev_ctx_cpu, csr_x, csr_y, dense_dims);
   TestElementWiseSubtractCsr<T>(dev_ctx_cpu, csr_x, csr_y, dense_dims);
@@ -190,20 +187,18 @@ TEST(DEV_API, sparse_elementwise_csr_kernel_float) {
         dev_ctx_cpu,                                                        \
         DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW));  \
                                                                             \
-    phi::name##GradKernel<T>(                                               \
-        dev_ctx_cpu,                                                        \
-        sparse::Sparse##type##ToDense<T>(dev_ctx_cpu, x),                   \
-        sparse::Sparse##type##ToDense<T>(dev_ctx_cpu, y),                   \
-        sparse::Sparse##type##ToDense<T>(dev_ctx_cpu, out),                 \
+    phi::name##GradKernel<T>(dev_ctx_cpu,                                   \
+                             sparse::type##ToDense<T>(dev_ctx_cpu, x),      \
+                             sparse::type##ToDense<T>(dev_ctx_cpu, y),      \
+                             sparse::type##ToDense<T>(dev_ctx_cpu, out),    \
                              -1,                                            \
                              &expectdx,                                     \
                              &expectdy);                                    \
     const DenseTensor densedX =                                             \
-        sparse::Sparse##type##ToDense<T>(dev_ctx_cpu, dresult[0]);          \
+        sparse::type##ToDense<T>(dev_ctx_cpu, dresult[0]);                  \
     const DenseTensor densedY =                                             \
-        sparse::Sparse##type##ToDense<T>(dev_ctx_cpu, dresult[1]);          \
-    const DenseTensor denseOut =                                            \
-        sparse::Sparse##type##ToDense<T>(dev_ctx_cpu, out);                 \
+        sparse::type##ToDense<T>(dev_ctx_cpu, dresult[1]);                  \
+    const DenseTensor denseOut = sparse::type##ToDense<T>(dev_ctx_cpu, out); \
                                                                             \
     for (int j = 0; j < densedX.numel(); ++j) {                             \
       auto actualResultRow = densedX.template data<T>()[j];                 \
@@ -248,18 +243,16 @@ void TestElementWiseDivideCsrGrad(const Context& dev_ctx_cpu,
       dev_ctx_cpu,
       DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW));
   phi::DivideGradKernel<T>(dev_ctx_cpu,
-                           sparse::SparseCsrToDense<T>(dev_ctx_cpu, x),
-                           sparse::SparseCsrToDense<T>(dev_ctx_cpu, y),
-                           sparse::SparseCsrToDense<T>(dev_ctx_cpu, out),
-                           sparse::SparseCsrToDense<T>(dev_ctx_cpu, out),
+                           sparse::CsrToDense<T>(dev_ctx_cpu, x),
+                           sparse::CsrToDense<T>(dev_ctx_cpu, y),
+                           sparse::CsrToDense<T>(dev_ctx_cpu, out),
+                           sparse::CsrToDense<T>(dev_ctx_cpu, out),
                            -1,
                            &expectdx,
                            &expectdy);
-  const DenseTensor densedX =
-      sparse::SparseCsrToDense<T>(dev_ctx_cpu, dresult[0]);
-  const DenseTensor densedY =
-      sparse::SparseCsrToDense<T>(dev_ctx_cpu, dresult[1]);
-  const DenseTensor denseOut = sparse::SparseCsrToDense<T>(dev_ctx_cpu, out);
+  const DenseTensor densedX = sparse::CsrToDense<T>(dev_ctx_cpu, dresult[0]);
+  const DenseTensor densedY = sparse::CsrToDense<T>(dev_ctx_cpu, dresult[1]);
+  const DenseTensor denseOut = sparse::CsrToDense<T>(dev_ctx_cpu, out);
   for (int j = 0; j < densedX.numel(); ++j) {
     auto actualResultRow = densedX.template data<T>()[j];
     auto expectResultRow = expectdx.template data<T>()[j];
@@ -291,18 +284,16 @@ void TestElementWiseDivideCooGrad(const Context& dev_ctx_cpu,
       dev_ctx_cpu,
      DenseTensorMeta(DataType::FLOAT32, dense_dims, DataLayout::NCHW));
   phi::DivideGradKernel<T>(dev_ctx_cpu,
-                           sparse::SparseCooToDense<T>(dev_ctx_cpu, x),
-                           sparse::SparseCooToDense<T>(dev_ctx_cpu, y),
-                           sparse::SparseCooToDense<T>(dev_ctx_cpu, out),
-                           sparse::SparseCooToDense<T>(dev_ctx_cpu, out),
+                           sparse::CooToDense<T>(dev_ctx_cpu, x),
+                           sparse::CooToDense<T>(dev_ctx_cpu, y),
+                           sparse::CooToDense<T>(dev_ctx_cpu, out),
+                           sparse::CooToDense<T>(dev_ctx_cpu, out),
                            -1,
                            &expectdx,
                           &expectdy);
-  const DenseTensor densedX =
-      sparse::SparseCooToDense<T>(dev_ctx_cpu, dresult[0]);
-  const DenseTensor densedY =
-      sparse::SparseCooToDense<T>(dev_ctx_cpu, dresult[1]);
-  const DenseTensor denseOut = sparse::SparseCooToDense<T>(dev_ctx_cpu, out);
+  const DenseTensor densedX = sparse::CooToDense<T>(dev_ctx_cpu, dresult[0]);
+  const DenseTensor densedY = sparse::CooToDense<T>(dev_ctx_cpu, dresult[1]);
+  const DenseTensor denseOut = sparse::CooToDense<T>(dev_ctx_cpu, out);
   for (int j = 0; j < densedX.numel(); ++j) {
     auto actualResultRow = densedX.template data<T>()[j];
     auto expectResultRow = expectdx.template data<T>()[j];
@@ -356,11 +347,11 @@ TEST(DEV_API, sparse_elementwise_csr_grad_kernel_float) {
           .GetAllocator(paddle::platform::CPUPlace())
           .get());
 
-  auto csr_x = sparse::DenseToSparseCsr<T>(dev_ctx_cpu, dense_x);
-  auto csr_y = sparse::DenseToSparseCsr<T>(dev_ctx_cpu, dense_y);
+  auto csr_x = sparse::DenseToCsr<T>(dev_ctx_cpu, dense_x);
+  auto csr_y = sparse::DenseToCsr<T>(dev_ctx_cpu, dense_y);
-  auto dx = sparse::DenseToSparseCsr<T>(dev_ctx_cpu, dense_y);
-  auto dy = sparse::DenseToSparseCsr<T>(dev_ctx_cpu, dense_x);
+  auto dx = sparse::DenseToCsr<T>(dev_ctx_cpu, dense_y);
+  auto dy = sparse::DenseToCsr<T>(dev_ctx_cpu, dense_x);
 
   TestElementWiseAddCsrGrad<T>(dev_ctx_cpu, csr_x, csr_y, dense_dims);
   TestElementWiseSubtractCsrGrad<T>(dev_ctx_cpu, csr_x, csr_y, dense_dims);
@@ -402,11 +393,11 @@ TEST(DEV_API, sparse_elementwise_coo_grad_kernel_double) {
           .GetAllocator(paddle::platform::CPUPlace())
           .get());
 
-  auto csr_x = sparse::DenseToSparseCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
-  auto csr_y = sparse::DenseToSparseCoo<T>(dev_ctx_cpu, dense_y, sparse_dim);
+  auto csr_x = sparse::DenseToCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
+  auto csr_y = sparse::DenseToCoo<T>(dev_ctx_cpu, dense_y, sparse_dim);
-  auto dx = sparse::DenseToSparseCoo<T>(dev_ctx_cpu, dense_y, sparse_dim);
-  auto dy = sparse::DenseToSparseCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
+  auto dx = sparse::DenseToCoo<T>(dev_ctx_cpu, dense_y, sparse_dim);
+  auto dy = sparse::DenseToCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
 
   TestElementWiseAddCooGrad<T>(dev_ctx_cpu, csr_x, csr_y, dense_dims);
   TestElementWiseSubtractCooGrad<T>(dev_ctx_cpu, csr_x, csr_y, dense_dims);
paddle/phi/tests/kernels/test_sparse_utils_dev_api.cc

@@ -94,8 +94,7 @@ void TestDenseToSparseCoo(const DenseTensor& dense_x,
           .get());
 
   // 1. test cpu
-  auto cpu_sparse_out =
-      sparse::DenseToSparseCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
+  auto cpu_sparse_out = sparse::DenseToCoo<T>(dev_ctx_cpu, dense_x, sparse_dim);
   CheckResult<T, int64_t>(&dev_ctx_cpu,
                           cpu_sparse_out,
                           non_zero_data,
@@ -129,8 +128,7 @@ void TestDenseToSparseCoo(const DenseTensor& dense_x,
       DenseTensorMeta(dense_x.dtype(), dense_x.dims(), dense_x.layout()));
   phi::Copy(dev_ctx_gpu, dense_x, phi::GPUPlace(), true, &d_dense_x);
-  auto sparse_out =
-      sparse::DenseToSparseCoo<T>(dev_ctx_gpu, d_dense_x, sparse_dim);
+  auto sparse_out = sparse::DenseToCoo<T>(dev_ctx_gpu, d_dense_x, sparse_dim);
   CheckResult<T, int64_t>(&dev_ctx_gpu,
                           sparse_out,
                           non_zero_data,
@@ -310,7 +308,7 @@ void TestSparseCsrToCoo(const DDim& dense_dims,
       paddle::memory::allocation::AllocatorFacade::Instance()
           .GetAllocator(phi::CPUPlace())
          .get());
-  auto cpu_sparse_out = sparse::SparseCsrToCoo<T>(dev_ctx_cpu, csr);
+  auto cpu_sparse_out = sparse::CsrToCoo<T>(dev_ctx_cpu, csr);
   CheckResult<T, int64_t>(&dev_ctx_cpu,
                           cpu_sparse_out,
                           non_zero_data,
@@ -345,7 +343,7 @@ void TestSparseCsrToCoo(const DDim& dense_dims,
   phi::Copy(dev_ctx_gpu, cols, d_cols.place(), true, &d_cols);
   phi::Copy(dev_ctx_gpu, values, d_values.place(), true, &d_values);
   phi::SparseCsrTensor d_csr(d_crows, d_cols, d_values, dense_dims);
-  auto cuda_sparse_out = sparse::SparseCsrToCoo<T>(dev_ctx_gpu, d_csr);
+  auto cuda_sparse_out = sparse::CsrToCoo<T>(dev_ctx_gpu, d_csr);
   CheckResult<T, int64_t>(&dev_ctx_gpu,
                           cuda_sparse_out,
                           non_zero_data,
@@ -491,7 +489,7 @@ void TestCooToCsr(const DDim& dense_dims,
       paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(phi::CPUPlace())
           .get());
-  auto cpu_sparse_out = sparse::SparseCooToCsr<T>(dev_ctx_cpu, coo);
+  auto cpu_sparse_out = sparse::CooToCsr<T>(dev_ctx_cpu, coo);
   CheckCsrResult<T, int64_t>(&dev_ctx_cpu,
                              cpu_sparse_out,
                              non_zero_data,
@@ -525,7 +523,7 @@ void TestCooToCsr(const DDim& dense_dims,
   phi::Copy(dev_ctx_gpu, indices, phi::GPUPlace(), true, &d_indices);
   phi::Copy(dev_ctx_gpu, values, phi::GPUPlace(), true, &d_values);
   phi::SparseCooTensor d_coo(d_indices, d_values, dense_dims);
-  auto cuda_sparse_out = sparse::SparseCooToCsr<T>(dev_ctx_gpu, d_coo);
+  auto cuda_sparse_out = sparse::CooToCsr<T>(dev_ctx_gpu, d_coo);
   CheckCsrResult<T, int64_t>(&dev_ctx_gpu,
                              cuda_sparse_out,
                              non_zero_data,
@@ -591,7 +589,7 @@ void TestDenseToSparseCsr(const DenseTensor& dense_x,
          .get());
 
   // 1. test cpu
-  auto cpu_sparse_out = sparse::DenseToSparseCsr<T>(dev_ctx_cpu, dense_x);
+  auto cpu_sparse_out = sparse::DenseToCsr<T>(dev_ctx_cpu, dense_x);
   CheckCsrResult<T, int64_t>(&dev_ctx_cpu,
                              cpu_sparse_out,
                              non_zero_data,
@@ -624,7 +622,7 @@ void TestDenseToSparseCsr(const DenseTensor& dense_x,
           .get());
   dev_ctx_gpu.PartialInitWithAllocator();
   phi::Copy(dev_ctx_gpu, dense_x, phi::GPUPlace(), true, &d_dense_x);
-  auto sparse_out = sparse::DenseToSparseCsr<T>(dev_ctx_gpu, d_dense_x);
+  auto sparse_out = sparse::DenseToCsr<T>(dev_ctx_gpu, d_dense_x);
   CheckCsrResult<T, int64_t>(&dev_ctx_gpu,
                              sparse_out,
@@ -731,7 +729,7 @@ void TestSparseCooToDense(const DDim& dense_dims,
   SparseCooTensor coo(dense_indices, dense_elements, dense_dims);
 
-  DenseTensor dense_out = sparse::SparseCooToDense<T>(dev_ctx_cpu, coo);
+  DenseTensor dense_out = sparse::CooToDense<T>(dev_ctx_cpu, coo);
 
   int cmp = memcmp(
       &dense_data[0], dense_out.data<T>(), sizeof(T) * dense_data.size());
@@ -763,7 +761,7 @@ void TestSparseCooToDense(const DDim& dense_dims,
   phi::Copy(
       dev_ctx_gpu, dense_elements, phi::GPUPlace(), true, &d_dense_elements);
   SparseCooTensor coo_cuda(d_dense_indices, d_dense_elements, dense_dims);
-  auto dense_out_cuda = sparse::SparseCooToDense<T>(dev_ctx_gpu, coo_cuda);
+  auto dense_out_cuda = sparse::CooToDense<T>(dev_ctx_gpu, coo_cuda);
 
   DenseTensor h_dense_out(alloc.get(),
                           DenseTensorMeta(dense_out_cuda.dtype(),
@@ -878,7 +876,7 @@ void TestSparseCsrToDense(const DDim& dense_dims,
       paddle::memory::allocation::AllocatorFacade::Instance()
          .GetAllocator(phi::CPUPlace())
          .get());
-  DenseTensor cpu_sparse_out = sparse::SparseCsrToDense<T>(dev_ctx_cpu, csr);
+  DenseTensor cpu_sparse_out = sparse::CsrToDense<T>(dev_ctx_cpu, csr);
   int cmp_cpu = memcmp(cpu_sparse_out.data<T>(),
                        dense_data.data(),
                        sizeof(T) * dense_data.size());
@@ -911,7 +909,7 @@ void TestSparseCsrToDense(const DDim& dense_dims,
   phi::Copy(dev_ctx_gpu, cols, phi::GPUPlace(), true, &d_cols);
   phi::Copy(dev_ctx_gpu, values, phi::GPUPlace(), true, &d_values);
   phi::SparseCsrTensor d_csr(d_crows, d_cols, d_values, dense_dims);
-  auto cuda_sparse_out = sparse::SparseCsrToDense<T>(dev_ctx_gpu, d_csr);
+  auto cuda_sparse_out = sparse::CsrToDense<T>(dev_ctx_gpu, d_csr);
   phi::DenseTensor h_out(alloc.get(), cpu_sparse_out.meta());
   phi::Copy(dev_ctx_gpu, cuda_sparse_out, phi::CPUPlace(), true, &h_out);
   int cmp_cuda =
python/paddle/fluid/dygraph/varbase_patch_methods.py

@@ -923,12 +923,7 @@ def monkey_patch_varbase():
             print(sparse_x.values())
             #[1, 2, 3, 4, 5]
         """
-        if self.is_sparse_coo() or self.is_sparse_csr():
-            return _C_ops.sparse_values(self)
-        else:
-            raise ValueError(
-                "only SparseCooTensor and SparseCsrTensor have method values")
+        return _C_ops.sparse_values(self)
 
     @framework.dygraph_only
     def to_dense(self):
@@ -956,12 +951,7 @@ def monkey_patch_varbase():
            # [4., 5., 0., 0.]]
         """
-        if self.is_sparse_coo():
-            return _C_ops.sparse_coo_to_dense(self)
-        elif self.is_sparse_csr():
-            return _C_ops.sparse_to_dense(self)
-        else:
-            return self
+        return _C_ops.sparse_to_dense(self)
 
     @framework.dygraph_only
     def to_sparse_coo(self, sparse_dim):
@@ -987,16 +977,7 @@ def monkey_patch_varbase():
            #values=[1., 2., 3., 4.]
         """
-        if self.is_sparse_csr():
-            return _C_ops.sparse_to_sparse_coo(self, sparse_dim)
-        elif self.is_sparse_coo():
-            return self
-        elif self.is_selected_rows():
-            raise ValueError(
-                "SelectedRows does not support to_sparse_coo method")
-        else:
-            #is dense tensor
-            return _C_ops.sparse_dense_to_coo(self, sparse_dim)
+        return _C_ops.sparse_to_sparse_coo(self, sparse_dim)
 
     if framework._in_eager_mode_ and not hasattr(core, "eager"):
         return
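The branching in these Python methods could be collapsed because, after the YAML changes above, each of the unified C++ ops (`sparse_values`, `sparse_to_dense`, `sparse_to_sparse_coo`) is registered against kernels for every supported input layout and dispatches on the tensor's layout internally, so COO and CSR inputs no longer need separate code paths in Python.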
python/paddle/incubate/sparse/creation.py

@@ -166,7 +166,7 @@ def sparse_coo_tensor(indices,
                 "the number of dimensions(len(shape) must be sparse_dim({}) + dense_dim({}), but get {}"
                 .format(sparse_dim, dense_dim, len(shape)))
 
-    return _C_ops.sparse_create_sparse_coo_tensor(values, indices, shape)
+    return _C_ops.sparse_sparse_coo_tensor(values, indices, shape)
 
 #TODO: need to support shape is None