From 33583ddecaeee78afc6956fd76658f64b3043ad8 Mon Sep 17 00:00:00 2001
From: Ruibiao Chen
Date: Thu, 14 Apr 2022 00:13:34 +0800
Subject: [PATCH] [Cherry pick] Add some final state OPs (#41737)

* Add yaml for matrix rank op (#41466)

* modify matrix_rank

* add matrix_rank shape

* add matrix_rank shape

* Add yaml for matrix_rank OP

* Add UT

Co-authored-by: zhoujianqian <15205085056@163.com>

* Add yaml for eye OP (#41476)

* [cherry-pick] Add yaml config for matrix_rank, eye, deformable_conv and deformable_conv_v1 OPs

* Add yaml for deformable_conv and deformable_conv_v1 OPs

* Add UT

* Add to skipped_phi_api list for infrt

Co-authored-by: zhoujianqian <15205085056@163.com>
---
 paddle/phi/infermeta/backward.cc              | 21 ++++++++
 paddle/phi/infermeta/backward.h               | 16 ++++++
 paddle/phi/infermeta/binary.cc                | 51 +++++++++++++++++++
 paddle/phi/infermeta/binary.h                 |  6 +++
 paddle/phi/infermeta/unary.cc                 | 35 +++++++++++++
 paddle/phi/infermeta/unary.h                  |  5 ++
 python/paddle/fluid/layers/tensor.py          |  6 ++-
 .../fluid/tests/unittests/CMakeLists.txt      |  4 +-
 .../tests/unittests/test_deform_conv2d.py     |  9 ++++
 .../unittests/test_deformable_conv_op.py      | 35 +++++++++++--
 .../unittests/test_deformable_conv_v1_op.py   | 35 +++++++++++--
 .../fluid/tests/unittests/test_eye_op.py      |  9 ++--
 .../tests/unittests/test_matrix_rank_op.py    | 29 ++++++++++-
 python/paddle/tensor/linalg.py                | 20 +++++++-
 python/paddle/utils/code_gen/api.yaml         | 39 ++++++++++++++
 python/paddle/utils/code_gen/backward.yaml    | 10 ++++
 python/paddle/vision/ops.py                   | 10 +++-
 tools/infrt/skipped_phi_api.json              |  2 +-
 18 files changed, 322 insertions(+), 20 deletions(-)

diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 49e416fd01..c0c50d6886 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -169,6 +169,27 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
   logits_grad->set_dtype(softmax.dtype());
 }
 
+void DeformableConvGradInferMeta(const MetaTensor& x,
+                                 const MetaTensor& offset,
+                                 const MetaTensor& filter,
+                                 paddle::optional<const MetaTensor&> mask,
+                                 const MetaTensor& out_grad,
+                                 const std::vector<int>& strides,
+                                 const std::vector<int>& paddings,
+                                 const std::vector<int>& dilations,
+                                 int deformable_groups,
+                                 int groups,
+                                 int im2col_step,
+                                 MetaTensor* dx,
+                                 MetaTensor* offset_grad,
+                                 MetaTensor* filter_grad,
+                                 MetaTensor* mask_grad) {
+  GeneralTernaryGradInferMeta(x, offset, filter, dx, offset_grad, filter_grad);
+  if (mask) {
+    UnchangedInferMeta(mask.get(), mask_grad);
+  }
+}
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index eff3731bf2..ad375e6093 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -79,6 +79,22 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
                                           MetaTensor* logits_grad,
                                           MetaConfig config = MetaConfig());
 
+void DeformableConvGradInferMeta(const MetaTensor& x,
+                                 const MetaTensor& offset,
+                                 const MetaTensor& filter,
+                                 paddle::optional<const MetaTensor&> mask,
+                                 const MetaTensor& out_grad,
+                                 const std::vector<int>& strides,
+                                 const std::vector<int>& paddings,
+                                 const std::vector<int>& dilations,
+                                 int deformable_groups,
+                                 int groups,
+                                 int im2col_step,
+                                 MetaTensor* dx,
+                                 MetaTensor* offset_grad,
+                                 MetaTensor* filter_grad,
+                                 MetaTensor* mask_grad);
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 298ad14f9e..2139605fb2 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -64,6 +64,16 @@ static void BinarySameInputDimsCheck(const MetaTensor& x,
   }
 }
 
+// Used in MatrixRankTolInferMeta
+static DDim CheckAndGetOutputDim(const DDim& dim_x) {
+  auto x_vec = phi::vectorize(dim_x);
+  if (x_vec.size() == 2) {
+    return phi::make_ddim({1});
+  }
+  x_vec.erase(x_vec.end() - 2, x_vec.end());
+  return phi::make_ddim(x_vec);
+}
+
 }  // namespace detail
 
 void AllValueCompareInferMeta(const MetaTensor& x,
@@ -1465,6 +1475,47 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void MatrixRankTolInferMeta(const MetaTensor& x,
+                            const MetaTensor& atol_tensor,
+                            bool use_default_tol,
+                            bool hermitian,
+                            MetaTensor* out) {
+  auto dim_x = x.dims();
+  PADDLE_ENFORCE_GE(
+      dim_x.size(),
+      2,
+      phi::errors::InvalidArgument("The dims of input must be no less than 2"));
+
+  if (hermitian) {
+    int rows = dim_x[dim_x.size() - 2];
+    int cols = dim_x[dim_x.size() - 1];
+    PADDLE_ENFORCE_EQ(rows,
+                      cols,
+                      phi::errors::InvalidArgument(
+                          "if hermitian == true, matrix should be n*n"));
+  }
+  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
+  auto dim_tol = atol_tensor.dims();
+  if (dim_x_batch == dim_tol) {
+    out->set_dims(dim_x_batch);
+  } else {
+    int max_dim = std::max(dim_x_batch.size(), dim_tol.size());
+    int axis = std::abs(dim_x_batch.size() - dim_tol.size());
+    std::vector<int> x_batch_dims_array(max_dim);
+    std::vector<int> tol_dims_array(max_dim);
+    std::vector<int> out_dims_array(max_dim);
+    phi::funcs::GetBroadcastDimsArrays(dim_x_batch,
+                                       dim_tol,
+                                       x_batch_dims_array.data(),
+                                       tol_dims_array.data(),
+                                       out_dims_array.data(),
+                                       max_dim,
+                                       axis);
+    out->set_dims(phi::make_ddim(out_dims_array));
+  }
+  out->share_lod(x);
+}
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out) {
   auto dim_x = x.dims();
   auto dim_vec = vec.dims();
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 70c3c9dfe8..192fa214c9 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -218,6 +218,12 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x,
                                 int y_num_col_dims,
                                 MetaTensor* out);
 
+void MatrixRankTolInferMeta(const MetaTensor& x,
+                            const MetaTensor& atol_tensor,
+                            bool use_default_tol,
+                            bool hermitian,
+                            MetaTensor* out);
+
 void MvInferMeta(const MetaTensor& x, const MetaTensor& vec, MetaTensor* out);
 
 void PReluInferMeta(const MetaTensor& x,
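For readers tracing MatrixRankTolInferMeta above: the output shape is the batch shape of x (its dims minus the trailing two matrix dims, per detail::CheckAndGetOutputDim) broadcast against the shape of the tol tensor. A minimal NumPy sketch of that rule; the shapes below are illustrative assumptions, not taken from the patch:

    import numpy as np

    # x of shape (2, 3, 4, 5) -> batch shape (2, 3); the trailing
    # 4x5 matrix dims are dropped by CheckAndGetOutputDim.
    x_batch_shape = (2, 3)
    # A tol tensor of shape (3,) broadcasts against the batch shape,
    # as phi::funcs::GetBroadcastDimsArrays does in the infer meta.
    tol_shape = (3,)
    print(np.broadcast_shapes(x_batch_shape, tol_shape))  # (2, 3)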
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index fa3ea84c93..c6e2cb7619 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -31,6 +31,18 @@ limitations under the License. */
 
 namespace phi {
 
+namespace detail {
+// Used in MatrixRankInferMeta
+static DDim CheckAndGetOutputDim(const DDim& dim_x) {
+  auto x_vec = phi::vectorize(dim_x);
+  if (x_vec.size() == 2) {
+    return phi::make_ddim({1});
+  }
+  x_vec.erase(x_vec.end() - 2, x_vec.end());
+  return phi::make_ddim(x_vec);
+}
+}  // namespace detail
+
 void ArgMinMaxInferMeta(const MetaTensor& x,
                         int64_t axis,
                         bool keepdims,
@@ -901,6 +913,29 @@ void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out) {
   out->set_dtype(x.dtype());
 }
 
+void MatrixRankInferMeta(const MetaTensor& x,
+                         bool use_default_tol,
+                         bool hermitian,
+                         MetaTensor* out) {
+  auto dim_x = x.dims();
+  PADDLE_ENFORCE_GE(
+      dim_x.size(),
+      2,
+      phi::errors::InvalidArgument("The dims of input must be no less than 2"));
+
+  if (hermitian) {
+    int rows = dim_x[dim_x.size() - 2];
+    int cols = dim_x[dim_x.size() - 1];
+    PADDLE_ENFORCE_EQ(rows,
+                      cols,
+                      phi::errors::InvalidArgument(
+                          "if hermitian == true, matrix should be n*n"));
+  }
+  DDim dim_x_batch = detail::CheckAndGetOutputDim(dim_x);
+  out->set_dims(dim_x_batch);
+  out->share_lod(x);
+}
+
 void MaxOutInferMeta(const MetaTensor& x,
                      int groups,
                      int axis,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index a79f21c4a3..c49e4c88dd 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -142,6 +142,11 @@ void LogsumexpInferMeta(const MetaTensor& input,
 
 void MatrixPowerInferMeta(const MetaTensor& x, int n, MetaTensor* out);
 
+void MatrixRankInferMeta(const MetaTensor& x,
+                         bool use_default_tol,
+                         bool hermitian,
+                         MetaTensor* out);
+
 void MaxOutInferMeta(const MetaTensor& x,
                      int groups,
                      int axis,
diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py
index 2fad915be1..28e0d4eff3 100644
--- a/python/paddle/fluid/layers/tensor.py
+++ b/python/paddle/fluid/layers/tensor.py
@@ -1752,10 +1752,12 @@ def eye(num_rows,
     else:
         num_columns = num_rows
 
-    if _non_static_mode():
+    if in_dygraph_mode():
+        out = _C_ops.final_state_eye(num_rows, num_columns, dtype,
+                                     _current_expected_place())
+    elif _in_legacy_dygraph():
         out = _C_ops.eye('dtype', dtype, 'num_rows', num_rows, 'num_columns',
                          num_columns)
-
     else:
         helper = LayerHelper("eye", **locals())
         check_dtype(dtype, 'dtype',
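A minimal usage sketch for the eye dispatch added above, assuming an eager-mode (dygraph) build where in_dygraph_mode() is true and that paddle.eye routes through this function, as the updated test's python_api suggests:

    import paddle

    # Dispatches to _C_ops.final_state_eye in eager mode, and to the
    # legacy _C_ops.eye or the static-graph LayerHelper path otherwise.
    out = paddle.eye(3, 4, dtype='float32')
    print(out.shape)  # [3, 4]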
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 9c3ca13327..d42166d832 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -968,7 +968,7 @@ set_tests_properties(test_lstm_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_star_gan_with_gradient_penalty PROPERTIES TIMEOUT 120)
 set_tests_properties(test_bicubic_interp_op PROPERTIES TIMEOUT 120)
-set_tests_properties(test_deformable_conv_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_deformable_conv_op PROPERTIES TIMEOUT 200)
 set_tests_properties(test_nearest_interp_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_profiler PROPERTIES TIMEOUT 120)
 set_tests_properties(test_inplace_softmax_with_cross_entropy PROPERTIES TIMEOUT 120)
@@ -1045,7 +1045,7 @@ set_tests_properties(test_distributed_fused_lamb_op_with_clip PROPERTIES TIMEOUT
 set_tests_properties(test_distributed_fused_lamb_op_without_clip PROPERTIES TIMEOUT 120)
 set_tests_properties(test_elementwise_min_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_nan_inf PROPERTIES TIMEOUT 120)
-set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 300)
 set_tests_properties(test_parallel_executor_transformer_auto_growth PROPERTIES TIMEOUT 120)
 set_tests_properties(test_py_reader_using_executor PROPERTIES TIMEOUT 120)
 set_tests_properties(test_elementwise_add_op PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py
index 508fc17052..f5f1479d07 100644
--- a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py
+++ b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py
@@ -17,6 +17,7 @@ import paddle.nn.functional as F
 import paddle.nn.initializer as I
 import numpy as np
 import unittest
+from paddle.fluid.framework import _test_eager_guard
 from unittest import TestCase
 
@@ -183,6 +184,10 @@ class TestDeformConv2D(TestCase):
         self.place = paddle.CUDAPlace(0)
         self._test_identity()
 
+    def test_identity_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_identity()
+
 
 class TestDeformConv2DFunctional(TestCase):
     batch_size = 4
@@ -418,6 +423,10 @@ class TestDeformConv2DFunctional(TestCase):
         self.place = paddle.CUDAPlace(0)
         self._test_identity()
 
+    def test_identity_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_identity()
+
 
 # testcases for DeformConv2D
 class TestDeformConv2DWithPadding(TestDeformConv2D):
diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py
index 45a2323194..5fc849575b 100644
--- a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py
+++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py
@@ -14,13 +14,15 @@
 
 from __future__ import print_function
 
+import paddle
 import unittest
 import numpy as np
-
-import paddle
 import paddle.fluid.core as core
 import paddle.fluid as fluid
 from op_test import OpTest
+from paddle.fluid.framework import _test_eager_guard
+
+paddle.enable_static()
 
 def dmc_bilinear(data_im, height, width, h, w):
@@ -108,8 +110,24 @@ def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
     return out
 
+def deform_conv2d_wrapper(x,
+                          offset,
+                          weight,
+                          mask=None,
+                          stride=1,
+                          padding=0,
+                          dilation=1,
+                          deformable_groups=1,
+                          groups=1,
+                          im2col_step=1):
+    return paddle.vision.ops.deform_conv2d(x, offset, weight, None, stride,
+                                           padding, dilation, deformable_groups,
+                                           groups, mask)
+
+
 class TestModulatedDeformableConvOp(OpTest):
     def setUp(self):
+        self.python_api = deform_conv2d_wrapper
         self.op_type = "deformable_conv"
         self.init_type()
         self.init_group()
@@ -148,13 +166,14 @@ class TestModulatedDeformableConvOp(OpTest):
         self.outputs = {'Output': output}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
         self.check_grad(
             {'Input', 'Offset', 'Mask', 'Filter'},
             'Output',
-            max_relative_error=0.05)
+            max_relative_error=0.05,
+            check_eager=True)
 
     def init_test_case(self):
         self.pad = [1, 1]
@@ -327,6 +346,10 @@ class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
 
         self.assertRaises(ValueError, test_invalid_filter)
 
+    def test_error_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_error()
+
 
 class TestDeformConv2DAPI(unittest.TestCase):
     def test_api(self):
@@ -358,6 +381,10 @@ class TestDeformConv2DAPI(unittest.TestCase):
 
         test_deform_conv2d_v2()
 
+    def test_api_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_api()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py
index e8b18d601a..304a151c4d 100644
--- a/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py
+++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py
@@ -14,12 +14,13 @@
 
 from __future__ import print_function
 
+import paddle
 import unittest
 import numpy as np
-
-import paddle.fluid.core as core
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 from op_test import OpTest
+from paddle.fluid.framework import _test_eager_guard
 
 def dmc_bilinear(data_im, height, width, h, w):
@@ -105,8 +106,24 @@ def dconv_im2col_gemm(input, offset, filter, group, conv_param):
     return out
 
+def deform_conv2d_wrapper(x,
+                          offset,
+                          weight,
+                          mask=None,
+                          stride=1,
+                          padding=0,
+                          dilation=1,
+                          deformable_groups=1,
+                          groups=1,
+                          im2col_step=1):
+    return paddle.vision.ops.deform_conv2d(x, offset, weight, None, stride,
+                                           padding, dilation, deformable_groups,
+                                           groups, mask)
+
+
 class TestModulatedDeformableConvOp(OpTest):
     def setUp(self):
+        self.python_api = deform_conv2d_wrapper
         self.op_type = "deformable_conv_v1"
         self.init_type()
         self.init_group()
@@ -142,18 +159,22 @@ class TestModulatedDeformableConvOp(OpTest):
         self.outputs = {'Output': output}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
         self.check_grad(
-            ['Input', 'Offset', 'Filter'], 'Output', max_relative_error=0.05)
+            ['Input', 'Offset', 'Filter'],
+            'Output',
+            max_relative_error=0.05,
+            check_eager=True)
 
     def test_check_grad_no_filter(self):
         self.check_grad(
             ['Input', 'Offset'],
             'Output',
             max_relative_error=0.1,
-            no_grad_set=set(['Filter']))
+            no_grad_set=set(['Filter']),
+            check_eager=True)
 
     def init_test_case(self):
         self.pad = [1, 1]
@@ -292,6 +313,10 @@ class TestModulatedDeformableConvV1InvalidInput(unittest.TestCase):
 
         self.assertRaises(TypeError, test_invalid_offset)
 
+    def test_error_with_eager_guard(self):
+        with _test_eager_guard():
+            self.test_error()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_eye_op.py b/python/paddle/fluid/tests/unittests/test_eye_op.py
index cb757cffc4..704762d809 100644
--- a/python/paddle/fluid/tests/unittests/test_eye_op.py
+++ b/python/paddle/fluid/tests/unittests/test_eye_op.py
@@ -28,6 +28,7 @@ class TestEyeOp(OpTest):
         '''
         Test eye op with specified shape
         '''
+        self.python_api = paddle.eye
         self.op_type = "eye"
         self.inputs = {}
 
@@ -39,7 +40,7 @@ class TestEyeOp(OpTest):
         self.outputs = {'Out': np.eye(219, 319, dtype=np.int32)}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 class TestEyeOp1(OpTest):
@@ -47,6 +48,7 @@ class TestEyeOp1(OpTest):
         '''
         Test eye op with default parameters
         '''
+        self.python_api = paddle.eye
         self.op_type = "eye"
         self.inputs = {}
 
@@ -54,7 +56,7 @@ class TestEyeOp1(OpTest):
         self.outputs = {'Out': np.eye(50, dtype=float)}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 class TestEyeOp2(OpTest):
@@ -62,6 +64,7 @@ class TestEyeOp2(OpTest):
         '''
         Test eye op with specified shape
         '''
+        self.python_api = paddle.eye
        self.op_type = "eye"
         self.inputs = {}
 
@@ -69,7 +72,7 @@ class TestEyeOp2(OpTest):
         self.outputs = {'Out': np.eye(99, 1, dtype=float)}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
 
 class API_TestTensorEye(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
index d0b84a0d7e..b13b346261 100644
--- a/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matrix_rank_op.py
@@ -30,8 +30,13 @@ SEED = 2049
 np.random.seed(SEED)
 
+def matrix_rank_wrapper(x, tol=None, use_default_tol=True, hermitian=False):
+    return paddle.linalg.matrix_rank(x, tol, hermitian)
+
+
 class TestMatrixRankOP(OpTest):
     def setUp(self):
+        self.python_api = matrix_rank_wrapper
         self.op_type = "matrix_rank"
         self.init_data()
         self.inputs = {'X': self.x}
@@ -44,7 +49,7 @@ class TestMatrixRankOP(OpTest):
         self.outputs = {'Out': self.out}
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_eager=True)
 
     def init_data(self):
         self.x = np.eye(3, dtype=np.float32)
@@ -110,6 +115,28 @@ class TestMatrixRankOP5(TestMatrixRankOP):
                                          self.hermitian)
 
+class TestMatrixRankOP6(TestMatrixRankOP):
+    def init_data(self):
+        self.x = np.random.rand(3, 4, 5, 6).astype(np.float32)
+        self.tol_tensor = None
+        self.tol = None
+        self.use_default_tol = False
+        self.hermitian = False
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)
+
+
+class TestMatrixRankOP7(TestMatrixRankOP):
+    def init_data(self):
+        self.x = np.eye(200, dtype=np.float64)
+        self.tol_tensor = np.random.random([200, 200]).astype(self.x.dtype)
+        self.tol = None
+        self.use_default_tol = True
+        self.hermitian = True
+        self.out = np.linalg.matrix_rank(self.x, self.tol_tensor,
+                                         self.hermitian)
+
+
 class TestMatrixRankAPI(unittest.TestCase):
     def test_dygraph(self):
         paddle.disable_static()
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 23c83b1e38..33ff272020 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -1288,8 +1288,26 @@ def matrix_rank(x, tol=None, hermitian=False, name=None):
             #         [1, 1, 1, 1]]
     """
+    if in_dygraph_mode():
+        if isinstance(tol, Variable):
+            if tol.dtype != x.dtype:
+                tol_tensor = cast(tol, x.dtype)
+            else:
+                tol_tensor = tol
+            use_default_tol = False
+            return _C_ops.final_state_matrix_rank_tol(
+                x, tol_tensor, use_default_tol, hermitian)
 
-    if paddle.in_dynamic_mode():
+        if tol is None:
+            tol_attr = 0.0
+            use_default_tol = True
+        else:
+            tol_attr = float(tol)
+            use_default_tol = False
+        return _C_ops.final_state_matrix_rank(x, tol_attr, use_default_tol,
+                                              hermitian)
+
+    if _in_legacy_dygraph():
         if tol is None:
             tol_tensor = None
             tol_attr = 0.0
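To illustrate the two eager branches added to matrix_rank above (a float tol versus a Tensor tol), a short sketch; the input values are illustrative:

    import paddle

    x = paddle.eye(10)
    # A float tol takes the final_state_matrix_rank branch.
    rank_a = paddle.linalg.matrix_rank(x, tol=0.1)
    # A Tensor tol is cast to x.dtype if needed and takes the
    # final_state_matrix_rank_tol branch.
    rank_b = paddle.linalg.matrix_rank(x, tol=paddle.to_tensor(0.1),
                                       hermitian=True)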
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 6cb7bfa793..718c35683c 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -435,6 +435,16 @@
     func : cumsum
   backward : cumsum_grad
 
+- api : deformable_conv
+  args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
+  output : Tensor(out)
+  infer_meta :
+    func : DeformableConvInferMeta
+  kernel :
+    func : deformable_conv
+  optional : mask
+  backward : deformable_conv_grad
+
 - api : depthwise_conv2d_transpose
   args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
   output : Tensor(out)
@@ -609,6 +619,18 @@
     func : expm1
   backward : expm1_grad
 
+- api : eye
+  args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={})
+  output : Tensor(out)
+  infer_meta :
+    func : EyeInferMeta
+    param : [num_rows, num_columns, dtype]
+  kernel :
+    func : eye
+    param : [num_rows, num_columns, dtype]
+    data_type : dtype
+    backend : place
+
 - api : flatten
   args : (Tensor x, int start_axis, int stop_axis)
   output : Tensor(out), Tensor(xshape)
@@ -1167,6 +1189,23 @@
     func : matrix_power
   backward : matrix_power_grad
 
+- api : matrix_rank
+  args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false)
+  output : Tensor(out)
+  infer_meta :
+    func : MatrixRankInferMeta
+    param : [x, use_default_tol, hermitian]
+  kernel :
+    func : matrix_rank
+
+- api : matrix_rank_tol
+  args : (Tensor x, Tensor atol_tensor, bool use_default_tol=true, bool hermitian=false)
+  output : Tensor(out)
+  infer_meta :
+    func : MatrixRankTolInferMeta
+  kernel :
+    func : matrix_rank_tol
+
 - api : max
   args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
   output : Tensor(out)
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 4c50967b6f..f60563d5d0 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -302,6 +302,16 @@
   output : Tensor(x_grad)
   invoke : cumsum(out_grad, axis, flatten, exclusive, !reverse)
 
+- backward_api : deformable_conv_grad
+  forward : deformable_conv(Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) -> Tensor(out)
+  args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, Tensor out_grad, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step)
+  output : Tensor(x_grad), Tensor(offset_grad), Tensor(filter_grad), Tensor(mask_grad)
+  infer_meta :
+    func : DeformableConvGradInferMeta
+  kernel :
+    func : deformable_conv_grad
+  optional : mask
+
 - backward_api : depthwise_conv2d_transpose_grad
   forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
   args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py
index 2ed01d42cf..8fa51df9ac 100644
--- a/python/paddle/vision/ops.py
+++ b/python/paddle/vision/ops.py
@@ -558,7 +558,15 @@ def deform_conv2d(x,
 
     use_deform_conv2d_v1 = True if mask is None else False
 
-    if _non_static_mode():
+    if in_dygraph_mode():
+        pre_bias = _C_ops.final_state_deformable_conv(
+            x, offset, weight, mask, stride, padding, dilation,
+            deformable_groups, groups, 1)
+        if bias is not None:
+            out = nn.elementwise_add(pre_bias, bias, axis=1)
+        else:
+            out = pre_bias
+    elif _in_legacy_dygraph():
         attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
                  'deformable_groups', deformable_groups, 'groups', groups,
                  'im2col_step', 1)
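A short usage sketch of the new eager path in paddle.vision.ops.deform_conv2d; the tensor shapes are illustrative assumptions (3x3 kernel, stride 1, no padding, so the offset/mask spatial dims match the 26x26 output):

    import paddle

    kh, kw = 3, 3
    x = paddle.rand((8, 1, 28, 28))
    offset = paddle.rand((8, 2 * kh * kw, 26, 26))
    mask = paddle.rand((8, kh * kw, 26, 26))
    weight = paddle.rand((16, 1, kh, kw))
    # mask=None would select the v1 op; passing a mask uses the
    # modulated (v2) deformable_conv wired up in this patch.
    y = paddle.vision.ops.deform_conv2d(x, offset, weight, mask=mask)
    print(y.shape)  # [8, 16, 26, 26]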
"label_smooth", "layer_norm"], "phi_kernels":["equal_all"] } -- GitLab