diff --git a/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml b/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml index 1ffeba07995860f964e22b8b9d2538320d80f651..f7e327d1e0378ed75b03b1f866414ef34e197c47 100644 --- a/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml +++ b/configs/cls/ch_PP-OCRv3/ch_PP-OCRv3_rotnet.yml @@ -63,8 +63,7 @@ Train: - DecodeImage: img_mode: BGR channel_first: false - - RecAug: - use_tia: False + - BaseDataAugmentation: - RandAugment: - SSLRotateResize: image_shape: [3, 48, 320] diff --git a/configs/cls/cls_mv3.yml b/configs/cls/cls_mv3.yml index 5e643dc3839b2e2edf3c811db813dd6a90797366..0c46ff560277d9e318587432671de09f0d13cf35 100644 --- a/configs/cls/cls_mv3.yml +++ b/configs/cls/cls_mv3.yml @@ -60,8 +60,7 @@ Train: img_mode: BGR channel_first: False - ClsLabelEncode: # Class handling label - - RecAug: - use_tia: False + - BaseDataAugmentation: - RandAugment: - ClsResizeImg: image_shape: [3, 48, 192] diff --git a/doc/doc_ch/FAQ.md b/doc/doc_ch/FAQ.md index 24f8a3e92be1c003ec7c37b74d14f4ae4117086a..2dad829284806e199801ff3ac55d2e760c3e2583 100644 --- a/doc/doc_ch/FAQ.md +++ b/doc/doc_ch/FAQ.md @@ -682,7 +682,7 @@ lr: #### Q: 关于dygraph分支中,文本识别模型训练,要使用数据增强应该如何设置? -**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4.由于tia数据增强特殊性,默认不采用,可以通过添加use_tia设置,使tia数据增强生效。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)。 +**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)。 #### Q: 训练过程中,训练程序意外退出/挂起,应该如何解决? diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 548832fb0d116ba2de622bd97562b591d74501d8..f0fd578f61452d319ba9ea88eb6a5f37ab55afcf 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -22,7 +22,7 @@ from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, RandomCropImgMask from .make_pse_gt import MakePseGt -from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ +from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 7483dffe5b6d9a0a2204702757fcb49762a1cc7a..32de2b3fc36757fee954de41686d09999ce4c640 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -22,13 +22,74 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort class RecAug(object): - def __init__(self, use_tia=True, aug_prob=0.4, **kwargs): - self.use_tia = use_tia - self.aug_prob = aug_prob + def __init__(self, + tia_prob=0.4, + crop_prob=0.4, + reverse_prob=0.4, + noise_prob=0.4, + jitter_prob=0.4, + blur_prob=0.4, + hsv_aug_prob=0.4, + **kwargs): + self.tia_prob = tia_prob + self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob, + jitter_prob, blur_prob, hsv_aug_prob) def __call__(self, data): img = data['image'] - img = warp(img, 10, self.use_tia, self.aug_prob) + h, w, _ = img.shape + + # tia + if random.random() <= self.tia_prob: + if h >= 20 and w >= 20: + img = tia_distort(img, random.randint(3, 6)) + img = tia_stretch(img, random.randint(3, 6)) + img = tia_perspective(img) + + # bda + data['image'] = img + data = self.bda(data) + return data + + +class BaseDataAugmentation(object): + def __init__(self, + crop_prob=0.4, + reverse_prob=0.4, + noise_prob=0.4, + jitter_prob=0.4, + blur_prob=0.4, + hsv_aug_prob=0.4, + **kwargs): + self.crop_prob = crop_prob + self.reverse_prob = reverse_prob + self.noise_prob = noise_prob + self.jitter_prob = jitter_prob + self.blur_prob = blur_prob + self.hsv_aug_prob = hsv_aug_prob + + def __call__(self, data): + img = data['image'] + h, w, _ = img.shape + + if random.random() <= self.crop_prob and h >= 20 and w >= 20: + img = get_crop(img) + + if random.random() <= self.blur_prob: + img = blur(img) + + if random.random() <= self.hsv_aug_prob: + img = hsv_aug(img) + + if random.random() <= self.jitter_prob: + img = jitter(img) + + if random.random() <= self.noise_prob: + img = add_gasuss_noise(img) + + if random.random() <= self.reverse_prob: + img = 255 - img + data['image'] = img return data @@ -359,7 +420,7 @@ def flag(): return 1 if random.random() > 0.5000001 else -1 -def cvtColor(img): +def hsv_aug(img): """ cvtColor """ @@ -427,50 +488,6 @@ def get_crop(image): return crop_img -class Config: - """ - Config - """ - - def __init__(self, use_tia): - self.anglex = random.random() * 30 - self.angley = random.random() * 15 - self.anglez = random.random() * 10 - self.fov = 42 - self.r = 0 - self.shearx = random.random() * 0.3 - self.sheary = random.random() * 0.05 - self.borderMode = cv2.BORDER_REPLICATE - self.use_tia = use_tia - - def make(self, w, h, ang): - """ - make - """ - self.anglex = random.random() * 5 * flag() - self.angley = random.random() * 5 * flag() - self.anglez = -1 * random.random() * int(ang) * flag() - self.fov = 42 - self.r = 0 - self.shearx = 0 - self.sheary = 0 - self.borderMode = cv2.BORDER_REPLICATE - self.w = w - self.h = h - - self.perspective = self.use_tia - self.stretch = self.use_tia - self.distort = self.use_tia - - self.crop = True - self.affine = False - self.reverse = True - self.noise = True - self.jitter = True - self.blur = True - self.color = True - - def rad(x): """ rad @@ -554,48 +571,3 @@ def get_warpAffine(config): rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0], [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32) return rz - - -def warp(img, ang, use_tia=True, prob=0.4): - """ - warp - """ - h, w, _ = img.shape - config = Config(use_tia=use_tia) - config.make(w, h, ang) - new_img = img - - if config.distort: - img_height, img_width = img.shape[0:2] - if random.random() <= prob and img_height >= 20 and img_width >= 20: - new_img = tia_distort(new_img, random.randint(3, 6)) - - if config.stretch: - img_height, img_width = img.shape[0:2] - if random.random() <= prob and img_height >= 20 and img_width >= 20: - new_img = tia_stretch(new_img, random.randint(3, 6)) - - if config.perspective: - if random.random() <= prob: - new_img = tia_perspective(new_img) - - if config.crop: - img_height, img_width = img.shape[0:2] - if random.random() <= prob and img_height >= 20 and img_width >= 20: - new_img = get_crop(new_img) - - if config.blur: - if random.random() <= prob: - new_img = blur(new_img) - if config.color: - if random.random() <= prob: - new_img = cvtColor(new_img) - if config.jitter: - new_img = jitter(new_img) - if config.noise: - if random.random() <= prob: - new_img = add_gasuss_noise(new_img) - if config.reverse: - if random.random() <= prob: - new_img = 255 - new_img - return new_img