From 682acd2224c415afdd2d3353917ca4712963b7e8 Mon Sep 17 00:00:00 2001
From: Zhou Wei <1183042833@qq.com>
Date: Tue, 12 Jul 2022 14:39:12 +0800
Subject: [PATCH] [Sparse]add sparse unary
 api(sin/tan/pow/neg/log1p/square/cast...) (#44022)

---
 paddle/fluid/pybind/eager_method.cc           |  25 ++
 .../api/yaml/generator/sparse_bw_api_gen.py   |   1 +
 paddle/phi/api/yaml/sparse_api.yaml           | 168 ++++++-
 paddle/phi/api/yaml/sparse_bw_api.yaml        | 129 +++++-
 paddle/phi/kernels/activation_grad_kernel.h   |   8 +
 paddle/phi/kernels/activation_kernel.h        |  27 +-
 .../phi/kernels/funcs/eigen/eigen_function.h  |  12 +
 paddle/phi/kernels/funcs/eigen/elementwise.cc |  17 +
 paddle/phi/kernels/funcs/eigen/elementwise.cu |  17 +
 .../kernels/sparse/cpu/unary_grad_kernel.cc   |  79 ++++
 paddle/phi/kernels/sparse/cpu/unary_kernel.cc | 139 ++++++
 .../kernels/sparse/gpu/unary_grad_kernel.cu   |  79 ++++
 paddle/phi/kernels/sparse/gpu/unary_kernel.cu | 142 ++++++
 .../sparse/impl/unary_grad_kernel_impl.h      | 141 ++++++
 .../kernels/sparse/impl/unary_kernel_impl.h   | 207 +++++++++
 .../phi/kernels/sparse/unary_grad_kernel.cc   | 183 --------
 paddle/phi/kernels/sparse/unary_grad_kernel.h |  68 ++-
 paddle/phi/kernels/sparse/unary_kernel.cc     | 177 --------
 paddle/phi/kernels/sparse/unary_kernel.h      |  95 +++-
 .../kernels/test_sparse_activation_dev_api.cc |   4 +-
 .../unittests/test_sparse_elementwise_op.py   |  10 +-
 .../tests/unittests/test_sparse_model.py      |   4 +
 .../tests/unittests/test_sparse_unary_op.py   | 237 +++++-----
 .../tests/unittests/test_sparse_utils_op.py   |   4 +-
 python/paddle/incubate/sparse/__init__.py     |  37 +-
 python/paddle/incubate/sparse/binary.py       | 199 ++++++++-
 python/paddle/incubate/sparse/math.py         | 260 -----------
 python/paddle/incubate/sparse/unary.py        | 413 ++++++++++++++++--
 28 files changed, 2036 insertions(+), 846 deletions(-)
 create mode 100644 paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/sparse/cpu/unary_kernel.cc
 create mode 100644 paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/sparse/gpu/unary_kernel.cu
 create mode 100644 paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h
 create mode 100644 paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
 delete mode 100644 paddle/phi/kernels/sparse/unary_grad_kernel.cc
 delete mode 100644 paddle/phi/kernels/sparse/unary_kernel.cc
 delete mode 100644 python/paddle/incubate/sparse/math.py

diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc
index 77e1962911..086c15dafd 100644
--- a/paddle/fluid/pybind/eager_method.cc
+++ b/paddle/fluid/pybind/eager_method.cc
@@ -1473,6 +1473,27 @@ static PyObject* tensor_method_get_map_tensor(TensorObject* self,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 
+static PyObject* tensor_method_get_non_zero_nums(TensorObject* self,
+                                                 PyObject* args,
+                                                 PyObject* kwargs) {
+  EAGER_TRY
+  PADDLE_ENFORCE(
+      self->tensor.is_sparse_coo_tensor() ||
+          self->tensor.is_sparse_csr_tensor(),
+      paddle::platform::errors::Fatal("this method is only effective for "
+                                      "SparseCooTensor or SparseCsrTensor"));
+  if (self->tensor.is_sparse_coo_tensor()) {
+    auto sparse_coo_tensor =
+        std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
+    return ToPyObject(sparse_coo_tensor->nnz());
+  } else {
+    auto sparse_csr_tensor =
+        std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
+    return ToPyObject(sparse_csr_tensor->nnz());
+  }
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 static PyObject* tensor_method_get_non_zero_indices(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
@@ -1962,6 +1983,10 @@ PyMethodDef variable_methods[] = {
      METH_VARARGS | METH_KEYWORDS,
      NULL},
     /***the method of sparse tensor****/
+    {"nnz",
+     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_nums,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
     {"indices",
      (PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
      METH_VARARGS | METH_KEYWORDS,
diff --git a/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py b/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py
index e30c5e3c5d..f3172a23cb 100644
--- a/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py
+++ b/paddle/phi/api/yaml/generator/sparse_bw_api_gen.py
@@ -109,6 +109,7 @@ def source_include(header_file_path):
 
 #include "glog/logging.h"
 
+#include "paddle/phi/api/include/sparse_api.h"
 #include "paddle/phi/api/lib/api_gen_utils.h"
 #include "paddle/phi/api/lib/kernel_dispatch.h"
 #include "paddle/phi/api/lib/sparse_api_custom_impl.h"
diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml
index 68c41d50ae..d8c275ff1f 100644
--- a/paddle/phi/api/yaml/sparse_api.yaml
+++ b/paddle/phi/api/yaml/sparse_api.yaml
@@ -1,12 +1,85 @@
+- api : abs
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : abs_coo{sparse_coo -> sparse_coo},
+           abs_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : abs_grad
+
+- api : acos
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : acos_coo{sparse_coo -> sparse_coo},
+           acos_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : acos_grad
+
+- api : acosh
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : acosh_coo{sparse_coo -> sparse_coo},
+           acosh_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : acosh_grad
+
 - api : add
   args : (Tensor x, Tensor y)
   output : Tensor(out)
   kernel :
-    func : add_coo_coo{sparse_coo -> sparse_coo},
-           add_csr_csr{sparse_csr -> sparse_csr}
+    func : add_coo_coo{sparse_coo, sparse_coo -> sparse_coo},
+           add_csr_csr{sparse_csr, sparse_csr -> sparse_csr}
     layout : x
   backward : add_grad
 
+- api : asin
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : asin_coo{sparse_coo -> sparse_coo},
+           asin_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : asin_grad
+
+- api : asinh
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : asinh_coo{sparse_coo -> sparse_coo},
+           asinh_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : asinh_grad
+
+- api : atan
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : atan_coo{sparse_coo -> sparse_coo},
+           atan_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : atan_grad
+
+- api : atanh
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : atanh_coo{sparse_coo -> sparse_coo},
+           atanh_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : atanh_grad
+
+- api : cast
+  args : (Tensor x, DataType index_dtype=DataType::UNDEFINED, DataType value_dtype=DataType::UNDEFINED)
+  output : Tensor(out)
+  kernel :
+    func : cast_coo{sparse_coo -> sparse_coo},
+           cast_csr{sparse_csr -> sparse_csr}
+    layout : x
+    data_type : x
+  backward : cast_grad
+
 - api : conv3d
   args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
   output : Tensor(out), Tensor(rulebook)
@@ -41,38 +114,81 @@
   args : (Tensor x, Tensor y)
   output : Tensor(out)
   kernel :
-    func : divide_coo_coo{sparse_coo -> sparse_coo},
-           divide_csr_csr{sparse_csr -> sparse_csr}
+    func : divide_coo_coo{sparse_coo, sparse_coo -> sparse_coo},
+           divide_csr_csr{sparse_csr, sparse_csr -> sparse_csr}
     layout : x
   backward : divide_grad
 
+- api : divide_scalar
+  args : (Tensor x, float scalar)
+  output : Tensor(out)
+  kernel :
+    func : divide_coo_scalar{sparse_coo -> sparse_coo},
+           divide_csr_scalar{sparse_csr -> sparse_csr}
+  backward : divide_scalar_grad
+
+- api : log1p
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : log1p_coo{sparse_coo -> sparse_coo},
+           log1p_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : log1p_grad
+
 - api : multiply
   args : (Tensor x, Tensor y)
   output : Tensor(out)
   kernel :
-    func : multiply_coo_coo{sparse_coo -> sparse_coo},
-           multiply_csr_csr{sparse_csr -> sparse_csr}
+    func : multiply_coo_coo{sparse_coo, sparse_coo -> sparse_coo},
+           multiply_csr_csr{sparse_csr, sparse_csr -> sparse_csr}
     layout : x
   backward : multiply_grad
 
+- api : pow
+  args : (Tensor x, float factor)
+  output : Tensor(out)
+  kernel :
+    func : pow_coo{sparse_coo -> sparse_coo},
+           pow_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : pow_grad
+
 - api : relu
   args : (Tensor x)
   output : Tensor(out)
   kernel :
-    func : sparse_coo_relu{sparse_coo -> sparse_coo},
-           sparse_csr_relu{sparse_csr -> sparse_csr}
+    func : relu_coo{sparse_coo -> sparse_coo},
+           relu_csr{sparse_csr -> sparse_csr}
     layout : x
   backward : relu_grad
 
+- api : scale
+  args : (Tensor x, float scale, float bias, bool bias_after_scale)
+  output : Tensor(out)
+  kernel :
+    func : scale_coo{sparse_coo -> sparse_coo},
+           scale_csr{sparse_csr -> sparse_csr}
+  backward : scale_grad
+
 - api : sin
   args : (Tensor x)
-  output : Tensor(out@SparseCooTensor)
+  output : Tensor(out)
   kernel :
-    func : sparse_coo_sin {sparse_coo -> sparse_coo},
-           sparse_csr_sin {sparse_csr -> sparse_csr}
+    func : sin_coo{sparse_coo -> sparse_coo},
+           sin_csr{sparse_csr -> sparse_csr}
     layout : x
   backward : sin_grad
 
+- api : sinh
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : sinh_coo{sparse_coo -> sparse_coo},
+           sinh_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : sinh_grad
+
 - api : softmax
   args : (Tensor x, int axis=-1)
   output : Tensor(out)
@@ -85,26 +201,44 @@
   args : (Tensor x)
   output : Tensor(out)
   kernel :
-    func : sparse_coo_sqrt{sparse_coo -> sparse_coo},
-           sparse_csr_sqrt{sparse_csr -> sparse_csr}
+    func : sqrt_coo{sparse_coo -> sparse_coo},
+           sqrt_csr{sparse_csr -> sparse_csr}
     layout : x
   backward : sqrt_grad
 
+- api : square
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : square_coo{sparse_coo -> sparse_coo},
+           square_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : square_grad
+
 - api : subtract
   args : (Tensor x, Tensor y)
   output : Tensor(out)
   kernel :
-    func : subtract_coo_coo{sparse_coo -> sparse_coo},
-           subtract_csr_csr{sparse_csr -> sparse_csr}
+    func : subtract_coo_coo{sparse_coo, sparse_coo -> sparse_coo},
+           subtract_csr_csr{sparse_csr, sparse_csr -> sparse_csr}
     layout : x
   backward : subtract_grad
 
+- api : tan
+  args : (Tensor x)
+  output : Tensor(out)
+  kernel :
+    func : tan_coo{sparse_coo -> sparse_coo},
+           tan_csr{sparse_csr -> sparse_csr}
+    layout : x
+  backward : tan_grad
+
 - api : tanh
   args : (Tensor x)
   output : Tensor(out)
   kernel :
-    func : sparse_coo_tanh{sparse_coo -> sparse_coo},
-           sparse_csr_tanh{sparse_csr -> sparse_csr}
+    func : tanh_coo{sparse_coo -> sparse_coo},
+           tanh_csr{sparse_csr -> sparse_csr}
     layout : x
   backward : tanh_grad
 
diff --git a/paddle/phi/api/yaml/sparse_bw_api.yaml b/paddle/phi/api/yaml/sparse_bw_api.yaml
index 0ca9c9daa9..220d45cadc 100644
--- a/paddle/phi/api/yaml/sparse_bw_api.yaml
+++ b/paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -1,3 +1,27 @@
+- backward_api : abs_grad
+  forward : tanh(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : abs_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           abs_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
+- backward_api : acos_grad
+  forward : acos(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : acos_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           acos_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
+- backward_api : acosh_grad
+  forward : acosh(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : acosh_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           acosh_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
 - backward_api : add_grad
   forward : add(Tensor x, Tensor y) -> Tensor(out)
   args : (Tensor x, Tensor y, Tensor out_grad)
@@ -6,6 +30,47 @@
     func : add_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
            add_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
 
+- backward_api : asin_grad
+  forward : asin(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : asin_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           asin_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
+- backward_api : asinh_grad
+  forward : asinh(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : asinh_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           asinh_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
+- backward_api : atan_grad
+  forward : atan(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : atan_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           atan_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
+- backward_api : atanh_grad
+  forward : atanh(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : atanh_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           atanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
+- backward_api : cast_grad
+  forward : cast(Tensor x, DataType index_dtype, DataType value_dtype) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, DataType value_dtype)
+  output : Tensor(x_grad)
+  kernel :
+    func : cast_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+    data_type : out_grad
+
 - backward_api : conv3d_grad
   forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
   args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
@@ -41,6 +106,20 @@
     func : divide_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
            divide_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
 
+- backward_api : divide_scalar_grad
+  forward : divide_scalar (Tensor x, float scalar) -> Tensor(out)
+  args : (Tensor out_grad, float scalar)
+  output : Tensor(x_grad)
+  invoke : divide_scalar(out_grad, scalar)
+
+- backward_api : log1p_grad
+  forward : log1p(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : log1p_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           log1p_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
 - backward_api : masked_matmul_grad
   forward : masked_matmul(Tensor x, Tensor y, Tensor mask) -> Tensor(out)
   args : (Tensor x, Tensor y, Tensor out_grad)
@@ -71,19 +150,43 @@
     func : mv_coo_grad{sparse_coo, dense, dense -> sparse_coo, dense},
            mv_csr_grad{sparse_csr, dense, dense -> sparse_csr, dense}
 
+- backward_api : pow_grad
+  forward : pow(Tensor x, float factor) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, float factor)
+  output : Tensor(x_grad)
+  kernel :
+    func : pow_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           pow_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
 - backward_api : relu_grad
   forward : relu(Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
   output : Tensor(x_grad)
   kernel :
-    func : sparse_coo_relu_grad {sparse_coo, sparse_coo -> sparse_coo}
+    func : relu_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           relu_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
+- backward_api : scale_grad
+  forward : scale(Tensor x, float scale, float bias, bool bias_after_scale) -> Tensor(out)
+  args : (Tensor out_grad, float scale)
+  output : Tensor(x_grad)
+  invoke : scale(out_grad, scale, 0.0, true)
 
 - backward_api : sin_grad
   forward : sin(Tensor x) -> Tensor(out)
   args : (Tensor x, Tensor out_grad)
   output : Tensor(x_grad)
   kernel :
-    func : sparse_coo_sin_grad {sparse_coo, sparse_coo -> sparse_coo}
+    func : sin_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           sin_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
+- backward_api : sinh_grad
+  forward : sinh(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : sinh_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           sinh_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
 
 - backward_api : softmax_grad
   forward : softmax(Tensor x, int axis=-1) -> Tensor(out)
@@ -104,7 +207,16 @@
   args : (Tensor out, Tensor out_grad)
   output : Tensor(x_grad)
   kernel :
-    func : sparse_coo_sqrt_grad {sparse_coo, sparse_coo -> sparse_coo}
+    func : sqrt_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           sqrt_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
+- backward_api : square_grad
+  forward : square(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : square_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           square_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
 
 - backward_api : subtract_grad
   forward : subtract(Tensor x, Tensor y) -> Tensor(out)
@@ -114,12 +226,21 @@
     func : subtract_coo_coo_grad{sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo},
            subtract_csr_csr_grad{sparse_csr, sparse_csr, sparse_csr -> sparse_csr, sparse_csr}
 
+- backward_api : tan_grad
+  forward : tan(Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad)
+  output : Tensor(x_grad)
+  kernel :
+    func : tan_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           tan_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
+
 - backward_api : tanh_grad
   forward : tanh(Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
   output : Tensor(x_grad)
   kernel :
-    func : sparse_coo_tanh_grad {sparse_coo, sparse_coo -> sparse_coo}
+    func : tanh_coo_grad {sparse_coo, sparse_coo -> sparse_coo},
+           tanh_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
 
 - backward_api : values_grad
   forward : coo_values(Tensor x) -> Tensor(out)
diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h
index 8e63a0fd22..4daa231437 100644
--- a/paddle/phi/kernels/activation_grad_kernel.h
+++ b/paddle/phi/kernels/activation_grad_kernel.h
@@ -212,12 +212,17 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acosh);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Atanh);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Silu);
+DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Square);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log2);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log10);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log1p);
 
+DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp);
+DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1);
+DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal);
+DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Rsqrt);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid);
@@ -233,9 +238,12 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps);
+DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, threshold);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha);
 
 DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max);
+DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b);
+DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold);
 
 DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset);
 
diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h
index 5cc4357c93..8e5913e10f 100644
--- a/paddle/phi/kernels/activation_kernel.h
+++ b/paddle/phi/kernels/activation_kernel.h
@@ -40,12 +40,12 @@ namespace phi {
                     float attr2,                                     \
                     DenseTensor* out);
 
+DECLARE_ACTIVATION_KERNEL(Sin)
 DECLARE_ACTIVATION_KERNEL(Cos)
 DECLARE_ACTIVATION_KERNEL(Tan)
-DECLARE_ACTIVATION_KERNEL(Acos)
-DECLARE_ACTIVATION_KERNEL(Sin)
 DECLARE_ACTIVATION_KERNEL(Asin)
 DECLARE_ACTIVATION_KERNEL(Atan)
+DECLARE_ACTIVATION_KERNEL(Acos)
 DECLARE_ACTIVATION_KERNEL(Sinh)
 DECLARE_ACTIVATION_KERNEL(Cosh)
 DECLARE_ACTIVATION_KERNEL(Asinh)
@@ -53,15 +53,14 @@ DECLARE_ACTIVATION_KERNEL(Acosh)
 DECLARE_ACTIVATION_KERNEL(Atanh)
 DECLARE_ACTIVATION_KERNEL(Relu)
 DECLARE_ACTIVATION_KERNEL(Tanh)
+DECLARE_ACTIVATION_KERNEL(TanhShrink)
+DECLARE_ACTIVATION_KERNEL(Silu)
 DECLARE_ACTIVATION_KERNEL(Exp)
 DECLARE_ACTIVATION_KERNEL(Expm1)
 DECLARE_ACTIVATION_KERNEL(Reciprocal)
 DECLARE_ACTIVATION_KERNEL(Square)
 DECLARE_ACTIVATION_KERNEL(Sqrt)
 DECLARE_ACTIVATION_KERNEL(Rsqrt)
-
-DECLARE_ACTIVATION_KERNEL(TanhShrink)
-DECLARE_ACTIVATION_KERNEL(Silu)
 DECLARE_ACTIVATION_KERNEL(Sigmoid)
 DECLARE_ACTIVATION_KERNEL(LogSigmoid)
 DECLARE_ACTIVATION_KERNEL(Log)
@@ -77,28 +76,18 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Swish, beta)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Celu, alpha)
+DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Logit, eps)
 
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max)
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
-DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
-
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold)
-
-template <typename T, typename Context>
-void LogitKernel(const Context& dev_ctx,
-                 const DenseTensor& x,
-                 float eps,
-                 DenseTensor* out);
-
-template <typename T, typename Context>
-void MishKernel(const Context& dev_ctx,
-                const DenseTensor& x,
-                float threshold,
-                DenseTensor* out);
+DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
 
 template <typename T, typename Context>
 void HardSwishKernel(const Context& dev_ctx,
diff --git a/paddle/phi/kernels/funcs/eigen/eigen_function.h b/paddle/phi/kernels/funcs/eigen/eigen_function.h
index b971b4f95e..1e81256e79 100644
--- a/paddle/phi/kernels/funcs/eigen/eigen_function.h
+++ b/paddle/phi/kernels/funcs/eigen/eigen_function.h
@@ -118,6 +118,18 @@ struct EigenSub {
                    const InType& right);
 };
 
