未验证 提交 5e835d36 编写于 作者: 张春乔 提交者: GitHub

[Dy2St] RandomHorizontalFlip (#49590)

* Dy2St_RandomHorizontalFlip

* Dy2St_RandomHorizontalFlip
上级 2f601282
...@@ -100,6 +100,16 @@ class TestRandomVerticalFlip1(TestTransformUnitTestBase): ...@@ -100,6 +100,16 @@ class TestRandomVerticalFlip1(TestTransformUnitTestBase):
self.api = transforms.RandomVerticalFlip(prob=1) self.api = transforms.RandomVerticalFlip(prob=1)
class TestRandomHorizontalFlip0(TestTransformUnitTestBase):
def set_trans_api(self):
self.api = transforms.RandomHorizontalFlip(0)
class TestRandomHorizontalFlip1(TestTransformUnitTestBase):
def set_trans_api(self):
self.api = transforms.RandomHorizontalFlip(1)
class TestRandomCrop_random(TestTransformUnitTestBase): class TestRandomCrop_random(TestTransformUnitTestBase):
def get_shape(self): def get_shape(self):
return (3, 240, 240) return (3, 240, 240)
......
...@@ -709,10 +709,23 @@ class RandomHorizontalFlip(BaseTransform): ...@@ -709,10 +709,23 @@ class RandomHorizontalFlip(BaseTransform):
self.prob = prob self.prob = prob
def _apply_image(self, img): def _apply_image(self, img):
if paddle.in_dynamic_mode():
return self._dynamic_apply_image(img)
else:
return self._static_apply_image(img)
def _dynamic_apply_image(self, img):
if random.random() < self.prob: if random.random() < self.prob:
return F.hflip(img) return F.hflip(img)
return img return img
def _static_apply_image(self, img):
return paddle.static.nn.cond(
paddle.rand(shape=(1,)) < self.prob,
lambda: F.hflip(img),
lambda: img,
)
class RandomVerticalFlip(BaseTransform): class RandomVerticalFlip(BaseTransform):
"""Vertically flip the input data randomly with a given probability. """Vertically flip the input data randomly with a given probability.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册