From a05762508f3573f3be3356b9234d5cdfaf5f5246 Mon Sep 17 00:00:00 2001 From: Ryan <44900829+DrRyanHuang@users.noreply.github.com> Date: Fri, 16 Dec 2022 10:29:31 +0800 Subject: [PATCH] [Dy2St] transforms.RandomCrop Support static mode (#49057) * add RandomCrop * 10e-5 => eps * add same shape test --- python/paddle/tests/test_transforms_static.py | 44 +++++++++++++++++++ python/paddle/vision/transforms/transforms.py | 8 +++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/python/paddle/tests/test_transforms_static.py b/python/paddle/tests/test_transforms_static.py index 47a0606c5b..4a8018b688 100644 --- a/python/paddle/tests/test_transforms_static.py +++ b/python/paddle/tests/test_transforms_static.py @@ -60,6 +60,8 @@ class TestTransformUnitTestBase(unittest.TestCase): def test_transform(self): dy_res = self.dynamic_transform() + if isinstance(dy_res, paddle.Tensor): + dy_res = dy_res.numpy() st_res = self.static_transform() np.testing.assert_almost_equal(dy_res, st_res) @@ -98,5 +100,47 @@ class TestRandomVerticalFlip1(TestTransformUnitTestBase): self.api = transforms.RandomVerticalFlip(prob=1) +class TestRandomCrop_random(TestTransformUnitTestBase): + def get_shape(self): + return (3, 240, 240) + + def set_trans_api(self): + self.crop_size = (224, 224) + self.api = transforms.RandomCrop(self.crop_size) + + def assert_test_random_equal(self, res, eps=10e-5): + + _, h, w = self.get_shape() + c_h, c_w = self.crop_size + res_assert = True + for y in range(h - c_h): + for x in range(w - c_w): + diff_abs_sum = np.abs( + (self.img[:, y : y + c_h, x : x + c_w] - res) + ).sum() + if diff_abs_sum < eps: + res_assert = False + break + if not res_assert: + break + assert not res_assert + + def test_transform(self): + dy_res = self.dynamic_transform().numpy() + st_res = self.static_transform() + + self.assert_test_random_equal(dy_res) + self.assert_test_random_equal(st_res) + + +class TestRandomCrop_same(TestTransformUnitTestBase): + def get_shape(self): + return (3, 224, 224) + + def set_trans_api(self): + self.crop_size = (224, 224) + self.api = transforms.RandomCrop(self.crop_size) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py index 14d2511994..8c1554e00d 100644 --- a/python/paddle/vision/transforms/transforms.py +++ b/python/paddle/vision/transforms/transforms.py @@ -1138,8 +1138,12 @@ class RandomCrop(BaseTransform): if w == tw and h == th: return 0, 0, h, w - i = random.randint(0, h - th) - j = random.randint(0, w - tw) + if paddle.in_dynamic_mode(): + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + else: + i = paddle.randint(low=0, high=h - th) + j = paddle.randint(low=0, high=w - tw) return i, j, th, tw def _apply_image(self, img): -- GitLab