From 5e835d3698c60f0d51382226b297bfbba3d81f99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com> Date: Mon, 9 Jan 2023 14:07:34 +0800 Subject: [PATCH] [Dy2St] RandomHorizontalFlip (#49590) * Dy2St_RandomHorizontalFlip * Dy2St_RandomHorizontalFlip --- python/paddle/tests/test_transforms_static.py | 10 ++++++++++ python/paddle/vision/transforms/transforms.py | 13 +++++++++++++ 2 files changed, 23 insertions(+) diff --git a/python/paddle/tests/test_transforms_static.py b/python/paddle/tests/test_transforms_static.py index e8cb449df0..d6656c96ab 100644 --- a/python/paddle/tests/test_transforms_static.py +++ b/python/paddle/tests/test_transforms_static.py @@ -100,6 +100,16 @@ class TestRandomVerticalFlip1(TestTransformUnitTestBase): self.api = transforms.RandomVerticalFlip(prob=1) +class TestRandomHorizontalFlip0(TestTransformUnitTestBase): + def set_trans_api(self): + self.api = transforms.RandomHorizontalFlip(0) + + +class TestRandomHorizontalFlip1(TestTransformUnitTestBase): + def set_trans_api(self): + self.api = transforms.RandomHorizontalFlip(1) + + class TestRandomCrop_random(TestTransformUnitTestBase): def get_shape(self): return (3, 240, 240) diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py index 7dd4daeaeb..8b8d822886 100644 --- a/python/paddle/vision/transforms/transforms.py +++ b/python/paddle/vision/transforms/transforms.py @@ -709,10 +709,23 @@ class RandomHorizontalFlip(BaseTransform): self.prob = prob def _apply_image(self, img): + if paddle.in_dynamic_mode(): + return self._dynamic_apply_image(img) + else: + return self._static_apply_image(img) + + def _dynamic_apply_image(self, img): if random.random() < self.prob: return F.hflip(img) return img + def _static_apply_image(self, img): + return paddle.static.nn.cond( + paddle.rand(shape=(1,)) < self.prob, + lambda: F.hflip(img), + lambda: img, + ) + class RandomVerticalFlip(BaseTransform): """Vertically flip the input data randomly with a given probability. -- GitLab