+template <typename EigenDevice, typename T>
+struct EigenDiv {
+  using InType = Eigen::TensorMap<
+      Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using OutType =
+      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  static void Eval(const EigenDevice& dev,
+                   OutType out,
+                   const InType& in,
+                   const T value);
+};
+
 template <typename EigenDevice, typename T, int Rank>
 struct EigenSlice {
   using Array = Eigen::DSizes<Eigen::DenseIndex, Rank>;
diff --git a/paddle/phi/kernels/funcs/eigen/elementwise.cc b/paddle/phi/kernels/funcs/eigen/elementwise.cc
index 507a0116c3..713513757a 100644
--- a/paddle/phi/kernels/funcs/eigen/elementwise.cc
+++ b/paddle/phi/kernels/funcs/eigen/elementwise.cc
@@ -55,5 +55,22 @@ struct EigenSub<Eigen::DefaultDevice, T> {
 
 template struct EigenSub<Eigen::DefaultDevice, float>;
 
+template <typename T>
+struct EigenDiv<Eigen::DefaultDevice, T> {
+  using InType = Eigen::TensorMap<
+      Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using OutType =
+      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  static void Eval(const Eigen::DefaultDevice& dev,
+                   OutType out,
+                   const InType& in,
+                   const T value) {
+    out.device(dev) = in / value;
+  }
+};
+
+template struct EigenDiv<Eigen::DefaultDevice, float>;
+template struct EigenDiv<Eigen::DefaultDevice, double>;
+
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/eigen/elementwise.cu b/paddle/phi/kernels/funcs/eigen/elementwise.cu
index 3855ba8ccf..1fb3b8a376 100644
--- a/paddle/phi/kernels/funcs/eigen/elementwise.cu
+++ b/paddle/phi/kernels/funcs/eigen/elementwise.cu
@@ -55,5 +55,22 @@ struct EigenSub<Eigen::GpuDevice, T> {
 
 template struct EigenSub<Eigen::GpuDevice, float>;
 
+template <typename T>
+struct EigenDiv<Eigen::GpuDevice, T> {
+  using InType = Eigen::TensorMap<
+      Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  using OutType =
+      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>;
+  static void Eval(const Eigen::GpuDevice& dev,
+                   OutType out,
+                   const InType& in,
+                   const T value) {
+    out.device(dev) = in / value;
+  }
+};
+
+template struct EigenDiv<Eigen::GpuDevice, float>;
+template struct EigenDiv<Eigen::GpuDevice, double>;
+
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc
new file mode 100644
index 0000000000..f8520db2ca
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/unary_grad_kernel.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
+
+#define PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(name, prefix)     \
+  PD_REGISTER_KERNEL(name##_coo_grad,                              \
+                     CPU,                                          \
+                     ALL_LAYOUT,                                   \
+                     phi::sparse::prefix##CooGradKernel,           \
+                     float,                                        \
+                     double) {                                     \
+    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
+  }                                                                \
+                                                                   \
+  PD_REGISTER_KERNEL(name##_csr_grad,                              \
+                     CPU,                                          \
+                     ALL_LAYOUT,                                   \
+                     phi::sparse::prefix##CsrGradKernel,           \
+                     float,                                        \
+                     double) {                                     \
+    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
+  }
+
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(sin, Sin)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(tan, Tan)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(asin, Asin)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(atan, Atan)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(sinh, Sinh)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(tanh, Tanh)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(asinh, Asinh)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(atanh, Atanh)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(sqrt, Sqrt)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(square, Square)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(log1p, Log1p)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(relu, Relu)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(abs, Abs)
+PD_REGISTER_SPARSE_UNARY_CPU_GRAD_KERNEL(pow, Pow)
+
+PD_REGISTER_KERNEL(cast_coo_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::CastCooGradKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
+
+PD_REGISTER_KERNEL(cast_csr_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::CastCsrGradKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
diff --git a/paddle/phi/kernels/sparse/cpu/unary_kernel.cc b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc
new file mode 100644
index 0000000000..1c1ece27d9
--- /dev/null
+++ b/paddle/phi/kernels/sparse/cpu/unary_kernel.cc
@@ -0,0 +1,139 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sparse/unary_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
+#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T, typename Context>
+void DivCooScalarKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        float scalar,
+                        SparseCooTensor* out) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);
+
+  auto eigen_out =
+      phi::EigenVector<T>::Flatten(*(out->mutable_non_zero_elements()));
+  auto eigen_x = phi::EigenVector<T>::Flatten(x.non_zero_elements());
+  auto& dev = *dev_ctx.eigen_device();
+
+  phi::funcs::EigenDiv<std::decay_t<decltype(dev)>, T>::Eval(
+      dev, eigen_out, eigen_x, static_cast<T>(scalar));
+}
+
+template <typename T, typename Context>
+void DivCsrScalarKernel(const Context& dev_ctx,
+                        const SparseCsrTensor& x,
+                        float scalar,
+                        SparseCsrTensor* out) {
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);
+
+  auto eigen_out =
+      phi::EigenVector<T>::Flatten(*(out->mutable_non_zero_elements()));
+  auto eigen_x = phi::EigenVector<T>::Flatten(x.non_zero_elements());
+  auto& dev = *dev_ctx.eigen_device();
+
+  phi::funcs::EigenDiv<std::decay_t<decltype(dev)>, T>::Eval(
+      dev, eigen_out, eigen_x, static_cast<T>(scalar));
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+#define PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(name, prefix)          \
+  PD_REGISTER_KERNEL(name##_coo,                                   \
+                     CPU,                                          \
+                     ALL_LAYOUT,                                   \
+                     phi::sparse::prefix##CooKernel,               \
+                     float,                                        \
+                     double) {                                     \
+    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
+  }                                                                \
+                                                                   \
+  PD_REGISTER_KERNEL(name##_csr,                                   \
+                     CPU,                                          \
+                     ALL_LAYOUT,                                   \
+                     phi::sparse::prefix##CsrKernel,               \
+                     float,                                        \
+                     double) {                                     \
+    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
+  }
+
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(sin, Sin)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(tan, Tan)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(asin, Asin)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(atan, Atan)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(sinh, Sinh)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(tanh, Tanh)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(asinh, Asinh)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(atanh, Atanh)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(sqrt, Sqrt)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(square, Square)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(log1p, Log1p)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(relu, Relu)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(abs, Abs)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(pow, Pow)
+PD_REGISTER_SPARSE_UNARY_CPU_KERNEL(scale, Scale)
+
+PD_REGISTER_KERNEL(divide_coo_scalar,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::DivCooScalarKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+PD_REGISTER_KERNEL(divide_csr_scalar,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::DivCsrScalarKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
+
+PD_REGISTER_KERNEL(cast_coo,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::CastCooKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
+
+PD_REGISTER_KERNEL(cast_csr,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::sparse::CastCsrKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
diff --git a/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
new file mode 100644
index 0000000000..c1f2b2a1f0
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/unary_grad_kernel.cu
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h"
+
+#define PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(name, prefix)     \
+  PD_REGISTER_KERNEL(name##_coo_grad,                              \
+                     GPU,                                          \
+                     ALL_LAYOUT,                                   \
+                     phi::sparse::prefix##CooGradKernel,           \
+                     float,                                        \
+                     double) {                                     \
+    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
+  }                                                                \
+                                                                   \
+  PD_REGISTER_KERNEL(name##_csr_grad,                              \
+                     GPU,                                          \
+                     ALL_LAYOUT,                                   \
+                     phi::sparse::prefix##CsrGradKernel,           \
+                     float,                                        \
+                     double) {                                     \
+    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
+  }
+
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(sin, Sin)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(tan, Tan)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(asin, Asin)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(atan, Atan)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(sinh, Sinh)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(tanh, Tanh)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(asinh, Asinh)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(atanh, Atanh)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(sqrt, Sqrt)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(square, Square)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(log1p, Log1p)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(relu, Relu)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(abs, Abs)
+PD_REGISTER_SPARSE_UNARY_GPU_GRAD_KERNEL(pow, Pow)
+
+PD_REGISTER_KERNEL(cast_coo_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::CastCooGradKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
+
+PD_REGISTER_KERNEL(cast_csr_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::CastCsrGradKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
diff --git a/paddle/phi/kernels/sparse/gpu/unary_kernel.cu b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
new file mode 100644
index 0000000000..fdf0b5106d
--- /dev/null
+++ b/paddle/phi/kernels/sparse/gpu/unary_kernel.cu
@@ -0,0 +1,142 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sparse/unary_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
+
+namespace phi {
+namespace sparse {
+
+template <typename T>
+struct DivScalarFunctor {
+  T value_;
+
+  explicit DivScalarFunctor(T value) : value_(value) {}
+
+  __device__ __forceinline__ T operator()(const T x) const {
+    return x / value_;
+  }
+};
+
+template <typename T, typename Context>
+void DivCooScalarKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        float scalar,
+                        SparseCooTensor* out) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);
+
+  std::vector<const DenseTensor*> ins = {&(x.non_zero_elements())};
+  std::vector<DenseTensor*> outs = {out->mutable_non_zero_elements()};
+  DivScalarFunctor<T> func(static_cast<T>(scalar));
+  funcs::ElementwiseKernel<T, DivScalarFunctor<T>>(dev_ctx, ins, &outs, func);
+}
+
+template <typename T, typename Context>
+void DivCsrScalarKernel(const Context& dev_ctx,
+                        const SparseCsrTensor& x,
+                        float scalar,
+                        SparseCsrTensor* out) {
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);
+
+  std::vector<const DenseTensor*> ins = {&(x.non_zero_elements())};
+  std::vector<DenseTensor*> outs = {out->mutable_non_zero_elements()};
+  DivScalarFunctor<T> func(static_cast<T>(scalar));
+  funcs::ElementwiseKernel<T, DivScalarFunctor<T>>(dev_ctx, ins, &outs, func);
+}
+
+}  // namespace sparse
+}  // namespace phi
+
+#define PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(name, prefix)          \
+  PD_REGISTER_KERNEL(name##_coo,                                   \
+                     GPU,                                          \
+                     ALL_LAYOUT,                                   \
+                     phi::sparse::prefix##CooKernel,               \
+                     float,                                        \
+                     double) {                                     \
+    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); \
+  }                                                                \
+                                                                   \
+  PD_REGISTER_KERNEL(name##_csr,                                   \
+                     GPU,                                          \
+                     ALL_LAYOUT,                                   \
+                     phi::sparse::prefix##CsrKernel,               \
+                     float,                                        \
+                     double) {                                     \
+    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); \
+  }
+
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(sin, Sin)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(tan, Tan)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(asin, Asin)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(atan, Atan)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(sinh, Sinh)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(tanh, Tanh)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(asinh, Asinh)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(atanh, Atanh)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(sqrt, Sqrt)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(square, Square)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(log1p, Log1p)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(relu, Relu)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(abs, Abs)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(pow, Pow)
+PD_REGISTER_SPARSE_UNARY_GPU_KERNEL(scale, Scale)
+
+PD_REGISTER_KERNEL(divide_coo_scalar,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::DivCooScalarKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
+}
+
+PD_REGISTER_KERNEL(divide_csr_scalar,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::DivCsrScalarKernel,
+                   float,
+                   double) {
+  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
+}
+
+PD_REGISTER_KERNEL(cast_coo,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::CastCooKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
+
+PD_REGISTER_KERNEL(cast_csr,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sparse::CastCsrKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool) {}
diff --git a/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h
new file mode 100644
index 0000000000..ffc5f6bbac
--- /dev/null
+++ b/paddle/phi/kernels/sparse/impl/unary_grad_kernel_impl.h
@@ -0,0 +1,141 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/sparse_coo_tensor.h"
+#include "paddle/phi/core/sparse_csr_tensor.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/abs_grad_kernel.h"
+#include "paddle/phi/kernels/activation_grad_kernel.h"
+#include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
+#include "paddle/phi/kernels/sparse/impl/unary_kernel_impl.h"
+
+namespace phi {
+namespace sparse {
+
+#define DEFINE_SPARSE_UNARY_GRAD_KERNEL(prefix)                           \
+  template <typename T, typename Context>                                 \
+  void prefix##CooGradKernel(const Context& dev_ctx,                      \
+                             const SparseCooTensor& x_or_out,             \
+                             const SparseCooTensor& dout,                 \
+                             SparseCooTensor* dx) {                       \
+    EmptyLikeCooKernel<T, Context>(dev_ctx, x_or_out, dx);                \
+    phi::prefix##GradKernel<T, Context>(dev_ctx,                          \
+                                        x_or_out.non_zero_elements(),     \
+                                        dout.non_zero_elements(),         \
+                                        dx->mutable_non_zero_elements()); \
+  }                                                                       \
+                                                                          \
+  template <typename T, typename Context>                                 \
+  void prefix##CsrGradKernel(const Context& dev_ctx,                      \
+                             const SparseCsrTensor& x_or_out,             \
+                             const SparseCsrTensor& dout,                 \
+                             SparseCsrTensor* dx) {                       \
+    EmptyLikeCsrKernel<T, Context>(dev_ctx, x_or_out, dx);                \
+    phi::prefix##GradKernel<T, Context>(dev_ctx,                          \
+                                        x_or_out.non_zero_elements(),     \
+                                        dout.non_zero_elements(),         \
+                                        dx->mutable_non_zero_elements()); \
+  }
+
+#define DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(prefix, attr)       \
+  template <typename T, typename Context>                                 \
+  void prefix##CooGradKernel(const Context& dev_ctx,                      \
+                             const SparseCooTensor& x_or_out,             \
+                             const SparseCooTensor& dout,                 \
+                             float attr,                                  \
+                             SparseCooTensor* dx) {                       \
+    EmptyLikeCooKernel<T, Context>(dev_ctx, x_or_out, dx);                \
+    phi::prefix##GradKernel<T, Context>(dev_ctx,                          \
+                                        x_or_out.non_zero_elements(),     \
+                                        dout.non_zero_elements(),         \
+                                        attr,                             \
+                                        dx->mutable_non_zero_elements()); \
+  }                                                                       \
+                                                                          \
+  template <typename T, typename Context>                                 \
+  void prefix##CsrGradKernel(const Context& dev_ctx,                      \
+                             const SparseCsrTensor& x_or_out,             \
+                             const SparseCsrTensor& dout,                 \
+                             float attr,                                  \
+                             SparseCsrTensor* dx) {                       \
+    EmptyLikeCsrKernel<T, Context>(dev_ctx, x_or_out, dx);                \
+    phi::prefix##GradKernel<T, Context>(dev_ctx,                          \
+                                        x_or_out.non_zero_elements(),     \
+                                        dout.non_zero_elements(),         \
+                                        attr,                             \
+                                        dx->mutable_non_zero_elements()); \
+  }
+
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Sin)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Tan)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Asin)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Atan)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Sinh)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Tanh)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Asinh)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Atanh)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Sqrt)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Square)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Log1p)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Relu)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL(Abs)
+DEFINE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor)
+
+template <typename T, typename Context>
+void CastCooGradKernel(const Context& dev_ctx,
+                       const SparseCooTensor& x,
+                       const SparseCooTensor& dout,
+                       DataType value_dtype,
+                       SparseCooTensor* dx) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, dx);
+  if (value_dtype == DataType::UNDEFINED) {
+    phi::Copy(dev_ctx,
+              dout.non_zero_elements(),
+              dev_ctx.GetPlace(),
+              false,
+              dx->mutable_non_zero_elements());
+  } else {
+    phi::CastKernel<T, Context>(dev_ctx,
+                                dout.non_zero_elements(),
+                                x.non_zero_elements().dtype(),
+                                dx->mutable_non_zero_elements());
+  }
+}
+
+template <typename T, typename Context>
+void CastCsrGradKernel(const Context& dev_ctx,
+                       const SparseCsrTensor& x,
+                       const SparseCsrTensor& dout,
+                       DataType value_dtype,
+                       SparseCsrTensor* dx) {
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, x, dx);
+  if (value_dtype == DataType::UNDEFINED) {
+    phi::Copy(dev_ctx,
+              dout.non_zero_elements(),
+              dev_ctx.GetPlace(),
+              false,
+              dx->mutable_non_zero_elements());
+  } else {
+    phi::CastKernel<T, Context>(dev_ctx,
+                                dout.non_zero_elements(),
+                                x.non_zero_elements().dtype(),
+                                dx->mutable_non_zero_elements());
+  }
+}
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
new file mode 100644
index 0000000000..231fc551f4
--- /dev/null
+++ b/paddle/phi/kernels/sparse/impl/unary_kernel_impl.h
@@ -0,0 +1,207 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/meta_tensor.h"
+#include "paddle/phi/core/sparse_coo_tensor.h"
+#include "paddle/phi/core/sparse_csr_tensor.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/abs_kernel.h"
+#include "paddle/phi/kernels/activation_kernel.h"
+#include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/scale_kernel.h"
+#include "paddle/phi/kernels/sparse/empty_kernel.h"
+#include "paddle/phi/kernels/trunc_kernel.h"
+
+namespace phi {
+namespace sparse {
+
+#define DEFINE_SPARSE_UNARY_KERNEL(prefix)                                 \
+  template <typename T, typename Context>                                  \
+  void prefix##CooKernel(const Context& dev_ctx,                           \
+                         const SparseCooTensor& x,                         \
+                         SparseCooTensor* out) {                           \
+    EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);                       \
+    phi::prefix##Kernel<T, Context>(                                       \
+        dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements()); \
+  }                                                                        \
+                                                                           \
+  template <typename T, typename Context>                                  \
+  void prefix##CsrKernel(const Context& dev_ctx,                           \
+                         const SparseCsrTensor& x,                         \
+                         SparseCsrTensor* out) {                           \
+    EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);                       \
+    phi::prefix##Kernel<T, Context>(                                       \
+        dev_ctx, x.non_zero_elements(), out->mutable_non_zero_elements()); \
+  }
+
+#define DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(prefix, attr)         \
+  template <typename T, typename Context>                              \
+  void prefix##CooKernel(const Context& dev_ctx,                       \
+                         const SparseCooTensor& x,                     \
+                         float attr,                                   \
+                         SparseCooTensor* out) {                       \
+    EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);                   \
+    phi::prefix##Kernel<T, Context>(dev_ctx,                           \
+                                    x.non_zero_elements(),             \
+                                    attr,                              \
+                                    out->mutable_non_zero_elements()); \
+  }                                                                    \
+                                                                       \
+  template <typename T, typename Context>                              \
+  void prefix##CsrKernel(const Context& dev_ctx,                       \
+                         const SparseCsrTensor& x,                     \
+                         float attr,                                   \
+                         SparseCsrTensor* out) {                       \
+    EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);                   \
+    phi::prefix##Kernel<T, Context>(dev_ctx,                           \
+                                    x.non_zero_elements(),             \
+                                    attr,                              \
+                                    out->mutable_non_zero_elements()); \
+  }
+
+DEFINE_SPARSE_UNARY_KERNEL(Sin)
+DEFINE_SPARSE_UNARY_KERNEL(Tan)
+DEFINE_SPARSE_UNARY_KERNEL(Asin)
+DEFINE_SPARSE_UNARY_KERNEL(Atan)
+DEFINE_SPARSE_UNARY_KERNEL(Sinh)
+DEFINE_SPARSE_UNARY_KERNEL(Tanh)
+DEFINE_SPARSE_UNARY_KERNEL(Asinh)
+DEFINE_SPARSE_UNARY_KERNEL(Atanh)
+DEFINE_SPARSE_UNARY_KERNEL(Sqrt)
+DEFINE_SPARSE_UNARY_KERNEL(Square)
+DEFINE_SPARSE_UNARY_KERNEL(Log1p)
+DEFINE_SPARSE_UNARY_KERNEL(Relu)
+DEFINE_SPARSE_UNARY_KERNEL(Abs)
+DEFINE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
+
+template <typename T, typename Context>
+void ScaleCooKernel(const Context& dev_ctx,
+                    const SparseCooTensor& x,
+                    float scale,
+                    float bias,
+                    bool bias_after_scale,
+                    SparseCooTensor* out) {
+  EmptyLikeCooKernel<T, Context>(dev_ctx, x, out);
+  phi::ScaleKernel<T, Context>(dev_ctx,
+                               x.non_zero_elements(),
+                               scale,
+                               bias,
+                               bias_after_scale,
+                               out->mutable_non_zero_elements());
+}
+
+template <typename T, typename Context>
+void ScaleCsrKernel(const Context& dev_ctx,
+                    const SparseCsrTensor& x,
+                    float scale,
+                    float bias,
+                    bool bias_after_scale,
+                    SparseCsrTensor* out) {
+  EmptyLikeCsrKernel<T, Context>(dev_ctx, x, out);
+  phi::ScaleKernel<T, Context>(dev_ctx,
+                               x.non_zero_elements(),
+                               scale,
+                               bias,
+                               bias_after_scale,
+                               out->mutable_non_zero_elements());
+}
+
+template <typename T, typename Context>
+void CastCooKernel(const Context& dev_ctx,
+                   const SparseCooTensor& x,
+                   DataType index_dtype,
+                   DataType value_dtype,
+                   SparseCooTensor* out) {
+  out->set_dims(x.dims());
+
+  const DenseTensor& x_indices = x.non_zero_indices();
+  const DenseTensor& x_values = x.non_zero_elements();
+  DenseTensor* out_indices = out->mutable_non_zero_indices();
+  DenseTensor* out_values = out->mutable_non_zero_elements();
+
+  if (index_dtype == DataType::UNDEFINED) {
+    phi::Copy(dev_ctx, x_indices, dev_ctx.GetPlace(), false, out_indices);
+  } else {
+    phi::MetaTensor meta(out_indices);
+    meta.set_dims(x_indices.dims());
+    meta.set_dtype(index_dtype);
+
+    PD_VISIT_INTEGRAL_TYPES(x_indices.dtype(), "CastCooKernel", [&] {
+      phi::CastKernel<data_t, Context>(
+          dev_ctx, x_indices, index_dtype, out_indices);
+    });
+  }
+
+  if (value_dtype == DataType::UNDEFINED) {
+    phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
+  } else {
+    phi::MetaTensor meta(out_values);
+    meta.set_dims(x_values.dims());
+    meta.set_dtype(value_dtype);
+    phi::CastKernel<T, Context>(dev_ctx, x_values, value_dtype, out_values);
+  }
+}
+
+template <typename T, typename Context>
+void CastCsrKernel(const Context& dev_ctx,
+                   const SparseCsrTensor& x,
+                   DataType index_dtype,
+                   DataType value_dtype,
+                   SparseCsrTensor* out) {
+  out->set_dims(x.dims());
+
+  const DenseTensor& x_crows = x.non_zero_crows();
+  const DenseTensor& x_cols = x.non_zero_cols();
+  const DenseTensor& x_values = x.non_zero_elements();
+  DenseTensor* out_crows = out->mutable_non_zero_crows();
+  DenseTensor* out_cols = out->mutable_non_zero_cols();
+  DenseTensor* out_values = out->mutable_non_zero_elements();
+
+  if (index_dtype == DataType::UNDEFINED) {
+    phi::Copy(dev_ctx, x_crows, dev_ctx.GetPlace(), false, out_crows);
+    phi::Copy(dev_ctx, x_cols, dev_ctx.GetPlace(), false, out_cols);
+  } else {
+    phi::MetaTensor crows_meta(out_crows);
+    crows_meta.set_dims(x_crows.dims());
+    crows_meta.set_dtype(index_dtype);
+
+    PD_VISIT_INTEGRAL_TYPES(x_crows.dtype(), "CastCsrKernel", [&] {
+      phi::CastKernel<data_t, Context>(
+          dev_ctx, x_crows, index_dtype, out_crows);
+    });
+
+    phi::MetaTensor cols_meta(out_cols);
+    cols_meta.set_dims(x_cols.dims());
+    cols_meta.set_dtype(index_dtype);
+
+    PD_VISIT_INTEGRAL_TYPES(x_cols.dtype(), "CastCsrKernel", [&] {
+      phi::CastKernel<data_t, Context>(dev_ctx, x_cols, index_dtype, out_cols);
+    });
+  }
+
+  if (value_dtype == DataType::UNDEFINED) {
+    phi::Copy(dev_ctx, x_values, dev_ctx.GetPlace(), false, out_values);
+  } else {
+    phi::MetaTensor meta(out_values);
+    meta.set_dims(x_values.dims());
+    meta.set_dtype(value_dtype);
+    phi::CastKernel<T, Context>(dev_ctx, x_values, value_dtype, out_values);
+  }
+}
+
+}  // namespace sparse
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse/unary_grad_kernel.cc b/paddle/phi/kernels/sparse/unary_grad_kernel.cc
deleted file mode 100644
index cd844532e9..0000000000
--- a/paddle/phi/kernels/sparse/unary_grad_kernel.cc
+++ /dev/null
@@ -1,183 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/sparse/unary_grad_kernel.h"
-
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/core/sparse_coo_tensor.h"
-#include "paddle/phi/core/sparse_csr_tensor.h"
-#include "paddle/phi/core/tensor_utils.h"
-#include "paddle/phi/kernels/activation_grad_kernel.h"
-#include "paddle/phi/kernels/empty_kernel.h"
-
-#define DEFINE_SPARSE_UNARY_GRAD_KERNEL(DenseKernelFunc)                    \
-  namespace phi {                                                           \
-  namespace sparse {                                                        \
-                                                                            \
-  template <typename T, typename Context>                                   \
-  void SparseCoo##DenseKernelFunc(const Context& dev_ctx,                   \
-                                  const SparseCooTensor& x_or_out,          \
-                                  const SparseCooTensor& out_grad,          \
-                                  SparseCooTensor* x_grad) {                \
-    DenseTensor non_zero_indices =                                          \
-        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_indices());   \
-    DenseTensor non_zero_elements =                                         \
-        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_elements());  \
-    phi::Copy(dev_ctx,                                                      \
-              x_or_out.non_zero_indices(),                                  \
-              dev_ctx.GetPlace(),                                           \
-              false,                                                        \
-              &non_zero_indices);                                           \
-    phi::DenseKernelFunc<T, Context>(dev_ctx,                               \
-                                     x_or_out.non_zero_elements(),          \
-                                     out_grad.non_zero_elements(),          \
-                                     &non_zero_elements);                   \
-    x_grad->SetMember(                                                      \
-        non_zero_indices, non_zero_elements, x_or_out.dims(), true);        \
-  }                                                                         \
-                                                                            \
-  template <typename T, typename Context>                                   \
-  void SparseCsr##DenseKernelFunc(const Context& dev_ctx,                   \
-                                  const SparseCsrTensor& x_or_out,          \
-                                  const SparseCsrTensor& out_grad,          \
-                                  SparseCsrTensor* out) {                   \
-    DenseTensor non_zero_crows =                                            \
-        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_crows());     \
-    DenseTensor non_zero_cols =                                             \
-        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_cols());      \
-    DenseTensor non_zero_elements =                                         \
-        phi::EmptyLike<T, Context>(dev_ctx, x_or_out.non_zero_elements());  \
-    phi::Copy(dev_ctx,                                                      \
-              x_or_out.non_zero_crows(),                                    \
-              dev_ctx.GetPlace(),                                           \
-              false,                                                        \
-              &non_zero_crows);                                             \
-    phi::Copy(dev_ctx,                                                      \
-              x_or_out.non_zero_cols(),                                     \
-              dev_ctx.GetPlace(),                                           \
-              false,                                                        \
-              &non_zero_cols);                                              \
-    phi::DenseKernelFunc<T, Context>(dev_ctx,                               \
-                                     x_or_out.non_zero_elements(),          \
-                                     out_grad.non_zero_elements(),          \
-                                     &non_zero_elements);                   \
-    out->SetMember(                                                         \
-        non_zero_crows, non_zero_cols, non_zero_elements, x_or_out.dims()); \
-  }                                                                         \
-  }                                                                         \
-  }
-
-#define REGISTER_CPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
-  PD_REGISTER_KERNEL(sparse_coo_##kernel_name,                         \
-                     CPU,                                              \
-                     ALL_LAYOUT,                                       \
-                     phi::sparse::SparseCoo##DenseKernelFunc,          \
-                     float,                                            \
-                     double) {                                         \
-    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);     \
-  }                                                                    \
-  PD_REGISTER_KERNEL(sparse_csr_##kernel_name,                         \
-                     CPU,                                              \
-                     ALL_LAYOUT,                                       \
-                     phi::sparse::SparseCsr##DenseKernelFunc,          \
-                     float,                                            \
-                     double) {                                         \
-    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);     \
-  }
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#define REGISTER_GPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
-  PD_REGISTER_KERNEL(sparse_coo_##kernel_name,                         \
-                     GPU,                                              \
-                     ALL_LAYOUT,                                       \
-                     phi::sparse::SparseCoo##DenseKernelFunc,          \
-                     float,                                            \
-                     double,                                           \
-                     phi::dtype::float16) {                            \
-    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);     \
-  }                                                                    \
-                                                                       \
-  PD_REGISTER_KERNEL(sparse_csr_##kernel_name,                         \
-                     GPU,                                              \
-                     ALL_LAYOUT,                                       \
-                     phi::sparse::SparseCsr##DenseKernelFunc,          \
-                     float,                                            \
-                     double,                                           \
-                     phi::dtype::float16) {                            \
-    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);     \
-  }
-#else
-// This macro definition is empty when GPU is disabled
-#define REGISTER_GPU_SPARSE_UNARY_KERNEL(sparse_kernel_name, DenseKernelFunc)
-#endif
-
-#define REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
-  REGISTER_CPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)   \
-  REGISTER_GPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)
-
-#define DEFINE_AND_REGISTER_SPARSE_UNARY_GRAD_KERNEL(kernel_name,     \
-                                                     DenseKernelFunc) \
-  DEFINE_SPARSE_UNARY_GRAD_KERNEL(DenseKernelFunc)                    \
-  REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)
-
-// NOTE: the following code is to bypass the restriction of Paddle
-// kernel registration mechanism. Do NOT refactor them unless you
-// know what you are doing.
-// If you want to implement any new kernel, please follow `sin_grad`,
-// `tanh_grad` etc, do NOT follow the following `relu_grad`.
-DEFINE_SPARSE_UNARY_GRAD_KERNEL(ReluGradKernel)
-
-PD_REGISTER_KERNEL(sparse_coo_relu_grad,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SparseCooReluGradKernel,
-                   float,
-                   double) {
-  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
-}
-PD_REGISTER_KERNEL(sparse_csr_relu_grad,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SparseCsrReluGradKernel,
-                   float,
-                   double) {
-  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
-}
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(sparse_coo_relu_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SparseCooReluGradKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {
-  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
-}
-
-PD_REGISTER_KERNEL(sparse_csr_relu_grad,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SparseCsrReluGradKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {
-  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
-}
-#endif
-
-DEFINE_AND_REGISTER_SPARSE_UNARY_GRAD_KERNEL(sin_grad, SinGradKernel)
-DEFINE_AND_REGISTER_SPARSE_UNARY_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
-DEFINE_AND_REGISTER_SPARSE_UNARY_GRAD_KERNEL(tanh_grad, TanhGradKernel)
diff --git a/paddle/phi/kernels/sparse/unary_grad_kernel.h b/paddle/phi/kernels/sparse/unary_grad_kernel.h
index 24ea4fee1a..eb2cf9ed69 100644
--- a/paddle/phi/kernels/sparse/unary_grad_kernel.h
+++ b/paddle/phi/kernels/sparse/unary_grad_kernel.h
@@ -17,25 +17,65 @@
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
 
-#define DECLARE_SPARSE_UNARY_GRAD_KERNEL(name)                      \
-  template <typename T, typename Context>                           \
-  void SparseCoo##name##GradKernel(const Context& dev_ctx,          \
-                                   const SparseCooTensor& x,        \
-                                   const SparseCooTensor& out_grad, \
-                                   SparseCooTensor* x_grad);        \
-                                                                    \
-  template <typename T, typename Context>                           \
-  void SparseCsr##name##GradKernel(const Context& dev_ctx,          \
-                                   const SparseCsrTensor& x,        \
-                                   const SparseCsrTensor& out_grad, \
-                                   SparseCsrTensor* x_grad);
-
 namespace phi {
 namespace sparse {
 
+#define DECLARE_SPARSE_UNARY_GRAD_KERNEL(prefix)              \
+  template <typename T, typename Context>                     \
+  void prefix##CooGradKernel(const Context& dev_ctx,          \
+                             const SparseCooTensor& x_or_out, \
+                             const SparseCooTensor& dout,     \
+                             SparseCooTensor* dx);            \
+                                                              \
+  template <typename T, typename Context>                     \
+  void prefix##CsrGradKernel(const Context& dev_ctx,          \
+                             const SparseCsrTensor& x_or_out, \
+                             const SparseCsrTensor& dout,     \
+                             SparseCsrTensor* dx);
+
+#define DECLARE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(prefix, attr) \
+  template <typename T, typename Context>                            \
+  void prefix##CooGradKernel(const Context& dev_ctx,                 \
+                             const SparseCooTensor& x_or_out,        \
+                             const SparseCooTensor& dout,            \
+                             float attr,                             \
+                             SparseCooTensor* dx);                   \
+                                                                     \
+  template <typename T, typename Context>                            \
+  void prefix##CsrGradKernel(const Context& dev_ctx,                 \
+                             const SparseCsrTensor& x_or_out,        \
+                             const SparseCsrTensor& dout,            \
+                             float attr,                             \
+                             SparseCsrTensor* dx);
+
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Sin)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Tan)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Asin)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Atan)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Sinh)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Asinh)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Atanh)
 DECLARE_SPARSE_UNARY_GRAD_KERNEL(Relu)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Tanh)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Square)
 DECLARE_SPARSE_UNARY_GRAD_KERNEL(Sqrt)
