Unverified commit 19d9c736 authored by zhouweiwei2014, committed by GitHub

[Sparse] Add sparse unary APIs (expm1/deg2rad/rad2deg/relu6/leaky_relu) (#44432)

Parent 18c77325
@@ -127,6 +127,24 @@
      divide_csr_scalar{sparse_csr -> sparse_csr}
  backward : divide_scalar_grad

- api : expm1
  args : (Tensor x)
  output : Tensor(out)
  kernel :
    func : expm1_coo{sparse_coo -> sparse_coo},
           expm1_csr{sparse_csr -> sparse_csr}
  layout : x
  backward : expm1_grad

- api : leaky_relu
  args : (Tensor x, float alpha)
  output : Tensor(out)
  kernel :
    func : leaky_relu_coo{sparse_coo -> sparse_coo},
           leaky_relu_csr{sparse_csr -> sparse_csr}
  layout : x
  backward : leaky_relu_grad

- api : log1p
  args : (Tensor x)
  output : Tensor(out)
@@ -163,6 +181,15 @@
  layout : x
  backward : relu_grad

- api : relu6
  args : (Tensor x, float threshold)
  output : Tensor(out)
  kernel :
    func : relu6_coo{sparse_coo -> sparse_coo},
           relu6_csr{sparse_csr -> sparse_csr}
  layout : x
  backward : relu6_grad

- api : scale
  args : (Tensor x, float scale, float bias, bool bias_after_scale)
  output : Tensor(out)
...
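For orientation, a minimal usage sketch of the three forward entries added above (dynamic-graph mode assumed, since the Python wrappers are @dygraph_only; the API names come straight from this commit):

import paddle

dense_x = paddle.to_tensor([-2., 0., 8.])
sparse_x = dense_x.to_sparse_coo(1)  # only the non-zero entries are stored

out1 = paddle.incubate.sparse.expm1(sparse_x)                          # exp(x) - 1
out2 = paddle.incubate.sparse.nn.functional.relu6(sparse_x)            # min(max(0, x), 6)
out3 = paddle.incubate.sparse.nn.functional.leaky_relu(sparse_x, 0.5)  # x if x >= 0 else 0.5 * x
print(out2.to_dense())  # [0., 0., 6.]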
@@ -122,6 +122,22 @@
  output : Tensor(x_grad)
  invoke : divide_scalar(out_grad, scalar)

- backward_api : expm1_grad
  forward : expm1(Tensor x) -> Tensor(out)
  args : (Tensor out, Tensor out_grad)
  output : Tensor(x_grad)
  kernel :
    func : expm1_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
           expm1_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : leaky_relu_grad
  forward : leaky_relu(Tensor x, float alpha) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, float alpha)
  output : Tensor(x_grad)
  kernel :
    func : leaky_relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
           leaky_relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : log1p_grad
  forward : log1p(Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
@@ -178,6 +194,14 @@
    func : pow_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
           pow_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : relu6_grad
  forward : relu6(Tensor x, float threshold) -> Tensor(out)
  args : (Tensor out, Tensor out_grad, float threshold)
  output : Tensor(x_grad)
  kernel :
    func : relu6_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
           relu6_csr_grad {sparse_csr, sparse_csr -> sparse_csr}

- backward_api : relu_grad
  forward : relu(Tensor x) -> Tensor(out)
  args : (Tensor out, Tensor out_grad)
...
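Note the dependency split above: expm1_grad and relu6_grad take the forward output (Tensor out), while leaky_relu_grad takes Tensor x. That works because d/dx (exp(x) - 1) = exp(x) = out + 1 and relu6's gradient mask 0 < out < 6 is recoverable from out alone, whereas leaky_relu's branch is chosen by the sign of x. A NumPy sketch of the expm1 identity (illustrative only, not the Paddle kernels):

import numpy as np

x = np.array([-2.0, 0.5, 3.0])
out = np.expm1(x)        # forward: exp(x) - 1
dout = np.ones_like(x)   # upstream gradient
dx = dout * (out + 1.0)  # gradient computed from `out` alone
assert np.allclose(dx, dout * np.exp(x))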
@@ -240,11 +240,11 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset);

}  // namespace phi
@@ -51,6 +51,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(log1p, Log1p)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(cast_coo_grad,
                   CPU,
...
@@ -93,6 +93,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(scale, Scale)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(divide_coo_scalar,
                   CPU,
...
@@ -53,6 +53,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(log1p, Log1p)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(cast_coo_grad,
                   GPU,
...
@@ -98,6 +98,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu, Relu)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(abs, Abs)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(pow, Pow)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(scale, Scale)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(expm1, Expm1)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6, Relu6)
PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(leaky_relu, LeakyRelu)

PD_REGISTER_KERNEL(divide_coo_scalar,
                   GPU,
...
@@ -93,7 +93,10 @@ DEFINE_SPARSE_UNARY_GRAD_KERNEL(Square)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Log1p)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Relu)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Abs)
DEFINE_SPARSE_UNARY_GRAD_KERNEL(Expm1)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)
DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Relu6, threshold)

template <typename T, typename Context>
void CastCooGradKernel(const Context& dev_ctx,
...
@@ -86,7 +86,10 @@ DEFINE_SPARSE_UNARY_KERNEL(Square)
DEFINE_SPARSE_UNARY_KERNEL(Log1p)
DEFINE_SPARSE_UNARY_KERNEL(Relu)
DEFINE_SPARSE_UNARY_KERNEL(Abs)
DEFINE_SPARSE_UNARY_KERNEL(Expm1)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6, threshold)
DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha)

template <typename T, typename Context>
void ScaleCooKernel(const Context& dev_ctx,
...
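These DEFINE_SPARSE_UNARY_* macros presumably expand to COO and CSR kernels that apply the dense element-wise functor to the stored non-zero values and copy the sparsity pattern through unchanged. That is only sound because every op added here maps 0 to 0 (expm1(0) = 0, relu6(0) = 0, leaky_relu(0) = 0). A NumPy analogue of the value-only computation (illustrative, not the actual kernel code):

import numpy as np

values = np.array([-2.0, 8.0])  # stored non-zeros of a COO/CSR tensor
relu6_vals = np.minimum(np.maximum(values, 0.0), 6.0)      # [0., 6.]
leaky_vals = np.where(values >= 0, values, 0.01 * values)  # [-0.02, 8.]
expm1_vals = np.expm1(values)                              # [-0.865, 2979.958]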
@@ -123,9 +123,26 @@ class TestSparseUnary(unittest.TestCase):
        self.compare_with_dense(paddle.nn.ReLU(),
                                paddle.incubate.sparse.nn.ReLU())

    def test_sparse_relu6(self):
        self.compare_with_dense(paddle.nn.ReLU6(),
                                paddle.incubate.sparse.nn.ReLU6())

    def test_sparse_leaky_relu(self):
        self.compare_with_dense(paddle.nn.LeakyReLU(0.1),
                                paddle.incubate.sparse.nn.LeakyReLU(0.1))

    def test_sparse_abs(self):
        self.compare_with_dense(paddle.abs, paddle.incubate.sparse.abs)

    def test_sparse_expm1(self):
        self.compare_with_dense(paddle.expm1, paddle.incubate.sparse.expm1)

    def test_sparse_deg2rad(self):
        self.compare_with_dense(paddle.deg2rad, paddle.incubate.sparse.deg2rad)

    def test_sparse_rad2deg(self):
        self.compare_with_dense(paddle.rad2deg, paddle.incubate.sparse.rad2deg)

    def test_sparse_neg(self):
        self.compare_with_dense(paddle.neg, paddle.incubate.sparse.neg)
...
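Each test funnels through compare_with_dense; a hypothetical reconstruction of that helper (the sample data, tolerance, and body are assumptions, not part of this diff):

import numpy as np
import paddle

def compare_with_dense(dense_func, sparse_func):
    # Apply the dense and sparse variants of an op to the same data
    # and check that the densified sparse result matches the dense one.
    dense_x = paddle.to_tensor([-2., 0., 1., 6.5])
    sparse_x = dense_x.to_sparse_coo(1)
    dense_out = dense_func(dense_x)
    sparse_out = sparse_func(sparse_x)
    np.testing.assert_allclose(sparse_out.to_dense().numpy(),
                               dense_out.numpy(), rtol=1e-05)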
@@ -31,6 +31,9 @@ from .unary import pow
from .unary import cast
from .unary import neg
from .unary import coalesce
from .unary import deg2rad
from .unary import rad2deg
from .unary import expm1

from .binary import mv
from .binary import matmul
@@ -62,6 +65,9 @@ __all__ = [
    'pow',
    'cast',
    'neg',
    'deg2rad',
    'rad2deg',
    'expm1',
    'mv',
    'matmul',
    'masked_matmul',
...
@@ -16,6 +16,8 @@ from . import functional
from .layer.activation import ReLU
from .layer.activation import Softmax
from .layer.activation import ReLU6
from .layer.activation import LeakyReLU
from .layer.norm import BatchNorm
from .layer.conv import Conv3D
from .layer.conv import SubmConv3D
@@ -23,6 +25,8 @@ from .layer.pooling import MaxPool3D
__all__ = [
    'ReLU',
    'ReLU6',
    'LeakyReLU',
    'Softmax',
    'BatchNorm',
    'Conv3D',
...
@@ -17,6 +17,8 @@ from .conv import subm_conv3d  # noqa: F401
from .transformer import attention  # noqa: F401
from .pooling import max_pool3d  # noqa: F401
from .activation import relu  # noqa: F401
from .activation import relu6  # noqa: F401
from .activation import leaky_relu  # noqa: F401
from .activation import softmax  # noqa: F401

__all__ = [
@@ -24,6 +26,8 @@ __all__ = [
    'subm_conv3d',
    'max_pool3d',
    'relu',
    'relu6',
    'leaky_relu',
    'softmax',
    'attention',
]
@@ -21,7 +21,7 @@ from paddle.fluid.framework import dygraph_only
@dygraph_only
def relu(x, name=None):
    """
    sparse relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

    .. math::
@@ -39,12 +39,11 @@ def relu(x, name=None):
        .. code-block:: python

            import paddle

            dense_x = paddle.to_tensor([-2., 0., 1.])
            sparse_x = dense_x.to_sparse_coo(1)
            out = paddle.incubate.sparse.nn.functional.relu(sparse_x)
            # [0., 0., 1.]
    """
    return _C_ops.final_state_sparse_relu(x)
@@ -52,7 +51,7 @@ def relu(x, name=None):
@dygraph_only
def softmax(x, axis=-1, name=None):
    """
    sparse softmax activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

    Note:
        Only support axis=-1 for SparseCsrTensor, which is faster when read data
@@ -79,30 +78,92 @@ def softmax(x, axis=-1, name=None):
            import paddle
            import numpy as np

            paddle.seed(100)

            mask = np.random.rand(3, 4) < 0.5
            np_x = np.random.rand(3, 4) * mask
            # [[0.         0.         0.96823406 0.19722934]
            #  [0.94373937 0.         0.02060066 0.71456372]
            #  [0.         0.         0.         0.98275049]]

            csr = paddle.to_tensor(np_x).to_sparse_csr()
            # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
            #        crows=[0, 2, 5, 6],
            #        cols=[2, 3, 0, 2, 3, 3],
            #        values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372,
            #                0.98275049])

            out = paddle.incubate.sparse.nn.functional.softmax(csr)
            # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
            #        crows=[0, 2, 5, 6],
            #        cols=[2, 3, 0, 2, 3, 3],
            #        values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269,
            #                1.        ])
    """
    return _C_ops.final_state_sparse_softmax(x, axis)
@dygraph_only
def relu6(x, name=None):
    """
    sparse relu6 activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

    .. math::

        relu6(x) = min(max(0, x), 6)

    Parameters:
        x (Tensor): The input Sparse Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Sparse Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle

            dense_x = paddle.to_tensor([-2., 0., 8.])
            sparse_x = dense_x.to_sparse_coo(1)
            out = paddle.incubate.sparse.nn.functional.relu6(sparse_x)
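            # [0., 0., 6.]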
"""
return _C_ops.final_state_sparse_relu6(x, 6.0)
@dygraph_only
def leaky_relu(x, negative_slope=0.01, name=None):
    """
    sparse leaky_relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

    .. math::

        leaky\_relu(x)=
            \left\{
                \begin{array}{rcl}
                    x, & & if \ x >= 0 \\
                    negative\_slope * x, & & otherwise \\
                \end{array}
            \right.

    Parameters:
        x (Tensor): The input Sparse Tensor with data type float32, float64.
        negative_slope (float, optional): Slope of the activation function at
            :math:`x < 0` . Default is 0.01.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Sparse Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle

            dense_x = paddle.to_tensor([-2., 0., 5.])
            sparse_x = dense_x.to_sparse_coo(1)
            out = paddle.incubate.sparse.nn.functional.leaky_relu(sparse_x, 0.5)
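            # [-1., 0., 5.]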
"""
return _C_ops.final_state_sparse_leaky_relu(x, negative_slope)
@@ -20,7 +20,7 @@ __all__ = []
class ReLU(Layer):
    """
    Sparse ReLU Activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

    .. math::
@@ -38,15 +38,12 @@ class ReLU(Layer):
        .. code-block:: python

            import paddle

            dense_x = paddle.to_tensor([-2., 0., 1.])
            sparse_x = dense_x.to_sparse_coo(1)
            relu = paddle.incubate.sparse.nn.ReLU()
            out = relu(sparse_x)
            # [0., 0., 1.]
    """

    def __init__(self, name=None):
@@ -63,7 +60,7 @@ class ReLU(Layer):
class Softmax(Layer):
    """
    Sparse Softmax Activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

    Note:
        Only support axis=-1 for SparseCsrTensor, which is faster when read data
@@ -90,31 +87,28 @@ class Softmax(Layer):
            import paddle
            import numpy as np

            paddle.seed(100)

            mask = np.random.rand(3, 4) < 0.5
            np_x = np.random.rand(3, 4) * mask
            # [[0.         0.         0.96823406 0.19722934]
            #  [0.94373937 0.         0.02060066 0.71456372]
            #  [0.         0.         0.         0.98275049]]

            csr = paddle.to_tensor(np_x).to_sparse_csr()
            # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
            #        crows=[0, 2, 5, 6],
            #        cols=[2, 3, 0, 2, 3, 3],
            #        values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372,
            #                0.98275049])

            softmax = paddle.incubate.sparse.nn.Softmax()
            out = softmax(csr)
            # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True,
            #        crows=[0, 2, 5, 6],
            #        cols=[2, 3, 0, 2, 3, 3],
            #        values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269,
            #                1.        ])
    """
    def __init__(self, axis=-1, name=None):
@@ -128,3 +122,90 @@ class Softmax(Layer):
    def extra_repr(self):
        name_str = 'name={}'.format(self._name) if self._name else ''
        return name_str
class ReLU6(Layer):
    """
    Sparse ReLU6 Activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

    .. math::

        ReLU6(x) = min(max(0, x), 6)

    Parameters:
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: Sparse Tensor with any shape.
        - output: Sparse Tensor with the same shape as input.

    Examples:
        .. code-block:: python

            import paddle

            dense_x = paddle.to_tensor([-2., 0., 8.])
            sparse_x = dense_x.to_sparse_coo(1)
            relu6 = paddle.incubate.sparse.nn.ReLU6()
            out = relu6(sparse_x)
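            # [0., 0., 6.]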
"""
def __init__(self, name=None):
super(ReLU6, self).__init__()
self._name = name
def forward(self, x):
return F.relu6(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class LeakyReLU(Layer):
    """
    Sparse Leaky ReLU Activation, requiring x to be a SparseCooTensor or SparseCsrTensor.

    .. math::

        LeakyReLU(x)=
            \left\{
                \begin{array}{rcl}
                    x, & & if \ x >= 0 \\
                    negative\_slope * x, & & otherwise \\
                \end{array}
            \right.

    Parameters:
        negative_slope (float, optional): Slope of the activation function at
            :math:`x < 0` . Default is 0.01.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: Sparse Tensor with any shape.
        - output: Sparse Tensor with the same shape as input.

    Examples:
        .. code-block:: python

            import paddle

            dense_x = paddle.to_tensor([-2., 0., 5.])
            sparse_x = dense_x.to_sparse_coo(1)
            leaky_relu = paddle.incubate.sparse.nn.LeakyReLU(0.5)
            out = leaky_relu(sparse_x)
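            # [-1., 0., 5.]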
"""
def __init__(self, negative_slope=0.01, name=None):
super(LeakyReLU, self).__init__()
self._negative_slope = negative_slope
self._name = name
def forward(self, x):
return F.leaky_relu(x, self._negative_slope, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
@@ -12,11 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from paddle import _C_ops
from paddle.fluid.framework import dygraph_only, core, convert_np_dtype_to_dtype_

__all__ = []
_int_dtype_ = [
    core.VarDesc.VarType.UINT8,
    core.VarDesc.VarType.INT8,
    core.VarDesc.VarType.INT16,
    core.VarDesc.VarType.INT32,
    core.VarDesc.VarType.INT64,
    core.VarDesc.VarType.BOOL,
]
@dygraph_only
def sin(x, name=None):
@@ -489,17 +500,111 @@ def coalesce(x):
        .. code-block:: python

            import paddle
            from paddle.incubate import sparse

            indices = [[0, 0, 1], [1, 1, 2]]
            values = [1.0, 2.0, 3.0]
            sp_x = sparse.sparse_coo_tensor(indices, values)
            sp_x = sparse.coalesce(sp_x)
            print(sp_x.indices())
            # [[0, 1], [1, 2]]
            print(sp_x.values())
            # [3.0, 3.0]
    """
    return _C_ops.final_state_sparse_coalesce(x)
@dygraph_only
def rad2deg(x, name=None):
    """
    Convert each of the elements of input x from angles in radians to degrees,
    requiring x to be a SparseCooTensor or SparseCsrTensor.

    .. math::

        rad2deg(x) = 180 / \pi * x

    Parameters:
        x (Tensor): The input Sparse Tensor with data type float32, float64, int32, int64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Sparse Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle

            dense_x = paddle.to_tensor([3.142, 0., -3.142])
            sparse_x = dense_x.to_sparse_coo(1)
            out = paddle.incubate.sparse.rad2deg(sparse_x)
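            # [180.02334, 0., -180.02334]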
"""
if x.dtype in _int_dtype_:
x = _C_ops.final_state_sparse_cast(x, None, core.VarDesc.VarType.FP32)
return _C_ops.final_state_sparse_scale(x, 180.0 / np.pi, 0.0, True)
@dygraph_only
def deg2rad(x, name=None):
    """
    Convert each of the elements of input x from degrees to angles in radians,
    requiring x to be a SparseCooTensor or SparseCsrTensor.

    .. math::

        deg2rad(x) = \pi * x / 180

    Parameters:
        x (Tensor): The input Sparse Tensor with data type float32, float64, int32, int64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Sparse Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle

            dense_x = paddle.to_tensor([-180, 0, 180])
            sparse_x = dense_x.to_sparse_coo(1)
            out = paddle.incubate.sparse.deg2rad(sparse_x)
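            # [-3.1415927, 0., 3.1415927]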
"""
if x.dtype in _int_dtype_:
x = _C_ops.final_state_sparse_cast(x, None, core.VarDesc.VarType.FP32)
return _C_ops.final_state_sparse_scale(x, np.pi / 180.0, 0.0, True)
@dygraph_only
def expm1(x, name=None):
    """
    Calculate elementwise `exp(x) - 1` , requiring x to be a SparseCooTensor or SparseCsrTensor.

    .. math::

        out = exp(x) - 1

    Parameters:
        x (Tensor): The input Sparse Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Sparse Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle

            dense_x = paddle.to_tensor([-2., 0., 1.])
            sparse_x = dense_x.to_sparse_coo(1)
            out = paddle.incubate.sparse.expm1(sparse_x)
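            # [-0.86466473, 0., 1.71828183]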
"""
return _C_ops.final_state_sparse_expm1(x)