未验证 提交 82caa677 编写于 作者: R Ryan 提交者: GitHub

[Dy2St] transforms.RandomResizedCrop Support static mode (#49619)

* add RandomResizedCrop

* fix code style

* fix first loop

* add uni test
上级 50a8b655
......@@ -158,5 +158,14 @@ class TestRandomRotation_expand_True(TestTransformUnitTestBase):
self.api = transforms.RandomRotation(degree_tuple, expand=True, fill=3)
class TestRandomResizedCrop(TestTransformUnitTestBase):
def set_trans_api(self, eps=10e-5):
c, h, w = self.get_shape()
size = h, w
scale = (1 - eps, 1.0)
ratio = (1 - eps, 1.0)
self.api = transforms.RandomResizedCrop(size, scale=scale, ratio=ratio)
if __name__ == "__main__":
unittest.main()
......@@ -494,7 +494,7 @@ class RandomResizedCrop(BaseTransform):
self.ratio = ratio
self.interpolation = interpolation
def _get_param(self, image, attempts=10):
def _dynamic_get_param(self, image, attempts=10):
width, height = _get_image_size(image)
area = height * width
......@@ -527,8 +527,106 @@ class RandomResizedCrop(BaseTransform):
j = (width - w) // 2
return i, j, h, w
def _static_get_param(self, image, attempts=10):
width, height = _get_image_size(image)
area = height * width
log_ratio = tuple(math.log(x) for x in self.ratio)
counter = paddle.full(
shape=[1], fill_value=0, dtype='int32'
) # loop counter
ten = paddle.full(
shape=[1], fill_value=10, dtype='int32'
) # loop length
i = paddle.zeros([1], dtype="int32")
j = paddle.zeros([1], dtype="int32")
h = paddle.ones([1], dtype="int32") * (height + 1)
w = paddle.ones([1], dtype="int32") * (width + 1)
def cond(counter, ten, i, j, h, w):
return (counter < ten) and (w > width or h > height)
def body(counter, ten, i, j, h, w):
target_area = (
paddle.uniform(shape=[1], min=self.scale[0], max=self.scale[1])
* area
)
aspect_ratio = paddle.exp(
paddle.uniform(shape=[1], min=log_ratio[0], max=log_ratio[1])
)
w = paddle.round(paddle.sqrt(target_area * aspect_ratio)).astype(
'int32'
)
h = paddle.round(paddle.sqrt(target_area / aspect_ratio)).astype(
'int32'
)
i = paddle.static.nn.cond(
0 < w <= width and 0 < h <= height,
lambda: paddle.uniform(shape=[1], min=0, max=height - h).astype(
"int32"
),
lambda: i,
)
j = paddle.static.nn.cond(
0 < w <= width and 0 < h <= height,
lambda: paddle.uniform(shape=[1], min=0, max=width - w).astype(
"int32"
),
lambda: j,
)
counter += 1
return counter, ten, i, j, h, w
counter, ten, i, j, h, w = paddle.static.nn.while_loop(
cond, body, [counter, ten, i, j, h, w]
)
def central_crop(width, height):
height = paddle.assign([height]).astype("float32")
width = paddle.assign([width]).astype("float32")
# Fallback to central crop
in_ratio = width / height
w, h = paddle.static.nn.cond(
in_ratio < self.ratio[0],
lambda: [
width.astype("int32"),
paddle.round(width / self.ratio[0]).astype("int32"),
],
lambda: paddle.static.nn.cond(
in_ratio > self.ratio[1],
lambda: [
paddle.round(height * self.ratio[1]),
height.astype("int32"),
],
lambda: [width.astype("int32"), height.astype("int32")],
),
)
i = (height.astype("int32") - h) // 2
j = (width.astype("int32") - w) // 2
return i, j, h, w, counter
return paddle.static.nn.cond(
0 < w <= width and 0 < h <= height,
lambda: [i, j, h, w, counter],
lambda: central_crop(width, height),
)
def _apply_image(self, img):
i, j, h, w = self._get_param(img)
if paddle.in_dynamic_mode():
i, j, h, w = self._dynamic_get_param(img)
else:
i, j, h, w, counter = self._static_get_param(img)
cropped_img = F.crop(img, i, j, h, w)
return F.resize(cropped_img, self.size, self.interpolation)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册