-DECLARE_SPARSE_UNARY_GRAD_KERNEL(Sin)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Log1p)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL(Abs)
+DECLARE_SPARSE_UNARY_GRAD_KERNEL_WITH_ONE_ATTR(Pow, factor)
+
+template <typename T, typename Context>
+void CastCooGradKernel(const Context& dev_ctx,
+                       const SparseCooTensor& x,
+                       const SparseCooTensor& dout,
+                       DataType value_dtype,
+                       SparseCooTensor* dx);
+
+template <typename T, typename Context>
+void CastCsrGradKernel(const Context& dev_ctx,
+                       const SparseCsrTensor& x,
+                       const SparseCsrTensor& dout,
+                       DataType value_dtype,
+                       SparseCsrTensor* dx);
 
 }  // namespace sparse
 }  // namespace phi
diff --git a/paddle/phi/kernels/sparse/unary_kernel.cc b/paddle/phi/kernels/sparse/unary_kernel.cc
deleted file mode 100644
index 2999536b34..0000000000
--- a/paddle/phi/kernels/sparse/unary_kernel.cc
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/kernels/sparse/unary_kernel.h"
-
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/backends/gpu/gpu_context.h"
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/core/sparse_coo_tensor.h"
-#include "paddle/phi/core/sparse_csr_tensor.h"
-#include "paddle/phi/core/tensor_utils.h"
-#include "paddle/phi/kernels/activation_kernel.h"
-#include "paddle/phi/kernels/empty_kernel.h"
-
-#define DEFINE_SPARSE_UNARY_KERNEL(DenseKernelFunc)                      \
-  namespace phi {                                                        \
-  namespace sparse {                                                     \
-                                                                         \
-  template <typename T, typename Context>                                \
-  void SparseCoo##DenseKernelFunc(const Context& dev_ctx,                \
-                                  const SparseCooTensor& x,              \
-                                  SparseCooTensor* out) {                \
-    DenseTensor non_zero_indices =                                       \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_indices());       \
-    DenseTensor non_zero_elements =                                      \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements());      \
-    phi::Copy(dev_ctx,                                                   \
-              x.non_zero_indices(),                                      \
-              dev_ctx.GetPlace(),                                        \
-              false,                                                     \
-              &non_zero_indices);                                        \
-    phi::DenseKernelFunc<T, Context>(                                    \
-        dev_ctx, x.non_zero_elements(), &non_zero_elements);             \
-    out->SetMember(non_zero_indices, non_zero_elements, x.dims(), true); \
-  }                                                                      \
-                                                                         \
-  template <typename T, typename Context>                                \
-  void SparseCsr##DenseKernelFunc(const Context& dev_ctx,                \
-                                  const SparseCsrTensor& x,              \
-                                  SparseCsrTensor* out) {                \
-    DenseTensor non_zero_crows =                                         \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_crows());         \
-    DenseTensor non_zero_cols =                                          \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_cols());          \
-    DenseTensor non_zero_elements =                                      \
-        phi::EmptyLike<T, Context>(dev_ctx, x.non_zero_elements());      \
-    phi::Copy(dev_ctx,                                                   \
-              x.non_zero_crows(),                                        \
-              dev_ctx.GetPlace(),                                        \
-              false,                                                     \
-              &non_zero_crows);                                          \
-    phi::Copy(dev_ctx,                                                   \
-              x.non_zero_cols(),                                         \
-              dev_ctx.GetPlace(),                                        \
-              false,                                                     \
-              &non_zero_cols);                                           \
-    phi::DenseKernelFunc<T, Context>(                                    \
-        dev_ctx, x.non_zero_elements(), &non_zero_elements);             \
-    out->SetMember(                                                      \
-        non_zero_crows, non_zero_cols, non_zero_elements, x.dims());     \
-  }                                                                      \
-  }                                                                      \
-  }
-
-#define REGISTER_CPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
-  PD_REGISTER_KERNEL(sparse_coo_##kernel_name,                         \
-                     CPU,                                              \
-                     ALL_LAYOUT,                                       \
-                     phi::sparse::SparseCoo##DenseKernelFunc,          \
-                     float,                                            \
-                     double) {                                         \
-    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);     \
-  }                                                                    \
-  PD_REGISTER_KERNEL(sparse_csr_##kernel_name,                         \
-                     CPU,                                              \
-                     ALL_LAYOUT,                                       \
-                     phi::sparse::SparseCsr##DenseKernelFunc,          \
-                     float,                                            \
-                     double) {                                         \
-    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);     \
-  }
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#define REGISTER_GPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
-  PD_REGISTER_KERNEL(sparse_coo_##kernel_name,                         \
-                     GPU,                                              \
-                     ALL_LAYOUT,                                       \
-                     phi::sparse::SparseCoo##DenseKernelFunc,          \
-                     float,                                            \
-                     double,                                           \
-                     phi::dtype::float16) {                            \
-    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);     \
-  }                                                                    \
-                                                                       \
-  PD_REGISTER_KERNEL(sparse_csr_##kernel_name,                         \
-                     GPU,                                              \
-                     ALL_LAYOUT,                                       \
-                     phi::sparse::SparseCsr##DenseKernelFunc,          \
-                     float,                                            \
-                     double,                                           \
-                     phi::dtype::float16) {                            \
-    kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);     \
-  }
-#else
-// This macro definition is empty when GPU is disabled
-#define REGISTER_GPU_SPARSE_UNARY_KERNEL(sparse_kernel_name, DenseKernelFunc)
-#endif
-
-#define REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
-  REGISTER_CPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)   \
-  REGISTER_GPU_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)
-
-#define DEFINE_AND_REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc) \
-  DEFINE_SPARSE_UNARY_KERNEL(DenseKernelFunc)                                 \
-  REGISTER_SPARSE_UNARY_KERNEL(kernel_name, DenseKernelFunc)
-
-// NOTE: the following code is to bypass the restriction of Paddle
-// kernel registration mechanism. Do NOT refactor them unless you
-// know what you are doing.
-// If you want to implement any new kernel, please follow `sin`,
-// `tanh` etc, do NOT follow `sqrt`.
-DEFINE_SPARSE_UNARY_KERNEL(SqrtKernel)
-
-PD_REGISTER_KERNEL(sparse_coo_sqrt,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SparseCooSqrtKernel,
-                   float,
-                   double) {
-  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
-}
-PD_REGISTER_KERNEL(sparse_csr_sqrt,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SparseCsrSqrtKernel,
-                   float,
-                   double) {
-  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
-}
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(sparse_coo_sqrt,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SparseCooSqrtKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {
-  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO);
-}
-
-PD_REGISTER_KERNEL(sparse_csr_sqrt,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::sparse::SparseCsrSqrtKernel,
-                   float,
-                   double,
-                   phi::dtype::float16) {
-  kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR);
-}
-
-#endif
-
-DEFINE_AND_REGISTER_SPARSE_UNARY_KERNEL(sin, SinKernel)
-DEFINE_AND_REGISTER_SPARSE_UNARY_KERNEL(tanh, TanhKernel)
-DEFINE_AND_REGISTER_SPARSE_UNARY_KERNEL(relu, ReluKernel)
diff --git a/paddle/phi/kernels/sparse/unary_kernel.h b/paddle/phi/kernels/sparse/unary_kernel.h
index 4470173c14..fdb6b21a44 100644
--- a/paddle/phi/kernels/sparse/unary_kernel.h
+++ b/paddle/phi/kernels/sparse/unary_kernel.h
@@ -14,35 +14,104 @@
 
 #pragma once
 
