未验证 提交 889e5834 编写于 作者: R Ryan 提交者: GitHub

[Dy2St] transforms.RandomVerticalFlip Support static mode (#49024)

* add static RandomVerticalFlip

* object => unittest.TestCase
上级 31922692
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.vision.transforms import transforms
SEED = 2022
class TestTransformUnitTestBase(unittest.TestCase):
def setUp(self):
self.img = (np.random.rand(*self.get_shape()) * 255.0).astype(
np.float32
)
self.set_trans_api()
def get_shape(self):
return (64, 64, 3)
def set_trans_api(self):
self.api = transforms.Resize(size=16)
def dynamic_transform(self):
paddle.seed(SEED)
img_t = paddle.to_tensor(self.img)
return self.api(img_t)
def static_transform(self):
paddle.enable_static()
paddle.seed(SEED)
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.static.data(
shape=self.get_shape(), dtype=paddle.float32, name='img'
)
out = self.api(x)
exe = paddle.static.Executor()
res = exe.run(main_program, fetch_list=[out], feed={'img': self.img})
paddle.disable_static()
return res[0]
def test_transform(self):
dy_res = self.dynamic_transform()
st_res = self.static_transform()
np.testing.assert_almost_equal(dy_res, st_res)
class TestResize(TestTransformUnitTestBase):
def set_trans_api(self):
self.api = transforms.Resize(size=(16, 16))
class TestResizeError(TestTransformUnitTestBase):
def test_transform(self):
pass
def test_error(self):
paddle.enable_static()
# Not support while w<=0 or h<=0, but received w=-1, h=-1
with self.assertRaises(NotImplementedError):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
x = paddle.static.data(
shape=[-1, -1, -1], dtype=paddle.float32, name='img'
)
self.api(x)
paddle.disable_static()
class TestRandomVerticalFlip0(TestTransformUnitTestBase):
def set_trans_api(self):
self.api = transforms.RandomVerticalFlip(prob=0)
class TestRandomVerticalFlip1(TestTransformUnitTestBase):
def set_trans_api(self):
self.api = transforms.RandomVerticalFlip(prob=1)
if __name__ == "__main__":
unittest.main()
...@@ -20,6 +20,7 @@ from PIL import Image ...@@ -20,6 +20,7 @@ from PIL import Image
import paddle import paddle
from ...fluid.framework import Variable
from . import functional_cv2 as F_cv2 from . import functional_cv2 as F_cv2
from . import functional_pil as F_pil from . import functional_pil as F_pil
from . import functional_tensor as F_t from . import functional_tensor as F_t
...@@ -32,7 +33,10 @@ def _is_pil_image(img): ...@@ -32,7 +33,10 @@ def _is_pil_image(img):
def _is_tensor_image(img): def _is_tensor_image(img):
return isinstance(img, paddle.Tensor) """
Return True if img is a Tensor for dynamic mode or Variable for static mode.
"""
return isinstance(img, (paddle.Tensor, Variable))
def _is_numpy_image(img): def _is_numpy_image(img):
......
...@@ -18,12 +18,14 @@ import numbers ...@@ -18,12 +18,14 @@ import numbers
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from ...fluid.framework import Variable
__all__ = [] __all__ = []
def _assert_image_tensor(img, data_format): def _assert_image_tensor(img, data_format):
if ( if (
not isinstance(img, paddle.Tensor) not isinstance(img, (paddle.Tensor, Variable))
or img.ndim < 3 or img.ndim < 3
or img.ndim > 4 or img.ndim > 4
or not data_format.lower() in ('chw', 'hwc') or not data_format.lower() in ('chw', 'hwc')
...@@ -725,6 +727,14 @@ def resize(img, size, interpolation='bilinear', data_format='CHW'): ...@@ -725,6 +727,14 @@ def resize(img, size, interpolation='bilinear', data_format='CHW'):
if isinstance(size, int): if isinstance(size, int):
w, h = _get_image_size(img, data_format) w, h = _get_image_size(img, data_format)
# TODO(Aurelius84): In static mode, w and h will be -1 for dynamic shape.
# We should consider to support this case in future.
if w <= 0 or h <= 0:
raise NotImplementedError(
"Not support while w<=0 or h<=0, but received w={}, h={}".format(
w, h
)
)
if (w <= h and w == size) or (h <= w and h == size): if (w <= h and w == size) or (h <= w and h == size):
return img return img
if w < h: if w < h:
......
...@@ -653,10 +653,23 @@ class RandomVerticalFlip(BaseTransform): ...@@ -653,10 +653,23 @@ class RandomVerticalFlip(BaseTransform):
self.prob = prob self.prob = prob
def _apply_image(self, img): def _apply_image(self, img):
if paddle.in_dynamic_mode():
return self._dynamic_apply_image(img)
else:
return self._static_apply_image(img)
def _dynamic_apply_image(self, img):
if random.random() < self.prob: if random.random() < self.prob:
return F.vflip(img) return F.vflip(img)
return img return img
def _static_apply_image(self, img):
return paddle.static.nn.cond(
paddle.rand(shape=(1,)) < self.prob,
lambda: F.vflip(img),
lambda: img,
)
class Normalize(BaseTransform): class Normalize(BaseTransform):
"""Normalize the input data with mean and standard deviation. """Normalize the input data with mean and standard deviation.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册