Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
19d9c736
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
19d9c736
编写于
7月 22, 2022
作者:
zhouweiwei2014
提交者:
GitHub
7月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Sparse]add sparse unary api(expm1/deg2rad/rad2deg/relu6/leaky_relu) (#44432)
上级
18c77325
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
422 addition
and
75 deletion
+422
-75
paddle/phi/api/yaml/sparse_api.yaml
paddle/phi/api/yaml/sparse_api.yaml
+27
-0
paddle/phi/api/yaml/sparse_bw_api.yaml
paddle/phi/api/yaml/sparse_bw_api.yaml
+24
-0
paddle/phi/kernels/activation_grad_kernel.h
paddle/phi/kernels/activation_grad_kernel.h
+1
-1
paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc
paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc
+3
-0
paddle/phi/kernels/sparse/cpu/unary_kernel.cc
paddle/phi/kernels/sparse/cpu/unary_kernel.cc
+3
-0
paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
+3
-0
paddle/phi/kernels/sparse/gpu/unary_kernel.cu
paddle/phi/kernels/sparse/gpu/unary_kernel.cu
+3
-0
paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h
paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h
+3
-0
paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
+3
-0
python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
+17
-0
python/paddle/incubate/sparse/__init__.py
python/paddle/incubate/sparse/__init__.py
+6
-0
python/paddle/incubate/sparse/nn/__init__.py
python/paddle/incubate/sparse/nn/__init__.py
+4
-0
python/paddle/incubate/sparse/nn/functional/__init__.py
python/paddle/incubate/sparse/nn/functional/__init__.py
+4
-0
python/paddle/incubate/sparse/nn/functional/activation.py
python/paddle/incubate/sparse/nn/functional/activation.py
+90
-29
python/paddle/incubate/sparse/nn/layer/activation.py
python/paddle/incubate/sparse/nn/layer/activation.py
+115
-34
python/paddle/incubate/sparse/unary.py
python/paddle/incubate/sparse/unary.py
+116
-11
未找到文件。
paddle/phi/api/yaml/sparse_api.yaml
浏览文件 @
19d9c736
...
...
@@ -127,6 +127,24 @@
divide_csr_scalar{sparse_csr -> sparse_csr}
backward
:
divide_scalar_grad
-
api
:
expm1
args
:
(Tensor x)
output
:
Tensor(out)
kernel
:
func
:
expm1_coo{sparse_coo -> sparse_coo},
expm1_csr{sparse_csr -> sparse_csr}
layout
:
x
backward
:
expm1_grad
-
api
:
leaky_relu
args
:
(Tensor x, float alpha)
output
:
Tensor(out)
kernel
:
func
:
leaky_relu_coo{sparse_coo -> sparse_coo},
leaky_relu_csr{sparse_csr -> sparse_csr}
layout
:
x
backward
:
leaky_relu_grad
-
api
:
log1p
args
:
(Tensor x)
output
:
Tensor(out)
...
...
@@ -163,6 +181,15 @@
layout
:
x
backward
:
relu_grad
-
api
:
relu6
args
:
(Tensor x, float threshold)
output
:
Tensor(out)
kernel
:
func
:
relu6_coo{sparse_coo -> sparse_coo},
relu6_csr{sparse_csr -> sparse_csr}
layout
:
x
backward
:
relu6_grad
-
api
:
scale
args
:
(Tensor x, float scale, float bias, bool bias_after_scale)
output
:
Tensor(out)
...
...
paddle/phi/api/yaml/sparse_bw_api.yaml
浏览文件 @
19d9c736
...
...
@@ -122,6 +122,22 @@
output
:
Tensor(x_grad)
invoke
:
divide_scalar(out_grad, scalar)
-
backward_api
:
expm1_grad
forward
:
expm1(Tensor x) -> Tensor(out)
args
:
(Tensor out, Tensor out_grad)
output
:
Tensor(x_grad)
kernel
:
func
:
expm1_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
expm1_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
-
backward_api
:
leaky_relu_grad
forward
:
leaky_relu(Tensor x, float alpha) -> Tensor(out)
args
:
(Tensor x, Tensor out_grad, float alpha)
output
:
Tensor(x_grad)
kernel
:
func
:
leaky_relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
leaky_relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
-
backward_api
:
log1p_grad
forward
:
log1p(Tensor x) -> Tensor(out)
args
:
(Tensor x, Tensor out_grad)
...
...
@@ -178,6 +194,14 @@
func
:
pow_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
pow_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
-
backward_api
:
relu6_grad
forward
:
relu6(Tensor x, float threshold) -> Tensor(out)
args
:
(Tensor out, Tensor out_grad, float threshold)
output
:
Tensor(x_grad)
kernel
:
func
:
relu6_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
relu6_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
-
backward_api
:
relu_grad
forward
:
relu(Tensor x) -> Tensor(out)
args
:
(Tensor out, Tensor out_grad)
...
...
paddle/phi/kernels/activation_grad_kernel.h
浏览文件 @
19d9c736
...
...
@@ -240,11 +240,11 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Logit
,
eps
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Mish
,
threshold
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX
(
Celu
,
alpha
);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT
(
Relu6
,
threshold
);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
BRelu
,
t_min
,
t_max
);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
STanh
,
scale_a
,
scale_b
);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX
(
Softplus
,
beta
,
threshold
);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT
(
HardSigmoid
,
slope
,
offset
);
}
// namespace phi
paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc
浏览文件 @
19d9c736
...
...
@@ -51,6 +51,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(log1p, Log1p)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL
(
relu
,
Relu
)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL
(
abs
,
Abs
)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL
(
pow
,
Pow
)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL
(
expm1
,
Expm1
)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL
(
relu6
,
Relu6
)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL
(
leaky_relu
,
LeakyRelu
)
PD_REGISTER_KERNEL
(
cast_coo_grad
,
CPU
,
...
...
paddle/phi/kernels/sparse/cpu/unary_kernel.cc
浏览文件 @
19d9c736
...
...
@@ -93,6 +93,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL
(
abs
,
Abs
)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL
(
pow
,
Pow
)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL
(
scale
,
Scale
)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL
(
expm1
,
Expm1
)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL
(
relu6
,
Relu6
)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL
(
leaky_relu
,
LeakyRelu
)
PD_REGISTER_KERNEL
(
divide_coo_scalar
,
CPU
,
...
...
paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
浏览文件 @
19d9c736
...
...
@@ -53,6 +53,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(log1p, Log1p)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL
(
relu
,
Relu
)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL
(
abs
,
Abs
)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL
(
pow
,
Pow
)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL
(
expm1
,
Expm1
)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL
(
relu6
,
Relu6
)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL
(
leaky_relu
,
LeakyRelu
)
PD_REGISTER_KERNEL
(
cast_coo_grad
,
GPU
,
...
...
paddle/phi/kernels/sparse/gpu/unary_kernel.cu
浏览文件 @
19d9c736
...
...
@@ -98,6 +98,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL
(
abs
,
Abs
)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL
(
pow
,
Pow
)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL
(
scale
,
Scale
)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL
(
expm1
,
Expm1
)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL
(
relu6
,
Relu6
)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL
(
leaky_relu
,
LeakyRelu
)
PD_REGISTER_KERNEL
(
divide_coo_scalar
,
GPU
,
...
...
paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h
浏览文件 @
19d9c736
...
...
@@ -93,7 +93,10 @@ DEFINE_SPARSE_UNARY_GRAD_KERNEL(Square)
DEFINE_SPARSE_UNARY_GRAD_KERNEL
(
Log1p
)
DEFINE_SPARSE_UNARY_GRAD_KERNEL
(
Relu
)
DEFINE_SPARSE_UNARY_GRAD_KERNEL
(
Abs
)
DEFINE_SPARSE_UNARY_GRAD_KERNEL
(
Expm1
)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR
(
Pow
,
factor
)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR
(
LeakyRelu
,
alpha
)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR
(
Relu6
,
threshold
)
template
<
typename
T
,
typename
Context
>
void
CastCooGradKernel
(
const
Context
&
dev_ctx
,
...
...
paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
浏览文件 @
19d9c736
...
...
@@ -86,7 +86,10 @@ DEFINE_SPARSE_UNARY_KERNEL(Square)
DEFINE_SPARSE_UNARY_KERNEL
(
Log1p
)
DEFINE_SPARSE_UNARY_KERNEL
(
Relu
)
DEFINE_SPARSE_UNARY_KERNEL
(
Abs
)
DEFINE_SPARSE_UNARY_KERNEL
(
Expm1
)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR
(
Pow
,
factor
)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR
(
Relu6
,
threshold
)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR
(
LeakyRelu
,
alpha
)
template
<
typename
T
,
typename
Context
>
void
ScaleCooKernel
(
const
Context
&
dev_ctx
,
...
...
python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
浏览文件 @
19d9c736
...
...
@@ -123,9 +123,26 @@ class TestSparseUnary(unittest.TestCase):
self
.
compare_with_dense
(
paddle
.
nn
.
ReLU
(),
paddle
.
incubate
.
sparse
.
nn
.
ReLU
())
def
test_sparse_relu6
(
self
):
self
.
compare_with_dense
(
paddle
.
nn
.
ReLU6
(),
paddle
.
incubate
.
sparse
.
nn
.
ReLU6
())
def
test_sparse_leaky_relu
(
self
):
self
.
compare_with_dense
(
paddle
.
nn
.
LeakyReLU
(
0.1
),
paddle
.
incubate
.
sparse
.
nn
.
LeakyReLU
(
0.1
))
def
test_sparse_abs
(
self
):
self
.
compare_with_dense
(
paddle
.
abs
,
paddle
.
incubate
.
sparse
.
abs
)
def
test_sparse_expm1
(
self
):
self
.
compare_with_dense
(
paddle
.
expm1
,
paddle
.
incubate
.
sparse
.
expm1
)
def
test_sparse_deg2rad
(
self
):
self
.
compare_with_dense
(
paddle
.
deg2rad
,
paddle
.
incubate
.
sparse
.
deg2rad
)
def
test_sparse_rad2deg
(
self
):
self
.
compare_with_dense
(
paddle
.
rad2deg
,
paddle
.
incubate
.
sparse
.
rad2deg
)
def
test_sparse_neg
(
self
):
self
.
compare_with_dense
(
paddle
.
neg
,
paddle
.
incubate
.
sparse
.
neg
)
...
...
python/paddle/incubate/sparse/__init__.py
浏览文件 @
19d9c736
...
...
@@ -31,6 +31,9 @@ from .unary import pow
from
.unary
import
cast
from
.unary
import
neg
from
.unary
import
coalesce
from
.unary
import
deg2rad
from
.unary
import
rad2deg
from
.unary
import
expm1
from
.binary
import
mv
from
.binary
import
matmul
...
...
@@ -62,6 +65,9 @@ __all__ = [
'pow'
,
'cast'
,
'neg'
,
'deg2rad'
,
'rad2deg'
,
'expm1'
,
'mv'
,
'matmul'
,
'masked_matmul'
,
...
...
python/paddle/incubate/sparse/nn/__init__.py
浏览文件 @
19d9c736
...
...
@@ -16,6 +16,8 @@ from . import functional
from
.layer.activation
import
ReLU
from
.layer.activation
import
Softmax
from
.layer.activation
import
ReLU6
from
.layer.activation
import
LeakyReLU
from
.layer.norm
import
BatchNorm
from
.layer.conv
import
Conv3D
from
.layer.conv
import
SubmConv3D
...
...
@@ -23,6 +25,8 @@ from .layer.pooling import MaxPool3D
__all__
=
[
'ReLU'
,
'ReLU6'
,
'LeakyReLU'
,
'Softmax'
,
'BatchNorm'
,
'Conv3D'
,
...
...
python/paddle/incubate/sparse/nn/functional/__init__.py
浏览文件 @
19d9c736
...
...
@@ -17,6 +17,8 @@ from .conv import subm_conv3d # noqa: F401
from
.transformer
import
attention
# noqa: F401
from
.pooling
import
max_pool3d
# noqa: F401
from
.activation
import
relu
# noqa: F401
from
.activation
import
relu6
# noqa: F401
from
.activation
import
leaky_relu
# noqa: F401
from
.activation
import
softmax
# noqa: F401
__all__
=
[
...
...
@@ -24,6 +26,8 @@ __all__ = [
'subm_conv3d'
,
'max_pool3d'
,
'relu'
,
'relu6'
,
'leaky_relu'
,
'softmax'
,
'attention'
,
]
python/paddle/incubate/sparse/nn/functional/activation.py
浏览文件 @
19d9c736
...
...
@@ -21,7 +21,7 @@ from paddle.fluid.framework import dygraph_only
@
dygraph_only
def
relu
(
x
,
name
=
None
):
"""
sparse relu activation, requiring x to be a
sparse coo or sparse csr t
ensor.
sparse relu activation, requiring x to be a
SparseCooTensor or SparseCsrT
ensor.
.. math::
...
...
@@ -39,12 +39,11 @@ def relu(x, name=None):
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
with _test_eager_guard():
dense_x = paddle.to_tensor([-2, 0, 1], dtype='float32')
dense_x = paddle.to_tensor([-2., 0., 1.])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.nn.functional.relu(sparse_x)
# [0., 0., 1.]
"""
return
_C_ops
.
final_state_sparse_relu
(
x
)
...
...
@@ -52,7 +51,7 @@ def relu(x, name=None):
@
dygraph_only
def
softmax
(
x
,
axis
=-
1
,
name
=
None
):
"""
sparse softmax activation,
x must be SparseCsrTensor or SparseCoo
Tensor.
sparse softmax activation,
requiring x to be a SparseCooTensor or SparseCsr
Tensor.
Note:
Only support axis=-1 for SparseCsrTensor, which is faster when read data
...
...
@@ -79,11 +78,8 @@ def softmax(x, axis=-1, name=None):
import paddle
import numpy as np
from paddle.fluid.framework import _test_eager_guard
paddle.seed(100)
with _test_eager_guard():
mask = np.random.rand(3, 4) < 0.5
np_x = np.random.rand(3, 4) * mask
# [[0. 0. 0.96823406 0.19722934]
...
...
@@ -106,3 +102,68 @@ def softmax(x, axis=-1, name=None):
"""
return
_C_ops
.
final_state_sparse_softmax
(
x
,
axis
)
@
dygraph_only
def
relu6
(
x
,
name
=
None
):
"""
sparse relu6 activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
relu6(x) = min(max(0, x), 6)
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([-2., 0., 8.])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.nn.functional.relu6(sparse_x)
"""
return
_C_ops
.
final_state_sparse_relu6
(
x
,
6.0
)
@
dygraph_only
def
leaky_relu
(
x
,
negative_slope
=
0.01
,
name
=
None
):
"""
sparse leaky_relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
leaky\_relu(x)=
\left\{
\b
egin{array}{rcl}
x, & & if \ x >= 0
\\
negative\_slope * x, & & otherwise
\\
\end{array}
\r
ight.
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
negative_slope (float, optional): Slope of the activation function at
:math:`x < 0` . Default is 0.01.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([-2., 0., 5.])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.nn.functional.leaky_relu(sparse_x, 0.5)
"""
return
_C_ops
.
final_state_sparse_leaky_relu
(
x
,
negative_slope
)
python/paddle/incubate/sparse/nn/layer/activation.py
浏览文件 @
19d9c736
...
...
@@ -20,7 +20,7 @@ __all__ = []
class
ReLU
(
Layer
):
"""
Sparse ReLU Activation.
Sparse ReLU Activation
, requiring x to be a SparseCooTensor or SparseCsrTensor
.
.. math::
...
...
@@ -38,15 +38,12 @@ class ReLU(Layer):
.. code-block:: python
import paddle
from paddle.fluid.framework import _test_eager_guard
with _test_eager_guard():
x = [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]]
dense_x = paddle.to_tensor(x, dtype='float32')
sparse_dim = 2
sparse_x = dense_x.to_sparse_coo(sparse_dim)
dense_x = paddle.to_tensor([-2., 0., 1.])
sparse_x = dense_x.to_sparse_coo(1)
relu = paddle.incubate.sparse.nn.ReLU()
out = relu(sparse_x)
#out.values: [0., 2., 0., 4., 5
.]
# [0., 0., 1
.]
"""
def
__init__
(
self
,
name
=
None
):
...
...
@@ -63,7 +60,7 @@ class ReLU(Layer):
class
Softmax
(
Layer
):
"""
sparse softmax activation, x must be SparseCsrTensor or SparseCoo
Tensor.
Sparse Softmax Activation, requiring x to be a SparseCooTensor or SparseCsr
Tensor.
Note:
Only support axis=-1 for SparseCsrTensor, which is faster when read data
...
...
@@ -90,11 +87,8 @@ class Softmax(Layer):
import paddle
import numpy as np
from paddle.fluid.framework import _test_eager_guard
paddle.seed(100)
with _test_eager_guard():
mask = np.random.rand(3, 4) < 0.5
np_x = np.random.rand(3, 4) * mask
# [[0. 0. 0.96823406 0.19722934]
...
...
@@ -108,8 +102,8 @@ class Softmax(Layer):
# values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372,
# 0.98275049])
m
= paddle.incubate.sparse.nn.Softmax()
out = m
(csr)
softmax
= paddle.incubate.sparse.nn.Softmax()
out = softmax
(csr)
# Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
# crows=[0, 2, 5, 6],
# cols=[2, 3, 0, 2, 3, 3],
...
...
@@ -128,3 +122,90 @@ class Softmax(Layer):
def
extra_repr
(
self
):
name_str
=
'name={}'
.
format
(
self
.
_name
)
if
self
.
_name
else
''
return
name_str
class
ReLU6
(
Layer
):
"""
Sparse ReLU6 Activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
ReLU(x) = min(max(0,x), 6)
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Sparse Tensor with any shape.
- output: Sparse Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([-2., 0., 8.])
sparse_x = dense_x.to_sparse_coo(1)
relu6 = paddle.incubate.sparse.nn.ReLU6()
out = relu6(sparse_x)
"""
def
__init__
(
self
,
name
=
None
):
super
(
ReLU6
,
self
).
__init__
()
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
relu6
(
x
,
self
.
_name
)
def
extra_repr
(
self
):
name_str
=
'name={}'
.
format
(
self
.
_name
)
if
self
.
_name
else
''
return
name_str
class
LeakyReLU
(
Layer
):
"""
Sparse Leaky ReLU Activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
LeakyReLU(x)=
\left\{
\b
egin{array}{rcl}
x, & & if \ x >= 0
\\
negative\_slope * x, & & otherwise
\\
\end{array}
\r
ight.
Parameters:
negative_slope (float, optional): Slope of the activation function at
:math:`x < 0` . Default is 0.01.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Sparse Tensor with any shape.
- output: Sparse Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([-2., 0., 5.])
sparse_x = dense_x.to_sparse_coo(1)
leaky_relu = paddle.incubate.sparse.nn.LeakyReLU(0.5)
out = leaky_relu(sparse_x)
"""
def
__init__
(
self
,
negative_slope
=
0.01
,
name
=
None
):
super
(
LeakyReLU
,
self
).
__init__
()
self
.
_negative_slope
=
negative_slope
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
leaky_relu
(
x
,
self
.
_negative_slope
,
self
.
_name
)
def
extra_repr
(
self
):
name_str
=
'name={}'
.
format
(
self
.
_name
)
if
self
.
_name
else
''
return
name_str
python/paddle/incubate/sparse/unary.py
浏览文件 @
19d9c736
...
...
@@ -12,11 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
paddle
import
_C_ops
from
paddle.fluid.framework
import
dygraph_only
,
core
,
convert_np_dtype_to_dtype_
__all__
=
[]
_int_dtype_
=
[
core
.
VarDesc
.
VarType
.
UINT8
,
core
.
VarDesc
.
VarType
.
INT8
,
core
.
VarDesc
.
VarType
.
INT16
,
core
.
VarDesc
.
VarType
.
INT32
,
core
.
VarDesc
.
VarType
.
INT64
,
core
.
VarDesc
.
VarType
.
BOOL
,
]
@
dygraph_only
def
sin
(
x
,
name
=
None
):
...
...
@@ -489,10 +500,9 @@ def coalesce(x):
.. code-block:: python
import paddle
from paddle.incubate import sparse
from paddle.fluid.framework import _test_eager_guard
with _test_eager_guard():
indices = [[0, 0, 1], [1, 1, 2]]
values = [1.0, 2.0, 3.0]
sp_x = sparse.sparse_coo_tensor(indices, values)
...
...
@@ -503,3 +513,98 @@ def coalesce(x):
#[3.0, 3.0]
"""
return
_C_ops
.
final_state_sparse_coalesce
(
x
)
@
dygraph_only
def
rad2deg
(
x
,
name
=
None
):
"""
Convert each of the elements of input x from angles in radians to degrees,
requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
rad2deg(x) = 180/ \pi * x
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([3.142, 0., -3.142])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.rad2deg(sparse_x)
"""
if
x
.
dtype
in
_int_dtype_
:
x
=
_C_ops
.
final_state_sparse_cast
(
x
,
None
,
core
.
VarDesc
.
VarType
.
FP32
)
return
_C_ops
.
final_state_sparse_scale
(
x
,
180.0
/
np
.
pi
,
0.0
,
True
)
@
dygraph_only
def
deg2rad
(
x
,
name
=
None
):
"""
Convert each of the elements of input x from degrees to angles in radians,
requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
deg2rad(x) = \pi * x / 180
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([-180, 0, 180])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.deg2rad(sparse_x)
"""
if
x
.
dtype
in
_int_dtype_
:
x
=
_C_ops
.
final_state_sparse_cast
(
x
,
None
,
core
.
VarDesc
.
VarType
.
FP32
)
return
_C_ops
.
final_state_sparse_scale
(
x
,
np
.
pi
/
180.0
,
0.0
,
True
)
@
dygraph_only
def
expm1
(
x
,
name
=
None
):
"""
Calculate elementwise `exp(x)-1` , requiring x to be a SparseCooTensor or SparseCsrTensor.
.. math::
out = exp(x) - 1
Parameters:
x (Tensor): The input Sparse Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Sparse Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
dense_x = paddle.to_tensor([-2., 0., 1.])
sparse_x = dense_x.to_sparse_coo(1)
out = paddle.incubate.sparse.expm1(sparse_x)
"""
return
_C_ops
.
final_state_sparse_expm1
(
x
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录