From 889e5834d1d2fd6d1e90cbb291f97c75b47586dd Mon Sep 17 00:00:00 2001 From: Ryan <44900829+DrRyanHuang@users.noreply.github.com> Date: Tue, 13 Dec 2022 16:52:38 +0800 Subject: [PATCH] [Dy2St] transforms.RandomVerticalFlip Support static mode (#49024) * add static RandomVerticalFlip * object => unittest.TestCase --- python/paddle/tests/test_transforms_static.py | 102 ++++++++++++++++++ python/paddle/vision/transforms/functional.py | 6 +- .../vision/transforms/functional_tensor.py | 12 ++- python/paddle/vision/transforms/transforms.py | 13 +++ 4 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 python/paddle/tests/test_transforms_static.py diff --git a/python/paddle/tests/test_transforms_static.py b/python/paddle/tests/test_transforms_static.py new file mode 100644 index 0000000000..47a0606c5b --- /dev/null +++ b/python/paddle/tests/test_transforms_static.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.vision.transforms import transforms + +SEED = 2022 + + +class TestTransformUnitTestBase(unittest.TestCase): + def setUp(self): + self.img = (np.random.rand(*self.get_shape()) * 255.0).astype( + np.float32 + ) + self.set_trans_api() + + def get_shape(self): + return (64, 64, 3) + + def set_trans_api(self): + self.api = transforms.Resize(size=16) + + def dynamic_transform(self): + paddle.seed(SEED) + + img_t = paddle.to_tensor(self.img) + return self.api(img_t) + + def static_transform(self): + paddle.enable_static() + paddle.seed(SEED) + + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data( + shape=self.get_shape(), dtype=paddle.float32, name='img' + ) + out = self.api(x) + + exe = paddle.static.Executor() + res = exe.run(main_program, fetch_list=[out], feed={'img': self.img}) + + paddle.disable_static() + return res[0] + + def test_transform(self): + dy_res = self.dynamic_transform() + st_res = self.static_transform() + + np.testing.assert_almost_equal(dy_res, st_res) + + +class TestResize(TestTransformUnitTestBase): + def set_trans_api(self): + self.api = transforms.Resize(size=(16, 16)) + + +class TestResizeError(TestTransformUnitTestBase): + def test_transform(self): + pass + + def test_error(self): + paddle.enable_static() + # Not support while w<=0 or h<=0, but received w=-1, h=-1 + with self.assertRaises(NotImplementedError): + main_program = paddle.static.Program() + with paddle.static.program_guard(main_program): + x = paddle.static.data( + shape=[-1, -1, -1], dtype=paddle.float32, name='img' + ) + self.api(x) + + paddle.disable_static() + + +class TestRandomVerticalFlip0(TestTransformUnitTestBase): + def set_trans_api(self): + self.api = transforms.RandomVerticalFlip(prob=0) + + +class TestRandomVerticalFlip1(TestTransformUnitTestBase): + def set_trans_api(self): + self.api = transforms.RandomVerticalFlip(prob=1) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/vision/transforms/functional.py b/python/paddle/vision/transforms/functional.py index d58c0f610e..91a600efd3 100644 --- a/python/paddle/vision/transforms/functional.py +++ b/python/paddle/vision/transforms/functional.py @@ -20,6 +20,7 @@ from PIL import Image import paddle +from ...fluid.framework import Variable from . import functional_cv2 as F_cv2 from . import functional_pil as F_pil from . import functional_tensor as F_t @@ -32,7 +33,10 @@ def _is_pil_image(img): def _is_tensor_image(img): - return isinstance(img, paddle.Tensor) + """ + Return True if img is a Tensor for dynamic mode or Variable for static mode. + """ + return isinstance(img, (paddle.Tensor, Variable)) def _is_numpy_image(img): diff --git a/python/paddle/vision/transforms/functional_tensor.py b/python/paddle/vision/transforms/functional_tensor.py index d18fdfc51c..8137a4f284 100644 --- a/python/paddle/vision/transforms/functional_tensor.py +++ b/python/paddle/vision/transforms/functional_tensor.py @@ -18,12 +18,14 @@ import numbers import paddle import paddle.nn.functional as F +from ...fluid.framework import Variable + __all__ = [] def _assert_image_tensor(img, data_format): if ( - not isinstance(img, paddle.Tensor) + not isinstance(img, (paddle.Tensor, Variable)) or img.ndim < 3 or img.ndim > 4 or not data_format.lower() in ('chw', 'hwc') @@ -725,6 +727,14 @@ def resize(img, size, interpolation='bilinear', data_format='CHW'): if isinstance(size, int): w, h = _get_image_size(img, data_format) + # TODO(Aurelius84): In static mode, w and h will be -1 for dynamic shape. + # We should consider to support this case in future. + if w <= 0 or h <= 0: + raise NotImplementedError( + "Not support while w<=0 or h<=0, but received w={}, h={}".format( + w, h + ) + ) if (w <= h and w == size) or (h <= w and h == size): return img if w < h: diff --git a/python/paddle/vision/transforms/transforms.py b/python/paddle/vision/transforms/transforms.py index cb48598c8f..14d2511994 100644 --- a/python/paddle/vision/transforms/transforms.py +++ b/python/paddle/vision/transforms/transforms.py @@ -653,10 +653,23 @@ class RandomVerticalFlip(BaseTransform): self.prob = prob def _apply_image(self, img): + if paddle.in_dynamic_mode(): + return self._dynamic_apply_image(img) + else: + return self._static_apply_image(img) + + def _dynamic_apply_image(self, img): if random.random() < self.prob: return F.vflip(img) return img + def _static_apply_image(self, img): + return paddle.static.nn.cond( + paddle.rand(shape=(1,)) < self.prob, + lambda: F.vflip(img), + lambda: img, + ) + class Normalize(BaseTransform): """Normalize the input data with mean and standard deviation. -- GitLab