diff --git a/python/paddle/tests/test_transforms_static.py b/python/paddle/tests/test_transforms_static.py index e8cb449df0014cb091072588539980212b8eefc0..d6656c96ab1863cca1c8d24af6f81fe082f18a38 100644 --- a/python/paddle/tests/test_transforms_static.py +++ b/python/paddle/tests/test_transforms_static.py @@ -100,6 +100,16 @@ class TestRandomVerticalFlip1(TestTransformUnitTestBase): self.api = transforms.RandomVerticalFlip(prob=1) +class TestRandomHorizontalFlip0(TestTransformUnitTestBase): + def set_trans_api(self): + self.api = transforms.RandomHorizontalFlip(0) + + +class TestRandomHorizontalFlip1(TestTransformUnitTestBase): + def set_trans_api(self): + self.api = transforms.RandomHorizontalFlip(1) + + class TestRandomCrop_random(TestTransformUnitTestBase): def get_shape(self): return (3, 240, 240) diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py index 7dd4daeaebb26a440c68926afcc5f5f54abe4f2e..8b8d822886201d4ea7e0b16d698cdc024d8cf105 100644 --- a/python/paddle/vision/transforms/transforms.py +++ b/python/paddle/vision/transforms/transforms.py @@ -709,10 +709,23 @@ class RandomHorizontalFlip(BaseTransform): self.prob = prob def _apply_image(self, img): + if paddle.in_dynamic_mode(): + return self._dynamic_apply_image(img) + else: + return self._static_apply_image(img) + + def _dynamic_apply_image(self, img): if random.random() < self.prob: return F.hflip(img) return img + def _static_apply_image(self, img): + return paddle.static.nn.cond( + paddle.rand(shape=(1,)) < self.prob, + lambda: F.hflip(img), + lambda: img, + ) + class RandomVerticalFlip(BaseTransform): """Vertically flip the input data randomly with a given probability.