提交 feca6368 编写于 作者: 文幕地方's avatar 文幕地方

add bda

上级 8a5a9870
...@@ -63,8 +63,7 @@ Train: ...@@ -63,8 +63,7 @@ Train:
- DecodeImage: - DecodeImage:
img_mode: BGR img_mode: BGR
channel_first: false channel_first: false
- RecAug: - BaseDataAugmentation:
use_tia: False
- RandAugment: - RandAugment:
- SSLRotateResize: - SSLRotateResize:
image_shape: [3, 48, 320] image_shape: [3, 48, 320]
......
...@@ -60,8 +60,7 @@ Train: ...@@ -60,8 +60,7 @@ Train:
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- ClsLabelEncode: # Class handling label - ClsLabelEncode: # Class handling label
- RecAug: - BaseDataAugmentation:
use_tia: False
- RandAugment: - RandAugment:
- ClsResizeImg: - ClsResizeImg:
image_shape: [3, 48, 192] image_shape: [3, 48, 192]
......
...@@ -682,7 +682,7 @@ lr: ...@@ -682,7 +682,7 @@ lr:
#### Q: 关于dygraph分支中,文本识别模型训练,要使用数据增强应该如何设置? #### Q: 关于dygraph分支中,文本识别模型训练,要使用数据增强应该如何设置?
**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4.由于tia数据增强特殊性,默认不采用,可以通过添加use_tia设置,使tia数据增强生效。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744) **A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)
#### Q: 训练过程中,训练程序意外退出/挂起,应该如何解决? #### Q: 训练过程中,训练程序意外退出/挂起,应该如何解决?
......
...@@ -22,7 +22,7 @@ from .make_shrink_map import MakeShrinkMap ...@@ -22,7 +22,7 @@ from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, RandomCropImgMask from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
from .ssl_img_aug import SSLRotateResize from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment from .randaugment import RandAugment
......
...@@ -22,13 +22,74 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort ...@@ -22,13 +22,74 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort
class RecAug(object): class RecAug(object):
def __init__(self, use_tia=True, aug_prob=0.4, **kwargs): def __init__(self,
self.use_tia = use_tia tia_prob=True,
self.aug_prob = aug_prob crop_prob=0.4,
reverse_prob=0.4,
noise_prob=0.4,
jitter_prob=0.4,
blur_prob=0.4,
hsv_aug_prob=0.4,
**kwargs):
self.tia_prob = tia_prob
self.bda = BaseDataAugmentation(crop_prob, reverse_prob, noise_prob,
jitter_prob, blur_prob, hsv_aug_prob)
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
img = warp(img, 10, self.use_tia, self.aug_prob) h, w, _ = img.shape
# tia
if random.random() <= self.tia_prob:
if h >= 20 and w >= 20:
img = tia_distort(img, random.randint(3, 6))
img = tia_stretch(img, random.randint(3, 6))
img = tia_perspective(img)
# bda
data['image'] = img
data = self.bda(data)
return data
class BaseDataAugmentation(object):
def __init__(self,
crop_prob=0.4,
reverse_prob=0.4,
noise_prob=0.4,
jitter_prob=0.4,
blur_prob=0.4,
hsv_aug_prob=0.4,
**kwargs):
self.crop_prob = crop_prob
self.reverse_prob = reverse_prob
self.noise_prob = noise_prob
self.jitter_prob = jitter_prob
self.blur_prob = blur_prob
self.hsv_aug_prob = hsv_aug_prob
def __call__(self, data):
img = data['image']
h, w, _ = img.shape
if random.random() <= self.crop_prob and h >= 20 and w >= 20:
img = get_crop(img)
if random.random() <= self.blur_prob:
img = blur(img)
if random.random() <= self.hsv_aug_prob:
img = hsv_aug(img)
if random.random() <= self.jitter_prob:
img = jitter(img)
if random.random() <= self.noise_prob:
img = add_gasuss_noise(img)
if random.random() <= self.reverse_prob:
img = 255 - img
data['image'] = img data['image'] = img
return data return data
...@@ -359,7 +420,7 @@ def flag(): ...@@ -359,7 +420,7 @@ def flag():
return 1 if random.random() > 0.5000001 else -1 return 1 if random.random() > 0.5000001 else -1
def cvtColor(img): def hsv_aug(img):
""" """
cvtColor cvtColor
""" """
...@@ -427,50 +488,6 @@ def get_crop(image): ...@@ -427,50 +488,6 @@ def get_crop(image):
return crop_img return crop_img
class Config:
"""
Config
"""
def __init__(self, use_tia):
self.anglex = random.random() * 30
self.angley = random.random() * 15
self.anglez = random.random() * 10
self.fov = 42
self.r = 0
self.shearx = random.random() * 0.3
self.sheary = random.random() * 0.05
self.borderMode = cv2.BORDER_REPLICATE
self.use_tia = use_tia
def make(self, w, h, ang):
"""
make
"""
self.anglex = random.random() * 5 * flag()
self.angley = random.random() * 5 * flag()
self.anglez = -1 * random.random() * int(ang) * flag()
self.fov = 42
self.r = 0
self.shearx = 0
self.sheary = 0
self.borderMode = cv2.BORDER_REPLICATE
self.w = w
self.h = h
self.perspective = self.use_tia
self.stretch = self.use_tia
self.distort = self.use_tia
self.crop = True
self.affine = False
self.reverse = True
self.noise = True
self.jitter = True
self.blur = True
self.color = True
def rad(x): def rad(x):
""" """
rad rad
...@@ -554,48 +571,3 @@ def get_warpAffine(config): ...@@ -554,48 +571,3 @@ def get_warpAffine(config):
rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0], rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0],
[-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32) [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32)
return rz return rz
def warp(img, ang, use_tia=True, prob=0.4):
"""
warp
"""
h, w, _ = img.shape
config = Config(use_tia=use_tia)
config.make(w, h, ang)
new_img = img
if config.distort:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = tia_distort(new_img, random.randint(3, 6))
if config.stretch:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = tia_stretch(new_img, random.randint(3, 6))
if config.perspective:
if random.random() <= prob:
new_img = tia_perspective(new_img)
if config.crop:
img_height, img_width = img.shape[0:2]
if random.random() <= prob and img_height >= 20 and img_width >= 20:
new_img = get_crop(new_img)
if config.blur:
if random.random() <= prob:
new_img = blur(new_img)
if config.color:
if random.random() <= prob:
new_img = cvtColor(new_img)
if config.jitter:
new_img = jitter(new_img)
if config.noise:
if random.random() <= prob:
new_img = add_gasuss_noise(new_img)
if config.reverse:
if random.random() <= prob:
new_img = 255 - new_img
return new_img
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册