diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 49e416fd0152dce7c09670f53155b23217c81b29..c0c50d68868ea95942386ccb1da9e190251525d4 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -169,6 +169,27 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
   logits_grad->set_dtype(softmax.dtype());
 }
 
+void DeformableConvGradInferMeta(const MetaTensor& x,
+                                 const MetaTensor& offset,
+                                 const MetaTensor& filter,
+                                 paddle::optional<const MetaTensor&> mask,
+                                 const MetaTensor& out_grad,
+                                 const std::vector<int>& strides,
+                                 const std::vector<int>& paddings,
+                                 const std::vector<int>& dilations,
+                                 int deformable_groups,
+                                 int groups,
+                                 int im2col_step,
+                                 MetaTensor* dx,
+                                 MetaTensor* offset_grad,
+                                 MetaTensor* filter_grad,
+                                 MetaTensor* mask_grad) {
+  GeneralTernaryGradInferMeta(x, offset, filter, dx, offset_grad, filter_grad);
+  if (mask) {
+    UnchangedInferMeta(mask.get(), mask_grad);
+  }
+}
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index eff3731bf22536c8f698bc513df2db68603d9f57..ad375e609313da5ecaab48c1b5ff439fd80e170e 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -79,6 +79,22 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
                                           MetaTensor* logits_grad,
                                           MetaConfig config = MetaConfig());
 
+void DeformableConvGradInferMeta(const MetaTensor& x,
+                                 const MetaTensor& offset,
+                                 const MetaTensor& filter,
+                                 paddle::optional<const MetaTensor&> mask,
+                                 const MetaTensor& out_grad,
+                                 const std::vector<int>& strides,
+                                 const std::vector<int>& paddings,
+                                 const std::vector<int>& dilations,
+                                 int deformable_groups,
+                                 int groups,
+                                 int im2col_step,
+                                 MetaTensor* dx,
+                                 MetaTensor* offset_grad,
+                                 MetaTensor* filter_grad,
+                                 MetaTensor* mask_grad);
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 298ad14f9e04b66147932d4e1960f6e3bb58c45c..2139605fb204861fb8e20372ce844dbe20436665 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -64,6 +64,16 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
   }
 }
 
+// Used in MatrixRankTolInferMeta
+static DDim CheckAndGetOutputDim(const DDim& dim_x) {
+  auto x_vec = phi::vectorize(dim_x);
+  if (x_vec.size() == 2) {
+    return phi::make_ddim({1});
+  }
+  x_vec.erase(x_vec.end() - 2, x_vec.end());
+  return phi::make_ddim(x_vec);
+}
+
 }  // namespace detail
 
 void AllValueCompareInferMeta(const MetaTensor& x,
@@ -1465,6 +1475,47 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void MatrixRankTolInferMeta(const MetaTensor& x,
+                            const MetaTensor& atol_tensor,
+                            bool use_default_tol,
+                            bool hermitian,
+                            MetaTensor* out) {
+  auto dim_x = x.dims();
+  PADDLE_ENFORCE_GE(
+      dim_x.size(),
+      2,
+      phi::errors::InvalidArgument("The dims of input must be greater than 2"));
+
+  if (hermitian) {
+    int rows = dim_x[dim_x.size() - 2];
+    int cols = dim_x[dim_x.size() - 1];
+    PADDLE_ENFORCE_EQ(rows,
+                      cols,
+                      phi::errors::InvalidArgument(
+                          "if hermitian == true, matrix should be n*n"));
+  }
+  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
+  auto dim_tol = atol_tensor.dims();
+  if (dim_x_batch == dim_tol) {
+    out->set_dims(dim_x_batch);
+  } else {
+    int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
+    int axis = std::abs(dim_x_batch.size() - dim_tol.size());
+    std::vector<int> x_batch_dims_array(max_dim);
+    std::vector<int> tol_dims_array(max_dim);
+    std::vector<int> out_dims_array(max_dim);
+    phi::funcs::GetBroadcastDimsArrays(dim_x_batch,
+                                       dim_tol,
+                                       x_batch_dims_array.data(),
+                                       tol_dims_array.data(),
+                                       out_dims_array.data(),
+                                       max_dim,
+                                       axis);
+    out->set_dims(phi::make_ddim(out_dims_array));
+  }
+  out->share_lod(x);
+}
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
   auto dim_x = x.dims();
   auto dim_vec = vec.dims();
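
The output-shape logic above reduces to: drop the trailing two matrix dims of x to get its batch dims (or [1] for a plain 2-D input), then broadcast those against atol_tensor's dims. A hedged NumPy restatement of the same rule:

    import numpy as np

    def matrix_rank_tol_out_shape(x_shape, tol_shape):
        # CheckAndGetOutputDim: a 2-D input collapses to [1]; otherwise the
        # last two (matrix) dims are dropped, leaving the batch dims.
        batch = list(x_shape[:-2]) or [1]
        # GetBroadcastDimsArrays applies standard NumPy-style broadcasting.
        return np.broadcast_shapes(tuple(batch), tuple(tol_shape))

    assert matrix_rank_tol_out_shape((3, 4, 5, 6), (4,)) == (3, 4)
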
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 70c3c9dfe849dee15242674d70af95d1932f9e02..192fa214c905fede429e5a2b10069a61a91794e4 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -218,6 +218,12 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
                                 int y_num_col_dims,
                                 MetaTensor* out);
 
+void MatrixRankTolInferMeta(const MetaTensor& x,
+                            const MetaTensor& atol_tensor,
+                            bool use_default_tol,
+                            bool hermitian,
+                            MetaTensor* out);
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
 
 void PReluInferMeta(const MetaTensor& x,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index fa3ea84c931548157d4143d9aba88d8c941ff97e..c6e2cb761911e0c8d6856f9438099593a1d53d90 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -31,6 +31,18 @@ limitations under the License. */
 
 namespace phi {
 
+namespace detail {
+// Used in MatrixRankInferMeta
+static DDim CheckAndGetOutputDim(const DDim& dim_x) {
+  auto x_vec = phi::vectorize(dim_x);
+  if (x_vec.size() == 2) {
+    return phi::make_ddim({1});
+  }
+  x_vec.erase(x_vec.end() - 2, x_vec.end());
+  return phi::make_ddim(x_vec);
+}
+}  // namespace detail
+
 void ArgMinMaxInferMeta(const MetaTensor& x,
                         int64_t axis,
                         bool keepdims,
@@ -901,6 +913,29 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
   out->set_dtype(x.dtype());
 }
 
+void MatrixRankInferMeta(const MetaTensor& x,
+                         bool use_default_tol,
+                         bool hermitian,
+                         MetaTensor* out) {
+  auto dim_x = x.dims();
+  PADDLE_ENFORCE_GE(
+      dim_x.size(),
+      2,
+      phi::errors::InvalidArgument("The dims of input must be greater than 2"));
+
+  if (hermitian) {
+    int rows = dim_x[dim_x.size() - 2];
+    int cols = dim_x[dim_x.size() - 1];
+    PADDLE_ENFORCE_EQ(rows,
+                      cols,
+                      phi::errors::InvalidArgument(
+                          "if hermitian == true, matrix should be n*n"));
+  }
+  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
+  out->set_dims(dim_x_batch);
+  out->share_lod(x);
+}
+
 void MaxOutInferMeta(const MetaTensor& x,
                      int groups,
                      int axis,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index a79f21c4a3f535a33d8a46a2008d13b44ad57337..c49e4c88dd89910e41d94e1cf4d5a5d2cd368bcd 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -142,6 +142,11 @@ void LogsumexpInferMeta(const MetaTensor& input,
 
 void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
 
+void MatrixRankInferMeta(const MetaTensor& x,
+                         bool use_default_tol,
+                         bool hermitian,
+                         MetaTensor* out);
+
 void MaxOutInferMeta(const MetaTensor& x,
                      int groups,
                      int axis,
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index 2fad915be13dfcdebf15f7892d777dd97f79131e..28e0d4eff377f6f23da436d8219da4475a474bff 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -1752,10 +1752,12 @@ def eye(num_rows,
     else:
         num_columns = num_rows
 
-    if _non_static_mode():
+    if in_dygraph_mode():
+        out = _C_ops.final_state_eye(num_rows, num_columns, dtype,
+                                     _current_expected_place())
+    elif _in_legacy_dygraph():
         out = _C_ops.eye('dtype', dtype, 'num_rows', num_rows, 'num_columns',
                          num_columns)
-
     else:
         helper = LayerHelper("eye", **locals())
         check_dtype(dtype, 'dtype',
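
Call sites are unchanged; only dispatch differs: eager mode now routes to the final-state kernel, legacy dygraph keeps the old _C_ops.eye path, and static graph still builds the op. A quick usage check (assuming a build with the final-state API enabled; paddle.eye delegates to this fluid.layers.eye implementation):

    import paddle

    paddle.disable_static()          # eager mode -> final_state_eye
    print(paddle.eye(3, 4, dtype='float32'))
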
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 9c3ca133270226e0010283bd430fdeb3c33021db..d42166d8324e6b9d7a9d4311453d8014314a0158 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -968,7 +968,7 @@ set_tests_properties(test_lstm_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_star_gan_with_gradient_penalty PROPERTIES TIMEOUT 120)
 
 set_tests_properties(test_bicubic_interp_op PROPERTIES TIMEOUT 120)
-set_tests_properties(test_deformable_conv_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_deformable_conv_op PROPERTIES TIMEOUT 200)
 set_tests_properties(test_nearest_interp_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_profiler PROPERTIES TIMEOUT 120)
 set_tests_properties(test_inplace_softmax_with_cross_entropy PROPERTIES TIMEOUT 120)
@@ -1045,7 +1045,7 @@ set_tests_properties(test_distributed_fused_lamb_op_with_clip PROPERTIES TIMEOUT
 set_tests_properties(test_distributed_fused_lamb_op_without_clip PROPERTIES TIMEOUT 120)
 set_tests_properties(test_elementwise_min_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_nan_inf PROPERTIES TIMEOUT 120)
-set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 300)
 set_tests_properties(test_parallel_executor_transformer_auto_growth PROPERTIES TIMEOUT 120)
 set_tests_properties(test_py_reader_using_executor PROPERTIES TIMEOUT 120)
 set_tests_properties(test_elementwise_add_op PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py
index 508fc1705218a0da72d1fb5213f4663852e08c3f..f5f1479d07d2f0e570624ebe3f84ea20df59da32 100644
--- a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py
+++ b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py
@@ -17,6 +17,7 @@ import paddle.nn.functional as F
 import paddle.nn.initializer as I
 import numpy as np
 import unittest
+from paddle.fluid.framework import _test_eager_guard
 from unittest import TestCase
 
 
@@ -183,6 +184,10 @@ class TestDeformConv2D(TestCase):
             self.place = paddle.CUDAPlace(0)
             self._test_identity()
 
+    def test_identity_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_identity()
+
 
 class TestDeformConv2DFunctional(TestCase):
     batch_size = 4
@@ -418,6 +423,10 @@ class TestDeformConv2DFunctional(TestCase):
             self.place = paddle.CUDAPlace(0)
             self._test_identity()
 
+    def test_identity_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_identity()
+
 
 # testcases for DeformConv2D
 class TestDeformConv2DWithPadding(TestDeformConv2D):
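
The _test_eager_guard additions follow one idiom throughout this patch: re-run an existing test body under eager mode so both dispatch paths are covered by the same assertions. The shape of the idiom, sketched with a hypothetical test class:

    from unittest import TestCase
    from paddle.fluid.framework import _test_eager_guard

    class SomeDeformConvCase(TestCase):
        def test_identity(self):
            ...                          # runs on the legacy dygraph path
        def test_identity_with_eager_guard(self):
            with _test_eager_guard():
                self.test_identity()     # same assertions, eager path
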
diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py
index 45a23231945ece4247b5e5f1b9eaa63f8c33f964..5fc849575b6597a2a229434355cadb59d43e75fe 100644
--- a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py
+++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py
@@ -14,13 +14,15 @@
 
 from __future__ import print_function
 
+import paddle
 import unittest
 import numpy as np
-
-import paddle
 import paddle.fluid.core as core
 import paddle.fluid as fluid
 from op_test import OpTest
+from paddle.fluid.framework import _test_eager_guard
+
+paddle.enable_static()
 
 
 def dmc_bilinear(data_im, height, width, h, w):
@@ -108,8 +110,24 @@ def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
     return out
 
 
+def deform_conv2d_wrapper(x,
+                          offset,
+                          weight,
+                          mask=None,
+                          stride=1,
+                          padding=0,
+                          dilation=1,
+                          deformable_groups=1,
+                          groups=1,
+                          im2col_step=1):
+    return paddle.vision.ops.deform_conv2d(x, offset, weight, None, stride,
+                                           padding, dilation, deformable_groups,
+                                           groups, mask)
+
+
 class TestModulatedDeformableConvOp(OpTest):
     def setUp(self):
+        self.python_api = deform_conv2d_wrapper
         self.op_type = "deformable_conv"
         self.init_type()
         self.init_group()
@@ -148,13 +166,14 @@ class TestModulatedDeformableConvOp(OpTest):
         self.outputs = {'Output': output}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
         self.check_grad(
             {'Input', 'Offset', 'Mask', 'Filter'},
             'Output',
-            max_relative_error=0.05)
+            max_relative_error=0.05,
+            check_eager=True)
 
     def init_test_case(self):
         self.pad = [1, 1]
@@ -327,6 +346,10 @@ class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
 
         self.assertRaises(ValueError, test_invalid_filter)
 
+    def test_error_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_error()
+
 
 class TestDeformConv2DAPI(unittest.TestCase):
     def test_api(self):
@@ -358,6 +381,10 @@ class TestDeformConv2DAPI(unittest.TestCase):
 
         test_deform_conv2d_v2()
 
+    def test_api_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_api()
+
 
 if __name__ == '__main__':
     unittest.main()
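
Setting self.python_api plus check_eager=True tells OpTest to also run the op through the given Python function under eager mode and compare against the same reference outputs. The wrapper exists because OpTest supplies arguments in op order, while paddle.vision.ops.deform_conv2d takes bias fourth and mask last; im2col_step is accepted but dropped, since the Python API fixes it at 1. An illustrative call with hypothetical shapes:

    import paddle

    x = paddle.rand([2, 3, 8, 8])
    offset = paddle.rand([2, 2 * 3 * 3, 8, 8])   # 2 * kh * kw channels
    weight = paddle.rand([4, 3, 3, 3])
    out = deform_conv2d_wrapper(x, offset, weight, padding=1)
    print(out.shape)                             # [2, 4, 8, 8]
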
diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py
index e8b18d601afae649ba6af49230f41bc0465a8959..304a151c4d3bfc7aa1e5228a1335aaf9b8663a31 100644
--- a/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py
+++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py
@@ -14,12 +14,13 @@
 
 from __future__ import print_function
 
+import paddle
 import unittest
 import numpy as np
-
-import paddle.fluid.core as core
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 from op_test import OpTest
+from paddle.fluid.framework import _test_eager_guard
 
 
 def dmc_bilinear(data_im, height, width, h, w):
@@ -105,8 +106,24 @@ def dconv_im2col_gemm(input, offset, filter, group, conv_param):
     return out
 
 
+def deform_conv2d_wrapper(x,
+                          offset,
+                          weight,
+                          mask=None,
+                          stride=1,
+                          padding=0,
+                          dilation=1,
+                          deformable_groups=1,
+                          groups=1,
+                          im2col_step=1):
+    return paddle.vision.ops.deform_conv2d(x, offset, weight, None, stride,
+                                           padding, dilation, deformable_groups,
+                                           groups, mask)
+
+
 class TestModulatedDeformableConvOp(OpTest):
     def setUp(self):
+        self.python_api = deform_conv2d_wrapper
         self.op_type = "deformable_conv_v1"
         self.init_type()
         self.init_group()
@@ -142,18 +159,22 @@ class TestModulatedDeformableConvOp(OpTest):
         self.outputs = {'Output': output}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
         self.check_grad(
-            ['Input', 'Offset', 'Filter'], 'Output', max_relative_error=0.05)
+            ['Input', 'Offset', 'Filter'],
+            'Output',
+            max_relative_error=0.05,
+            check_eager=True)
 
     def test_check_grad_no_filter(self):
         self.check_grad(
             ['Input', 'Offset'],
             'Output',
             max_relative_error=0.1,
-            no_grad_set=set(['Filter']))
+            no_grad_set=set(['Filter']),
+            check_eager=True)
 
     def init_test_case(self):
         self.pad = [1, 1]
@@ -292,6 +313,10 @@ class TestModulatedDeformableConvV1InvalidInput(unittest.TestCase):
 
         self.assertRaises(TypeError, test_invalid_offset)
 
+    def test_error_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_error()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_eye_op.py b/python/paddle/fluid/tests/unittests/test_eye_op.py
index cb757cffc442598a75aee2cf2ef0b648280844bf..704762d8094148f76063f364606dc65d5686d518 100644
--- a/python/paddle/fluid/tests/unittests/test_eye_op.py
+++ b/python/paddle/fluid/tests/unittests/test_eye_op.py
@@ -28,6 +28,7 @@ class TestEyeOp(OpTest):
         '''
 	Test eye op with specified shape
         '''
+        self.python_api = paddle.eye
         self.op_type = "eye"
 
         self.inputs = {}
@@ -39,7 +40,7 @@ class TestEyeOp(OpTest):
         self.outputs = {'Out': np.eye(219, 319, dtype=np.int32)}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 class TestEyeOp1(OpTest):
@@ -47,6 +48,7 @@ class TestEyeOp1(OpTest):
         '''
 	Test eye op with default parameters
         '''
+        self.python_api = paddle.eye
         self.op_type = "eye"
 
         self.inputs = {}
@@ -54,7 +56,7 @@ class TestEyeOp1(OpTest):
         self.outputs = {'Out': np.eye(50, dtype=float)}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 class TestEyeOp2(OpTest):
@@ -62,6 +64,7 @@ class TestEyeOp2(OpTest):
         '''
         Test eye op with specified shape
         '''
+        self.python_api = paddle.eye
         self.op_type = "eye"
 
         self.inputs = {}
@@ -69,7 +72,7 @@ class TestEyeOp2(OpTest):
         self.outputs = {'Out': np.eye(99, 1, dtype=float)}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 class API_TestTensorEye(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
index d0b84a0d7e1082ddb8bca5879be2aca3962f3535..b13b3462617627aa28a204e80a5e6239ba1d60fd 100644
--- a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
@@ -30,8 +30,13 @@ SEED = 2049
 np.random.seed(SEED)
 
 
+def matrix_rank_wrapper(x, tol=None, use_default_tol=True, hermitian=False):
+    return paddle.linalg.matrix_rank(x, tol, hermitian)
+
+
 class TestMatrixRankOP(OpTest):
     def setUp(self):
+        self.python_api = matrix_rank_wrapper
         self.op_type = "matrix_rank"
         self.init_data()
         self.inputs = {'X': self.x}
@@ -44,7 +49,7 @@ class TestMatrixRankOP(OpTest):
         self.outputs = {'Out': self.out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def init_data(self):
         self.x = np.eye(3, dtype=np.float32)
@@ -110,6 +115,28 @@ class TestMatrixRankOP5(TestMatrixRankOP):
                                          self.hermitian)
 
 
+class TestMatrixRankOP6(TestMatrixRankOP):
+    def init_data(self):
+        self.x = np.random.rand(3, 4, 5, 6).astype(np.float32)
+        self.tol_tensor = None
+        self.tol = None
+        self.use_default_tol = False
+        self.hermitian = False
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)
+
+
+class TestMatrixRankOP7(TestMatrixRankOP):
+    def init_data(self):
+        self.x = np.eye(200, dtype=np.float64)
+        self.tol_tensor = np.random.random([200, 200]).astype(self.x.dtype)
+        self.tol = None
+        self.use_default_tol = True
+        self.hermitian = True
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)
+
+
 class TestMatrixRankAPI(unittest.TestCase):
     def test_dygraph(self):
         paddle.disable_static()
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 23c83b1e3810279c8a1312266d1fa1c865ef10da..33ff27202031f4ec65164d6f94642fc7d4ce153a 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -1288,8 +1288,26 @@ def matrix_rank(x, tol=None, hermitian=False, name=None):
             #      [1, 1, 1, 1]]
 
     """
+    if in_dygraph_mode():
+        if isinstance(tol, Variable):
+            if tol.dtype != x.dtype:
+                tol_tensor = cast(tol, x.dtype)
+            else:
+                tol_tensor = tol
+            use_default_tol = False
+            return _C_ops.final_state_matrix_rank_tol(
+                x, tol_tensor, use_default_tol, hermitian)
 
-    if paddle.in_dynamic_mode():
+        if tol is None:
+            tol_attr = 0.0
+            use_default_tol = True
+        else:
+            tol_attr = float(tol)
+            use_default_tol = False
+        return _C_ops.final_state_matrix_rank(x, tol_attr, use_default_tol,
+                                              hermitian)
+
+    if _in_legacy_dygraph():
         if tol is None:
             tol_tensor = None
             tol_attr = 0.0
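
Under eager mode, matrix_rank now has two exits: a Tensor tol (cast to x's dtype if needed) goes to final_state_matrix_rank_tol, while a float or absent tol goes to final_state_matrix_rank, with use_default_tol telling the kernel whether to derive the default threshold itself. The public API is unchanged either way:

    import paddle

    x = paddle.rand([3, 4, 5])
    paddle.linalg.matrix_rank(x)                    # tol=None -> default tol
    paddle.linalg.matrix_rank(x, tol=1e-3)          # scalar tol
    paddle.linalg.matrix_rank(x, tol=paddle.to_tensor(1e-3))  # Tensor tol
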
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 6cb7bfa793f555f38ab49e2b1ea865a986f0362b..718c35683cb0bf08490b1b50036fb5b10806106e 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -435,6 +435,16 @@
     func : cumsum
   backward : cumsum_grad
 
+- api : deformable_conv
+  args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
+  output : Tensor(out)
+  infer_meta :
+    func : DeformableConvInferMeta
+  kernel :
+    func : deformable_conv
+  optional : mask
+  backward : deformable_conv_grad
+
 - api : depthwise_conv2d_transpose
   args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
   output : Tensor(out)
@@ -609,6 +619,18 @@
     func : expm1
   backward : expm1_grad
 
+- api : eye
+  args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
+  output : Tensor(out)
+  infer_meta :
+    func : EyeInferMeta
+    param : [num_rows, num_columns, dtype]
+  kernel :
+    func : eye
+    param : [num_rows, num_columns, dtype]
+    data_type : dtype
+    backend : place
+
 - api : flatten
   args : (Tensor x, int start_axis, int stop_axis)
   output : Tensor(out), Tensor(xshape)
@@ -1167,6 +1189,23 @@
     func : matrix_power
   backward : matrix_power_grad
 
+- api : matrix_rank
+  args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
+  output : Tensor(out)
+  infer_meta :
+    func : MatrixRankInferMeta
+    param : [x, use_default_tol, hermitian]
+  kernel :
+    func : matrix_rank
+
+- api : matrix_rank_tol
+  args : (Tensor x, Tensor atol_tensor, bool use_default_tol=true, bool hermitian=false)
+  output : Tensor(out)
+  infer_meta :
+    func : MatrixRankTolInferMeta
+  kernel :
+    func : matrix_rank_tol
+
 - api : max
   args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
   output : Tensor(out)
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 4c50967b6f4d1f0ef96107dcd9e0bef63cde15eb..f60563d5d018e8f281833209ceb86da35b3bb699 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -302,6 +302,16 @@
   output : Tensor(x_grad)
   invoke : cumsum(out_grad, axis, flatten, exclusive, !reverse)
 
+- backward_api : deformable_conv_grad
+  forward : deformable_conv(Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) -> Tensor(out)
+  args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, Tensor out_grad, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
+  output : Tensor(x_grad), Tensor(offset_grad), Tensor(filter_grad), Tensor(mask_grad)
+  infer_meta :
+    func : DeformableConvGradInferMeta
+  kernel :
+    func : deformable_conv_grad
+  optional : mask
+
 - backward_api : depthwise_conv2d_transpose_grad
   forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
   args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py
index 2ed01d42cfb8c239258999f235c3d67b09df3be2..8fa51df9ac10d7f6213dfe5906395914a4527637 100644
--- a/python/paddle/vision/ops.py
+++ b/python/paddle/vision/ops.py
@@ -558,7 +558,15 @@ def deform_conv2d(x,
 
     use_deform_conv2d_v1 = True if mask is None else False
 
-    if _non_static_mode():
+    if in_dygraph_mode():
+        pre_bias = _C_ops.final_state_deformable_conv(
+            x, offset, weight, mask, stride, padding, dilation,
+            deformable_groups, groups, 1)
+        if bias is not None:
+            out = nn.elementwise_add(pre_bias, bias, axis=1)
+        else:
+            out = pre_bias
+    elif _in_legacy_dygraph():
         attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
                  'deformable_groups', deformable_groups, 'groups', groups,
                  'im2col_step', 1)
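
As with eye, the eager branch calls the final-state kernel and then adds bias (if any) exactly as the legacy branch did. A smoke test of the modulated (v2, with mask) variant under eager mode, with hypothetical shapes:

    import paddle

    kh = kw = 3
    x = paddle.rand([1, 3, 8, 8])
    offset = paddle.rand([1, 2 * kh * kw, 8, 8])
    mask = paddle.rand([1, kh * kw, 8, 8])          # modulation mask
    weight = paddle.rand([5, 3, kh, kw])
    y = paddle.vision.ops.deform_conv2d(x, offset, weight, mask=mask,
                                        padding=1)
    print(y.shape)                                  # [1, 5, 8, 8]
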
diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json
index 64fc4c618aebc342f4be06f6e7d1d494a37f16b8..eeb94d67032e23525e1aae51e81d7314d64543b8 100644
--- a/tools/infrt/skipped_phi_api.json
+++ b/tools/infrt/skipped_phi_api.json
@@ -1,4 +1,4 @@
 {
-"phi_apis":["conj", "dropout", "expand_as", "flatten", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth"],
+"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "flatten", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm"],
 "phi_kernels":["equal_all"]
 }