diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 81d3cb9ddf0f4dccdd5f01b3b75f02c10979a219..efbf02e3314333f1e12a1b65856309822a3d2465 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -169,6 +169,27 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, logits_grad->set_dtype(softmax.dtype()); } +void DeformableConvGradInferMeta(const MetaTensor& x, + const MetaTensor& offset, + const MetaTensor& filter, + paddle::optional mask, + const MetaTensor& out_grad, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + int deformable_groups, + int groups, + int im2col_step, + MetaTensor* dx, + MetaTensor* offset_grad, + MetaTensor* filter_grad, + MetaTensor* mask_grad) { + GeneralTernaryGradInferMeta(x, offset, filter, dx, offset_grad, filter_grad); + if (mask) { + UnchangedInferMeta(mask.get(), mask_grad); + } +} + void GatherNdGradInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 058ff7541cd8b770f10fb0e839dfca3f2b338e95..6e730c83d1d5065cf95331f460aa850c3127232b 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -79,6 +79,22 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label, MetaTensor* logits_grad, MetaConfig config = MetaConfig()); +void DeformableConvGradInferMeta(const MetaTensor& x, + const MetaTensor& offset, + const MetaTensor& filter, + paddle::optional mask, + const MetaTensor& out_grad, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + int deformable_groups, + int groups, + int im2col_step, + MetaTensor* dx, + MetaTensor* offset_grad, + MetaTensor* filter_grad, + MetaTensor* mask_grad); + void GatherNdGradInferMeta(const MetaTensor& x, const MetaTensor& index, const MetaTensor& out_grad, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index d947784e518c86d0f0747f9f9cb62f6ef7165baa..f7f88ab76f22775489df8dd7d974853bffbbc0e6 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -968,7 +968,7 @@ set_tests_properties(test_lstm_op PROPERTIES TIMEOUT 120) set_tests_properties(test_imperative_star_gan_with_gradient_penalty PROPERTIES TIMEOUT 120) set_tests_properties(test_bicubic_interp_op PROPERTIES TIMEOUT 120) -set_tests_properties(test_deformable_conv_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_deformable_conv_op PROPERTIES TIMEOUT 200) set_tests_properties(test_nearest_interp_op PROPERTIES TIMEOUT 120) set_tests_properties(test_profiler PROPERTIES TIMEOUT 120) set_tests_properties(test_inplace_softmax_with_cross_entropy PROPERTIES TIMEOUT 120) @@ -1045,7 +1045,7 @@ set_tests_properties(test_distributed_fused_lamb_op_with_clip PROPERTIES TIMEOUT set_tests_properties(test_distributed_fused_lamb_op_without_clip PROPERTIES TIMEOUT 120) set_tests_properties(test_elementwise_min_op PROPERTIES TIMEOUT 120) set_tests_properties(test_nan_inf PROPERTIES TIMEOUT 120) -set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_deformable_conv_v1_op PROPERTIES TIMEOUT 300) set_tests_properties(test_parallel_executor_transformer_auto_growth PROPERTIES TIMEOUT 120) set_tests_properties(test_py_reader_using_executor PROPERTIES TIMEOUT 120) set_tests_properties(test_elementwise_add_op PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py index 508fc1705218a0da72d1fb5213f4663852e08c3f..f5f1479d07d2f0e570624ebe3f84ea20df59da32 100644 --- a/python/paddle/fluid/tests/unittests/test_deform_conv2d.py +++ b/python/paddle/fluid/tests/unittests/test_deform_conv2d.py @@ -17,6 +17,7 @@ import paddle.nn.functional as F import paddle.nn.initializer as I import numpy as np import unittest +from paddle.fluid.framework import _test_eager_guard from unittest import TestCase @@ -183,6 +184,10 @@ class TestDeformConv2D(TestCase): self.place = paddle.CUDAPlace(0) self._test_identity() + def test_identity_with_eager_guard(self): + with _test_eager_guard(): + self.test_identity() + class TestDeformConv2DFunctional(TestCase): batch_size = 4 @@ -418,6 +423,10 @@ class TestDeformConv2DFunctional(TestCase): self.place = paddle.CUDAPlace(0) self._test_identity() + def test_identity_with_eager_guard(self): + with _test_eager_guard(): + self.test_identity() + # testcases for DeformConv2D class TestDeformConv2DWithPadding(TestDeformConv2D): diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py index 45a23231945ece4247b5e5f1b9eaa63f8c33f964..5fc849575b6597a2a229434355cadb59d43e75fe 100644 --- a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py @@ -14,13 +14,15 @@ from __future__ import print_function +import paddle import unittest import numpy as np - -import paddle import paddle.fluid.core as core import paddle.fluid as fluid from op_test import OpTest +from paddle.fluid.framework import _test_eager_guard + +paddle.enable_static() def dmc_bilinear(data_im, height, width, h, w): @@ -108,8 +110,24 @@ def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param): return out +def deform_conv2d_wrapper(x, + offset, + weight, + mask=None, + stride=1, + padding=0, + dilation=1, + deformable_groups=1, + groups=1, + im2col_step=1): + return paddle.vision.ops.deform_conv2d(x, offset, weight, None, stride, + padding, dilation, deformable_groups, + groups, mask) + + class TestModulatedDeformableConvOp(OpTest): def setUp(self): + self.python_api = deform_conv2d_wrapper self.op_type = "deformable_conv" self.init_type() self.init_group() @@ -148,13 +166,14 @@ class TestModulatedDeformableConvOp(OpTest): self.outputs = {'Output': output} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): self.check_grad( {'Input', 'Offset', 'Mask', 'Filter'}, 'Output', - max_relative_error=0.05) + max_relative_error=0.05, + check_eager=True) def init_test_case(self): self.pad = [1, 1] @@ -327,6 +346,10 @@ class TestModulatedDeformableConvInvalidInput(unittest.TestCase): self.assertRaises(ValueError, test_invalid_filter) + def test_error_with_eager_guard(self): + with _test_eager_guard(): + self.test_error() + class TestDeformConv2DAPI(unittest.TestCase): def test_api(self): @@ -358,6 +381,10 @@ class TestDeformConv2DAPI(unittest.TestCase): test_deform_conv2d_v2() + def test_api_with_eager_guard(self): + with _test_eager_guard(): + self.test_api() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py index e8b18d601afae649ba6af49230f41bc0465a8959..304a151c4d3bfc7aa1e5228a1335aaf9b8663a31 100644 --- a/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py +++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py @@ -14,12 +14,13 @@ from __future__ import print_function +import paddle import unittest import numpy as np - -import paddle.fluid.core as core import paddle.fluid as fluid +import paddle.fluid.core as core from op_test import OpTest +from paddle.fluid.framework import _test_eager_guard def dmc_bilinear(data_im, height, width, h, w): @@ -105,8 +106,24 @@ def dconv_im2col_gemm(input, offset, filter, group, conv_param): return out +def deform_conv2d_wrapper(x, + offset, + weight, + mask=None, + stride=1, + padding=0, + dilation=1, + deformable_groups=1, + groups=1, + im2col_step=1): + return paddle.vision.ops.deform_conv2d(x, offset, weight, None, stride, + padding, dilation, deformable_groups, + groups, mask) + + class TestModulatedDeformableConvOp(OpTest): def setUp(self): + self.python_api = deform_conv2d_wrapper self.op_type = "deformable_conv_v1" self.init_type() self.init_group() @@ -142,18 +159,22 @@ class TestModulatedDeformableConvOp(OpTest): self.outputs = {'Output': output} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): self.check_grad( - ['Input', 'Offset', 'Filter'], 'Output', max_relative_error=0.05) + ['Input', 'Offset', 'Filter'], + 'Output', + max_relative_error=0.05, + check_eager=True) def test_check_grad_no_filter(self): self.check_grad( ['Input', 'Offset'], 'Output', max_relative_error=0.1, - no_grad_set=set(['Filter'])) + no_grad_set=set(['Filter']), + check_eager=True) def init_test_case(self): self.pad = [1, 1] @@ -292,6 +313,10 @@ class TestModulatedDeformableConvV1InvalidInput(unittest.TestCase): self.assertRaises(TypeError, test_invalid_offset) + def test_error_with_eager_guard(self): + with _test_eager_guard(): + self.test_error() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 08028ba17185cc141ca64e1eac1c2a8a1f1b1e51..6387525fa26f1bdf2ff0382b7343d1577397a6bb 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -451,6 +451,16 @@ func : cumsum backward : cumsum_grad +- api : deformable_conv + args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) + output : Tensor(out) + infer_meta : + func : DeformableConvInferMeta + kernel : + func : deformable_conv + optional : mask + backward : deformable_conv_grad + - api : depthwise_conv2d_transpose args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) output : Tensor(out) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index f8366744bdbe6d78e3a0a08c22599643e6ce53d7..d243b4d160d570cabc32ef53c8c61add425b6ee2 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -339,6 +339,16 @@ output : Tensor(x_grad) invoke : cumsum(out_grad, axis, flatten, exclusive, !reverse) +- backward_api : deformable_conv_grad + forward : deformable_conv(Tensor x, Tensor offset, Tensor filter, Tensor mask, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) -> Tensor(out) + args : (Tensor x, Tensor offset, Tensor filter, Tensor mask, Tensor out_grad, int[] strides, int[] paddings, int[] dilations, int deformable_groups, int groups, int im2col_step) + output : Tensor(x_grad), Tensor(offset_grad), Tensor(filter_grad), Tensor(mask_grad) + infer_meta : + func : DeformableConvGradInferMeta + kernel : + func : deformable_conv_grad + optional : mask + - backward_api : depthwise_conv2d_transpose_grad forward : depthwise_conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out) args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index 2ed01d42cfb8c239258999f235c3d67b09df3be2..8fa51df9ac10d7f6213dfe5906395914a4527637 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -558,7 +558,15 @@ def deform_conv2d(x, use_deform_conv2d_v1 = True if mask is None else False - if _non_static_mode(): + if in_dygraph_mode(): + pre_bias = _C_ops.final_state_deformable_conv( + x, offset, weight, mask, stride, padding, dilation, + deformable_groups, groups, 1) + if bias is not None: + out = nn.elementwise_add(pre_bias, bias, axis=1) + else: + out = pre_bias + elif _in_legacy_dygraph(): attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation, 'deformable_groups', deformable_groups, 'groups', groups, 'im2col_step', 1) diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json index b352240c6dcc5cff69ae19bfd0d9d6a6795dfa2a..2502e248c5c48310dbf979d0339c0f8c97b321b3 100644 --- a/tools/infrt/skipped_phi_api.json +++ b/tools/infrt/skipped_phi_api.json @@ -1,4 +1,4 @@ { -"phi_apis":["conj", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm"], +"phi_apis":["conj", "deformable_conv", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth", "layer_norm"], "phi_kernels":["equal_all"] }