From 19d9c73606d74aa7ed78b146ef59b7dc72dbf15e Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Fri, 22 Jul 2022 22:32:22 +0800 Subject: [PATCH] [Sparse]add sparse unary api(expm1/deg2rad/rad2deg/relu6/leaky_relu) (#44432) --- paddle/phi/api/yaml/sparse_api.yaml | 27 ++++ paddle/phi/api/yaml/sparse_bw_api.yaml | 24 +++ paddle/phi/kernels/activation_grad_kernel.h | 2 +- .../kernels/sparse/cpu/unary_grad_kernel.cc | 3 + paddle/phi/kernels/sparse/cpu/unary_kernel.cc | 3 + .../kernels/sparse/gpu/unary_grad_kernel.cu | 3 + paddle/phi/kernels/sparse/gpu/unary_kernel.cu | 3 + .../sparse/impl/unary_grad_kernel_impl.h | 3 + .../kernels/sparse/impl/unary_kernel_impl.h | 3 + .../tests/unittests/test_sparse_unary_op.py | 17 ++ python/paddle/incubate/sparse/__init__.py | 6 + python/paddle/incubate/sparse/nn/__init__.py | 4 + .../incubate/sparse/nn/functional/__init__.py | 4 + .../sparse/nn/functional/activation.py | 119 ++++++++++---- .../incubate/sparse/nn/layer/activation.py | 149 ++++++++++++++---- python/paddle/incubate/sparse/unary.py | 127 +++++++++++++-- 16 files changed, 422 insertions(+), 75 deletions(-) diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml index 28f35535be..2e44e20142 100644 --- a/paddle/phi/api/yaml/sparse_api.yaml +++ b/paddle/phi/api/yaml/sparse_api.yaml @@ -127,6 +127,24 @@ divide_csr_scalar{sparse_csr -> sparse_csr} backward : divide_scalar_grad +- api : expm1 + args : (Tensor x) + output : Tensor(out) + kernel : + func : expm1_coo{sparse_coo -> sparse_coo}, + expm1_csr{sparse_csr -> sparse_csr} + layout : x + backward : expm1_grad + +- api : leaky_relu + args : (Tensor x, float alpha) + output : Tensor(out) + kernel : + func : leaky_relu_coo{sparse_coo -> sparse_coo}, + leaky_relu_csr{sparse_csr -> sparse_csr} + layout : x + backward : leaky_relu_grad + - api : log1p args : (Tensor x) output : Tensor(out) @@ -163,6 +181,15 @@ layout : x backward : relu_grad +- api : relu6 + args : (Tensor x, float threshold) + output : Tensor(out) + kernel : + func : relu6_coo{sparse_coo -> sparse_coo}, + relu6_csr{sparse_csr -> sparse_csr} + layout : x + backward : relu6_grad + - api : scale args : (Tensor x, float scale, float bias, bool bias_after_scale) output : Tensor(out) diff --git a/paddle/phi/api/yaml/sparse_bw_api.yaml b/paddle/phi/api/yaml/sparse_bw_api.yaml index a39577e7e6..bde86f3816 100644 --- a/paddle/phi/api/yaml/sparse_bw_api.yaml +++ b/paddle/phi/api/yaml/sparse_bw_api.yaml @@ -122,6 +122,22 @@ output : Tensor(x_grad) invoke : divide_scalar(out_grad, scalar) +- backward_api : expm1_grad + forward : expm1(Tensor x) -> Tensor(out) + args : (Tensor out, Tensor out_grad) + output : Tensor(x_grad) + kernel : + func : expm1_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, + expm1_csr_grad {sparse_csr, sparse_csr -> sparse_csr} + +- backward_api : leaky_relu_grad + forward : leaky_relu(Tensor x, float alpha) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float alpha) + output : Tensor(x_grad) + kernel : + func : leaky_relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, + leaky_relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr} + - backward_api : log1p_grad forward : log1p(Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) @@ -178,6 +194,14 @@ func : pow_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, pow_csr_grad {sparse_csr, sparse_csr -> sparse_csr} +- backward_api : relu6_grad + forward : relu6(Tensor x, float threshold) -> Tensor(out) + args : (Tensor out, Tensor out_grad, float threshold) + output : Tensor(x_grad) + kernel : + func : relu6_coo_grad {sparse_coo, sparse_coo -> sparse_coo}, + relu6_csr_grad {sparse_csr, sparse_csr -> sparse_csr} + - backward_api : relu_grad forward : relu(Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index 4daa231437..ea33326278 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -240,11 +240,11 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, threshold); DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha); +DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6, threshold); DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max); DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b); DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold); - DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset); } // namespace phi diff --git a/paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc index f8520db2ca..4c993e3a27 100644 --- a/paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc @@ -51,6 +51,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(log1p, Log1p) PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu, Relu) PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(abs, Abs) PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(pow, Pow) +PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(expm1, Expm1) +PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu6, Relu6) +PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(leaky_relu, LeakyRelu) PD_REGISTER_KERNEL(cast_coo_grad, CPU, diff --git a/paddle/phi/kernels/sparse/cpu/unary_kernel.cc b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc index 1c1ece27d9..d0df009594 100644 --- a/paddle/phi/kernels/sparse/cpu/unary_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc @@ -93,6 +93,9 @@ PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu, Relu) PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(abs, Abs) PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(pow, Pow) PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(scale, Scale) +PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(expm1, Expm1) +PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu6, Relu6) +PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(leaky_relu, LeakyRelu) PD_REGISTER_KERNEL(divide_coo_scalar, CPU, diff --git a/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu index be0f13fb0e..ef66e91364 100644 --- a/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu @@ -53,6 +53,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(log1p, Log1p) PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(relu, Relu) PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(abs, Abs) PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(pow, Pow) +PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(expm1, Expm1) +PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(relu6, Relu6) +PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(leaky_relu, LeakyRelu) PD_REGISTER_KERNEL(cast_coo_grad, GPU, diff --git a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu index 6358b7b983..b03f508a32 100644 --- a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu @@ -98,6 +98,9 @@ PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu, Relu) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(abs, Abs) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(pow, Pow) PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(scale, Scale) +PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(expm1, Expm1) +PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu6, Relu6) +PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(leaky_relu, LeakyRelu) PD_REGISTER_KERNEL(divide_coo_scalar, GPU, diff --git a/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h index ffc5f6bbac..0709e6d946 100644 --- a/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h +++ b/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h @@ -93,7 +93,10 @@ DEFINE_SPARSE_UNARY_GRAD_KERNEL(Square) DEFINE_SPARSE_UNARY_GRAD_KERNEL(Log1p) DEFINE_SPARSE_UNARY_GRAD_KERNEL(Relu) DEFINE_SPARSE_UNARY_GRAD_KERNEL(Abs) +DEFINE_SPARSE_UNARY_GRAD_KERNEL(Expm1) DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor) +DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha) +DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Relu6, threshold) template void CastCooGradKernel(const Context& dev_ctx, diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h index 2639753266..338bb13d28 100644 --- a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h +++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h @@ -86,7 +86,10 @@ DEFINE_SPARSE_UNARY_KERNEL(Square) DEFINE_SPARSE_UNARY_KERNEL(Log1p) DEFINE_SPARSE_UNARY_KERNEL(Relu) DEFINE_SPARSE_UNARY_KERNEL(Abs) +DEFINE_SPARSE_UNARY_KERNEL(Expm1) DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor) +DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Relu6, threshold) +DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(LeakyRelu, alpha) template void ScaleCooKernel(const Context& dev_ctx, diff --git a/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py b/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py index 36d64f5067..d67fe0b7d5 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py @@ -123,9 +123,26 @@ class TestSparseUnary(unittest.TestCase): self.compare_with_dense(paddle.nn.ReLU(), paddle.incubate.sparse.nn.ReLU()) + def test_sparse_relu6(self): + self.compare_with_dense(paddle.nn.ReLU6(), + paddle.incubate.sparse.nn.ReLU6()) + + def test_sparse_leaky_relu(self): + self.compare_with_dense(paddle.nn.LeakyReLU(0.1), + paddle.incubate.sparse.nn.LeakyReLU(0.1)) + def test_sparse_abs(self): self.compare_with_dense(paddle.abs, paddle.incubate.sparse.abs) + def test_sparse_expm1(self): + self.compare_with_dense(paddle.expm1, paddle.incubate.sparse.expm1) + + def test_sparse_deg2rad(self): + self.compare_with_dense(paddle.deg2rad, paddle.incubate.sparse.deg2rad) + + def test_sparse_rad2deg(self): + self.compare_with_dense(paddle.rad2deg, paddle.incubate.sparse.rad2deg) + def test_sparse_neg(self): self.compare_with_dense(paddle.neg, paddle.incubate.sparse.neg) diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py index 6a672cb494..8408c3ca27 100644 --- a/python/paddle/incubate/sparse/__init__.py +++ b/python/paddle/incubate/sparse/__init__.py @@ -31,6 +31,9 @@ from .unary import pow from .unary import cast from .unary import neg from .unary import coalesce +from .unary import deg2rad +from .unary import rad2deg +from .unary import expm1 from .binary import mv from .binary import matmul @@ -62,6 +65,9 @@ __all__ = [ 'pow', 'cast', 'neg', + 'deg2rad', + 'rad2deg', + 'expm1', 'mv', 'matmul', 'masked_matmul', diff --git a/python/paddle/incubate/sparse/nn/__init__.py b/python/paddle/incubate/sparse/nn/__init__.py index 0df4d02414..1d58897537 100644 --- a/python/paddle/incubate/sparse/nn/__init__.py +++ b/python/paddle/incubate/sparse/nn/__init__.py @@ -16,6 +16,8 @@ from . import functional from .layer.activation import ReLU from .layer.activation import Softmax +from .layer.activation import ReLU6 +from .layer.activation import LeakyReLU from .layer.norm import BatchNorm from .layer.conv import Conv3D from .layer.conv import SubmConv3D @@ -23,6 +25,8 @@ from .layer.pooling import MaxPool3D __all__ = [ 'ReLU', + 'ReLU6', + 'LeakyReLU', 'Softmax', 'BatchNorm', 'Conv3D', diff --git a/python/paddle/incubate/sparse/nn/functional/__init__.py b/python/paddle/incubate/sparse/nn/functional/__init__.py index 21939eeb1a..3e8ff4ba50 100644 --- a/python/paddle/incubate/sparse/nn/functional/__init__.py +++ b/python/paddle/incubate/sparse/nn/functional/__init__.py @@ -17,6 +17,8 @@ from .conv import subm_conv3d # noqa: F401 from .transformer import attention # noqa: F401 from .pooling import max_pool3d # noqa: F401 from .activation import relu # noqa: F401 +from .activation import relu6 # noqa: F401 +from .activation import leaky_relu # noqa: F401 from .activation import softmax # noqa: F401 __all__ = [ @@ -24,6 +26,8 @@ __all__ = [ 'subm_conv3d', 'max_pool3d', 'relu', + 'relu6', + 'leaky_relu', 'softmax', 'attention', ] diff --git a/python/paddle/incubate/sparse/nn/functional/activation.py b/python/paddle/incubate/sparse/nn/functional/activation.py index dc29694240..2305abc8d5 100644 --- a/python/paddle/incubate/sparse/nn/functional/activation.py +++ b/python/paddle/incubate/sparse/nn/functional/activation.py @@ -21,7 +21,7 @@ from paddle.fluid.framework import dygraph_only @dygraph_only def relu(x, name=None): """ - sparse relu activation, requiring x to be a sparse coo or sparse csr tensor. + sparse relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor. .. math:: @@ -39,12 +39,11 @@ def relu(x, name=None): .. code-block:: python import paddle - from paddle.fluid.framework import _test_eager_guard - with _test_eager_guard(): - dense_x = paddle.to_tensor([-2, 0, 1], dtype='float32') - sparse_x = dense_x.to_sparse_coo(1) - out = paddle.incubate.sparse.nn.functional.relu(sparse_x) + dense_x = paddle.to_tensor([-2., 0., 1.]) + sparse_x = dense_x.to_sparse_coo(1) + out = paddle.incubate.sparse.nn.functional.relu(sparse_x) + # [0., 0., 1.] """ return _C_ops.final_state_sparse_relu(x) @@ -52,7 +51,7 @@ def relu(x, name=None): @dygraph_only def softmax(x, axis=-1, name=None): """ - sparse softmax activation, x must be SparseCsrTensor or SparseCooTensor. + sparse softmax activation, requiring x to be a SparseCooTensor or SparseCsrTensor. Note: Only support axis=-1 for SparseCsrTensor, which is faster when read data @@ -79,30 +78,92 @@ def softmax(x, axis=-1, name=None): import paddle import numpy as np - from paddle.fluid.framework import _test_eager_guard - paddle.seed(100) - with _test_eager_guard(): - mask = np.random.rand(3, 4) < 0.5 - np_x = np.random.rand(3, 4) * mask - # [[0. 0. 0.96823406 0.19722934] - # [0.94373937 0. 0.02060066 0.71456372] - # [0. 0. 0. 0.98275049]] - - csr = paddle.to_tensor(np_x).to_sparse_csr() - # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, - # crows=[0, 2, 5, 6], - # cols=[2, 3, 0, 2, 3, 3], - # values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372, - # 0.98275049]) - - out = paddle.incubate.sparse.nn.functional.softmax(csr) - # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, - # crows=[0, 2, 5, 6], - # cols=[2, 3, 0, 2, 3, 3], - # values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269, - # 1. ]) + mask = np.random.rand(3, 4) < 0.5 + np_x = np.random.rand(3, 4) * mask + # [[0. 0. 0.96823406 0.19722934] + # [0.94373937 0. 0.02060066 0.71456372] + # [0. 0. 0. 0.98275049]] + + csr = paddle.to_tensor(np_x).to_sparse_csr() + # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 5, 6], + # cols=[2, 3, 0, 2, 3, 3], + # values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372, + # 0.98275049]) + + out = paddle.incubate.sparse.nn.functional.softmax(csr) + # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 5, 6], + # cols=[2, 3, 0, 2, 3, 3], + # values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269, + # 1. ]) """ return _C_ops.final_state_sparse_softmax(x, axis) + + +@dygraph_only +def relu6(x, name=None): + """ + sparse relu6 activation, requiring x to be a SparseCooTensor or SparseCsrTensor. + + .. math:: + + relu6(x) = min(max(0, x), 6) + + Parameters: + x (Tensor): The input Sparse Tensor with data type float32, float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Sparse Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + + dense_x = paddle.to_tensor([-2., 0., 8.]) + sparse_x = dense_x.to_sparse_coo(1) + out = paddle.incubate.sparse.nn.functional.relu6(sparse_x) + """ + return _C_ops.final_state_sparse_relu6(x, 6.0) + + +@dygraph_only +def leaky_relu(x, negative_slope=0.01, name=None): + """ + sparse leaky_relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor. + + .. math:: + leaky\_relu(x)= + \left\{ + \begin{array}{rcl} + x, & & if \ x >= 0 \\ + negative\_slope * x, & & otherwise \\ + \end{array} + \right. + + Parameters: + x (Tensor): The input Sparse Tensor with data type float32, float64. + negative_slope (float, optional): Slope of the activation function at + :math:`x < 0` . Default is 0.01. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Sparse Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + + dense_x = paddle.to_tensor([-2., 0., 5.]) + sparse_x = dense_x.to_sparse_coo(1) + out = paddle.incubate.sparse.nn.functional.leaky_relu(sparse_x, 0.5) + """ + return _C_ops.final_state_sparse_leaky_relu(x, negative_slope) diff --git a/python/paddle/incubate/sparse/nn/layer/activation.py b/python/paddle/incubate/sparse/nn/layer/activation.py index 9aec20603a..da374fa87a 100644 --- a/python/paddle/incubate/sparse/nn/layer/activation.py +++ b/python/paddle/incubate/sparse/nn/layer/activation.py @@ -20,7 +20,7 @@ __all__ = [] class ReLU(Layer): """ - Sparse ReLU Activation. + Sparse ReLU Activation, requiring x to be a SparseCooTensor or SparseCsrTensor. .. math:: @@ -38,15 +38,12 @@ class ReLU(Layer): .. code-block:: python import paddle - from paddle.fluid.framework import _test_eager_guard - with _test_eager_guard(): - x = [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]] - dense_x = paddle.to_tensor(x, dtype='float32') - sparse_dim = 2 - sparse_x = dense_x.to_sparse_coo(sparse_dim) - relu = paddle.incubate.sparse.nn.ReLU() - out = relu(sparse_x) - #out.values: [0., 2., 0., 4., 5.] + + dense_x = paddle.to_tensor([-2., 0., 1.]) + sparse_x = dense_x.to_sparse_coo(1) + relu = paddle.incubate.sparse.nn.ReLU() + out = relu(sparse_x) + # [0., 0., 1.] """ def __init__(self, name=None): @@ -63,7 +60,7 @@ class ReLU(Layer): class Softmax(Layer): """ - sparse softmax activation, x must be SparseCsrTensor or SparseCooTensor. + Sparse Softmax Activation, requiring x to be a SparseCooTensor or SparseCsrTensor. Note: Only support axis=-1 for SparseCsrTensor, which is faster when read data @@ -90,31 +87,28 @@ class Softmax(Layer): import paddle import numpy as np - from paddle.fluid.framework import _test_eager_guard - paddle.seed(100) - with _test_eager_guard(): - mask = np.random.rand(3, 4) < 0.5 - np_x = np.random.rand(3, 4) * mask - # [[0. 0. 0.96823406 0.19722934] - # [0.94373937 0. 0.02060066 0.71456372] - # [0. 0. 0. 0.98275049]] - - csr = paddle.to_tensor(np_x).to_sparse_csr() - # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, - # crows=[0, 2, 5, 6], - # cols=[2, 3, 0, 2, 3, 3], - # values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372, - # 0.98275049]) - - m = paddle.incubate.sparse.nn.Softmax() - out = m(csr) - # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, - # crows=[0, 2, 5, 6], - # cols=[2, 3, 0, 2, 3, 3], - # values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269, - # 1. ]) + mask = np.random.rand(3, 4) < 0.5 + np_x = np.random.rand(3, 4) * mask + # [[0. 0. 0.96823406 0.19722934] + # [0.94373937 0. 0.02060066 0.71456372] + # [0. 0. 0. 0.98275049]] + + csr = paddle.to_tensor(np_x).to_sparse_csr() + # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 5, 6], + # cols=[2, 3, 0, 2, 3, 3], + # values=[0.96823406, 0.19722934, 0.94373937, 0.02060066, 0.71456372, + # 0.98275049]) + + softmax = paddle.incubate.sparse.nn.Softmax() + out = softmax(csr) + # Tensor(shape=[3, 4], dtype=paddle.float64, place=Place(gpu:0), stop_gradient=True, + # crows=[0, 2, 5, 6], + # cols=[2, 3, 0, 2, 3, 3], + # values=[0.68373820, 0.31626180, 0.45610887, 0.18119845, 0.36269269, + # 1. ]) """ def __init__(self, axis=-1, name=None): @@ -128,3 +122,90 @@ class Softmax(Layer): def extra_repr(self): name_str = 'name={}'.format(self._name) if self._name else '' return name_str + + +class ReLU6(Layer): + """ + Sparse ReLU6 Activation, requiring x to be a SparseCooTensor or SparseCsrTensor. + + .. math:: + + ReLU(x) = min(max(0,x), 6) + + Parameters: + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Sparse Tensor with any shape. + - output: Sparse Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + + dense_x = paddle.to_tensor([-2., 0., 8.]) + sparse_x = dense_x.to_sparse_coo(1) + relu6 = paddle.incubate.sparse.nn.ReLU6() + out = relu6(sparse_x) + """ + + def __init__(self, name=None): + super(ReLU6, self).__init__() + self._name = name + + def forward(self, x): + return F.relu6(x, self._name) + + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str + + +class LeakyReLU(Layer): + """ + Sparse Leaky ReLU Activation, requiring x to be a SparseCooTensor or SparseCsrTensor. + + .. math:: + + LeakyReLU(x)= + \left\{ + \begin{array}{rcl} + x, & & if \ x >= 0 \\ + negative\_slope * x, & & otherwise \\ + \end{array} + \right. + + Parameters: + negative_slope (float, optional): Slope of the activation function at + :math:`x < 0` . Default is 0.01. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Sparse Tensor with any shape. + - output: Sparse Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + + dense_x = paddle.to_tensor([-2., 0., 5.]) + sparse_x = dense_x.to_sparse_coo(1) + leaky_relu = paddle.incubate.sparse.nn.LeakyReLU(0.5) + out = leaky_relu(sparse_x) + """ + + def __init__(self, negative_slope=0.01, name=None): + super(LeakyReLU, self).__init__() + self._negative_slope = negative_slope + self._name = name + + def forward(self, x): + return F.leaky_relu(x, self._negative_slope, self._name) + + def extra_repr(self): + name_str = 'name={}'.format(self._name) if self._name else '' + return name_str diff --git a/python/paddle/incubate/sparse/unary.py b/python/paddle/incubate/sparse/unary.py index 1725c8791f..ae55a5b9ab 100644 --- a/python/paddle/incubate/sparse/unary.py +++ b/python/paddle/incubate/sparse/unary.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + from paddle import _C_ops from paddle.fluid.framework import dygraph_only, core, convert_np_dtype_to_dtype_ __all__ = [] +_int_dtype_ = [ + core.VarDesc.VarType.UINT8, + core.VarDesc.VarType.INT8, + core.VarDesc.VarType.INT16, + core.VarDesc.VarType.INT32, + core.VarDesc.VarType.INT64, + core.VarDesc.VarType.BOOL, +] + @dygraph_only def sin(x, name=None): @@ -489,17 +500,111 @@ def coalesce(x): .. code-block:: python import paddle + from paddle.incubate import sparse - from paddle.fluid.framework import _test_eager_guard - - with _test_eager_guard(): - indices = [[0, 0, 1], [1, 1, 2]] - values = [1.0, 2.0, 3.0] - sp_x = sparse.sparse_coo_tensor(indices, values) - sp_x = sparse.coalesce(sp_x) - print(sp_x.indices()) - #[[0, 1], [1, 2]] - print(sp_x.values()) - #[3.0, 3.0] + + indices = [[0, 0, 1], [1, 1, 2]] + values = [1.0, 2.0, 3.0] + sp_x = sparse.sparse_coo_tensor(indices, values) + sp_x = sparse.coalesce(sp_x) + print(sp_x.indices()) + #[[0, 1], [1, 2]] + print(sp_x.values()) + #[3.0, 3.0] """ return _C_ops.final_state_sparse_coalesce(x) + + +@dygraph_only +def rad2deg(x, name=None): + """ + Convert each of the elements of input x from angles in radians to degrees, + requiring x to be a SparseCooTensor or SparseCsrTensor. + + .. math:: + + rad2deg(x) = 180/ \pi * x + + Parameters: + x (Tensor): The input Sparse Tensor with data type float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Sparse Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + + dense_x = paddle.to_tensor([3.142, 0., -3.142]) + sparse_x = dense_x.to_sparse_coo(1) + out = paddle.incubate.sparse.rad2deg(sparse_x) + + """ + if x.dtype in _int_dtype_: + x = _C_ops.final_state_sparse_cast(x, None, core.VarDesc.VarType.FP32) + return _C_ops.final_state_sparse_scale(x, 180.0 / np.pi, 0.0, True) + + +@dygraph_only +def deg2rad(x, name=None): + """ + Convert each of the elements of input x from degrees to angles in radians, + requiring x to be a SparseCooTensor or SparseCsrTensor. + + .. math:: + + deg2rad(x) = \pi * x / 180 + + Parameters: + x (Tensor): The input Sparse Tensor with data type float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Sparse Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + + dense_x = paddle.to_tensor([-180, 0, 180]) + sparse_x = dense_x.to_sparse_coo(1) + out = paddle.incubate.sparse.deg2rad(sparse_x) + + """ + if x.dtype in _int_dtype_: + x = _C_ops.final_state_sparse_cast(x, None, core.VarDesc.VarType.FP32) + return _C_ops.final_state_sparse_scale(x, np.pi / 180.0, 0.0, True) + + +@dygraph_only +def expm1(x, name=None): + """ + Calculate elementwise `exp(x)-1` , requiring x to be a SparseCooTensor or SparseCsrTensor. + + .. math:: + + out = exp(x) - 1 + + Parameters: + x (Tensor): The input Sparse Tensor with data type float32, float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Sparse Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + + dense_x = paddle.to_tensor([-2., 0., 1.]) + sparse_x = dense_x.to_sparse_coo(1) + out = paddle.incubate.sparse.expm1(sparse_x) + """ + return _C_ops.final_state_sparse_expm1(x) -- GitLab