Repository: 机器未来 / Paddle (fork of PaddlePaddle / Paddle)

Commit eec4e034 (unverified)
Authored Jun 24, 2022 by zhouweiwei2014; committed via GitHub on Jun 24, 2022.
Parent: fa9586a7

[Sparse] support batch compute of SparseTensor matmul/masked_matmul/softmax (#43703)
Showing 16 changed files with 457 additions and 232 deletions (+457 −232).
Changed files:
  paddle/fluid/platform/dynload/cusparse.h                        +25 −18
  paddle/phi/backends/dynload/cusparse.h                          +25 −18
  paddle/phi/kernels/funcs/sparse/sparse_blas.h                   +12 −22
  paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h           +88 −58
  paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc            +28 −19
  paddle/phi/kernels/sparse/cpu/softmax_kernel.cc                 +29 −21
  paddle/phi/kernels/sparse/empty_kernel.cc                       +51 −1
  paddle/phi/kernels/sparse/empty_kernel.h                        +6 −0
  paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu             +8 −22
  paddle/phi/kernels/sparse/gpu/matmul_kernel.cu                  +9 −24
  paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu            +24 −9
  paddle/phi/kernels/sparse/gpu/softmax_kernel.cu                 +25 −13
  python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py    +71 −1
  python/paddle/fluid/tests/unittests/test_sparse_softmax_op.py   +54 −4
  python/paddle/incubate/sparse/nn/functional/activation.py       +1 −1
  python/paddle/incubate/sparse/nn/layer/activation.py            +1 −1
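Editor's note: for orientation, here is a minimal Python sketch of the behaviour this commit enables, based on the APIs exercised by the new unit tests further down (paddle.incubate.sparse.matmul, paddle.incubate.sparse.masked_matmul, paddle.incubate.sparse.nn.Softmax). It assumes a CUDA build with CUDA >= 11.7, as the tests do; shapes and the explicit float cast are illustrative, not taken from the commit.

```python
import paddle

# Batched CSR: dims[:-2] are batch dimensions, the last two are rows/cols.
mask = paddle.randint(0, 2, [16, 12]).astype('float32')
x = (paddle.rand([4, 16, 12]) * mask).to_sparse_csr()
y = paddle.rand([4, 12, 10])

# SpMM: sparse (CSR) @ dense -> dense, now supported with batch dimensions.
out = paddle.incubate.sparse.matmul(x, y)

# SDDMM: dense @ dense, evaluated only at the sparsity pattern of `sp_mask` -> CSR.
a, b = paddle.rand([4, 16, 12]), paddle.rand([4, 12, 10])
pattern = paddle.randint(0, 2, [16, 10]).astype('float32')
sp_mask = (paddle.rand([4, 16, 10]) * pattern).to_sparse_csr()
masked = paddle.incubate.sparse.masked_matmul(a, b, sp_mask)

# Row-wise softmax over the stored values of each row, batched as well.
probs = paddle.incubate.sparse.nn.Softmax()(x)
```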
paddle/fluid/platform/dynload/cusparse.h

@@ -31,24 +31,22 @@ namespace dynload {
 #if defined(PADDLE_WITH_CUDA)
 // APIs available after CUDA 11.0
 #if CUDA_VERSION >= 11000
 #define CUSPARSE_ROUTINE_EACH(__macro)      \
   __macro(cusparseCreate);                  \
   __macro(cusparseSetStream);               \
   __macro(cusparseCreateMatDescr);          \
   __macro(cusparseDestroy);                 \
   __macro(cusparseSnnz);                    \
   __macro(cusparseDnnz);                    \
   __macro(cusparseSetMatType);              \
   __macro(cusparseSetMatIndexBase);         \
   __macro(cusparseCreateCsr);               \
   __macro(cusparseCreateCoo);               \
   __macro(cusparseCreateDnMat);             \
   __macro(cusparseSpMM_bufferSize);         \
   __macro(cusparseSpMM);                    \
   __macro(cusparseDestroySpMat);            \
-  __macro(cusparseDestroyDnMat);            \
-  __macro(cusparseDnMatSetStridedBatch);    \
-  __macro(cusparseCsrSetStridedBatch);
+  __macro(cusparseDestroyDnMat);

 CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif

@@ -62,8 +60,17 @@ CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 CUSPARSE_ROUTINE_EACH_R2(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif

+#if CUDA_VERSION >= 11070
+#define CUSPARSE_ROUTINE_EACH_R3(__macro)   \
+  __macro(cusparseDnMatSetStridedBatch);    \
+  __macro(cusparseCooSetStridedBatch);      \
+  __macro(cusparseCsrSetStridedBatch);
+
+CUSPARSE_ROUTINE_EACH_R3(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
+#endif
+
 #endif  // PADDLE_WITH_CUDA

 #undef PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
 }  // namespace dynload
 }  // namespace platform
paddle/phi/backends/dynload/cusparse.h

@@ -43,24 +43,22 @@ extern void *cusparse_dso_handle;
 #if defined(PADDLE_WITH_CUDA)
 // APIs available after CUDA 11.0
 #if CUDA_VERSION >= 11000
 #define CUSPARSE_ROUTINE_EACH(__macro)      \
   __macro(cusparseCreate);                  \
   __macro(cusparseSetStream);               \
   __macro(cusparseCreateMatDescr);          \
   __macro(cusparseDestroy);                 \
   __macro(cusparseSnnz);                    \
   __macro(cusparseDnnz);                    \
   __macro(cusparseSetMatType);              \
   __macro(cusparseSetMatIndexBase);         \
   __macro(cusparseCreateCsr);               \
   __macro(cusparseCreateCoo);               \
   __macro(cusparseCreateDnMat);             \
   __macro(cusparseSpMM_bufferSize);         \
   __macro(cusparseSpMM);                    \
   __macro(cusparseDestroySpMat);            \
-  __macro(cusparseDestroyDnMat);            \
-  __macro(cusparseDnMatSetStridedBatch);    \
-  __macro(cusparseCsrSetStridedBatch);
+  __macro(cusparseDestroyDnMat);

 CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif

@@ -74,8 +72,17 @@ CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 CUSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif

+#if CUDA_VERSION >= 11070
+#define CUSPARSE_ROUTINE_EACH_R3(__macro)   \
+  __macro(cusparseDnMatSetStridedBatch);    \
+  __macro(cusparseCooSetStridedBatch);      \
+  __macro(cusparseCsrSetStridedBatch);
+
+CUSPARSE_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
+#endif
+
 #endif  // PADDLE_WITH_CUDA

 #undef DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
 }  // namespace dynload
 }  // namespace phi
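Editor's note: the strided-batch cuSPARSE entry points (cusparseCsrSetStridedBatch, cusparseCooSetStridedBatch, cusparseDnMatSetStridedBatch) only exist from CUDA 11.7, which is why they move into the new CUSPARSE_ROUTINE_EACH_R3 group above. The new unit tests gate on the same boundary via a get_cuda_version() helper defined in the test file (not shown in this diff). A hedged sketch of an equivalent gate; the helper name and the use of paddle.version.cuda() here are assumptions, not part of the commit:

```python
import paddle

def cuda_version_at_least(major, minor):
    # paddle.version.cuda() returns a string such as "11.7", or "False" on CPU-only builds.
    ver = paddle.version.cuda()
    if not paddle.is_compiled_with_cuda() or ver in (None, "False"):
        return False
    cur_major, cur_minor = (int(v) for v in ver.split(".")[:2])
    return (cur_major, cur_minor) >= (major, minor)

# Batched sparse matmul/masked_matmul/softmax need the strided-batch cuSPARSE APIs.
batch_sparse_supported = cuda_version_at_least(11, 7)
```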
paddle/phi/kernels/funcs/sparse/sparse_blas.h

@@ -28,33 +28,23 @@ class SparseBlas {
  public:
   explicit SparseBlas(const DeviceContext& dev_ctx) : dev_ctx_(dev_ctx) {}

-  // TODO(zhouwei25): implement "COO @ DENSE -> DENSE" of DSDMM
-  template <typename T>
-  void DSDMM(bool transa,
-             bool transb,
-             T alpha,
-             const phi::SparseCooTensor& mat_a,
-             const phi::DenseTensor& mat_b,
-             T beta,
-             phi::DenseTensor* mat_c) const;
-
-  template <typename T>
-  void DSDMM(bool transa,
-             bool transb,
-             T alpha,
-             const phi::SparseCsrTensor& mat_a,
-             const phi::DenseTensor& mat_b,
-             T beta,
-             phi::DenseTensor* mat_c) const;
+  template <typename T, typename TensorType>
+  void SPMM(bool transa,
+            bool transb,
+            T alpha,
+            const TensorType& mat_a,
+            const phi::DenseTensor& mat_b,
+            T beta,
+            phi::DenseTensor* mat_out) const;

-  template <typename T>
+  template <typename T, typename TensorType>
   void SDDMM(bool transa,
              bool transb,
              T alpha,
              const phi::DenseTensor& mat_a,
              const phi::DenseTensor& mat_b,
              T beta,
-             phi::SparseCsrTensor* mat_c) const;
+             TensorType* mat_out) const;

  private:
   const DeviceContext& dev_ctx_;

@@ -66,8 +56,8 @@ class SparseBlasT : private SparseBlas<DeviceContext> {
   using SparseBlas<DeviceContext>::SparseBlas;

   template <typename... ARGS>
-  void DSDMM(ARGS... args) const {
-    Base()->template DSDMM<T>(args...);
+  void SPMM(ARGS... args) const {
+    Base()->template SPMM<T>(args...);
   }

   template <typename... ARGS>
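Editor's note: the renaming spells out what the two entry points compute. SPMM multiplies a sparse matrix by a dense matrix into a dense result; SDDMM multiplies two dense matrices but only evaluates the entries selected by the sparsity pattern of the sparse output. A dense NumPy sketch of the two operations, ignoring the transa/transb/alpha/beta parameters of the real interface:

```python
import numpy as np

def spmm_reference(a_dense, b):
    # SPMM: sparse A (materialized densely here) @ dense B -> dense C.
    return a_dense @ b

def sddmm_reference(a, b, out_pattern):
    # SDDMM: (A @ B) sampled at the non-zero positions of the output pattern;
    # entries outside the pattern stay zero.
    return (a @ b) * (out_pattern != 0)

a = np.random.rand(16, 12) * (np.random.rand(16, 12) < 0.5)
b = np.random.rand(12, 10)
pattern = np.random.rand(16, 10) < 0.5
c_spmm = spmm_reference(a, b)
c_sddmm = sddmm_reference(np.random.rand(16, 12), b, pattern)
```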
paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h

@@ -47,6 +47,61 @@ inline cusparseOperation_t GetTransposeOperation(const bool trans) {
   }
 }

+template <typename T, typename IntT>
+inline void CreateCsrDescriptor(const phi::SparseCsrTensor& x,
+                                const phi::GPUContext& dev_ctx,
+                                cusparseSpMatDescr_t* descriptor) {
+  std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
+  auto x_ndims = xdim_vec.size();
+  PADDLE_ENFORCE_GE(
+      x_ndims,
+      2,
+      phi::errors::InvalidArgument("the dim size of SparseCsrTensor must be "
+                                   "greater than or eaqual to 2."));
+  int64_t M = xdim_vec[x_ndims - 2];
+  int64_t N = xdim_vec[x_ndims - 1];
+  int batch_size = 1;
+  for (int i = 0; i < x_ndims - 2; i++) {
+    batch_size *= xdim_vec[i];
+  }
+  PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(),
+                    batch_size * (M + 1),
+                    phi::errors::PreconditionNotMet(
+                        "the length of SparseCsrTensor crows is not right."));
+
+  const IntT* crows_data = x.non_zero_crows().data<IntT>();
+  const IntT* cols_data = x.non_zero_cols().data<IntT>();
+  const T* values_data = x.non_zero_elements().data<T>();
+
+  int64_t batch_nnz = x.nnz() / batch_size;
+  cudaDataType_t gpu_type = GetGpuDataType<T>();
+  dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
+    phi::dynload::cusparseCreateCsr(descriptor,
+                                    M,
+                                    N,
+                                    batch_nnz,
+                                    const_cast<IntT*>(crows_data),
+                                    const_cast<IntT*>(cols_data),
+                                    const_cast<T*>(values_data),
+                                    CUSPARSE_INDEX_64I,
+                                    CUSPARSE_INDEX_64I,
+                                    CUSPARSE_INDEX_BASE_ZERO,
+                                    gpu_type);
+  });
+
+  if (batch_size > 1) {
+#if CUDA_VERSION >= 11070
+    dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
+      phi::dynload::cusparseCsrSetStridedBatch(
+          *descriptor, batch_size, M + 1, batch_nnz);
+    });
+#else
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "Batch Sparse matmul use 'cusparseCsrSetStridedBatch', which is "
+        "supported from CUDA 11.7"));
+#endif
+  }
+}
+
 template <typename T>
 class CuSparseSpMatDescriptor {
  public:

@@ -55,45 +110,9 @@ class CuSparseSpMatDescriptor {
       : dev_ctx_(dev_ctx) {
     PD_VISIT_INTEGRAL_TYPES(
         x.non_zero_crows().dtype(), "CuSparseSpMatDescriptor", ([&] {
-          const data_t* crows_data = x.non_zero_crows().data<data_t>();
-          const data_t* cols_data = x.non_zero_cols().data<data_t>();
-          const T* values_data = x.non_zero_elements().data<T>();
-          int64_t nnz = x.nnz();
-
-          std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
-          auto x_ndims = xdim_vec.size();
-          int64_t M = xdim_vec[x_ndims - 2];
-          int64_t N = xdim_vec[x_ndims - 1];
-          int batch_size = 1;
-          for (int i = 0; i < x_ndims - 2; i++) {
-            batch_size *= xdim_vec[i];
-          }
-
-          cudaDataType_t gpu_type = GetGpuDataType<T>();
-          dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
-            phi::dynload::cusparseCreateCsr(&descriptor_,
-                                            M,
-                                            N,
-                                            nnz,
-                                            const_cast<data_t*>(crows_data),
-                                            const_cast<data_t*>(cols_data),
-                                            const_cast<T*>(values_data),
-                                            CUSPARSE_INDEX_64I,
-                                            CUSPARSE_INDEX_64I,
-                                            CUSPARSE_INDEX_BASE_ZERO,
-                                            gpu_type);
-          });
-          PADDLE_ENFORCE_EQ(x.non_zero_crows().numel(), batch_size * (M + 1));
-          PADDLE_ENFORCE_EQ(x.non_zero_cols().numel(), x.nnz());
-          if (batch_size > 1) {
-            dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
-              phi::dynload::cusparseCsrSetStridedBatch(
-                  descriptor_, batch_size, M + 1, nnz);
-            });
-          }
+          CreateCsrDescriptor<T, data_t>(x, dev_ctx_, &descriptor_);
         }));
-    VLOG(6) << "Create cusparseSpMatDescr_t " << &descriptor_;
+    VLOG(6) << "Create csr cusparseSpMatDescr_t " << &descriptor_;
   }

   ~CuSparseSpMatDescriptor() {

@@ -116,9 +135,14 @@ class CuSparseDnMatDescriptor {
   explicit CuSparseDnMatDescriptor(const phi::DenseTensor& x,
                                    const phi::GPUContext& dev_ctx)
       : dev_ctx_(dev_ctx) {
-    const T* x_data = x.data<T>();
     std::vector<int64_t> xdim_vec = phi::vectorize(x.dims());
     auto x_ndims = xdim_vec.size();
+    PADDLE_ENFORCE_GE(
+        x_ndims,
+        2,
+        phi::errors::InvalidArgument("the dim size of DenseTensor must be "
+                                     "greater than or eaqual to 2."));
+
     int64_t M = xdim_vec[x_ndims - 2];
     int64_t N = xdim_vec[x_ndims - 1];
     int batch_size = 1;

@@ -126,6 +150,7 @@ class CuSparseDnMatDescriptor {
       batch_size *= xdim_vec[i];
     }

+    const T* x_data = x.data<T>();
     cudaDataType_t gpu_type = GetGpuDataType<T>();
     dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
       phi::dynload::cusparseCreateDnMat(&descriptor_,

@@ -139,10 +164,16 @@ class CuSparseDnMatDescriptor {
     PADDLE_ENFORCE_EQ(x.numel(), batch_size * M * N);
     if (batch_size > 1) {
+#if CUDA_VERSION >= 11070
       dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
         phi::dynload::cusparseDnMatSetStridedBatch(
             descriptor_, batch_size, M * N);
       });
+#else
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Batch Sparse matmul use 'cusparseDnMatSetStridedBatch', which is "
+          "supported from CUDA 11.7"));
+#endif
     }
     VLOG(6) << "Create cusparseDnMatDescr_t " << &descriptor_;
   }

@@ -162,20 +193,19 @@ class CuSparseDnMatDescriptor {
 };

 template <>
-template <typename T>
-void SparseBlas<phi::GPUContext>::DSDMM(bool transa,
-                                        bool transb,
-                                        T alpha,
-                                        const phi::SparseCsrTensor& mat_a,
-                                        const phi::DenseTensor& mat_b,
-                                        T beta,
-                                        phi::DenseTensor* mat_c) const {
-  cudaDataType_t gpu_type = GetGpuDataType<T>();
+template <typename T, typename TensorType>
+void SparseBlas<phi::GPUContext>::SPMM(bool transa,
+                                       bool transb,
+                                       T alpha,
+                                       const TensorType& mat_a,
+                                       const phi::DenseTensor& mat_b,
+                                       T beta,
+                                       phi::DenseTensor* mat_out) const {
   auto a_descriptor = CuSparseSpMatDescriptor<T>(mat_a, dev_ctx_);
   auto b_descriptor = CuSparseDnMatDescriptor<T>(mat_b, dev_ctx_);
-  auto c_descriptor = CuSparseDnMatDescriptor<T>(*mat_c, dev_ctx_);
+  auto out_descriptor = CuSparseDnMatDescriptor<T>(*mat_out, dev_ctx_);

+  cudaDataType_t gpu_type = GetGpuDataType<T>();
   size_t buffer_size = 0;
   dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {
     phi::dynload::cusparseSpMM_bufferSize(handle,

@@ -185,7 +215,7 @@ void SparseBlas<phi::GPUContext>::DSDMM(bool transa,
                                           a_descriptor.descriptor(),
                                           b_descriptor.descriptor(),
                                           &beta,
-                                          c_descriptor.descriptor(),
+                                          out_descriptor.descriptor(),
                                           gpu_type,
                                           CUSPARSE_SPMM_ALG_DEFAULT,
                                           &buffer_size);

@@ -202,7 +232,7 @@ void SparseBlas<phi::GPUContext>::DSDMM(bool transa,
                                a_descriptor.descriptor(),
                                b_descriptor.descriptor(),
                                &beta,
-                               c_descriptor.descriptor(),
+                               out_descriptor.descriptor(),
                                gpu_type,
                                CUSPARSE_SPMM_ALG_DEFAULT,
                                tmp_buffer_ptr);

@@ -211,19 +241,19 @@ void SparseBlas<phi::GPUContext>::DSDMM(bool transa,
 #if CUDA_VERSION >= 11030
 template <>
-template <typename T>
+template <typename T, typename TensorType>
 void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
                                         bool transb,
                                         T alpha,
                                         const phi::DenseTensor& mat_a,
                                         const phi::DenseTensor& mat_b,
                                         T beta,
-                                        phi::SparseCsrTensor* mat_c) const {
+                                        TensorType* mat_out) const {
   cudaDataType_t gpu_type = GetGpuDataType<T>();
   auto a_descriptor = CuSparseDnMatDescriptor<T>(mat_a, dev_ctx_);
   auto b_descriptor = CuSparseDnMatDescriptor<T>(mat_b, dev_ctx_);
-  auto c_descriptor = CuSparseSpMatDescriptor<T>(*mat_c, dev_ctx_);
+  auto out_descriptor = CuSparseSpMatDescriptor<T>(*mat_out, dev_ctx_);

   size_t buffer_size = 0;
   dev_ctx_.CusparseCall([&](cusparseHandle_t handle) {

@@ -234,7 +264,7 @@ void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
                                            a_descriptor.descriptor(),
                                            b_descriptor.descriptor(),
                                            &beta,
-                                           c_descriptor.descriptor(),
+                                           out_descriptor.descriptor(),
                                            gpu_type,
                                            CUSPARSE_SDDMM_ALG_DEFAULT,
                                            &buffer_size);

@@ -252,7 +282,7 @@ void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
                                 a_descriptor.descriptor(),
                                 b_descriptor.descriptor(),
                                 &beta,
-                                c_descriptor.descriptor(),
+                                out_descriptor.descriptor(),
                                 gpu_type,
                                 CUSPARSE_SDDMM_ALG_DEFAULT,
                                 tmp_buffer_ptr);

@@ -266,7 +296,7 @@ void SparseBlas<phi::GPUContext>::SDDMM(bool transa,
                                 a_descriptor.descriptor(),
                                 b_descriptor.descriptor(),
                                 &beta,
-                                c_descriptor.descriptor(),
+                                out_descriptor.descriptor(),
                                 gpu_type,
                                 CUSPARSE_SDDMM_ALG_DEFAULT,
                                 tmp_buffer_ptr);
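Editor's note: CreateCsrDescriptor treats everything before the last two dimensions as a batch and requires every batch to share the same nnz (batch_nnz = x.nnz() / batch_size) and a crows array of length batch_size * (M + 1); cusparseCsrSetStridedBatch then strides over those per-batch segments. A small SciPy sketch of that layout, under the assumption that each batch slice happens to hold the same number of non-zeros (which this random construction does not guarantee):

```python
import numpy as np
import scipy.sparse as sp

batch, M, N = 4, 16, 12
dense = np.random.rand(batch, M, N)
dense[dense < 0.5] = 0.0  # note: batches will not have equal nnz in general

crows, cols, values = [], [], []
for b in range(batch):
    csr = sp.csr_matrix(dense[b])
    crows.append(csr.indptr)   # length M + 1 per batch
    cols.append(csr.indices)
    values.append(csr.data)

crows = np.concatenate(crows)  # length batch * (M + 1), as the kernel enforces
cols = np.concatenate(cols)
values = np.concatenate(values)

assert crows.size == batch * (M + 1)
# The CUDA path additionally assumes values.size is divisible by batch
# (batch_nnz = nnz / batch), i.e. every batch stores the same number of non-zeros.
```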
paddle/phi/kernels/sparse/cpu/softmax_grad_kernel.cc

@@ -38,11 +38,17 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
                         "SparseCsrTensor only support axis=-1 for softmax, "
                         "which is faster when reading data by row (axis=-1)"));

   EmptyLikeCsrKernel<T, Context>(dev_ctx, dout, dx);

   auto out_dim = out.dims();
-  int rows = 1;
-  for (int i = 0; i < out_dim.size() - 1; ++i) {
-    rows *= out_dim[i];
+  auto out_rank = out_dim.size();
+  int batch_size = 1;
+  int row_number = 1;
+  for (int i = 0; i < out_rank - 1; ++i) {
+    if (i < out_rank - 2) {
+      batch_size *= out_dim[i];
+    } else if (i == out_rank - 2) {
+      row_number = out_dim[i];
+    }
   }

   const DenseTensor& out_crows = out.non_zero_crows();

@@ -50,7 +56,6 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
   const DenseTensor& dout_values = dout.non_zero_elements();
   DenseTensor* dx_values = dx->mutable_non_zero_elements();

-  int row_first = 0;
   int row_nnz = 0;
   const T* out_data = out_values.data<T>();
   const T* dout_data = dout_values.data<T>();

@@ -60,20 +65,24 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
   PD_VISIT_INTEGRAL_TYPES(
       out.non_zero_crows().dtype(), "SoftmaxCsrGradKernel", ([&] {
         const data_t* out_crows_data = out_crows.data<data_t>();
-        for (int i = 0; i < rows; ++i) {
-          row_first = static_cast<int>(out_crows_data[i]);
-          row_nnz = static_cast<int>(out_crows_data[i + 1] - out_crows_data[i]);
-
-          out_data = out_data + row_first;
-          dout_data = dout_data + row_first;
-          dx_data = dx_data + row_first;
-
-          T sum = 0;
-          phi::funcs::vec_mul_reduce<T, plt::avx>(
-              row_nnz, dout_data, out_data, &sum);
-          phi::funcs::vec_add_bias<T, plt::avx>(
-              row_nnz, static_cast<T>(-1) * sum, dout_data, dx_data);
-          phi::funcs::vec_mul<T, plt::avx>(row_nnz, dx_data, out_data, dx_data);
+        for (int i = 0; i < batch_size; ++i) {
+          for (int j = 0; j < row_number; ++j) {
+            int crow_idx = i * (row_number + 1) + j;
+            row_nnz = static_cast<int>(out_crows_data[crow_idx + 1] -
+                                       out_crows_data[crow_idx]);
+
+            T sum = 0;
+            phi::funcs::vec_mul_reduce<T, plt::avx>(
+                row_nnz, dout_data, out_data, &sum);
+            phi::funcs::vec_add_bias<T, plt::avx>(
+                row_nnz, static_cast<T>(-1) * sum, dout_data, dx_data);
+            phi::funcs::vec_mul<T, plt::avx>(
+                row_nnz, dx_data, out_data, dx_data);
+
+            out_data = out_data + row_nnz;
+            dout_data = dout_data + row_nnz;
+            dx_data = dx_data + row_nnz;
+          }
         }
       }));
 }
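Editor's note: the batched CPU kernels walk the concatenated crows array directly. With row_number rows per batch, batch i keeps its row pointers at positions [i*(row_number+1), (i+1)*(row_number+1)), so row j of batch i is described by crow_idx = i*(row_number+1) + j and holds crows[crow_idx+1] - crows[crow_idx] stored values. The same arithmetic in Python, on hypothetical arrays:

```python
import numpy as np

def batched_csr_row_nnz(crows, batch_size, row_number):
    # crows concatenates one (row_number + 1)-long indptr per batch.
    nnz_per_row = np.empty((batch_size, row_number), dtype=np.int64)
    for i in range(batch_size):
        for j in range(row_number):
            crow_idx = i * (row_number + 1) + j
            nnz_per_row[i, j] = crows[crow_idx + 1] - crows[crow_idx]
    return nnz_per_row

# Two batches of a 3-row matrix with 2, 0, 1 and 1, 1, 1 non-zeros per row.
crows = np.array([0, 2, 2, 3,   0, 1, 2, 3])
print(batched_csr_row_nnz(crows, batch_size=2, row_number=3))
```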
paddle/phi/kernels/sparse/cpu/softmax_kernel.cc

@@ -37,18 +37,23 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
                         "SparseCsrTensor only support axis=-1 for softmax, "
                         "which is faster when reading data by row (axis=-1)"));

   EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);

   auto x_dim = x.dims();
+  auto x_rank = x_dim.size();
+  int batch_size = 1;
   int row_number = 1;
-  for (int i = 0; i < x_dim.size() - 1; ++i) {
-    row_number *= x_dim[i];
+  for (int i = 0; i < x_rank - 1; ++i) {
+    if (i < x_rank - 2) {
+      batch_size *= x_dim[i];
+    } else if (i == x_rank - 2) {
+      row_number = x_dim[i];
+    }
   }

   const DenseTensor& x_crows = x.non_zero_crows();
   const DenseTensor& x_values = x.non_zero_elements();
   DenseTensor* out_values = out->mutable_non_zero_elements();

-  int row_first = 0;
   int row_nnz = 0;
   T row_max_val = 0;
   const T* x_data = x_values.data<T>();

@@ -58,23 +63,26 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
   PD_VISIT_INTEGRAL_TYPES(
       x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] {
         const data_t* x_crows_data = x_crows.data<data_t>();
-        for (int i = 0; i < row_number; ++i) {
-          row_first = static_cast<int>(x_crows_data[i]);
-          row_nnz = static_cast<int>(x_crows_data[i + 1] - x_crows_data[i]);
-
-          x_data = x_data + row_first;
-          out_data = out_data + row_first;
-
-          row_max_val = *std::max_element(x_data, x_data + row_nnz);
-          phi::funcs::vec_add_bias<T, plt::avx>(
-              row_nnz, static_cast<T>(-1) * row_max_val, x_data, out_data);
-
-          phi::funcs::vec_exp<T>(row_nnz, out_data, out_data);
-
-          T sum = 0;
-          phi::funcs::vec_sum<T, plt::avx>(row_nnz, out_data, &sum);
-          phi::funcs::vec_scal<T, plt::avx>(
-              row_nnz, static_cast<T>(1) / sum, out_data, out_data);
+        for (int i = 0; i < batch_size; ++i) {
+          for (int j = 0; j < row_number; ++j) {
+            int crow_idx = i * (row_number + 1) + j;
+            row_nnz = static_cast<int>(x_crows_data[crow_idx + 1] -
+                                       x_crows_data[crow_idx]);
+
+            row_max_val = *std::max_element(x_data, x_data + row_nnz);
+            phi::funcs::vec_add_bias<T, plt::avx>(
+                row_nnz, static_cast<T>(-1) * row_max_val, x_data, out_data);
+
+            phi::funcs::vec_exp<T>(row_nnz, out_data, out_data);
+
+            T sum = 0;
+            phi::funcs::vec_sum<T, plt::avx>(row_nnz, out_data, &sum);
+            phi::funcs::vec_scal<T, plt::avx>(
+                row_nnz, static_cast<T>(1) / sum, out_data, out_data);
+
+            x_data = x_data + row_nnz;
+            out_data = out_data + row_nnz;
+          }
         }
       }));
 }
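Editor's note: per row, the CPU kernel applies an ordinary numerically stable softmax to the row's stored values only (subtract the row max, exponentiate, normalize), exactly as the reference computation in the updated unit test does. In NumPy terms, for one CSR row slice:

```python
import numpy as np

def csr_row_softmax(row_values):
    # Softmax over the stored (non-zero) entries of one CSR row;
    # implicit zeros take no part in the normalization.
    shifted = row_values - np.max(row_values, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, keepdims=True)

print(csr_row_softmax(np.array([1.0, 3.0, 0.5])))
```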
paddle/phi/kernels/sparse/empty_kernel.cc

@@ -22,6 +22,25 @@ limitations under the License. */
 namespace phi {
 namespace sparse {

+template <typename T, typename Context>
+void EmptyLikeCooKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        SparseCooTensor* out) {
+  const DenseTensor& x_indices = x.non_zero_indices();
+  const DenseTensor& x_values = x.non_zero_elements();
+  DenseTensor* out_indices = out->mutable_non_zero_indices();
+  DenseTensor* out_values = out->mutable_non_zero_elements();
+
+  phi::Copy(dev_ctx, x_indices, dev_ctx.GetPlace(), false, out_indices);
+  phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
+
+  out_values->Resize(x_values.dims());
+  dev_ctx.template Alloc<T>(out_values);
+
+  out->set_dims(x.dims());
+}
+
 template <typename T, typename Context>
 void EmptyLikeCsrKernel(const Context& dev_ctx,
                         const SparseCsrTensor& x,

@@ -34,17 +53,33 @@ void EmptyLikeCsrKernel(const Context& dev_ctx,
   DenseTensor* out_cols = out->mutable_non_zero_cols();
   DenseTensor* out_values = out->mutable_non_zero_elements();

+  out->set_dims(x.dims());
   phi::Copy(dev_ctx, x_crows, dev_ctx.GetPlace(), false, out_crows);
   phi::Copy(dev_ctx, x_cols, dev_ctx.GetPlace(), false, out_cols);

   out_values->Resize(x_values.dims());
   dev_ctx.template Alloc<T>(out_values);
-  out->set_dims(x.dims());
 }

 }  // namespace sparse
 }  // namespace phi

+PD_REGISTER_KERNEL(empty_like_coo,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::EmptyLikeCooKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
 PD_REGISTER_KERNEL(empty_like_csr,
                    CPU,
                    ALL_LAYOUT,

@@ -61,6 +96,21 @@ PD_REGISTER_KERNEL(empty_like_csr,
 }

 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PD_REGISTER_KERNEL(empty_like_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::EmptyLikeCooKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
 PD_REGISTER_KERNEL(empty_like_csr,
                    GPU,
                    ALL_LAYOUT,
paddle/phi/kernels/sparse/empty_kernel.h

@@ -14,11 +14,17 @@ limitations under the License. */
 #pragma once

+#include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"

 namespace phi {
 namespace sparse {

+template <typename T, typename Context>
+void EmptyLikeCooKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        SparseCooTensor* out);
+
 template <typename T, typename Context>
 void EmptyLikeCsrKernel(const Context& dev_ctx,
                         const SparseCsrTensor& x,
paddle/phi/kernels/sparse/gpu/matmul_grad_kernel.cu

@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/copy_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
 #include "paddle/phi/kernels/transpose_kernel.h"

 namespace phi {

@@ -38,23 +39,8 @@ void CsrDenseMatmulGradKernel(const Context& dev_ctx,
   // dx{SparseCsr} = dout{Dense} * y'{Dense}
   if (dx) {
-    // InferMeta of SparseCsrTensor 'dx'
-    dx->set_dims(x.dims());
-
-    phi::Copy(dev_ctx,
-              x.non_zero_crows(),
-              dev_ctx.GetPlace(),
-              false,
-              dx->mutable_non_zero_crows());
-    phi::Copy(dev_ctx,
-              x.non_zero_cols(),
-              dev_ctx.GetPlace(),
-              false,
-              dx->mutable_non_zero_cols());
-
-    DenseTensor* values = dx->mutable_non_zero_elements();
-    values->Resize(x.non_zero_elements().dims());
-    dev_ctx.template Alloc<T>(values);
+    // InferMeta of SparseCsrTensor 'dx', CreateLikeInferMeta
+    EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);

     sparse_blas.SDDMM(
         false, true, static_cast<T>(1), dout, y, static_cast<T>(0), dx);

@@ -69,13 +55,13 @@ void CsrDenseMatmulGradKernel(const Context& dev_ctx,
     dev_ctx.template Alloc<T>(dy);

-    sparse_blas.DSDMM(
+    sparse_blas.SPMM(
         true, false, static_cast<T>(1), x, dout, static_cast<T>(0), dy);
   }
 #else
   PADDLE_THROW(phi::errors::Unimplemented(
-      "backward of 'sparse.mm' use cusparseSDDMM, Only "
-      "support it from CUDA 11.3"));
+      "backward of 'sparse.matmul' use cusparseSDDMM, which is supported from "
+      "CUDA 11.3"));
 #endif
 }

@@ -97,7 +83,7 @@ void CsrMaskedMatmulGradKernel(const Context& dev_ctx,
     meta_dx.set_dtype(x.dtype());
     dev_ctx.template Alloc<T>(dx);

-    sparse_blas.DSDMM(
+    sparse_blas.SPMM(
         false, true, static_cast<T>(1), dout, y, static_cast<T>(0), dx);
   }

@@ -109,7 +95,7 @@ void CsrMaskedMatmulGradKernel(const Context& dev_ctx,
     std::swap(trans_dim_vec[rank - 1], trans_dim_vec[rank - 2]);
     DenseTensor trans_dy = phi::Empty<T, Context>(dev_ctx, trans_dim_vec);

-    sparse_blas.DSDMM(
+    sparse_blas.SPMM(
         true, false, static_cast<T>(1), dout, x, static_cast<T>(0), &trans_dy);

     // InferMeta of DenseTensor 'dy'
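Editor's note: the gradient path mirrors the forward SpMM. For out = x(CSR) @ y(dense), dx is the dense product dout @ yᵀ sampled at x's sparsity pattern (an SDDMM), and dy is xᵀ @ dout (an SPMM), which is what the SDDMM(false, true, ...) and SPMM(true, false, ...) calls above express. A dense NumPy check of those formulas, with the CSR operand materialized densely and its pattern used as the mask:

```python
import numpy as np

M, K, N = 16, 12, 10
pattern = np.random.rand(M, K) < 0.5
x = np.random.rand(M, K) * pattern   # stands in for the CSR operand
y = np.random.rand(K, N)
dout = np.random.rand(M, N)          # upstream gradient of out = x @ y

dx = (dout @ y.T) * pattern          # SDDMM: only x's stored entries receive a gradient
dy = x.T @ dout                      # SPMM with the sparse operand transposed
```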
paddle/phi/kernels/sparse/gpu/matmul_kernel.cu

@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/copy_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/sparse/sparse_blas.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"

 namespace phi {
 namespace sparse {

@@ -59,7 +60,7 @@ void CsrDenseMatmulKernel(const Context& dev_ctx,
     PADDLE_ENFORCE_EQ(xdim_vec[i],
                       ydim_vec[i],
                       phi::errors::InvalidArgument(
-                          "x.dim[%d] and x.dim[%d] must match.", i, i));
+                          "x.dim[%d] and x.dim[%d] must be eaqul.", i, i));
   }

   PADDLE_ENFORCE_GE(

@@ -80,11 +81,11 @@ void CsrDenseMatmulKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(out);

   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
-  sparse_blas.DSDMM(
+  sparse_blas.SPMM(
       false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
 #else
   PADDLE_THROW(
-      phi::errors::Unimplemented("forward of 'sparse.mm' use cusparseSpMM, "
+      phi::errors::Unimplemented("forward of 'sparse.matmul' use cusparseSpMM, "
                                  "which is supported from CUDA 11.0"));
 #endif
 }

@@ -159,32 +160,16 @@ void CsrMaskedMatmulKernel(const Context& dev_ctx,
           "The shape of Input(x) and Input(y) is not suitable for matmul "
           "opetation, mask_dim[-1] must be eaqual to y_dim[-1]."));

-  // InferMeta of SparseCsrTensor 'out'
-  out->set_dims(mask.dims());
-
-  phi::Copy(dev_ctx,
-            mask.non_zero_crows(),
-            dev_ctx.GetPlace(),
-            false,
-            out->mutable_non_zero_crows());
-  phi::Copy(dev_ctx,
-            mask.non_zero_cols(),
-            dev_ctx.GetPlace(),
-            false,
-            out->mutable_non_zero_cols());
-
-  DenseTensor* values = out->mutable_non_zero_elements();
-  values->Resize(mask.non_zero_elements().dims());
-  dev_ctx.template Alloc<T>(values);
+  // InferMeta of SparseCsrTensor 'out', CreateLikeInferMeta
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, mask, out);

   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
   sparse_blas.SDDMM(
       false, false, static_cast<T>(1), x, y, static_cast<T>(0), out);
 #else
-  PADDLE_THROW(phi::errors::Unimplemented(" forward of 'sparse.masked_mm' use "
-                                          "cusparseSDDMM, which is supported from "
-                                          "CUDA 11.3"));
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "forward of 'sparse.masked_matmul' use cusparseSDDMM, which is supported "
+      "from CUDA 11.3"));
 #endif
 }
paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu

@@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"
+
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/math_cuda_utils.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
-#include "paddle/phi/kernels/sparse/softmax_grad_kernel.h"

 namespace phi {
 namespace sparse {

@@ -27,13 +28,20 @@ __global__ void SoftmaxGradGpuKernel(const IntT* out_crows,
                                      const T* out_values,
                                      const T* dout_values,
                                      T* dx_values,
-                                     int row_number) {
+                                     int row_number,
+                                     int total_row_number) {
   // dx = (dout - sum(dout * out)) * out
   int row = blockIdx.x * blockDim.y + threadIdx.y;
   int non_zero_idx = threadIdx.x;
-  if (row >= row_number) return;
-  int row_first = static_cast<int>(out_crows[row]);
-  int row_nnz = static_cast<int>(out_crows[row + 1] - out_crows[row]);
+  if (row >= total_row_number) return;
+  int cur_batch = row / row_number;
+  int crow_idx = cur_batch * (row_number + 1) + (row % row_number);
+  int cur_batch_offset = 0;
+  for (int i = 1; i < cur_batch + 1; ++i) {
+    cur_batch_offset += out_crows[i * (row_number + 1) - 1];
+  }
+  int row_first = cur_batch_offset + static_cast<int>(out_crows[crow_idx]);
+  int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
   if (row_nnz == 0) return;

   int kIteration = (row_nnz + warpSize - 1) / warpSize;

@@ -70,12 +78,18 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
   EmptyLikeCsrKernel<T, Context>(dev_ctx, dout, dx);

   auto out_dim = out.dims();
+  auto out_rank = out_dim.size();
+
+  int total_row_number = 1;
   int row_number = 1;
-  for (int i = 0; i < out_dim.size() - 1; ++i) {
-    row_number *= out_dim[i];
+  for (int i = 0; i < out_rank - 1; ++i) {
+    total_row_number *= out_dim[i];
+    if (i == out_rank - 2) {
+      row_number = out_dim[i];
+    }
   }

-  dim3 grid((row_number + 3) / 4);
+  dim3 grid((total_row_number + 3) / 4);
   dim3 block(32, 4);

   PD_VISIT_INTEGRAL_TYPES(

@@ -85,7 +99,8 @@ void SoftmaxCsrGradKernel(const Context& dev_ctx,
             out.non_zero_elements().data<T>(),
             dout.non_zero_elements().data<T>(),
             dx->mutable_non_zero_elements()->data<T>(),
-            row_number);
+            row_number,
+            total_row_number);
       }));
 }
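Editor's note: on the GPU, each warp still handles one logical row, but the kernel now receives total_row_number = batch_size * row_number rows and has to translate a flat row index into (batch, row) plus an offset into the concatenated values array, because every batch's crows segment restarts at zero. The added cur_batch_offset loop accumulates the nnz of all earlier batches by reading the last entry of each earlier crows segment. The same index arithmetic in Python, on hypothetical arrays:

```python
import numpy as np

def locate_row(crows, row, row_number):
    # crows concatenates one (row_number + 1)-long indptr per batch, each starting at 0.
    cur_batch = row // row_number
    crow_idx = cur_batch * (row_number + 1) + (row % row_number)
    # nnz of all earlier batches = last indptr entry of each earlier segment
    cur_batch_offset = sum(crows[i * (row_number + 1) - 1]
                           for i in range(1, cur_batch + 1))
    row_first = cur_batch_offset + crows[crow_idx]
    row_nnz = crows[crow_idx + 1] - crows[crow_idx]
    return row_first, row_nnz

crows = np.array([0, 2, 2, 3,   0, 1, 2, 3])    # 2 batches, 3 rows each
print(locate_row(crows, row=4, row_number=3))   # row 1 of batch 1 -> (4, 1)
```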
paddle/phi/kernels/sparse/gpu/softmax_kernel.cu

@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include "paddle/phi/kernels/sparse/softmax_kernel.h"
+
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/visit_type.h"

@@ -19,7 +21,6 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/activation_functor.h"
 #include "paddle/phi/kernels/funcs/math_cuda_utils.h"
 #include "paddle/phi/kernels/sparse/empty_kernel.h"
-#include "paddle/phi/kernels/sparse/softmax_kernel.h"

 namespace phi {
 namespace sparse {

@@ -28,13 +29,20 @@ template <typename T, typename IntT = int>
 __global__ void SoftmaxGpuKernel(const IntT* x_crows,
                                  const T* x_values,
                                  T* out_values,
-                                 int row_number) {
+                                 int row_number,
+                                 int total_row_number) {
   // out = exp(x-x_max) / sum(exp(x-x_max))
   int row = blockIdx.x * blockDim.y + threadIdx.y;
   int non_zero_idx = threadIdx.x;
-  if (row >= row_number) return;
-  int row_first = static_cast<int>(x_crows[row]);
-  int row_nnz = static_cast<int>(x_crows[row + 1] - x_crows[row]);
+  if (row >= total_row_number) return;
+  int cur_batch = row / row_number;
+  int crow_idx = cur_batch * (row_number + 1) + (row % row_number);
+  int cur_batch_offset = 0;
+  for (int i = 1; i < cur_batch + 1; ++i) {
+    cur_batch_offset += x_crows[i * (row_number + 1) - 1];
+  }
+  int row_first = cur_batch_offset + static_cast<int>(x_crows[crow_idx]);
+  int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
   if (row_nnz == 0) return;

   int kIteration = (row_nnz + warpSize - 1) / warpSize;

@@ -81,17 +89,20 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
                         "SparseCsrTensor only support axis=-1 for softmax, "
                         "which is faster when reading data by row (axis=-1)"));

   EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);

   auto x_dim = x.dims();
+  auto x_rank = x_dim.size();
+
+  int total_row_number = 1;
   int row_number = 1;
-  for (int i = 0; i < x_dim.size() - 1; ++i) {
-    row_number *= x_dim[i];
+  for (int i = 0; i < x_rank - 1; ++i) {
+    total_row_number *= x_dim[i];
+    if (i == x_rank - 2) {
+      row_number = x_dim[i];
+    }
   }

-  dim3 grid((row_number + 3) / 4);
+  dim3 grid((total_row_number + 3) / 4);
   dim3 block(32, 4);
-  DenseTensor tmp_tensor =
-      phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements());

   PD_VISIT_INTEGRAL_TYPES(
       x.non_zero_crows().dtype(), "CsrSoftmaxKernel", ([&] {
         SoftmaxGpuKernel<T, data_t>

@@ -99,7 +110,8 @@ void SoftmaxCsrKernel(const Context& dev_ctx,
                 x.non_zero_crows().data<data_t>(),
                 x.non_zero_elements().data<T>(),
                 out->mutable_non_zero_elements()->data<T>(),
-                row_number);
+                row_number,
+                total_row_number);
       }));
 }
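Editor's note: the forward kernel uses the same indexing and the same launch geometry as the gradient kernel: one warp (32 threads) per CSR row and four rows per thread block, so the grid now has to cover total_row_number = batch_size * row_number rows instead of row_number. The launch arithmetic, written out in Python:

```python
# One warp (32 threads, threadIdx.x) per CSR row, 4 rows (threadIdx.y) per block.
def launch_shape(batch_size, row_number, rows_per_block=4, warp_size=32):
    total_row_number = batch_size * row_number
    grid_x = (total_row_number + rows_per_block - 1) // rows_per_block
    block = (warp_size, rows_per_block)
    return grid_x, block

print(launch_shape(batch_size=16, row_number=16))  # (64, (32, 4))
```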
python/paddle/fluid/tests/unittests/test_sparse_matmul_op.py

@@ -114,7 +114,77 @@ class TestCsrMaskedMatmul2D(unittest.TestCase):
         self.assertTrue(np.allclose(np_y_grad, y.grad.numpy()))


-#TODO(zhouwei25): support unit test of batch 'paddle.sparse.mm/masked_mm'
+@unittest.skipIf(
+    not paddle.is_compiled_with_cuda() or get_cuda_version() < 11070,
+    "paddle is not compiled with CUDA and cuda version need to >= 11.7")
+class TestCsrDenseMatmul3D(unittest.TestCase):
+    # x: csr, y: dense, out: dense
+    def test_matmul(self):
+        with _test_eager_guard():
+            paddle.set_default_dtype('float32')
+            origin_x = paddle.rand([16, 16, 12])
+            mask = paddle.randint(0, 2, [16, 12])
+            origin_x = origin_x * mask
+            origin_y = paddle.rand([16, 12, 10])
+
+            dense_x = origin_x.detach()
+            dense_x.stop_gradient = False
+            dense_y = origin_y.detach()
+            dense_y.stop_gradient = False
+            dense_out = paddle.matmul(dense_x, dense_y)
+            dense_out.backward()
+
+            sp_x = origin_x.detach().to_sparse_csr()
+            sp_x.stop_gradient = False
+            sp_y = origin_y.detach()
+            sp_y.stop_gradient = False
+            sp_out = paddle.incubate.sparse.matmul(sp_x, sp_y)
+            sp_out.backward()
+
+            self.assertTrue(np.allclose(sp_out.numpy(), dense_out.numpy()))
+            self.assertTrue(
+                np.allclose(sp_x.grad.to_dense().numpy(),
+                            (dense_x.grad * mask).numpy()))
+            self.assertTrue(np.allclose(sp_y.grad.numpy(), dense_y.grad.numpy()))
+
+
+@unittest.skipIf(
+    not paddle.is_compiled_with_cuda() or get_cuda_version() < 11070,
+    "paddle is not compiled with CUDA and cuda version need to >= 11.7")
+class TestCsrMaskedMatmul3D(unittest.TestCase):
+    # x: dense, y: dense, out: csr
+    def test_matmul(self):
+        with _test_eager_guard():
+            paddle.set_default_dtype('float64')
+            origin_x = paddle.rand([16, 16, 12])
+            origin_y = paddle.rand([16, 12, 10])
+            mask = paddle.randint(0, 2, [16, 10])
+
+            dense_x = origin_x.detach()
+            dense_x.stop_gradient = False
+            dense_y = origin_y.detach()
+            dense_y.stop_gradient = False
+            dense_out = paddle.matmul(dense_x, dense_y)
+            dense_out = dense_out * mask
+            dense_out.backward()
+
+            sp_x = origin_x.detach()
+            sp_x.stop_gradient = False
+            sp_y = origin_y.detach()
+            sp_y.stop_gradient = False
+            sp_out = paddle.incubate.sparse.masked_matmul(
+                sp_x, sp_y, dense_out.to_sparse_csr())
+            sp_out.backward()
+
+            self.assertTrue(
+                np.allclose(sp_out.to_dense().numpy(), dense_out.numpy()))
+            self.assertTrue(np.allclose(sp_x.grad.numpy(), dense_x.grad.numpy()))
+            self.assertTrue(np.allclose(sp_y.grad.numpy(), dense_y.grad.numpy()))
+
+
 if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/tests/unittests/test_sparse_softmax.py → python/paddle/fluid/tests/unittests/test_sparse_softmax_op.py (renamed)

@@ -28,10 +28,10 @@ np.random.seed(2022)
 class TestCsrSoftmax(unittest.TestCase):
-    def test_softmax(self):
+    def test_softmax2d(self):
         with _test_eager_guard():
-            mask = np.random.rand(1, 5) < 0.5
-            np_x = np.random.rand(1, 5) * mask
+            mask = np.random.rand(16, 128) < 0.5
+            np_x = np.random.rand(16, 128) * mask
             np_csr = sp.csr_matrix(np_x)

             row_number = np_csr.shape[0]

@@ -56,6 +56,7 @@ class TestCsrSoftmax(unittest.TestCase):
             # dx = (dout - sum(dout * out)) * out, dout=rand_x
             out.backward(csr.detach())

+            dx = np.array([])
             for i in range(row_number):
                 start = np_csr.indptr[i]
                 end = np_csr.indptr[i + 1]

@@ -64,7 +65,7 @@ class TestCsrSoftmax(unittest.TestCase):
                 out = np_out[start:end]
                 dout = np_csr.data[start:end]
                 sum = np.sum(dout * out, keepdims=True)
-                dx = (dout - sum) * out
+                dx = np.concatenate([dx, (dout - sum) * out])

             self.assertTrue(np.allclose(csr.grad.crows().numpy(),
                                         np_csr.indptr))

@@ -72,6 +73,55 @@ class TestCsrSoftmax(unittest.TestCase):
                                         np_csr.indices))
             self.assertTrue(np.allclose(csr.grad.values().numpy(), dx))

+    def test_softmax3d(self):
+        with _test_eager_guard():
+            batchNum = 16
+            mask = np.random.rand(batchNum, 16, 128) < 0.5
+            np_x = np.random.rand(batchNum, 16, 128) * mask
+
+            np_out_list = []
+            np_out = np.array([])
+            for i in range(batchNum):
+                np_csr = sp.csr_matrix(np_x[i, :, :])
+                row_number = np_csr.shape[0]
+                for j in range(row_number, ):
+                    start = np_csr.indptr[j]
+                    end = np_csr.indptr[j + 1]
+                    if start == end:
+                        continue
+                    x = np_csr.data[start:end]
+                    x_max = np.max(x, keepdims=True)
+                    x_exp = np.exp(x - x_max)
+                    x_exp_sum = np.sum(x_exp, keepdims=True)
+                    np_out_list.append(x_exp / x_exp_sum)
+                    np_out = np.concatenate([np_out, x_exp / x_exp_sum])
+
+            csr = paddle.to_tensor(np_x, stop_gradient=False).to_sparse_csr()
+            m = paddle.incubate.sparse.nn.Softmax()
+            out = m(csr)
+            self.assertTrue(np.allclose(out.values().numpy(), np_out))
+
+            # dx = (dout - sum(dout * out)) * out, dout=rand_x
+            out.backward(csr.detach())
+            dx = np.array([])
+            batch_offset = 0
+            for i in range(batchNum):
+                np_csr = sp.csr_matrix(np_x[i, :, :])
+                row_number = np_csr.shape[0]
+                for j in range(row_number):
+                    start = np_csr.indptr[j]
+                    end = np_csr.indptr[j + 1]
+                    if start == end:
+                        continue
+                    dout = np_csr.data[start:end]
+                    out = np_out[batch_offset + start:batch_offset + end]
+                    sum = np.sum(dout * out, keepdims=True)
+                    dx = np.concatenate([dx, (dout - sum) * out])
+
+                batch_offset += np_csr.nnz
+
+            self.assertTrue(np.allclose(csr.grad.values().numpy(), dx))
+
+
 if __name__ == "__main__":
     unittest.main()
python/paddle/incubate/sparse/nn/functional/activation.py

@@ -55,7 +55,7 @@ def softmax(x, axis=-1, name=None):
     sparse softmax activation, x must be SparseCsrTensor or SparseCooTensor.

     Note:
-        Only supported axis=-1 for SparseCsrTensor, which is faster when read data
+        Only support axis=-1 for SparseCsrTensor, which is faster when read data
         by row (axis=-1).

     From the point of view of dense matrix, for each row :math:`i` and each column :math:`j`
python/paddle/incubate/sparse/nn/layer/activation.py

@@ -66,7 +66,7 @@ class Softmax(Layer):
    sparse softmax activation, x must be SparseCsrTensor or SparseCooTensor.

    Note:
-        Only supported axis=-1 for SparseCsrTensor, which is faster when read data
+        Only support axis=-1 for SparseCsrTensor, which is faster when read data
        by row (axis=-1).

    From the point of view of dense matrix, for each row :math:`i` and each column :math:`j`