Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e3d94fc5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e3d94fc5
编写于
6月 23, 2022
作者:
M
Matsumoto Ruko
提交者:
GitHub
6月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Hackathon No.56 57 58 59】sparse elementwise add sub mul div (#41857)
上级
3be36a82
变更
15
展开全部
隐藏空白更改
内联
并排
Showing
15 changed file
with
1979 addition
and
0 deletion
+1979
-0
paddle/phi/kernels/activation_kernel.h
paddle/phi/kernels/activation_kernel.h
+1
-0
paddle/phi/kernels/cpu/activation_kernel.cc
paddle/phi/kernels/cpu/activation_kernel.cc
+10
-0
paddle/phi/kernels/funcs/activation_functor.h
paddle/phi/kernels/funcs/activation_functor.h
+8
-0
paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
+413
-0
paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc
paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc
+451
-0
paddle/phi/kernels/sparse/elementwise_grad_kernel.h
paddle/phi/kernels/sparse/elementwise_grad_kernel.h
+112
-0
paddle/phi/kernels/sparse/elementwise_kernel.h
paddle/phi/kernels/sparse/elementwise_kernel.h
+78
-0
paddle/phi/tests/kernels/CMakeLists.txt
paddle/phi/tests/kernels/CMakeLists.txt
+4
-0
paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc
paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc
+422
-0
python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
...addle/fluid/tests/unittests/test_sparse_elementwise_op.py
+142
-0
python/paddle/incubate/sparse/__init__.py
python/paddle/incubate/sparse/__init__.py
+9
-0
python/paddle/incubate/sparse/math.py
python/paddle/incubate/sparse/math.py
+260
-0
python/paddle/utils/code_gen/sparse_api.yaml
python/paddle/utils/code_gen/sparse_api.yaml
+36
-0
python/paddle/utils/code_gen/sparse_bw_api.yaml
python/paddle/utils/code_gen/sparse_bw_api.yaml
+32
-0
python/setup.py.in
python/setup.py.in
+1
-0
未找到文件。
paddle/phi/kernels/activation_kernel.h
浏览文件 @
e3d94fc5
...
...
@@ -71,6 +71,7 @@ DECLARE_ACTIVATION_KERNEL(Log1p)
DECLARE_ACTIVATION_KERNEL
(
Round
)
DECLARE_ACTIVATION_KERNEL
(
Floor
)
DECLARE_ACTIVATION_KERNEL
(
Ceil
)
DECLARE_ACTIVATION_KERNEL
(
Negative
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
alpha
)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
threshold
)
...
...
paddle/phi/kernels/cpu/activation_kernel.cc
浏览文件 @
e3d94fc5
...
...
@@ -89,6 +89,7 @@ DEFINE_CPU_ACTIVATION_KERNEL(Log1p, Log1pFunctor)
DEFINE_CPU_ACTIVATION_KERNEL
(
Round
,
RoundFunctor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Floor
,
FloorFunctor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Ceil
,
CeilFunctor
)
DEFINE_CPU_ACTIVATION_KERNEL
(
Negative
,
NegativeFunctor
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
LeakyRelu
,
LeakyReluFunctor
,
alpha
)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS
(
ThresholdedRelu
,
...
...
@@ -182,6 +183,15 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL
(
round
,
RoundKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
floor
,
FloorKernel
)
PD_REGISTER_ACTIVATION_KERNEL
(
ceil
,
CeilKernel
)
PD_REGISTER_KERNEL
(
negative
,
CPU
,
ALL_LAYOUT
,
phi
::
NegativeKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{}
PD_REGISTER_ACTIVATION_KERNEL
(
celu
,
CeluKernel
)
PD_REGISTER_KERNEL
(
pow
,
CPU
,
ALL_LAYOUT
,
phi
::
PowKernel
,
float
,
double
,
int
,
int64_t
)
{}
paddle/phi/kernels/funcs/activation_functor.h
浏览文件 @
e3d94fc5
...
...
@@ -1814,6 +1814,14 @@ struct CeilFunctor : public BaseActivationFunctor<T> {
}
};
template
<
typename
T
>
struct
NegativeFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
-
x
;
}
};
template
<
typename
T
>
struct
ZeroGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
template
<
typename
Device
,
...
...
paddle/phi/kernels/sparse/cpu/elementwise_grad_kernel.cc
0 → 100644
浏览文件 @
e3d94fc5
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/sparse/copy_kernel.h"
#include "paddle/phi/kernels/sparse/elementwise_kernel.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
AllocCsrPtr
(
const
Context
&
dev_ctx
,
const
SparseCsrTensor
&
x
,
SparseCsrTensor
*
dx
)
{
DenseTensor
dx_crows
=
phi
::
EmptyLike
<
IntT
>
(
dev_ctx
,
x
.
non_zero_crows
());
DenseTensor
dx_cols
=
phi
::
EmptyLike
<
IntT
>
(
dev_ctx
,
x
.
non_zero_cols
());
DenseTensor
dx_values
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
x
.
non_zero_elements
());
dx
->
SetMember
(
dx_crows
,
dx_cols
,
dx_values
,
x
.
dims
());
}
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
AllocCooPtr
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
SparseCooTensor
*
dx
)
{
DenseTensor
dx_indices
=
phi
::
EmptyLike
<
IntT
>
(
dev_ctx
,
x
.
non_zero_indices
());
DenseTensor
dx_values
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
x
.
non_zero_elements
());
dx
->
SetMember
(
dx_indices
,
dx_values
,
x
.
dims
(),
true
);
}
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
ElementWiseAddCsrGradCPUKernel
(
const
Context
&
dev_ctx
,
const
SparseCsrTensor
&
x
,
const
SparseCsrTensor
&
y
,
const
SparseCsrTensor
&
dout
,
SparseCsrTensor
*
dx
,
SparseCsrTensor
*
dy
)
{
// Special case when y_grad is not needed
if
(
dx
!=
nullptr
&&
dy
==
nullptr
)
{
VLOG
(
4
)
<<
"Special case when dy is not needed"
;
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
CopyCsr
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dx
);
}
else
if
(
dx
==
nullptr
&&
dy
!=
nullptr
)
{
VLOG
(
4
)
<<
"Special case when dx is not needed"
;
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
CopyCsr
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dy
);
}
else
{
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
CopyCsr
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dx
);
CopyCsr
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dy
);
}
}
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
ElementWiseSubtractCsrGradCPUKernel
(
const
Context
&
dev_ctx
,
const
SparseCsrTensor
&
x
,
const
SparseCsrTensor
&
y
,
const
SparseCsrTensor
&
dout
,
SparseCsrTensor
*
dx
,
SparseCsrTensor
*
dy
)
{
if
(
dx
)
{
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
CopyCsr
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dx
);
}
if
(
dy
)
{
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
CopyCsr
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dy
);
phi
::
NegativeKernel
<
T
,
Context
>
(
dev_ctx
,
dout
.
non_zero_elements
(),
dy
->
mutable_non_zero_elements
());
}
}
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
ElementWiseMultiplyCsrGradCPUKernel
(
const
Context
&
dev_ctx
,
const
SparseCsrTensor
&
x
,
const
SparseCsrTensor
&
y
,
const
SparseCsrTensor
&
dout
,
SparseCsrTensor
*
dx
,
SparseCsrTensor
*
dy
)
{
if
(
dx
)
{
// dout*y
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
sparse
::
ElementWiseMultiplyCsrKernel
<
T
,
Context
>
(
dev_ctx
,
dout
,
y
,
dx
);
}
if
(
dy
)
{
// dout*x
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
sparse
::
ElementWiseMultiplyCsrKernel
<
T
,
Context
>
(
dev_ctx
,
dout
,
x
,
dy
);
}
}
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
ElementWiseDivideCsrGradCPUKernel
(
const
Context
&
dev_ctx
,
const
SparseCsrTensor
&
x
,
const
SparseCsrTensor
&
y
,
const
SparseCsrTensor
&
out
,
const
SparseCsrTensor
&
dout
,
SparseCsrTensor
*
dx
,
SparseCsrTensor
*
dy
)
{
if
(
dx
)
{
// dout/y
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
sparse
::
ElementWiseDivideCsrKernel
<
T
,
Context
>
(
dev_ctx
,
dout
,
y
,
dx
);
}
if
(
dy
)
{
// -dout * out / y
AllocCsrPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
CopyCsr
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dy
);
phi
::
NegativeKernel
<
T
,
Context
>
(
dev_ctx
,
dout
.
non_zero_elements
(),
dy
->
mutable_non_zero_elements
());
auto
tmp
=
sparse
::
ElementWiseMultiplyCsr
<
T
,
Context
>
(
dev_ctx
,
*
dy
,
out
);
sparse
::
ElementWiseDivideCsrKernel
<
T
,
Context
>
(
dev_ctx
,
tmp
,
y
,
dy
);
}
}
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
ElementWiseAddCooGradCPUKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
y
,
const
SparseCooTensor
&
dout
,
SparseCooTensor
*
dx
,
SparseCooTensor
*
dy
)
{
// Special case when y_grad is not needed*/
if
(
dx
!=
nullptr
&&
dy
==
nullptr
)
{
VLOG
(
4
)
<<
"Special case when dy is not needed"
;
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
CopyCoo
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dx
);
}
else
if
(
dx
==
nullptr
&&
dy
!=
nullptr
)
{
VLOG
(
4
)
<<
"Special case when dx is not needed"
;
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
CopyCoo
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dy
);
}
else
{
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
CopyCoo
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dx
);
CopyCoo
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dy
);
}
}
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
ElementWiseSubtractCooGradCPUKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
y
,
const
SparseCooTensor
&
dout
,
SparseCooTensor
*
dx
,
SparseCooTensor
*
dy
)
{
if
(
dx
)
{
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
CopyCoo
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dx
);
}
if
(
dy
)
{
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
CopyCoo
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dy
);
phi
::
NegativeKernel
<
T
,
Context
>
(
dev_ctx
,
dout
.
non_zero_elements
(),
dy
->
mutable_non_zero_elements
());
}
}
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
ElementWiseMultiplyCooGradCPUKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
y
,
const
SparseCooTensor
&
dout
,
SparseCooTensor
*
dx
,
SparseCooTensor
*
dy
)
{
if
(
dx
)
{
// dout*y
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
sparse
::
ElementWiseMultiplyCooKernel
<
T
,
Context
>
(
dev_ctx
,
dout
,
y
,
dx
);
}
if
(
dy
)
{
// dout*x
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
sparse
::
ElementWiseMultiplyCooKernel
<
T
,
Context
>
(
dev_ctx
,
dout
,
x
,
dy
);
}
}
template
<
typename
T
,
typename
IntT
,
typename
Context
>
void
ElementWiseDivideCooGradCPUKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
y
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
dout
,
SparseCooTensor
*
dx
,
SparseCooTensor
*
dy
)
{
if
(
dx
)
{
// dout/y
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
x
,
dx
);
sparse
::
ElementWiseDivideCooKernel
<
T
,
Context
>
(
dev_ctx
,
dout
,
y
,
dx
);
}
if
(
dy
)
{
// -dout * out / y
AllocCooPtr
<
T
,
IntT
>
(
dev_ctx
,
y
,
dy
);
CopyCoo
(
dev_ctx
,
dout
,
dev_ctx
.
GetPlace
(),
false
,
dy
);
phi
::
NegativeKernel
<
T
,
Context
>
(
dev_ctx
,
dout
.
non_zero_elements
(),
dy
->
mutable_non_zero_elements
());
auto
tmp
=
sparse
::
ElementWiseMultiplyCoo
<
T
,
Context
>
(
dev_ctx
,
*
dy
,
out
);
sparse
::
ElementWiseDivideCooKernel
<
T
,
Context
>
(
dev_ctx
,
tmp
,
y
,
dy
);
}
}
// CPU Kernel end
// Kernel
template
<
typename
T
,
typename
Context
>
void
ElementWiseDivideCsrGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCsrTensor
&
x
,
const
SparseCsrTensor
&
y
,
const
SparseCsrTensor
&
out
,
const
SparseCsrTensor
&
dout
,
SparseCsrTensor
*
dx
,
SparseCsrTensor
*
dy
)
{
PD_VISIT_INTEGRAL_TYPES
(
x
.
non_zero_crows
().
dtype
(),
"ElementWiseDivideCsrGradCPUKernel"
,
([
&
]
{
ElementWiseDivideCsrGradCPUKernel
<
T
,
data_t
>
(
dev_ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}));
}
template
<
typename
T
,
typename
Context
>
void
ElementWiseDivideCooGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
y
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
dout
,
SparseCooTensor
*
dx
,
SparseCooTensor
*
dy
)
{
PD_VISIT_INTEGRAL_TYPES
(
x
.
non_zero_indices
().
dtype
(),
"ElementWiseDivideCooGradCPUKernel"
,
([
&
]
{
ElementWiseDivideCooGradCPUKernel
<
T
,
data_t
>
(
dev_ctx
,
x
,
y
,
out
,
dout
,
dx
,
dy
);
}));
}
#define DEFINE_ELEMENTWISE_GRAD_KERNEL(name) \
DEFINE_ELEMENTWISE_GRAD_KERNEL_CSR(name) \
\
DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name)
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_CSR(name) \
template <typename T, typename Context> \
void ElementWise##name##CsrGradKernel(const Context& dev_ctx, \
const SparseCsrTensor& x, \
const SparseCsrTensor& y, \
const SparseCsrTensor& dout, \
SparseCsrTensor* dx, \
SparseCsrTensor* dy) { \
PD_VISIT_INTEGRAL_TYPES(x.non_zero_crows().dtype(), \
"ElementWise##name##CsrGradCPUKernel", \
([&] { \
ElementWise##name##CsrGradCPUKernel<T, data_t>( \
dev_ctx, x, y, dout, dx, dy); \
})); \
}
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_COO(name) \
template <typename T, typename Context> \
void ElementWise##name##CooGradKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \
const SparseCooTensor& y, \
const SparseCooTensor& dout, \
SparseCooTensor* dx, \
SparseCooTensor* dy) { \
PD_VISIT_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \
"ElementWise##name##CooGradCPUKernel", \
([&] { \
ElementWise##name##CooGradCPUKernel<T, data_t>( \
dev_ctx, x, y, dout, dx, dy); \
})); \
}
DEFINE_ELEMENTWISE_GRAD_KERNEL
(
Add
)
DEFINE_ELEMENTWISE_GRAD_KERNEL
(
Subtract
)
DEFINE_ELEMENTWISE_GRAD_KERNEL
(
Multiply
)
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
add_csr_csr_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseAddCsrGradKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
2
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
PD_REGISTER_KERNEL
(
subtract_csr_csr_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseSubtractCsrGradKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
2
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
PD_REGISTER_KERNEL
(
multiply_csr_csr_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseMultiplyCsrGradKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
2
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
PD_REGISTER_KERNEL
(
divide_csr_csr_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseDivideCsrGradKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
2
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
3
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
PD_REGISTER_KERNEL
(
add_coo_coo_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseAddCooGradKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
2
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
subtract_coo_coo_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseSubtractCooGradKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
2
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
multiply_coo_coo_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseMultiplyCooGradKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
2
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
divide_coo_coo_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseDivideCooGradKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
2
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
3
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
paddle/phi/kernels/sparse/cpu/elementwise_kernel.cc
0 → 100644
浏览文件 @
e3d94fc5
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/elementwise_kernel.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/elementwise_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.h"
#include "paddle/phi/kernels/sparse/sparse_utils_kernel.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
,
typename
Functor
>
struct
BinaryOPWithZeroCompareFunctor
{
explicit
BinaryOPWithZeroCompareFunctor
(
Functor
functor
)
:
functor_
(
functor
)
{}
inline
HOSTDEVICE
bool
operator
()(
const
T
*
a
,
const
T
*
b
,
T
*
result
,
const
int64_t
len
)
const
{
bool
is_zero
=
true
;
for
(
int64_t
i
=
0
;
i
<
len
;
++
i
)
{
result
[
i
]
=
functor_
(
a
[
i
],
b
[
i
]);
if
(
result
[
i
]
!=
0
)
{
is_zero
=
false
;
}
}
return
is_zero
;
}
Functor
functor_
;
};
template
<
typename
T
,
typename
IntT
,
typename
Functor
>
void
Merge
(
const
IntT
el_len
,
const
IntT
*
a_index
,
const
T
*
a_values
,
const
IntT
len_a
,
const
IntT
*
b_index_org
,
const
T
*
b_values_org
,
const
IntT
len_b
,
const
IntT
len_b_max
,
IntT
*
c_index
,
T
*
c_values
,
IntT
&
nnz
,
const
Functor
&
functor_org
,
const
bool
is_divide
)
{
IntT
a
=
0
;
IntT
b
=
0
;
nnz
=
0
;
const
IntT
*
b_index
=
nullptr
;
std
::
vector
<
IntT
>
b_full_index
;
const
std
::
vector
<
T
>
zero
(
el_len
,
0
);
auto
functor
=
BinaryOPWithZeroCompareFunctor
<
T
,
Functor
>
(
functor_org
);
std
::
vector
<
const
T
*>
b_values
(
len_b_max
,
zero
.
data
());
for
(
auto
i
=
0
;
i
<
len_b
;
++
i
)
{
b_values
[
b_index_org
[
i
]]
=
b_values_org
+
i
*
el_len
;
}
// if is divide expend b_index_org to b_full_index
if
(
is_divide
)
{
b_full_index
=
std
::
vector
<
IntT
>
(
len_b_max
);
for
(
int64_t
j
=
0
;
j
<
static_cast
<
int64_t
>
(
b_full_index
.
size
());
++
j
)
{
b_full_index
[
j
]
=
j
;
}
b_index
=
b_full_index
.
data
();
}
else
{
b_index
=
b_index_org
;
}
// merge
while
(
a
<
len_a
&&
b
<
(
is_divide
?
len_b_max
:
len_b
))
{
if
(
a_index
[
a
]
==
b_index
[
b
])
{
if
(
!
functor
(
a_values
+
a
*
el_len
,
b_values
[
b_index
[
b
]],
c_values
+
nnz
*
el_len
,
el_len
))
{
c_index
[
nnz
]
=
a_index
[
a
];
++
nnz
;
}
++
a
;
++
b
;
}
// coordinate x[a] < coordinate y[b]
else
if
(
a_index
[
a
]
<
b_index
[
b
])
{
if
(
!
functor
(
a_values
+
a
*
el_len
,
zero
.
data
(),
c_values
+
nnz
*
el_len
,
el_len
))
{
c_index
[
nnz
]
=
a_index
[
a
];
++
nnz
;
}
++
a
;
}
// coordinate x[a] > coordinate y[b]
else
if
(
a_index
[
a
]
>
b_index
[
b
])
{
if
(
!
functor
(
zero
.
data
(),
b_values
[
b_index
[
b
]],
c_values
+
nnz
*
el_len
,
el_len
))
{
c_index
[
nnz
]
=
b_index
[
b
];
++
nnz
;
}
++
b
;
}
}
// a tail
while
(
a
<
len_a
)
{
if
(
!
functor
(
a_values
+
a
*
el_len
,
zero
.
data
(),
c_values
+
nnz
*
el_len
,
el_len
))
{
c_index
[
nnz
]
=
a_index
[
a
];
++
nnz
;
}
++
a
;
}
// b tail
while
(
b
<
(
is_divide
?
len_b_max
:
len_b
))
{
if
(
!
functor
(
zero
.
data
(),
b_values
[
b_index
[
b
]],
c_values
+
nnz
*
el_len
,
el_len
))
{
c_index
[
nnz
]
=
b_index
[
b
];
++
nnz
;
}
++
b
;
}
}
// SparseCooTensor elementwise op, only support same shape tensor now
template
<
typename
T
,
typename
IntT
,
typename
Context
,
typename
Functor
>
void
ElementWiseCooKernelImpl
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
y
,
SparseCooTensor
*
out
,
const
Functor
&
functor
)
{
PADDLE_ENFORCE_EQ
(
x
.
dims
(),
y
.
dims
(),
phi
::
errors
::
InvalidArgument
(
"Currently only support same shape elementwise "
"compute. The input tensor X's shape "
"should be identical with Y's shape. But received X's "
"shape = [%s], Y's shape = [%s]."
,
x
.
dims
(),
y
.
dims
()));
int64_t
element_size
=
1
;
for
(
auto
j
=
1
;
j
<
x
.
non_zero_elements
().
dims
().
size
();
++
j
)
{
element_size
*=
x
.
non_zero_elements
().
dims
()[
j
];
}
IntT
nnz
=
0
;
const
auto
x_values
=
x
.
non_zero_elements
().
data
<
T
>
();
const
auto
y_values
=
y
.
non_zero_elements
().
data
<
T
>
();
const
auto
sparse_dim
=
x
.
non_zero_indices
().
dims
()[
0
];
const
bool
is_divide
=
std
::
is_same
<
Functor
,
funcs
::
DivideFunctor
<
T
>>::
value
;
int64_t
max_len
=
1
;
for
(
auto
j
=
0
;
j
<
sparse_dim
;
++
j
)
{
max_len
*=
x
.
dims
()[
j
];
}
std
::
vector
<
IntT
>
sparse_offsets
(
sparse_dim
),
x_indexs
(
x
.
nnz
()),
y_indexs
(
y
.
nnz
());
phi
::
funcs
::
sparse
::
CalcOffsetsPerDim
<
IntT
>
(
x
.
dims
(),
sparse_dim
,
sparse_offsets
.
data
());
phi
::
funcs
::
sparse
::
FlattenIndices
(
x
.
non_zero_indices
().
data
<
IntT
>
(),
sparse_offsets
.
data
(),
x
.
nnz
(),
sparse_dim
,
0
,
1
,
x_indexs
.
data
());
phi
::
funcs
::
sparse
::
FlattenIndices
(
y
.
non_zero_indices
().
data
<
IntT
>
(),
sparse_offsets
.
data
(),
y
.
nnz
(),
sparse_dim
,
0
,
1
,
y_indexs
.
data
());
std
::
vector
<
IntT
>
out_indexs
;
std
::
vector
<
T
>
out_values_vec
;
if
(
is_divide
)
{
out_indexs
.
reserve
(
max_len
);
}
else
{
out_indexs
.
reserve
(
x
.
nnz
()
+
y
.
nnz
());
}
out_values_vec
.
reserve
(
max_len
*
element_size
);
// merge x and y
Merge
<
T
,
IntT
,
Functor
>
(
element_size
,
x_indexs
.
data
(),
x_values
,
x_indexs
.
size
(),
y_indexs
.
data
(),
y_values
,
y_indexs
.
size
(),
max_len
,
out_indexs
.
data
(),
out_values_vec
.
data
(),
nnz
,
functor
,
is_divide
);
std
::
vector
<
IntT
>
out_indices_vec
;
out_indices_vec
.
resize
(
nnz
*
sparse_dim
);
Dim
<
DDim
::
kMaxRank
>
const_dims
;
for
(
auto
i
=
0
;
i
<
x
.
dims
().
size
();
i
++
)
{
const_dims
[
i
]
=
x
.
dims
()[
i
];
}
funcs
::
sparse
::
IndexToCoordinate
<
IntT
>
(
out_indexs
.
data
(),
const_dims
,
nnz
,
sparse_dim
,
0
,
1
,
out_indices_vec
.
data
());
if
(
nnz
==
0
)
{
phi
::
DenseTensor
out_indices
=
phi
::
EmptyLike
<
IntT
>
(
dev_ctx
,
x
.
non_zero_indices
());
phi
::
DenseTensor
out_values
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
x
.
non_zero_elements
());
out
->
SetMember
(
out_indices
,
out_values
,
x
.
dims
());
}
else
{
DenseTensorMeta
indices_meta
(
paddle
::
experimental
::
CppTypeToDataType
<
IntT
>::
Type
(),
phi
::
make_ddim
(
{
static_cast
<
int64_t
>
(
sparse_dim
),
static_cast
<
int64_t
>
(
nnz
)}),
DataLayout
::
NCHW
);
auto
indeces_dim
=
vectorize
(
slice_ddim
(
x
.
non_zero_elements
().
dims
(),
1
,
x
.
non_zero_elements
().
dims
().
size
()));
indeces_dim
.
insert
(
indeces_dim
.
begin
(),
nnz
);
DenseTensorMeta
values_meta
(
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
(),
phi
::
make_ddim
(
indeces_dim
),
DataLayout
::
NCHW
);
phi
::
DenseTensor
out_indices
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
indices_meta
));
phi
::
DenseTensor
out_values
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
values_meta
));
std
::
memcpy
(
out_indices
.
data
<
IntT
>
(),
out_indices_vec
.
data
(),
sizeof
(
IntT
)
*
sparse_dim
*
nnz
);
std
::
memcpy
(
out_values
.
data
<
T
>
(),
out_values_vec
.
data
(),
sizeof
(
T
)
*
nnz
*
element_size
);
out
->
SetMember
(
out_indices
,
out_values
,
x
.
dims
());
}
}
#define DEFINE_CSR_ELEMENTWISE_CPU_KERNEL(name) \
template <typename T, typename IntT, typename Context> \
void ElementWise##name##CsrCPUKernel(const Context& dev_ctx, \
const SparseCsrTensor& x, \
const SparseCsrTensor& y, \
SparseCsrTensor* out) { \
funcs::name##Functor<T> functor; \
auto coo_x = SparseCsrToCoo<T>(dev_ctx, x); \
auto coo_y = SparseCsrToCoo<T>(dev_ctx, y); \
DenseTensor indeces; \
DenseTensor values; \
SparseCooTensor coo_out; \
coo_out.SetMember(indeces, values, x.dims()); \
ElementWiseCooKernelImpl<T, IntT, Context, funcs::name##Functor<T>>( \
dev_ctx, coo_x, coo_y, &coo_out, functor); \
*out = SparseCooToCsr<T>(dev_ctx, coo_out); \
}
#define DEFINE_CSR_ELEMENTWISE_KERNEL(name) \
template <typename T, typename Context> \
void ElementWise##name##CsrKernel(const Context& dev_ctx, \
const SparseCsrTensor& x, \
const SparseCsrTensor& y, \
SparseCsrTensor* out) { \
PD_VISIT_INTEGRAL_TYPES( \
x.non_zero_crows().dtype(), "ElementWise##name##CsrCPUKernel", ([&] { \
ElementWise##name##CsrCPUKernel<T, data_t>(dev_ctx, x, y, out); \
})); \
}
#define DEFINE_COO_ELEMENTWISE_CPU_KERNEL(name) \
template <typename T, typename IntT, typename Context> \
void ElementWise##name##CooCPUKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \
const SparseCooTensor& y, \
SparseCooTensor* out) { \
funcs::name##Functor<T> functor; \
ElementWiseCooKernelImpl<T, IntT, Context, funcs::name##Functor<T>>( \
dev_ctx, x, y, out, functor); \
}
#define DEFINE_COO_ELEMENTWISE_KERNEL(name) \
template <typename T, typename Context> \
void ElementWise##name##CooKernel(const Context& dev_ctx, \
const SparseCooTensor& x, \
const SparseCooTensor& y, \
SparseCooTensor* out) { \
PD_VISIT_INTEGRAL_TYPES(x.non_zero_indices().dtype(), \
"ElementWise##name##CooCPUKernel", \
([&] { \
ElementWise##name##CooCPUKernel<T, data_t>( \
dev_ctx, x, y, out); \
})); \
}
DEFINE_CSR_ELEMENTWISE_CPU_KERNEL
(
Add
)
DEFINE_CSR_ELEMENTWISE_CPU_KERNEL
(
Subtract
)
DEFINE_CSR_ELEMENTWISE_CPU_KERNEL
(
Multiply
)
DEFINE_CSR_ELEMENTWISE_CPU_KERNEL
(
Divide
)
DEFINE_CSR_ELEMENTWISE_KERNEL
(
Add
)
DEFINE_CSR_ELEMENTWISE_KERNEL
(
Subtract
)
DEFINE_CSR_ELEMENTWISE_KERNEL
(
Multiply
)
DEFINE_CSR_ELEMENTWISE_KERNEL
(
Divide
)
DEFINE_COO_ELEMENTWISE_CPU_KERNEL
(
Add
)
DEFINE_COO_ELEMENTWISE_CPU_KERNEL
(
Subtract
)
DEFINE_COO_ELEMENTWISE_CPU_KERNEL
(
Multiply
)
DEFINE_COO_ELEMENTWISE_CPU_KERNEL
(
Divide
)
DEFINE_COO_ELEMENTWISE_KERNEL
(
Add
)
DEFINE_COO_ELEMENTWISE_KERNEL
(
Subtract
)
DEFINE_COO_ELEMENTWISE_KERNEL
(
Multiply
)
DEFINE_COO_ELEMENTWISE_KERNEL
(
Divide
)
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
add_csr_csr
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseAddCsrKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
PD_REGISTER_KERNEL
(
add_coo_coo
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseAddCooKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
subtract_csr_csr
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseSubtractCsrKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
PD_REGISTER_KERNEL
(
subtract_coo_coo
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseSubtractCooKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
multiply_csr_csr
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseMultiplyCsrKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
PD_REGISTER_KERNEL
(
multiply_coo_coo
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseMultiplyCooKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
PD_REGISTER_KERNEL
(
divide_csr_csr
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseDivideCsrKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_CSR
);
}
PD_REGISTER_KERNEL
(
divide_coo_coo
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
ElementWiseDivideCooKernel
,
float
,
double
,
int16_t
,
int
,
int64_t
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
1
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
paddle/phi/kernels/sparse/elementwise_grad_kernel.h
0 → 100644
浏览文件 @
e3d94fc5
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace
phi
{
namespace
sparse
{
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD(name) \
DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD_WITH_TYPE(name, Csr) \
\
DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD_WITH_TYPE(name, Coo)
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC(name) \
DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC_WITH_TYPE(name, Csr) \
\
DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC_WITH_TYPE(name, Coo)
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD_WITH_TYPE(name, type) \
template <typename T, typename Context> \
void ElementWise##name##type##GradKernel(const Context& dev_ctx, \
const Sparse##type##Tensor& x, \
const Sparse##type##Tensor& y, \
const Sparse##type##Tensor& dout, \
Sparse##type##Tensor* dx, \
Sparse##type##Tensor* dy);
#define DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC_WITH_TYPE(name, type) \
template <typename T, typename Context> \
std::vector<Sparse##type##Tensor> ElementWise##name##type##Grad( \
const Context& dev_ctx, \
const Sparse##type##Tensor& x, \
const Sparse##type##Tensor& y, \
const Sparse##type##Tensor& dout) { \
Sparse##type##Tensor dx; \
Sparse##type##Tensor dy; \
ElementWise##name##type##GradKernel<T, Context>( \
dev_ctx, x, y, dout, &dx, &dy); \
return std::vector<Sparse##type##Tensor>{dx, dy}; \
}
DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD
(
Add
)
DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD
(
Subtract
)
DEFINE_ELEMENTWISE_GRAD_KERNEL_HEAD
(
Multiply
)
DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC
(
Add
)
DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC
(
Subtract
)
DEFINE_ELEMENTWISE_GRAD_KERNEL_FUNC
(
Multiply
)
template
<
typename
T
,
typename
Context
>
void
ElementWiseDivideCsrGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCsrTensor
&
x
,
const
SparseCsrTensor
&
y
,
const
SparseCsrTensor
&
out
,
const
SparseCsrTensor
&
dout
,
SparseCsrTensor
*
dx
,
SparseCsrTensor
*
dy
);
template
<
typename
T
,
typename
Context
>
void
ElementWiseDivideCooGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
y
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
dout
,
SparseCooTensor
*
dx
,
SparseCooTensor
*
dy
);
template
<
typename
T
,
typename
Context
>
std
::
vector
<
SparseCsrTensor
>
ElementWiseDivideCsrGrad
(
const
Context
&
dev_ctx
,
const
SparseCsrTensor
&
x
,
const
SparseCsrTensor
&
y
,
const
SparseCsrTensor
&
out
,
const
SparseCsrTensor
&
dout
)
{
SparseCsrTensor
dx
;
SparseCsrTensor
dy
;
ElementWiseDivideCsrGradKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
y
,
out
,
dout
,
&
dx
,
&
dy
);
return
std
::
vector
<
SparseCsrTensor
>
{
dx
,
dy
};
}
template
<
typename
T
,
typename
Context
>
std
::
vector
<
SparseCooTensor
>
ElementWiseDivideCooGrad
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
y
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
dout
)
{
SparseCooTensor
dx
;
SparseCooTensor
dy
;
ElementWiseDivideCooGradKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
y
,
out
,
dout
,
&
dx
,
&
dy
);
return
std
::
vector
<
SparseCooTensor
>
{
dx
,
dy
};
}
}
// namespace sparse
}
// namespace phi
paddle/phi/kernels/sparse/elementwise_kernel.h
0 → 100644
浏览文件 @
e3d94fc5
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
namespace
phi
{
namespace
sparse
{
#define DEFINE_ELEMENTWISE_KERNEL_HEAD(name) \
DEFINE_ELEMENTWISE_KERNEL_HEAD_WITH_TYPE(name, Csr) \
\
DEFINE_ELEMENTWISE_KERNEL_HEAD_WITH_TYPE(name, Coo)
#define DEFINE_ELEMENTWISE_KERNEL_FUNC(name) \
DEFINE_CSR_ELEMENTWISE_KERNEL_FUNC(name) \
\
DEFINE_COO_ELEMENTWISE_KERNEL_FUNC(name)
#define DEFINE_ELEMENTWISE_KERNEL_HEAD_WITH_TYPE(name, type) \
template <typename T, typename Context> \
void ElementWise##name##type##Kernel(const Context& dev_ctx, \
const Sparse##type##Tensor& x, \
const Sparse##type##Tensor& y, \
Sparse##type##Tensor* out);
#define DEFINE_CSR_ELEMENTWISE_KERNEL_FUNC(name) \
template <typename T, typename Context> \
SparseCsrTensor ElementWise##name##Csr(const Context& dev_ctx, \
const SparseCsrTensor& x, \
const SparseCsrTensor& y) { \
DenseTensor non_zero_crows; \
DenseTensor non_zero_cols; \
DenseTensor non_zero_elements; \
SparseCsrTensor out( \
non_zero_crows, non_zero_cols, non_zero_elements, x.dims()); \
ElementWise##name##CsrKernel<T, Context>(dev_ctx, x, y, &out); \
return out; \
}
#define DEFINE_COO_ELEMENTWISE_KERNEL_FUNC(name) \
template <typename T, typename Context> \
SparseCooTensor ElementWise##name##Coo(const Context& dev_ctx, \
const SparseCooTensor& x, \
const SparseCooTensor& y) { \
DenseTensor non_zero_indices; \
DenseTensor non_zero_elements; \
SparseCooTensor out(non_zero_indices, non_zero_elements, x.dims()); \
ElementWise##name##CooKernel<T, Context>(dev_ctx, x, y, &out); \
return out; \
}
DEFINE_ELEMENTWISE_KERNEL_HEAD
(
Add
)
DEFINE_ELEMENTWISE_KERNEL_HEAD
(
Subtract
)
DEFINE_ELEMENTWISE_KERNEL_HEAD
(
Multiply
)
DEFINE_ELEMENTWISE_KERNEL_HEAD
(
Divide
)
DEFINE_ELEMENTWISE_KERNEL_FUNC
(
Add
)
DEFINE_ELEMENTWISE_KERNEL_FUNC
(
Subtract
)
DEFINE_ELEMENTWISE_KERNEL_FUNC
(
Multiply
)
DEFINE_ELEMENTWISE_KERNEL_FUNC
(
Divide
)
}
// namespace sparse
}
// namespace phi
paddle/phi/tests/kernels/CMakeLists.txt
浏览文件 @
e3d94fc5
...
...
@@ -70,6 +70,10 @@ cc_test(
test_sparse_activation_dev_api
SRCS test_sparse_activation_dev_api.cc
DEPS phi phi_api_utils
)
cc_test
(
test_sparse_elementwise_dev_api
SRCS test_sparse_elementwise_dev_api.cc
DEPS phi phi_api_utils
)
cc_test
(
test_math_function
...
...
paddle/phi/tests/kernels/test_sparse_elementwise_dev_api.cc
0 → 100644
浏览文件 @
e3d94fc5
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
0 → 100644
浏览文件 @
e3d94fc5
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
from
operator
import
__add__
,
__sub__
,
__mul__
,
__truediv__
import
numpy
as
np
import
paddle
from
paddle.fluid.framework
import
_test_eager_guard
op_list
=
[
__add__
,
__sub__
,
__mul__
,
__truediv__
]
def
get_actual_res
(
x
,
y
,
op
):
if
op
==
__add__
:
res
=
paddle
.
incubate
.
sparse
.
add
(
x
,
y
)
elif
op
==
__sub__
:
res
=
paddle
.
incubate
.
sparse
.
subtract
(
x
,
y
)
elif
op
==
__mul__
:
res
=
paddle
.
incubate
.
sparse
.
multiply
(
x
,
y
)
elif
op
==
__truediv__
:
res
=
paddle
.
incubate
.
sparse
.
divide
(
x
,
y
)
else
:
raise
ValueError
(
"unsupported op"
)
return
res
class
TestSparseElementWiseAPI
(
unittest
.
TestCase
):
"""
test paddle.sparse.add, subtract, multiply, divide
"""
def
setUp
(
self
):
paddle
.
fluid
.
set_flags
({
"FLAGS_retain_grad_for_all_tensor"
:
True
})
np
.
random
.
seed
(
2022
)
self
.
op_list
=
op_list
self
.
csr_shape
=
[
128
,
256
]
self
.
coo_shape
=
[
4
,
8
,
3
,
5
]
self
.
support_dtypes
=
[
'float32'
,
'float64'
,
'int32'
,
'int64'
]
def
func_test_csr
(
self
,
op
):
for
dtype
in
self
.
support_dtypes
:
x
=
np
.
random
.
randint
(
-
255
,
255
,
size
=
self
.
csr_shape
).
astype
(
dtype
)
y
=
np
.
random
.
randint
(
-
255
,
255
,
size
=
self
.
csr_shape
).
astype
(
dtype
)
dense_x
=
paddle
.
to_tensor
(
x
,
dtype
=
dtype
,
stop_gradient
=
False
)
dense_y
=
paddle
.
to_tensor
(
y
,
dtype
=
dtype
,
stop_gradient
=
False
)
s_dense_x
=
paddle
.
to_tensor
(
x
,
dtype
=
dtype
,
stop_gradient
=
False
)
s_dense_y
=
paddle
.
to_tensor
(
y
,
dtype
=
dtype
,
stop_gradient
=
False
)
csr_x
=
s_dense_x
.
to_sparse_csr
()
csr_y
=
s_dense_y
.
to_sparse_csr
()
actual_res
=
get_actual_res
(
csr_x
,
csr_y
,
op
)
actual_res
.
backward
(
actual_res
)
expect_res
=
op
(
dense_x
,
dense_y
)
expect_res
.
backward
(
expect_res
)
self
.
assertTrue
(
np
.
allclose
(
expect_res
.
numpy
(),
actual_res
.
to_dense
().
numpy
(),
equal_nan
=
True
))
if
not
(
op
==
__truediv__
and
dtype
in
[
'int32'
,
'int64'
]):
self
.
assertTrue
(
np
.
allclose
(
dense_x
.
grad
.
numpy
(),
csr_x
.
grad
.
to_dense
().
numpy
(),
equal_nan
=
True
))
self
.
assertTrue
(
np
.
allclose
(
dense_y
.
grad
.
numpy
(),
csr_y
.
grad
.
to_dense
().
numpy
(),
equal_nan
=
True
))
def
func_test_coo
(
self
,
op
):
for
sparse_dim
in
range
(
len
(
self
.
coo_shape
)
-
1
,
len
(
self
.
coo_shape
)):
for
dtype
in
self
.
support_dtypes
:
x
=
np
.
random
.
randint
(
-
255
,
255
,
size
=
self
.
coo_shape
).
astype
(
dtype
)
y
=
np
.
random
.
randint
(
-
255
,
255
,
size
=
self
.
coo_shape
).
astype
(
dtype
)
dense_x
=
paddle
.
to_tensor
(
x
,
dtype
=
dtype
,
stop_gradient
=
False
)
dense_y
=
paddle
.
to_tensor
(
y
,
dtype
=
dtype
,
stop_gradient
=
False
)
s_dense_x
=
paddle
.
to_tensor
(
x
,
dtype
=
dtype
,
stop_gradient
=
False
)
s_dense_y
=
paddle
.
to_tensor
(
y
,
dtype
=
dtype
,
stop_gradient
=
False
)
coo_x
=
s_dense_x
.
to_sparse_coo
(
sparse_dim
)
coo_y
=
s_dense_y
.
to_sparse_coo
(
sparse_dim
)
actual_res
=
get_actual_res
(
coo_x
,
coo_y
,
op
)
actual_res
.
backward
(
actual_res
)
expect_res
=
op
(
dense_x
,
dense_y
)
expect_res
.
backward
(
expect_res
)
self
.
assertTrue
(
np
.
allclose
(
expect_res
.
numpy
(),
actual_res
.
to_dense
().
numpy
(),
equal_nan
=
True
))
self
.
assertTrue
(
np
.
allclose
(
dense_x
.
grad
.
numpy
(),
coo_x
.
grad
.
to_dense
().
numpy
(),
equal_nan
=
True
))
self
.
assertTrue
(
np
.
allclose
(
dense_y
.
grad
.
numpy
(),
coo_y
.
grad
.
to_dense
().
numpy
(),
equal_nan
=
True
))
def
test_support_dtypes_csr
(
self
):
paddle
.
device
.
set_device
(
'cpu'
)
if
paddle
.
device
.
get_device
()
==
"cpu"
:
with
_test_eager_guard
():
for
op
in
op_list
:
self
.
func_test_csr
(
op
)
def
test_support_dtypes_coo
(
self
):
paddle
.
device
.
set_device
(
'cpu'
)
if
paddle
.
device
.
get_device
()
==
"cpu"
:
with
_test_eager_guard
():
for
op
in
op_list
:
self
.
func_test_coo
(
op
)
if
__name__
==
"__main__"
:
paddle
.
device
.
set_device
(
'cpu'
)
unittest
.
main
()
python/paddle/incubate/sparse/__init__.py
浏览文件 @
e3d94fc5
...
...
@@ -22,6 +22,11 @@ from .unary import tanh
from
.binary
import
matmul
from
.binary
import
masked_matmul
from
.math
import
add
from
.math
import
divide
from
.math
import
multiply
from
.math
import
subtract
from
.
import
nn
__all__
=
[
...
...
@@ -32,4 +37,8 @@ __all__ = [
'tanh'
,
'matmul'
,
'masked_matmul'
,
'add'
,
'subtract'
,
'multiply'
,
'divide'
,
]
python/paddle/incubate/sparse/math.py
0 → 100644
浏览文件 @
e3d94fc5
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
sparse math functions
"""
from
__future__
import
print_function
from
paddle
import
_C_ops
,
in_dynamic_mode
,
device
,
int32
,
int64
from
paddle.tensor
import
cast
from
paddle.incubate.sparse
import
sparse_csr_tensor
def
_cast_coo
(
x
,
dtype
,
name
=
None
):
indices
=
x
.
indices
()
values
=
cast
(
x
.
values
(),
dtype
)
return
_C_ops
.
final_state_sparse_create_sparse_coo_tensor
(
values
,
indices
,
x
.
shape
)
def
_cast_csr
(
x
,
dtype
,
name
=
None
):
crows
=
x
.
crows
()
cols
=
x
.
cols
()
values
=
cast
(
x
.
values
(),
dtype
)
return
sparse_csr_tensor
(
crows
,
cols
,
values
,
x
.
shape
)
def
_cast
(
x
,
dtype
,
name
=
None
):
if
x
.
is_sparse_coo
():
return
_cast_coo
(
x
,
dtype
,
name
)
return
_cast_csr
(
x
,
dtype
,
name
)
def
add
(
x
,
y
,
name
=
None
):
"""
Add two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
The equation is:
.. math::
out = x + y
Args:
x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: the result tensor.
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
paddle.device.set_device("cpu")
with _test_eager_guard():
x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
sparse_x = x.to_sparse_csr()
sparse_y = y.to_sparse_csr()
sparse_z = paddle.incubate.sparse.add(sparse_x, sparse_y)
print(sparse_z.to_dense())
# [[ 0., -1., 0., 0.],
# [ 0., 2., -6., 0.],
# [ 6., 8., 4., 8.]]
"""
assert
device
.
get_device
(
)
==
"cpu"
,
"Currently, Sparse add only support CPU device."
assert
in_dynamic_mode
(),
"Currently, Sparse API only support dynamic mode"
assert
x
.
is_sparse_csr
()
==
y
.
is_sparse_csr
(
),
f
"Expect sparse tensor type to be same"
if
x
.
is_sparse_coo
()
or
x
.
is_sparse_csr
():
return
_C_ops
.
final_state_sparse_add
(
x
,
y
)
else
:
raise
ValueError
(
"Currently, sparse.add only support the input of SparseCooTensor or SparseCsrTensor"
)
def
subtract
(
x
,
y
,
name
=
None
):
"""
Subtract two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
The equation is:
.. math::
out = x - y
Args:
x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: the result tensor.
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
paddle.device.set_device("cpu")
with _test_eager_guard():
x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
sparse_x = x.to_sparse_csr()
sparse_y = y.to_sparse_csr()
sparse_z = paddle.incubate.sparse.subtract(sparse_x, sparse_y)
print(sparse_z.to_dense())
# [[ 0., -1., 0., 4.],
# [ 0., -2., 0., 0.],
# [ 2., 2., -4., -8.]]
"""
assert
device
.
get_device
(
)
==
"cpu"
,
"Currently, Sparse subtract only support CPU device."
assert
in_dynamic_mode
(),
"Currently, Sparse API only support dynamic mode"
assert
x
.
is_sparse_csr
()
==
y
.
is_sparse_csr
(
),
f
"Expect sparse tensor type to be same"
if
x
.
is_sparse_coo
()
or
x
.
is_sparse_csr
():
return
_C_ops
.
final_state_sparse_subtract
(
x
,
y
)
else
:
raise
ValueError
(
"Currently, sparse.subtract only support the input of SparseCooTensor or SparseCsrTensor"
)
def
multiply
(
x
,
y
,
name
=
None
):
"""
Multiply two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
The equation is:
.. math::
out = x * y
Args:
x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: the result tensor.
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
paddle.device.set_device("cpu")
with _test_eager_guard():
x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
sparse_x = x.to_sparse_csr()
sparse_y = y.to_sparse_csr()
sparse_z = paddle.incubate.sparse.multiply(sparse_x, sparse_y)
print(sparse_z.to_dense())
# [[ 0., 0., 0., -4.],
# [ 0., 0., 9., 0.],
# [ 8., 15., 0., 0.]]
"""
assert
device
.
get_device
(
)
==
"cpu"
,
"Currently, Sparse multiply only support CPU device."
assert
in_dynamic_mode
(),
"Currently, Sparse API only support dynamic mode"
assert
x
.
is_sparse_csr
()
==
y
.
is_sparse_csr
(
),
f
"Expect sparse tensor type to be same"
if
x
.
is_sparse_coo
()
or
x
.
is_sparse_csr
():
return
_C_ops
.
final_state_sparse_multiply
(
x
,
y
)
else
:
raise
ValueError
(
"Currently, sparse.multiply only support the input of SparseCooTensor or SparseCsrTensor"
)
def
divide
(
x
,
y
,
name
=
None
):
"""
Divide two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
The equation is:
.. math::
out = x / y
Args:
x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: the result tensor.
Examples:
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
paddle.device.set_device("cpu")
with _test_eager_guard():
x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
sparse_x = x.to_sparse_csr()
sparse_y = y.to_sparse_csr()
sparse_z = paddle.incubate.sparse.divide(sparse_x, sparse_y)
print(sparse_z.to_dense())
# [[ nan , -inf. , nan , -1. ],
# [ nan , 0. , 1. , nan ],
# [ 2. , 1.66666663, 0. , 0. ]]
"""
assert
device
.
get_device
(
)
==
"cpu"
,
"Currently, Sparse divide only support CPU device."
assert
in_dynamic_mode
(),
"Currently, Sparse API only support dynamic mode"
assert
x
.
is_sparse_csr
()
==
y
.
is_sparse_csr
(
),
f
"Expect sparse tensor type to be same"
if
x
.
dtype
in
[
int32
,
int64
]:
if
x
.
is_sparse_coo
()
or
x
.
is_sparse_csr
():
cx
=
_cast
(
x
,
'float32'
)
cy
=
_cast
(
y
,
'float32'
)
return
_C_ops
.
final_state_sparse_divide
(
cx
,
cy
)
else
:
raise
ValueError
(
"Currently, sparse.divide only support the input of SparseCooTensor or SparseCsrTensor"
)
else
:
if
x
.
is_sparse_coo
()
or
x
.
is_sparse_csr
():
return
_C_ops
.
final_state_sparse_divide
(
x
,
y
)
else
:
raise
ValueError
(
"Currently, sparse.divide only support the input of SparseCooTensor or SparseCsrTensor"
)
python/paddle/utils/code_gen/sparse_api.yaml
浏览文件 @
e3d94fc5
-
api
:
add
args
:
(Tensor x, Tensor y)
output
:
Tensor(out)
kernel
:
func
:
add_coo_coo{sparse_coo -> sparse_coo},
add_csr_csr{sparse_csr -> sparse_csr}
layout
:
x
backward
:
add_grad
-
api
:
conv3d
args
:
(Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
output
:
Tensor(out), Tensor(rulebook)
...
...
@@ -28,6 +37,24 @@
invoke
:
to_sparse_coo_impl(x, sparse_dim)
backward
:
dense_to_coo_grad
-
api
:
divide
args
:
(Tensor x, Tensor y)
output
:
Tensor(out)
kernel
:
func
:
divide_coo_coo{sparse_coo -> sparse_coo},
divide_csr_csr{sparse_csr -> sparse_csr}
layout
:
x
backward
:
divide_grad
-
api
:
multiply
args
:
(Tensor x, Tensor y)
output
:
Tensor(out)
kernel
:
func
:
multiply_coo_coo{sparse_coo -> sparse_coo},
multiply_csr_csr{sparse_csr -> sparse_csr}
layout
:
x
backward
:
multiply_grad
-
api
:
relu
args
:
(Tensor x)
output
:
Tensor(out)
...
...
@@ -63,6 +90,15 @@
layout
:
x
backward
:
sqrt_grad
-
api
:
subtract
args
:
(Tensor x, Tensor y)
output
:
Tensor(out)
kernel
:
func
:
subtract_coo_coo{sparse_coo -> sparse_coo},
subtract_csr_csr{sparse_csr -> sparse_csr}
layout
:
x
backward
:
subtract_grad
-
api
:
tanh
args
:
(Tensor x)
output
:
Tensor(out)
...
...
python/paddle/utils/code_gen/sparse_bw_api.yaml
浏览文件 @
e3d94fc5
-
backward_api
:
add_grad
forward
:
add(Tensor x, Tensor y) -> Tensor(out)
args
:
(Tensor x, Tensor y, Tensor out_grad)
output
:
Tensor(x_grad), Tensor(y_grad)
kernel
:
func
:
add_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
-
backward_api
:
conv3d_grad
forward
:
conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
args
:
(Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
...
...
@@ -25,6 +33,14 @@
output
:
Tensor(x_grad)
invoke
:
to_dense_impl(out_grad)
-
backward_api
:
divide_grad
forward
:
divide(Tensor x, Tensor y) -> Tensor(out)
args
:
(Tensor x, Tensor y, Tensor out, Tensor out_grad)
output
:
Tensor(x_grad), Tensor(y_grad)
kernel
:
func
:
divide_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
divide_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
-
backward_api
:
masked_matmul_grad
forward
:
masked_matmul(Tensor x, Tensor y, Tensor mask) -> Tensor(out)
args
:
(Tensor x, Tensor y, Tensor out_grad)
...
...
@@ -39,6 +55,14 @@
kernel
:
func
:
csr_dense_matmul_grad{sparse_csr, dense, dense -> sparse_csr, dense}
-
backward_api
:
multiply_grad
forward
:
multiply(Tensor x, Tensor y) -> Tensor(out)
args
:
(Tensor x, Tensor y, Tensor out_grad)
output
:
Tensor(x_grad), Tensor(y_grad)
kernel
:
func
:
multiply_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
multiply_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
-
backward_api
:
relu_grad
forward
:
relu(Tensor x) -> Tensor(out)
args
:
(Tensor out, Tensor out_grad)
...
...
@@ -74,6 +98,14 @@
kernel
:
func
:
sparse_coo_sqrt_grad {sparse_coo, sparse_coo -> sparse_coo}
-
backward_api
:
subtract_grad
forward
:
subtract(Tensor x, Tensor y) -> Tensor(out)
args
:
(Tensor x, Tensor y, Tensor out_grad)
output
:
Tensor(x_grad), Tensor(y_grad)
kernel
:
func
:
subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
-
backward_api
:
tanh_grad
forward
:
tanh(Tensor x) -> Tensor(out)
args
:
(Tensor out, Tensor out_grad)
...
...
python/setup.py.in
浏览文件 @
e3d94fc5
...
...
@@ -281,6 +281,7 @@ packages=['paddle',
'paddle.incubate.tensor',
'paddle.incubate.multiprocessing',
'paddle.incubate.nn',
'paddle.incubate.sparse',
'paddle.incubate.asp',
'paddle.incubate.passes',
'paddle.distribution',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录