未验证 提交 a0576250 编写于 作者: R Ryan 提交者: GitHub

[Dy2St] transforms.RandomCrop Support static mode (#49057)

* add RandomCrop

* 10e-5  =>  eps

* add same shape test
上级 73ec9b78
......@@ -60,6 +60,8 @@ class TestTransformUnitTestBase(unittest.TestCase):
def test_transform(self):
dy_res = self.dynamic_transform()
if isinstance(dy_res, paddle.Tensor):
dy_res = dy_res.numpy()
st_res = self.static_transform()
np.testing.assert_almost_equal(dy_res, st_res)
......@@ -98,5 +100,47 @@ class TestRandomVerticalFlip1(TestTransformUnitTestBase):
self.api = transforms.RandomVerticalFlip(prob=1)
class TestRandomCrop_random(TestTransformUnitTestBase):
def get_shape(self):
return (3, 240, 240)
def set_trans_api(self):
self.crop_size = (224, 224)
self.api = transforms.RandomCrop(self.crop_size)
def assert_test_random_equal(self, res, eps=10e-5):
_, h, w = self.get_shape()
c_h, c_w = self.crop_size
res_assert = True
for y in range(h - c_h):
for x in range(w - c_w):
diff_abs_sum = np.abs(
(self.img[:, y : y + c_h, x : x + c_w] - res)
).sum()
if diff_abs_sum < eps:
res_assert = False
break
if not res_assert:
break
assert not res_assert
def test_transform(self):
dy_res = self.dynamic_transform().numpy()
st_res = self.static_transform()
self.assert_test_random_equal(dy_res)
self.assert_test_random_equal(st_res)
class TestRandomCrop_same(TestTransformUnitTestBase):
def get_shape(self):
return (3, 224, 224)
def set_trans_api(self):
self.crop_size = (224, 224)
self.api = transforms.RandomCrop(self.crop_size)
if __name__ == "__main__":
unittest.main()
......@@ -1138,8 +1138,12 @@ class RandomCrop(BaseTransform):
if w == tw and h == th:
return 0, 0, h, w
if paddle.in_dynamic_mode():
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
else:
i = paddle.randint(low=0, high=h - th)
j = paddle.randint(low=0, high=w - tw)
return i, j, th, tw
def _apply_image(self, img):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册