-#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
-#include "paddle/phi/kernels/activation_kernel.h"
-#include "paddle/phi/kernels/empty_kernel.h"
 
-#define DECLARE_SPARSE_UNARY_KERNEL(name)                                      \
+namespace phi {
+namespace sparse {
+
+#define DECLARE_SPARSE_UNARY_KERNEL(prefix)                                    \
   template <typename T, typename Context>                                      \
-  void SparseCoo##name##Kernel(                                                \
+  void prefix##CooKernel(                                                      \
       const Context& dev_ctx, const SparseCooTensor& x, SparseCooTensor* out); \
                                                                                \
   template <typename T, typename Context>                                      \
-  void SparseCsr##name##Kernel(                                                \
+  void prefix##CsrKernel(                                                      \
       const Context& dev_ctx, const SparseCsrTensor& x, SparseCsrTensor* out);
 
-namespace phi {
-namespace sparse {
+#define DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(prefix, attr) \
+  template <typename T, typename Context>                       \
+  void prefix##CooKernel(const Context& dev_ctx,                \
+                         const SparseCooTensor& x,              \
+                         float attr,                            \
+                         SparseCooTensor* out);                 \
+                                                                \
+  template <typename T, typename Context>                       \
+  void prefix##CsrKernel(const Context& dev_ctx,                \
+                         const SparseCsrTensor& x,              \
+                         float attr,                            \
+                         SparseCsrTensor* out);
 
+DECLARE_SPARSE_UNARY_KERNEL(Sin)
+DECLARE_SPARSE_UNARY_KERNEL(Tan)
+DECLARE_SPARSE_UNARY_KERNEL(Asin)
+DECLARE_SPARSE_UNARY_KERNEL(Atan)
+DECLARE_SPARSE_UNARY_KERNEL(Sinh)
+DECLARE_SPARSE_UNARY_KERNEL(Asinh)
+DECLARE_SPARSE_UNARY_KERNEL(Atanh)
 DECLARE_SPARSE_UNARY_KERNEL(Relu)
+DECLARE_SPARSE_UNARY_KERNEL(Tanh)
+DECLARE_SPARSE_UNARY_KERNEL(Square)
 DECLARE_SPARSE_UNARY_KERNEL(Sqrt)
