提交 55e56b33 编写于 作者: B breezedeus

bugfix

上级 d3130df2
......@@ -78,28 +78,29 @@ class RandomCrop(torch.nn.Module):
self.crop_size = crop_size
self.interpolation = interpolation
def get_params(self, w, h) -> Tuple[int, int, int, int]:
def get_params(self, ori_w, ori_h) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random crop.
Args:
img (PIL Image or Tensor): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
h_top, h_bot = (
random.randint(0, self.crop_size[0]),
random.randint(0, self.crop_size[0]),
)
w_left, w_right = (
random.randint(0, self.crop_size[1]),
random.randint(0, self.crop_size[1]),
)
h = h - h_top - h_bot
w = w - w_left - w_right
return h_top, w_left, h, w
while True:
h_top, h_bot = (
random.randint(0, self.crop_size[0]),
random.randint(0, self.crop_size[0]),
)
w_left, w_right = (
random.randint(0, self.crop_size[1]),
random.randint(0, self.crop_size[1]),
)
h = ori_h - h_top - h_bot
w = ori_w - w_left - w_right
if h < ori_h * 0.5 or w < ori_w * 0.9:
continue
return h_top, w_left, h, w
def forward(self, img):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册