提交 d3130df2 编写于 作者: B breezedeus

Add random border cropping during training

上级 d52e296d
......@@ -33,7 +33,7 @@ from torchvision import transforms as T
from cnocr.consts import MODEL_VERSION, ENCODER_CONFIGS, DECODER_CONFIGS
from cnocr.utils import set_logger, load_model_params, check_model_name, save_img, read_img
from cnocr.data_utils.aug import NormalizeAug, RandomPaddingAug, RandomStretchAug
from cnocr.data_utils.aug import NormalizeAug, RandomPaddingAug, RandomStretchAug, RandomCrop
from cnocr.dataset import OcrDataModule
from cnocr.trainer import PlTrainer, resave_model
from cnocr import CnOcr, gen_model
......@@ -97,6 +97,7 @@ def train(
train_transform = T.Compose(
[
RandomStretchAug(min_ratio=0.5, max_ratio=1.5),
RandomCrop((8, 10)),
T.RandomInvert(p=0.2),
T.RandomApply([T.RandomRotation(degrees=1)], p=0.4),
# T.RandomAutocontrast(p=0.05),
......
......@@ -18,6 +18,7 @@
# under the License.
import random
from typing import Tuple
import torch
import torchvision.transforms.functional as F
......@@ -69,6 +70,50 @@ class RandomStretchAug(object):
return F.resize(img, [h, int(w * new_w_ratio)])
class RandomCrop(torch.nn.Module):
    """Randomly trim the borders of an image, then resize it back to its original size.

    A random number of pixels, from 0 up to ``crop_size[0]``, is removed from
    the top and (independently) from the bottom; likewise 0 up to
    ``crop_size[1]`` pixels are removed from the left and from the right.
    The remaining region is resized to the input's height and width, so the
    output has the same size as the input.

    Args:
        crop_size: ``(max_vertical, max_horizontal)`` upper bounds, in pixels,
            for the margin removed on each side.
        interpolation: Interpolation mode forwarded to ``F.resized_crop``.
    """

    def __init__(
        self, crop_size: Tuple[int, int], interpolation=F.InterpolationMode.BILINEAR
    ):
        super().__init__()
        self.crop_size = crop_size
        self.interpolation = interpolation

    def get_params(self, w, h) -> Tuple[int, int, int, int]:
        """Draw a random crop box for an image of width ``w`` and height ``h``.

        Args:
            w: Width of the image to be cropped.
            h: Height of the image to be cropped.

        Returns:
            tuple: ``(top, left, height, width)`` of the region to keep, in the
            argument order expected by ``F.resized_crop``.
        """
        # Cap each per-side margin at (size - 1) // 2 so the two opposite
        # margins together can never consume the whole axis; otherwise small
        # images would yield a zero or negative crop height/width.
        max_vertical = max(0, min(self.crop_size[0], (h - 1) // 2))
        max_horizontal = max(0, min(self.crop_size[1], (w - 1) // 2))

        # Draw order (top, bottom, left, right) is kept stable so seeded runs
        # on sufficiently large images reproduce the same crops.
        h_top = random.randint(0, max_vertical)
        h_bot = random.randint(0, max_vertical)
        w_left = random.randint(0, max_horizontal)
        w_right = random.randint(0, max_horizontal)

        return h_top, w_left, h - h_top - h_bot, w - w_left - w_right

    def forward(self, img):
        """Crop random borders from ``img`` and resize back to the original size.

        Args:
            img (PIL Image or Tensor): Image to be cropped and resized.

        Returns:
            PIL Image or Tensor: Randomly cropped and resized image.
        """
        # Prefer the public helper when this torchvision build provides it;
        # fall back to the private one used by older releases.
        size_fn = getattr(F, "get_image_size", None) or F._get_image_size
        ori_w, ori_h = size_fn(img)
        top, left, height, width = self.get_params(ori_w, ori_h)
        return F.resized_crop(
            img, top, left, height, width, (ori_h, ori_w), self.interpolation
        )
class RandomPaddingAug(object):
def __init__(self, p, max_pad_len):
self.p = p
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册