-DECLARE_SPARSE_UNARY_KERNEL(Sin)
+DECLARE_SPARSE_UNARY_KERNEL(Log1p)
+DECLARE_SPARSE_UNARY_KERNEL(Abs)
+DECLARE_SPARSE_UNARY_KERNEL_WITH_ONE_ATTR(Pow, factor)
+
+template <typename T, typename Context>
+void ScaleCooKernel(const Context& dev_ctx,
+                    const SparseCooTensor& x,
+                    float scale,
+                    float bias,
+                    bool bias_after_scale,
+                    SparseCooTensor* out);
+
+template <typename T, typename Context>
+void ScaleCsrKernel(const Context& dev_ctx,
+                    const SparseCsrTensor& x,
+                    float scale,
+                    float bias,
+                    bool bias_after_scale,
+                    SparseCsrTensor* out);
 
 template <typename T, typename Context>
-SparseCooTensor SparseRelu(const Context& dev_ctx, const SparseCooTensor& x) {
-  DenseTensor indices, values;
-  SparseCooTensor coo(indices, values, x.dims());
-  SparseCooReluKernel<T, Context>(dev_ctx, x, &coo);
+void DivCooScalarKernel(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        float scalar,
+                        SparseCooTensor* out);
+
+template <typename T, typename Context>
+void DivCsrScalarKernel(const Context& dev_ctx,
+                        const SparseCsrTensor& x,
+                        float scalar,
+                        SparseCsrTensor* out);
+
+template <typename T, typename Context>
+void CastCooKernel(const Context& dev_ctx,
+                   const SparseCooTensor& x,
+                   DataType index_dtype,
+                   DataType value_dtype,
+                   SparseCooTensor* out);
+
+template <typename T, typename Context>
+void CastCsrKernel(const Context& dev_ctx,
+                   const SparseCsrTensor& x,
+                   DataType index_dtype,
+                   DataType value_dtype,
+                   SparseCsrTensor* out);
+
+template <typename T, typename Context>
+SparseCooTensor ReluCoo(const Context& dev_ctx, const SparseCooTensor& x) {
+  SparseCooTensor coo;
+  ReluCooKernel<T, Context>(dev_ctx, x, &coo);
   return coo;
 }
 
+template <typename T, typename Context>
+SparseCooTensor ReluCsr(const Context& dev_ctx, const SparseCooTensor& x) {
+  SparseCooTensor csr;
+  ReluCsrKernel<T, Context>(dev_ctx, x, &csr);
+  return csr;
+}
+
 }  // namespace sparse
 }  // namespace phi
diff --git a/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc
index 51d1e67f5a..9c6776fb2a 100644
--- a/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc
+++ b/paddle/phi/tests/kernels/test_sparse_activation_dev_api.cc
@@ -49,7 +49,7 @@ TEST(DEV_API, sparse_relu) {
   memcpy(dense_x.data<float>(), data.data(), data.size() * sizeof(float));
   auto sparse_coo = sparse::DenseToSparseCoo<float>(dev_ctx_cpu, dense_x, 2);
 
-  auto sparse_out = sparse::SparseRelu<float>(dev_ctx_cpu, sparse_coo);
+  auto sparse_out = sparse::ReluCoo<float>(dev_ctx_cpu, sparse_coo);
   DenseTensor dense_out =
       phi::EmptyLike<float>(dev_ctx_cpu, sparse_out.non_zero_elements());
   ReluKernel<float>(dev_ctx_cpu, sparse_coo.non_zero_elements(), &dense_out);
@@ -69,7 +69,7 @@ TEST(DEV_API, sparse_relu) {
 
   SparseCooTensor sparse_out_grad(
       sparse_coo.non_zero_indices(), dense_out, {3, 4});
-  sparse::SparseCooReluGradKernel<float>(
+  sparse::ReluCooGradKernel<float>(
       dev_ctx_cpu, sparse_coo, sparse_out_grad, &sparse_grad_x);
 
   cmp = memcmp(dense_grad_x.data<float>(),
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
index 61932cf4a7..12546ea463 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_elementwise_op.py
@@ -125,16 +125,14 @@ class TestSparseElementWiseAPI(unittest.TestCase):
     def test_support_dtypes_csr(self):
         paddle.device.set_device('cpu')
         if paddle.device.get_device() == "cpu":
-            with _test_eager_guard():
-                for op in op_list:
-                    self.func_test_csr(op)
+            for op in op_list:
+                self.func_test_csr(op)
 
     def test_support_dtypes_coo(self):
         paddle.device.set_device('cpu')
         if paddle.device.get_device() == "cpu":
-            with _test_eager_guard():
-                for op in op_list:
-                    self.func_test_coo(op)
+            for op in op_list:
+                self.func_test_coo(op)
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_model.py b/python/paddle/fluid/tests/unittests/test_sparse_model.py
index 90f30e3831..c070614fc7 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_model.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_model.py
@@ -62,3 +62,7 @@ class TestGradientAdd(unittest.TestCase):
             sparse_loss.backward()
 
             assert np.allclose(x.grad.numpy(), sparse_x.grad.to_dense().numpy())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py b/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
index 2272022e8d..36d64f5067 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_unary_op.py
@@ -12,137 +12,142 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import print_function
 import unittest
-from typing import Union, Callable
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.framework import _test_eager_guard
-from paddle import _C_ops
+from paddle.fluid.framework import convert_np_dtype_to_dtype_
 
 
 class TestSparseUnary(unittest.TestCase):
 
-    def assert_raises_on_dense_tensor(self, sparse_func):
-        with _test_eager_guard():
-            dense_x = paddle.ones((2, 3))
-            with self.assertRaises(NotImplementedError):
-                sparse_func(dense_x)
-
-    def compare_with_dense(
-        self,
-        x,
-        to_sparse: Callable[[paddle.Tensor], paddle.Tensor],
-        dense_func: Callable[[paddle.Tensor], paddle.Tensor],
-        sparse_func: Callable[[paddle.Tensor], paddle.Tensor],
-        test_gradient: bool,
-    ):
-
-        def tensor_allclose(dense_tensor: paddle.Tensor,
-                            sparse_tensor: paddle.Tensor):
-            dense_numpy = dense_tensor.numpy()
-            mask = ~np.isnan(dense_numpy)
-            return np.allclose(dense_numpy[mask],
-                               sparse_tensor.to_dense().numpy()[mask])
-
-        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
-        with _test_eager_guard():
-            dense_x = paddle.to_tensor(x,
-                                       dtype="float32",
-                                       stop_gradient=not test_gradient)
-
-            sparse_x = to_sparse(dense_x)
-            sparse_out = sparse_func(sparse_x)
-
-            dense_x = paddle.to_tensor(x,
-                                       dtype="float32",
-                                       stop_gradient=not test_gradient)
+    def to_sparse(self, x, format):
+        if format == 'coo':
+            return x.detach().to_sparse_coo(sparse_dim=x.ndim)
+        elif format == 'csr':
+            return x.detach().to_sparse_csr()
+
+    def check_result(self, dense_func, sparse_func, format, *args):
+        origin_x = paddle.rand([8, 16, 32], dtype='float32')
+        mask = paddle.randint(0, 2, [8, 16, 32]).astype('float32')
+
+        ### check sparse coo with dense ###
+        dense_x = origin_x * mask
+        sp_x = self.to_sparse(dense_x, format)
+
+        sp_x.stop_gradient = False
+        if len(args) == 0:
+            sp_out = sparse_func(sp_x)
+        elif len(args) == 1:
+            sp_out = sparse_func(sp_x, args[0])
+        elif len(args) == 2:
+            sp_out = sparse_func(sp_x, args[0], args[1])
+        sp_out.backward()
+
+        dense_x.stop_gradient = False
+        if len(args) == 0:
             dense_out = dense_func(dense_x)
+        elif len(args) == 1:
+            dense_out = dense_func(dense_x, args[0])
+        elif len(args) == 2:
+            if dense_func == paddle.cast:
+                dense_out = dense_func(dense_x, args[1])
+
+                int_dtype = convert_np_dtype_to_dtype_(args[0])
+                if sp_out.is_sparse_csr():
+                    self.assertEqual(sp_out.crows().dtype, int_dtype)
+                    self.assertEqual(sp_out.cols().dtype, int_dtype)
+                elif sp_out.is_sparse_coo():
+                    self.assertEqual(sp_out.indices().dtype, int_dtype)
+            else:
+                dense_out = dense_func(dense_x, args[0], args[1])
+        dense_out.backward()
+
+        # compare forward
+        self.assertTrue(
+            np.allclose(sp_out.to_dense().numpy(), dense_out.numpy()))
+
+        # compare backward
+        if dense_func == paddle.sqrt:
+            expect_grad = np.nan_to_num(dense_x.grad.numpy(), 0., 0., 0.)
+        else:
+            expect_grad = (dense_x.grad * mask).numpy()
+        self.assertTrue(np.allclose(sp_x.grad.to_dense().numpy(), expect_grad))
+
+    def compare_with_dense(self, dense_func, sparse_func):
+        self.check_result(dense_func, sparse_func, 'coo')
+        self.check_result(dense_func, sparse_func, 'csr')
+
+    def compare_with_dense_one_attr(self, dense_func, sparse_func, attr1):
+        self.check_result(dense_func, sparse_func, 'coo', attr1)
+        self.check_result(dense_func, sparse_func, 'csr', attr1)
+
+    def compare_with_dense_two_attr(self, dense_func, sparse_func, attr1,
+                                    attr2):
+        self.check_result(dense_func, sparse_func, 'coo', attr1, attr2)
+        self.check_result(dense_func, sparse_func, 'csr', attr1, attr2)
 
-            assert tensor_allclose(dense_out, sparse_out)
+    def test_sparse_sin(self):
+        self.compare_with_dense(paddle.sin, paddle.incubate.sparse.sin)
 
-            if test_gradient:
-                dense_out.backward(dense_out)
-                sparse_out.backward(sparse_out)
-                assert tensor_allclose(dense_x.grad, sparse_x.grad)
-        fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
+    def test_sparse_tan(self):
+        self.compare_with_dense(paddle.tan, paddle.incubate.sparse.tan)
 
-    def test_sparse_relu(self):
-        x = [[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]]
-        sparse_dim = 2
-        self.compare_with_dense(
-            x,
-            lambda x: x.to_sparse_coo(sparse_dim),
-            paddle.nn.ReLU(),
-            paddle.incubate.sparse.nn.ReLU(),
-            True,
-        )
-        self.compare_with_dense(
-            x,
-            lambda x: x.to_sparse_csr(),
-            paddle.nn.ReLU(),
-            paddle.incubate.sparse.nn.ReLU(),
-            False,
-        )
-        self.assert_raises_on_dense_tensor(paddle.incubate.sparse.nn.ReLU())
+    def test_sparse_asin(self):
+        self.compare_with_dense(paddle.asin, paddle.incubate.sparse.asin)
 
-    def test_sparse_sqrt(self):
-        x = [[0, 16, 0, 0], [0, 0, 0, 0], [0, 4, 2, 0]]
-        sparse_dim = 2
-        self.compare_with_dense(
-            x,
-            lambda x: x.to_sparse_coo(sparse_dim),
-            paddle.sqrt,
-            paddle.incubate.sparse.sqrt,
-            True,
-        )
-        self.compare_with_dense(
-            x,
-            lambda x: x.to_sparse_csr(),
-            paddle.sqrt,
-            paddle.incubate.sparse.sqrt,
-            False,
-        )
-        self.assert_raises_on_dense_tensor(paddle.incubate.sparse.sqrt)
+    def test_sparse_atan(self):
+        self.compare_with_dense(paddle.atan, paddle.incubate.sparse.atan)
 
-    def test_sparse_sin(self):
-        x = [[0, 16, 0, 0], [0, 0, 0, 0], [0, 4, 2, 0]]
-        sparse_dim = 2
-        self.compare_with_dense(
-            x,
-            lambda x: x.to_sparse_coo(sparse_dim),
-            paddle.sin,
-            paddle.incubate.sparse.sin,
-            True,
-        )
-        self.compare_with_dense(
-            x,
-            lambda x: x.to_sparse_csr(),
-            paddle.sin,
-            paddle.incubate.sparse.sin,
-            False,
-        )
-        self.assert_raises_on_dense_tensor(paddle.incubate.sparse.sin)
+    def test_sparse_sinh(self):
+        self.compare_with_dense(paddle.sinh, paddle.incubate.sparse.sinh)
 
     def test_sparse_tanh(self):
-        x = [[0, 16, 0, 0], [0, 0, 0, 0], [0, -4, 2, 0]]
-        sparse_dim = 2
-        self.compare_with_dense(
-            x,
-            lambda x: x.to_sparse_coo(sparse_dim),
-            paddle.tanh,
-            paddle.incubate.sparse.tanh,
-            True,
-        )
-        self.compare_with_dense(
-            x,
-            lambda x: x.to_sparse_csr(),
-            paddle.tanh,
-            paddle.incubate.sparse.tanh,
-            False,
-        )
-        self.assert_raises_on_dense_tensor(paddle.incubate.sparse.tanh)
+        self.compare_with_dense(paddle.tanh, paddle.incubate.sparse.tanh)
+
+    def test_sparse_asinh(self):
+        self.compare_with_dense(paddle.asinh, paddle.incubate.sparse.asinh)
+
+    def test_sparse_atanh(self):
+        self.compare_with_dense(paddle.atanh, paddle.incubate.sparse.atanh)
+
+    def test_sparse_sqrt(self):
+        self.compare_with_dense(paddle.sqrt, paddle.incubate.sparse.sqrt)
+
+    def test_sparse_square(self):
+        self.compare_with_dense(paddle.square, paddle.incubate.sparse.square)
+
+    def test_sparse_log1p(self):
+        self.compare_with_dense(paddle.log1p, paddle.incubate.sparse.log1p)
+
+    def test_sparse_relu(self):
+        self.compare_with_dense(paddle.nn.ReLU(),
+                                paddle.incubate.sparse.nn.ReLU())
+
+    def test_sparse_abs(self):
+        self.compare_with_dense(paddle.abs, paddle.incubate.sparse.abs)
+
+    def test_sparse_neg(self):
+        self.compare_with_dense(paddle.neg, paddle.incubate.sparse.neg)
+
+    def test_sparse_pow(self):
+        self.compare_with_dense_one_attr(paddle.pow, paddle.incubate.sparse.pow,
+                                         3)
+
+    def test_sparse_mul_scalar(self):
+        self.compare_with_dense_one_attr(paddle.Tensor.__mul__,
+                                         paddle.incubate.sparse.multiply, 3)
+
+    def test_sparse_div_scalar(self):
+        self.compare_with_dense_one_attr(paddle.Tensor.__div__,
+                                         paddle.incubate.sparse.divide, 2)
+
+    def test_sparse_cast(self):
+        self.compare_with_dense_two_attr(paddle.cast,
+                                         paddle.incubate.sparse.cast, 'int16',
+                                         'float32')
+        self.compare_with_dense_two_attr(paddle.cast,
+                                         paddle.incubate.sparse.cast, 'int32',
+                                         'float64')
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
index a12425b692..ac69469cbb 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_utils_op.py
@@ -38,7 +38,6 @@ class TestSparseCreate(unittest.TestCase):
                                                            dense_shape,
                                                            stop_gradient=False)
             # test the to_string.py
-            print(coo)
             assert np.array_equal(indices, coo.indices().numpy())
             assert np.array_equal(values, coo.values().numpy())
 
@@ -49,6 +48,7 @@ class TestSparseCreate(unittest.TestCase):
             dense_shape = [3, 3]
             coo = paddle.incubate.sparse.sparse_coo_tensor(
                 indices, values, dense_shape)
+            assert np.array_equal(3, coo.nnz())
             assert np.array_equal(indices, coo.indices().numpy())
             assert np.array_equal(values, coo.values().numpy())
 
@@ -78,7 +78,7 @@ class TestSparseCreate(unittest.TestCase):
             csr = paddle.incubate.sparse.sparse_csr_tensor(
                 crows, cols, values, dense_shape)
             # test the to_string.py
-            print(csr)
+            assert np.array_equal(5, csr.nnz())
             assert np.array_equal(crows, csr.crows().numpy())
             assert np.array_equal(cols, csr.cols().numpy())
             assert np.array_equal(values, csr.values().numpy())
diff --git a/python/paddle/incubate/sparse/__init__.py b/python/paddle/incubate/sparse/__init__.py
index f696434118..c56ada3468 100644
--- a/python/paddle/incubate/sparse/__init__.py
+++ b/python/paddle/incubate/sparse/__init__.py
@@ -15,27 +15,50 @@
 from .creation import sparse_coo_tensor
 from .creation import sparse_csr_tensor
 
-from .unary import sqrt
 from .unary import sin
+from .unary import tan
+from .unary import asin
+from .unary import atan
+from .unary import sinh
 from .unary import tanh
+from .unary import asinh
+from .unary import atanh
+from .unary import sqrt
+from .unary import square
+from .unary import log1p
+from .unary import abs
+from .unary import pow
+from .unary import cast
+from .unary import neg
 
 from .binary import mv
 from .binary import matmul
 from .binary import masked_matmul
-
-from .math import add
-from .math import divide
-from .math import multiply
-from .math import subtract
+from .binary import add
+from .binary import divide
+from .binary import multiply
+from .binary import subtract
 
 from . import nn
 
 __all__ = [
     'sparse_coo_tensor',
     'sparse_csr_tensor',
-    'sqrt',
     'sin',
+    'tan',
+    'asin',
+    'atan',
+    'sinh',
     'tanh',
+    'asinh',
+    'atanh',
+    'sqrt',
+    'square',
+    'log1p',
+    'abs',
+    'pow',
+    'cast',
+    'neg',
     'mv',
     'matmul',
     'masked_matmul',
diff --git a/python/paddle/incubate/sparse/binary.py b/python/paddle/incubate/sparse/binary.py
index f34378924e..0c90cd92a7 100644
--- a/python/paddle/incubate/sparse/binary.py
+++ b/python/paddle/incubate/sparse/binary.py
@@ -13,10 +13,19 @@
 # limitations under the License.
 
 from paddle import _C_ops
-from paddle.fluid.framework import dygraph_only
+from paddle.fluid.framework import dygraph_only, core
 
 __all__ = []
 
+_int_dtype_ = [
+    core.VarDesc.VarType.UINT8,
+    core.VarDesc.VarType.INT8,
+    core.VarDesc.VarType.INT16,
+    core.VarDesc.VarType.INT32,
+    core.VarDesc.VarType.INT64,
+    core.VarDesc.VarType.BOOL,
+]
+
 
 @dygraph_only
 def matmul(x, y, name=None):
@@ -197,3 +206,191 @@ def mv(x, vec, name=None):
 
     """
     return _C_ops.final_state_sparse_mv(x, vec)
+
+
+def add(x, y, name=None):
+    """
+    Add two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
+    type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
+    The equation is:
+
+    .. math::
+        out = x + y
+
+    Args:
+        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: the result tensor.
+
+    Examples:
+
+    ..  code-block:: python
+
+        import paddle
+        from paddle.fluid.framework import _test_eager_guard
+
+        paddle.device.set_device("cpu")
+
+        with _test_eager_guard():
+            x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
+            y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
+            sparse_x = x.to_sparse_csr()
+            sparse_y = y.to_sparse_csr()
+            sparse_z = paddle.incubate.sparse.add(sparse_x, sparse_y)
+            print(sparse_z.to_dense())
+
+        # [[ 0., -1.,  0.,  0.],
+        # [ 0.,  2., -6.,  0.],
+        # [ 6.,  8.,  4.,  8.]]
+
+    """
+    if y.dtype != x.dtype:
+        y = _C_ops.final_state_sparse_cast(y, None, x.dtype)
+    return _C_ops.final_state_sparse_add(x, y)
+
+
+@dygraph_only
+def subtract(x, y, name=None):
+    """
+    Subtract two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
+    type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
+    The equation is:
+
+    .. math::
+        out = x - y
+
+    Args:
+        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: the result tensor.
+
+    Examples:
+
+    ..  code-block:: python
+
+        import paddle
+        from paddle.fluid.framework import _test_eager_guard
+
+        paddle.device.set_device("cpu")
+
+        with _test_eager_guard():
+            x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
+            y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
+            sparse_x = x.to_sparse_csr()
+            sparse_y = y.to_sparse_csr()
+            sparse_z = paddle.incubate.sparse.subtract(sparse_x, sparse_y)
+            print(sparse_z.to_dense())
+
+        # [[ 0., -1.,  0.,  4.],
+        # [ 0., -2.,  0.,  0.],
+        # [ 2.,  2., -4., -8.]]
+
+    """
+    if y.dtype != x.dtype:
+        y = _C_ops.final_state_sparse_cast(y, None, x.dtype)
+    return _C_ops.final_state_sparse_subtract(x, y)
+
+
+@dygraph_only
+def multiply(x, y, name=None):
+    """
+    Multiply two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
+    type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
+    The equation is:
+
+    .. math::
+        out = x * y
+
+    Args:
+        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: the result tensor.
+
+    Examples:
+
+    ..  code-block:: python
+
+        import paddle
+        from paddle.fluid.framework import _test_eager_guard
+
+        paddle.device.set_device("cpu")
+
+        with _test_eager_guard():
+            x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
+            y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
+            sparse_x = x.to_sparse_csr()
+            sparse_y = y.to_sparse_csr()
+            sparse_z = paddle.incubate.sparse.multiply(sparse_x, sparse_y)
+            print(sparse_z.to_dense())
+
+        # [[ 0.,  0.,  0., -4.],
+        # [ 0.,  0.,  9.,  0.],
+        # [ 8., 15.,  0.,  0.]]
+
+    """
+    if isinstance(y, (int, float)):
+        return _C_ops.final_state_sparse_scale(x, float(y), 0.0, True)
+    else:
+        if y.dtype != x.dtype:
+            y = _C_ops.final_state_sparse_cast(y, None, x.dtype)
+        return _C_ops.final_state_sparse_multiply(x, y)
+
+
+@dygraph_only
+def divide(x, y, name=None):
+    """
+    Divide two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
+    type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
+    The equation is:
+
+    .. math::
+        out = x / y
+
+    Args:
+        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: the result tensor.
+
+    Examples:
+
+    ..  code-block:: python
+
+        import paddle
+        from paddle.fluid.framework import _test_eager_guard
+
+        paddle.device.set_device("cpu")
+
+        with _test_eager_guard():
+            x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
+            y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
+            sparse_x = x.to_sparse_csr()
+            sparse_y = y.to_sparse_csr()
+            sparse_z = paddle.incubate.sparse.divide(sparse_x, sparse_y)
+            print(sparse_z.to_dense())
+
+        # [[ nan      , -inf.     ,  nan      , -1.       ],
+        # [ nan      ,  0.       ,  1.       ,  nan      ],
+        # [ 2.       , 1.66666663,  0.       ,  0.       ]]
+
+    """
+    if x.dtype in _int_dtype_:
+        x = _C_ops.final_state_sparse_cast(x, None, core.VarDesc.VarType.FP32)
+
+    if isinstance(y, (int, float)):
+        return _C_ops.final_state_sparse_divide_scalar(x, float(y))
+    else:
+        if y.dtype != x.dtype:
+            y = _C_ops.final_state_sparse_cast(y, None, x.dtype)
+        return _C_ops.final_state_sparse_divide(x, y)
diff --git a/python/paddle/incubate/sparse/math.py b/python/paddle/incubate/sparse/math.py
deleted file mode 100644
index c6a984c3ad..0000000000
--- a/python/paddle/incubate/sparse/math.py
+++ /dev/null
@@ -1,260 +0,0 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-sparse math functions
-"""
-from __future__ import print_function
-
-from paddle import _C_ops, in_dynamic_mode, device, int32, int64
-from paddle.tensor import cast
-from paddle.incubate.sparse import sparse_csr_tensor
-
-
-def _cast_coo(x, dtype, name=None):
-    indices = x.indices()
-    values = cast(x.values(), dtype)
-    return _C_ops.final_state_sparse_create_sparse_coo_tensor(
-        values, indices, x.shape)
-
-
-def _cast_csr(x, dtype, name=None):
-    crows = x.crows()
-    cols = x.cols()
-    values = cast(x.values(), dtype)
-    return sparse_csr_tensor(crows, cols, values, x.shape)
-
-
-def _cast(x, dtype, name=None):
-    if x.is_sparse_coo():
-        return _cast_coo(x, dtype, name)
-    return _cast_csr(x, dtype, name)
-
-
-def add(x, y, name=None):
-    """
-    Add two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
-    type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
-    The equation is:
-
-    .. math::
-        out = x + y
-
-    Args:
-        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Tensor: the result tensor.
-
-    Examples:
-
-    ..  code-block:: python
-
-        import paddle
-        from paddle.fluid.framework import _test_eager_guard
-
-        paddle.device.set_device("cpu")
-
-        with _test_eager_guard():
-            x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
-            y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
-            sparse_x = x.to_sparse_csr()
-            sparse_y = y.to_sparse_csr()
-            sparse_z = paddle.incubate.sparse.add(sparse_x, sparse_y)
-            print(sparse_z.to_dense())
-
-        # [[ 0., -1.,  0.,  0.],
-        # [ 0.,  2., -6.,  0.],
-        # [ 6.,  8.,  4.,  8.]]
-
-    """
-    assert device.get_device(
-    ) == "cpu", "Currently, Sparse add only support CPU device."
-    assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode"
-    assert x.is_sparse_csr() == y.is_sparse_csr(
-    ), f"Expect sparse tensor type to be same"
-    if x.is_sparse_coo() or x.is_sparse_csr():
-        return _C_ops.final_state_sparse_add(x, y)
-    else:
-        raise ValueError(
-            "Currently, sparse.add only support the input of SparseCooTensor or SparseCsrTensor"
-        )
-
-
-def subtract(x, y, name=None):
-    """
-    Subtract two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
-    type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
-    The equation is:
-
-    .. math::
-        out = x - y
-
-    Args:
-        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Tensor: the result tensor.
-
-    Examples:
-
-    ..  code-block:: python
-
-        import paddle
-        from paddle.fluid.framework import _test_eager_guard
-
-        paddle.device.set_device("cpu")
-
-        with _test_eager_guard():
-            x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
-            y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
-            sparse_x = x.to_sparse_csr()
-            sparse_y = y.to_sparse_csr()
-            sparse_z = paddle.incubate.sparse.subtract(sparse_x, sparse_y)
-            print(sparse_z.to_dense())
-
-        # [[ 0., -1.,  0.,  4.],
-        # [ 0., -2.,  0.,  0.],
-        # [ 2.,  2., -4., -8.]]
-
-    """
-    assert device.get_device(
-    ) == "cpu", "Currently, Sparse subtract only support CPU device."
-    assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode"
-    assert x.is_sparse_csr() == y.is_sparse_csr(
-    ), f"Expect sparse tensor type to be same"
-    if x.is_sparse_coo() or x.is_sparse_csr():
-        return _C_ops.final_state_sparse_subtract(x, y)
-    else:
-        raise ValueError(
-            "Currently, sparse.subtract only support the input of SparseCooTensor or SparseCsrTensor"
-        )
-
-
-def multiply(x, y, name=None):
-    """
-    Multiply two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
-    type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
-    The equation is:
-
-    .. math::
-        out = x * y
-
-    Args:
-        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Tensor: the result tensor.
-
-    Examples:
-
-    ..  code-block:: python
-
-        import paddle
-        from paddle.fluid.framework import _test_eager_guard
-
-        paddle.device.set_device("cpu")
-
-        with _test_eager_guard():
-            x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
-            y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
-            sparse_x = x.to_sparse_csr()
-            sparse_y = y.to_sparse_csr()
-            sparse_z = paddle.incubate.sparse.multiply(sparse_x, sparse_y)
-            print(sparse_z.to_dense())
-
-        # [[ 0.,  0.,  0., -4.],
-        # [ 0.,  0.,  9.,  0.],
-        # [ 8., 15.,  0.,  0.]]
-
-    """
-    assert device.get_device(
-    ) == "cpu", "Currently, Sparse multiply only support CPU device."
-    assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode"
-    assert x.is_sparse_csr() == y.is_sparse_csr(
-    ), f"Expect sparse tensor type to be same"
-    if x.is_sparse_coo() or x.is_sparse_csr():
-        return _C_ops.final_state_sparse_multiply(x, y)
-    else:
-        raise ValueError(
-            "Currently, sparse.multiply only support the input of SparseCooTensor or SparseCsrTensor"
-        )
-
-
-def divide(x, y, name=None):
-    """
-    Divide two sparse tensors element-wise. Input x and y's shape should be identical and have same sparse
-    type(SparseCooTensor or SparseCsrTensor).If input is SparseCooTensor, x and y's sparse_dim should be identical.
-    The equation is:
-
-    .. math::
-        out = x / y
-
-    Args:
-        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Tensor: the result tensor.
-
-    Examples:
-
-    ..  code-block:: python
-
-        import paddle
-        from paddle.fluid.framework import _test_eager_guard
-
-        paddle.device.set_device("cpu")
-
-        with _test_eager_guard():
-            x = paddle.to_tensor([[0, -1, 0, 2], [0, 0, -3, 0], [4, 5, 0, 0]], 'float32')
-            y = paddle.to_tensor([[0, 0, 0, -2], [0, 2, -3, 0], [2, 3, 4, 8]], 'float32')
-            sparse_x = x.to_sparse_csr()
-            sparse_y = y.to_sparse_csr()
-            sparse_z = paddle.incubate.sparse.divide(sparse_x, sparse_y)
-            print(sparse_z.to_dense())
-
-        # [[ nan      , -inf.     ,  nan      , -1.       ],
-        # [ nan      ,  0.       ,  1.       ,  nan      ],
-        # [ 2.       , 1.66666663,  0.       ,  0.       ]]
-
-    """
-    assert device.get_device(
-    ) == "cpu", "Currently, Sparse divide only support CPU device."
-    assert in_dynamic_mode(), "Currently, Sparse API only support dynamic mode"
-    assert x.is_sparse_csr() == y.is_sparse_csr(
-    ), f"Expect sparse tensor type to be same"
-
-    if x.dtype in [int32, int64]:
-        if x.is_sparse_coo() or x.is_sparse_csr():
-            cx = _cast(x, 'float32')
-            cy = _cast(y, 'float32')
-            return _C_ops.final_state_sparse_divide(cx, cy)
-        else:
-            raise ValueError(
-                "Currently, sparse.divide only support the input of SparseCooTensor or SparseCsrTensor"
-            )
-    else:
-        if x.is_sparse_coo() or x.is_sparse_csr():
-            return _C_ops.final_state_sparse_divide(x, y)
-        else:
-            raise ValueError(
-                "Currently, sparse.divide only support the input of SparseCooTensor or SparseCsrTensor"
-            )
diff --git a/python/paddle/incubate/sparse/unary.py b/python/paddle/incubate/sparse/unary.py
index 09e449b0d9..d3fb55b737 100644
--- a/python/paddle/incubate/sparse/unary.py
+++ b/python/paddle/incubate/sparse/unary.py
@@ -13,19 +13,79 @@
 # limitations under the License.
 
 from paddle import _C_ops
-from paddle.fluid.framework import dygraph_only
+from paddle.fluid.framework import dygraph_only, core, convert_np_dtype_to_dtype_
 
 __all__ = []
 
 
 @dygraph_only
-def tanh(x, name=None):
+def sin(x, name=None):
     """
-    sparse tanh activation, requiring x to be a sparse coo or sparse csr tensor.
+    Calculate elementwise sin of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
+    .. math::
+
+        out = sin(x)
 
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.sin(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_sin(x)
+
+
+@dygraph_only
+def tan(x, name=None):
+    """
+    Calculate elementwise tan of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
     .. math::
 
-        out = tanh(x)
+        out = tan(x)
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.tan(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_tan(x)
+
+
+@dygraph_only
+def asin(x, name=None):
+    """
+    Calculate elementwise asin of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
+    .. math::
+
+        out = asin(x)
 
     Parameters:
         x (Tensor): The input Sparse Tensor with data type float32, float64.
@@ -39,21 +99,200 @@ def tanh(x, name=None):
         .. code-block:: python
 
             import paddle
-            from paddle.fluid.framework import _test_eager_guard
 
-            with _test_eager_guard():
-                dense_x = paddle.to_tensor([-2, 0, 1], dtype='float32')
-                sparse_x = dense_x.to_sparse_coo(1)
-                out = paddle.incubate.sparse.tanh(sparse_x)
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.asin(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_asin(x)
+
+
+@dygraph_only
+def atan(x, name=None):
+    """
+    Calculate elementwise atan of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
+    .. math::
+
+        out = atan(x)
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.atan(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_atan(x)
+
+
+@dygraph_only
+def sinh(x, name=None):
+    """
+    Calculate elementwise sinh of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
+    .. math::
+
+        out = sinh(x)
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.sinh(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_sinh(x)
+
+
+@dygraph_only
+def asinh(x, name=None):
+    """
+    Calculate elementwise asinh of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
+    .. math::
+
+        out = asinh(x)
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.asinh(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_asinh(x)
+
+
+@dygraph_only
+def atanh(x, name=None):
+    """
+    Calculate elementwise atanh of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
+    .. math::
+
+        out = atanh(x)
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.atanh(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_atanh(x)
+
+
+@dygraph_only
+def tanh(x, name=None):
+    """
+    Calculate elementwise tanh of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
+    .. math::
+
+        out = tanh(x)
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.tanh(sparse_x)
+            
     """
     return _C_ops.final_state_sparse_tanh(x)
 
 
 @dygraph_only
-def sqrt(x, name=None):
+def square(x, name=None):
     """
-    Calculate square root of x, requiring x to be a sparse coo or sparse csr tensor.
+    Calculate elementwise square of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
+    .. math::
+
+        out = square(x)
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
 
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.square(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_square(x)
+
+
+@dygraph_only
+def sqrt(x, name=None):
+    """
+    Calculate elementwise sqrt of SparseTensor, requiring x to be a SparseCooTensor or SparseCsrTensor.
+        
     .. math::
 
         out = sqrt(x)
@@ -70,24 +309,23 @@ def sqrt(x, name=None):
         .. code-block:: python
 
             import paddle
-            from paddle.fluid.framework import _test_eager_guard
 
-            with _test_eager_guard():
-                dense_x = paddle.to_tensor([4, 0, 1], dtype='float32')
-                sparse_x = dense_x.to_sparse_coo(1)
-                out = paddle.incubate.sparse.sqrt(sparse_x)
+            dense_x = paddle.to_tensor([-2., 0., 1.])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.sqrt(sparse_x)
+            
     """
     return _C_ops.final_state_sparse_sqrt(x)
 
 
 @dygraph_only
-def sin(x, name=None):
+def log1p(x, name=None):
     """
-    Calculate sin of x, requiring x to be a sparse coo or sparse csr tensor.
+    Calculate the natural log of (1+x), requiring x to be a SparseCooTensor or SparseCsrTensor.
 
     .. math::
 
-        out = sin(x)
+        out = ln(1+x)
 
     Parameters:
         x (Tensor): The input Sparse Tensor with data type float32, float64.
@@ -101,11 +339,136 @@ def sin(x, name=None):
         .. code-block:: python
 
             import paddle
-            from paddle.fluid.framework import _test_eager_guard
 
-            with _test_eager_guard():
-                dense_x = paddle.to_tensor([-2, 0, 3], dtype='float32')
-                sparse_x = dense_x.to_sparse_coo(1)
-                out = paddle.incubate.sparse.sin(sparse_x)
+            dense_x = paddle.to_tensor([-2, 0, 1], dtype='float32')
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.log1p(sparse_x)
+            
     """
-    return _C_ops.final_state_sparse_sin(x)
+    return _C_ops.final_state_sparse_log1p(x)
+
+
+@dygraph_only
+def cast(x, index_dtype=None, value_dtype=None, name=None):
+    """
+    cast non-zero-index of SparseTensor to `index_dtype`, non-zero-element of SparseTensor to
+    `value_dtype` , requiring x to be a SparseCooTensor or SparseCsrTensor.
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        index_dtype (np.dtype|str, optional): Data type of the index of SparseCooTensor, 
+            or crows/cols of SparseCsrTensor. Can be uint8, int8, int16, int32, int64.
+        value_dtype (np.dtype|str, optional): Data type of the value of SparseCooTensor,
+            SparseCsrTensor. Can be bool, float16, float32, float64, int8, int32, int64, uint8.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2, 0, 1])
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.cast(sparse_x, 'int32', 'float64')
+            
+    """
+    if index_dtype and not isinstance(index_dtype, core.VarDesc.VarType):
+        index_dtype = convert_np_dtype_to_dtype_(index_dtype)
+    if value_dtype and not isinstance(value_dtype, core.VarDesc.VarType):
+        value_dtype = convert_np_dtype_to_dtype_(value_dtype)
+    return _C_ops.final_state_sparse_cast(x, index_dtype, value_dtype)
+
+
+@dygraph_only
+def pow(x, factor, name=None):
+    """
+    Calculate elementwise pow of x, requiring x to be a SparseCooTensor or SparseCsrTensor.
+
+    .. math::
+
+        out = x^{factor}
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        factor (float|int): factor of pow.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2, 0, 3], dtype='float32')
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.pow(sparse_x, 2)
+            
+    """
+    return _C_ops.final_state_sparse_pow(x, float(factor))
+
+
+@dygraph_only
+def neg(x, name=None):
+    """
+    Calculate elementwise negative of x, requiring x to be a SparseCooTensor or SparseCsrTensor.
+
+    .. math::
+
+        out = -x
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2, 0, 3], dtype='float32')
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.neg(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_scale(x, -1.0, 0.0, True)
+
+
+@dygraph_only
+def abs(x, name=None):
+    """
+    Calculate elementwise absolute value of x, requiring x to be a SparseCooTensor or SparseCsrTensor.
+
+    .. math::
+
+        out = |x|
+
+    Parameters:
+        x (Tensor): The input Sparse Tensor with data type float32, float64.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Sparse Tensor with the same data type and shape as ``x`` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            dense_x = paddle.to_tensor([-2, 0, 3], dtype='float32')
+            sparse_x = dense_x.to_sparse_coo(1)
+            out = paddle.incubate.sparse.abs(sparse_x)
+            
+    """
+    return _C_ops.final_state_sparse_abs(x)
-- 
GitLab