Commit 896e5739 authored by littletomatodonkey

add tia aug

Parent: 90f30dbe
ppocr/data/rec/img_tools.py

@@ -19,6 +19,8 @@ import random
 from ppocr.utils.utility import initial_logger
 logger = initial_logger()
+from .text_image_aug.augment import tia_distort, tia_stretch, tia_perspective


 def get_bounding_box_rect(pos):
     left = min(pos[0])
@@ -196,6 +198,9 @@ class Config:
         self.h = h
         self.perspective = True
+        self.stretch = True
+        self.distort = True
         self.crop = True
         self.affine = False
         self.reverse = True
@@ -299,41 +304,40 @@ def warp(img, ang):
     config.make(w, h, ang)
     new_img = img
+    prob = 0.4
+    if config.distort:
+        img_height, img_width = img.shape[0:2]
+        if random.random() <= prob and img_height >= 20 and img_width >= 20:
+            new_img = tia_distort(new_img, random.randint(3, 6))
+    if config.stretch:
+        img_height, img_width = img.shape[0:2]
+        if random.random() <= prob and img_height >= 20 and img_width >= 20:
+            new_img = tia_stretch(new_img, random.randint(3, 6))
     if config.perspective:
-        tp = random.randint(1, 100)
-        if tp >= 50:
-            warpR, (r1, c1), ratio, dst = get_warpR(config)
-            new_w = int(np.max(dst[:, 0])) - int(np.min(dst[:, 0]))
-            new_img = cv2.warpPerspective(
-                new_img,
-                warpR, (int(new_w * ratio), h),
-                borderMode=config.borderMode)
+        if random.random() <= prob:
+            new_img = tia_perspective(new_img)
     if config.crop:
         img_height, img_width = img.shape[0:2]
-        tp = random.randint(1, 100)
-        if tp >= 50 and img_height >= 20 and img_width >= 20:
+        if random.random() <= prob and img_height >= 20 and img_width >= 20:
             new_img = get_crop(new_img)
-    if config.affine:
-        warpT = get_warpAffine(config)
-        new_img = cv2.warpAffine(
-            new_img, warpT, (w, h), borderMode=config.borderMode)
     if config.blur:
-        tp = random.randint(1, 100)
-        if tp >= 50:
+        if random.random() <= prob:
             new_img = blur(new_img)
     if config.color:
-        tp = random.randint(1, 100)
-        if tp >= 50:
+        if random.random() <= prob:
             new_img = cvtColor(new_img)
     if config.jitter:
         new_img = jitter(new_img)
     if config.noise:
-        tp = random.randint(1, 100)
-        if tp >= 50:
+        if random.random() <= prob:
             new_img = add_gasuss_noise(new_img)
     if config.reverse:
-        tp = random.randint(1, 100)
-        if tp >= 50:
+        if random.random() <= prob:
             new_img = 255 - new_img
     return new_img
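For reference, the probability-gated augmentation pattern introduced above can be exercised on its own. The helper below is a minimal sketch, not code from this commit; it assumes the new `ppocr.data.rec.text_image_aug.augment` module is importable and reuses the same `prob = 0.4` threshold and 20-pixel size guard as the diff.

```python
# Minimal sketch of the probability-gated TIA augmentation added above.
# The standalone helper apply_tia is illustrative, not part of the commit.
import random

from ppocr.data.rec.text_image_aug.augment import (tia_distort, tia_perspective,
                                                   tia_stretch)


def apply_tia(img, prob=0.4):
    """Apply each TIA op with probability `prob`, skipping tiny crops."""
    img_height, img_width = img.shape[0:2]
    new_img = img
    if random.random() <= prob and img_height >= 20 and img_width >= 20:
        new_img = tia_distort(new_img, random.randint(3, 6))
    if random.random() <= prob and img_height >= 20 and img_width >= 20:
        new_img = tia_stretch(new_img, random.randint(3, 6))
    if random.random() <= prob:
        new_img = tia_perspective(new_img)
    return new_img
```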
@@ -360,7 +364,7 @@ def process_image(img,
         text = char_ops.encode(label)
         if len(text) == 0 or len(text) > max_text_length:
             logger.info(
-                "Warning in ppocr/data/rec/img_tools.py: Wrong data type."
+                "Warning in ppocr/data/rec/img_tools.py:line362: Wrong data type."
                 "Excepted string with length between 1 and {}, but "
                 "got '{}'. Label is '{}'".format(max_text_length,
                                                  len(text), label))
@@ -382,6 +386,7 @@ def process_image(img,
                 % loss_type
         return (norm_img)
+

 def resize_norm_img_srn(img, image_shape):
     imgC, imgH, imgW = image_shape
@@ -408,30 +413,39 @@ def resize_norm_img_srn(img, image_shape):
     return np.reshape(img_black, (c, row, col)).astype(np.float32)

-def srn_other_inputs(image_shape,
-                     num_heads,
-                     max_text_length,
-                     char_num):
+def srn_other_inputs(image_shape, num_heads, max_text_length, char_num):
     imgC, imgH, imgW = image_shape
     feature_dim = int((imgH / 8) * (imgW / 8))
-    encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
-    gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
+    encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+        (feature_dim, 1)).astype('int64')
+    gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+        (max_text_length, 1)).astype('int64')
-    lbl_weight = np.array([int(char_num-1)] * max_text_length).reshape((-1,1)).astype('int64')
+    lbl_weight = np.array([int(char_num - 1)] * max_text_length).reshape(
+        (-1, 1)).astype('int64')
     gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
-    gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
-    gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]) * [-1e9]
+    gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+        [-1, 1, max_text_length, max_text_length])
+    gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
+                                  [1, num_heads, 1, 1]) * [-1e9]
-    gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape([-1, 1, max_text_length, max_text_length])
-    gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]) * [-1e9]
+    gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+        [-1, 1, max_text_length, max_text_length])
+    gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
+                                  [1, num_heads, 1, 1]) * [-1e9]
     encoder_word_pos = encoder_word_pos[np.newaxis, :]
     gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
-    return [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
+    return [
+        lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+        gsrm_slf_attn_bias2
+    ]

 def process_image_srn(img,
                       image_shape,
@@ -453,14 +467,16 @@ def process_image_srn(img,
         return None
     else:
         if loss_type == "srn":
-            text_padded = [int(char_num-1)] * max_text_length
+            text_padded = [int(char_num - 1)] * max_text_length
             for i in range(len(text)):
                 text_padded[i] = text[i]
                 lbl_weight[i] = [1.0]
             text_padded = np.array(text_padded)
             text = text_padded.reshape(-1, 1)
-            return (norm_img, text,encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2,lbl_weight)
+            return (norm_img, text, encoder_word_pos, gsrm_word_pos,
+                    gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight)
         else:
             assert False, "Unsupport loss_type %s in process_image"\
                 % loss_type
-    return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)
+    return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+            gsrm_slf_attn_bias2)
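A quick shape check confirms that the reformatted `srn_other_inputs` still returns the same tensors. The numbers below assume the common SRN setting (1x64x256 input, 8 heads, `max_text_length=25`, `char_num=38`), which is an assumption, not something fixed by this diff.

```python
# Shape sanity check for srn_other_inputs (illustrative parameter values).
from ppocr.data.rec.img_tools import srn_other_inputs

lbl_weight, enc_pos, gsrm_pos, bias1, bias2 = srn_other_inputs(
    image_shape=[1, 64, 256], num_heads=8, max_text_length=25, char_num=38)

print(lbl_weight.shape)  # (25, 1)     -> every slot starts as the padding id char_num - 1
print(enc_pos.shape)     # (1, 256, 1) -> feature_dim = (64 // 8) * (256 // 8)
print(gsrm_pos.shape)    # (1, 25, 1)
print(bias1.shape)       # (1, 8, 25, 25) -> upper-triangular mask scaled by -1e9
print(bias2.shape)       # (1, 8, 25, 25) -> lower-triangular mask scaled by -1e9
```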
New file: ppocr/data/rec/text_image_aug/augment.py

# -*- coding:utf-8 -*-
# Author: RubanSeven
# Reference: https://github.com/RubanSeven/Text-Image-Augmentation-python

# import cv2
import numpy as np

from .warp_mls import WarpMLS


def tia_distort(src, segment=4):
    img_h, img_w = src.shape[:2]

    cut = img_w // segment
    thresh = cut // 3

    src_pts = list()
    dst_pts = list()

    src_pts.append([0, 0])
    src_pts.append([img_w, 0])
    src_pts.append([img_w, img_h])
    src_pts.append([0, img_h])

    dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
    dst_pts.append(
        [img_w - np.random.randint(thresh), np.random.randint(thresh)])
    dst_pts.append(
        [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
    dst_pts.append(
        [np.random.randint(thresh), img_h - np.random.randint(thresh)])

    half_thresh = thresh * 0.5

    for cut_idx in np.arange(1, segment, 1):
        src_pts.append([cut * cut_idx, 0])
        src_pts.append([cut * cut_idx, img_h])
        dst_pts.append([
            cut * cut_idx + np.random.randint(thresh) - half_thresh,
            np.random.randint(thresh) - half_thresh
        ])
        dst_pts.append([
            cut * cut_idx + np.random.randint(thresh) - half_thresh,
            img_h + np.random.randint(thresh) - half_thresh
        ])

    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    dst = trans.generate()

    return dst


def tia_stretch(src, segment=4):
    img_h, img_w = src.shape[:2]

    cut = img_w // segment
    thresh = cut * 4 // 5

    src_pts = list()
    dst_pts = list()

    src_pts.append([0, 0])
    src_pts.append([img_w, 0])
    src_pts.append([img_w, img_h])
    src_pts.append([0, img_h])

    dst_pts.append([0, 0])
    dst_pts.append([img_w, 0])
    dst_pts.append([img_w, img_h])
    dst_pts.append([0, img_h])

    half_thresh = thresh * 0.5

    for cut_idx in np.arange(1, segment, 1):
        move = np.random.randint(thresh) - half_thresh
        src_pts.append([cut * cut_idx, 0])
        src_pts.append([cut * cut_idx, img_h])
        dst_pts.append([cut * cut_idx + move, 0])
        dst_pts.append([cut * cut_idx + move, img_h])

    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    dst = trans.generate()

    return dst


def tia_perspective(src):
    img_h, img_w = src.shape[:2]

    thresh = img_h // 2

    src_pts = list()
    dst_pts = list()

    src_pts.append([0, 0])
    src_pts.append([img_w, 0])
    src_pts.append([img_w, img_h])
    src_pts.append([0, img_h])

    dst_pts.append([0, np.random.randint(thresh)])
    dst_pts.append([img_w, np.random.randint(thresh)])
    dst_pts.append([img_w, img_h - np.random.randint(thresh)])
    dst_pts.append([0, img_h - np.random.randint(thresh)])

    trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
    dst = trans.generate()

    return dst
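The three functions differ only in how the control points are moved: tia_distort jitters both the corners and the internal cut points in x and y, tia_stretch shifts only the internal cut points and only horizontally, and tia_perspective moves the top and bottom edges vertically. A hedged usage sketch follows; the image path is a placeholder, not something from the commit.

```python
# Illustrative usage of the new augment module; "word.jpg" is a placeholder path.
import cv2
from ppocr.data.rec.text_image_aug.augment import (tia_distort, tia_perspective,
                                                   tia_stretch)

img = cv2.imread("word.jpg")              # any cropped text-line image
distorted = tia_distort(img, segment=4)   # random x/y jitter of corners and cut points
stretched = tia_stretch(img, segment=4)   # horizontal-only jitter of the cut points
tilted = tia_perspective(img)             # vertical shift of the top/bottom edges
cv2.imwrite("word_tia.jpg", distorted)
```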
New file: ppocr/data/rec/text_image_aug/warp_mls.py

# -*- coding:utf-8 -*-
# Author: RubanSeven
# Reference: https://github.com/RubanSeven/Text-Image-Augmentation-python

import math
import numpy as np


class WarpMLS:
    def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
        self.src = src
        self.src_pts = src_pts
        self.dst_pts = dst_pts
        self.pt_count = len(self.dst_pts)
        self.dst_w = dst_w
        self.dst_h = dst_h
        self.trans_ratio = trans_ratio
        self.grid_size = 100
        self.rdx = np.zeros((self.dst_h, self.dst_w))
        self.rdy = np.zeros((self.dst_h, self.dst_w))

    @staticmethod
    def __bilinear_interp(x, y, v11, v12, v21, v22):
        return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
                                                      (1 - y) + v22 * y) * x

    def generate(self):
        self.calc_delta()
        return self.gen_img()

    def calc_delta(self):
        w = np.zeros(self.pt_count, dtype=np.float32)

        if self.pt_count < 2:
            return

        i = 0
        while 1:
            if self.dst_w <= i < self.dst_w + self.grid_size - 1:
                i = self.dst_w - 1
            elif i >= self.dst_w:
                break

            j = 0
            while 1:
                if self.dst_h <= j < self.dst_h + self.grid_size - 1:
                    j = self.dst_h - 1
                elif j >= self.dst_h:
                    break

                sw = 0
                swp = np.zeros(2, dtype=np.float32)
                swq = np.zeros(2, dtype=np.float32)
                new_pt = np.zeros(2, dtype=np.float32)
                cur_pt = np.array([i, j], dtype=np.float32)

                k = 0
                for k in range(self.pt_count):
                    if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
                        break

                    w[k] = 1. / (
                        (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
                        (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))

                    sw += w[k]
                    swp = swp + w[k] * np.array(self.dst_pts[k])
                    swq = swq + w[k] * np.array(self.src_pts[k])

                if k == self.pt_count - 1:
                    pstar = 1 / sw * swp
                    qstar = 1 / sw * swq

                    miu_s = 0
                    for k in range(self.pt_count):
                        if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
                            continue
                        pt_i = self.dst_pts[k] - pstar
                        miu_s += w[k] * np.sum(pt_i * pt_i)

                    cur_pt -= pstar
                    cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])

                    for k in range(self.pt_count):
                        if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
                            continue

                        pt_i = self.dst_pts[k] - pstar
                        pt_j = np.array([-pt_i[1], pt_i[0]])

                        tmp_pt = np.zeros(2, dtype=np.float32)
                        tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
                                    np.sum(pt_j * cur_pt) * self.src_pts[k][1]
                        tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
                                    np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
                        tmp_pt *= (w[k] / miu_s)
                        new_pt += tmp_pt

                    new_pt += qstar
                else:
                    new_pt = self.src_pts[k]

                self.rdx[j, i] = new_pt[0] - i
                self.rdy[j, i] = new_pt[1] - j

                j += self.grid_size
            i += self.grid_size

    def gen_img(self):
        src_h, src_w = self.src.shape[:2]
        dst = np.zeros_like(self.src, dtype=np.float32)

        for i in np.arange(0, self.dst_h, self.grid_size):
            for j in np.arange(0, self.dst_w, self.grid_size):
                ni = i + self.grid_size
                nj = j + self.grid_size
                w = h = self.grid_size
                if ni >= self.dst_h:
                    ni = self.dst_h - 1
                    h = ni - i + 1
                if nj >= self.dst_w:
                    nj = self.dst_w - 1
                    w = nj - j + 1

                di = np.reshape(np.arange(h), (-1, 1))
                dj = np.reshape(np.arange(w), (1, -1))
                delta_x = self.__bilinear_interp(
                    di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
                    self.rdx[ni, j], self.rdx[ni, nj])
                delta_y = self.__bilinear_interp(
                    di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
                    self.rdy[ni, j], self.rdy[ni, nj])
                nx = j + dj + delta_x * self.trans_ratio
                ny = i + di + delta_y * self.trans_ratio
                nx = np.clip(nx, 0, src_w - 1)
                ny = np.clip(ny, 0, src_h - 1)
                nxi = np.array(np.floor(nx), dtype=np.int32)
                nyi = np.array(np.floor(ny), dtype=np.int32)
                nxi1 = np.array(np.ceil(nx), dtype=np.int32)
                nyi1 = np.array(np.ceil(ny), dtype=np.int32)

                if len(self.src.shape) == 3:
                    x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
                    y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
                else:
                    x = ny - nyi
                    y = nx - nxi
                dst[i:i + h, j:j + w] = self.__bilinear_interp(
                    x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
                    self.src[nyi1, nxi], self.src[nyi1, nxi1])

        dst = np.clip(dst, 0, 255)
        dst = np.array(dst, dtype=np.uint8)

        return dst
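WarpMLS implements a grid-sampled moving-least-squares warp: calc_delta computes a displacement at each grid node from the control-point pairs, and gen_img bilinearly interpolates those displacements to remap every pixel. It can also be driven directly with hand-picked control points; the snippet below is a tiny synthetic example, not code from the commit, and assumes the module lands at ppocr/data/rec/text_image_aug/warp_mls.py as implied by the relative import in augment.py.

```python
# Tiny synthetic example of driving WarpMLS directly (illustrative only).
import numpy as np
from ppocr.data.rec.text_image_aug.warp_mls import WarpMLS

src = np.full((32, 100, 3), 255, dtype=np.uint8)    # blank 100x32 "text line"
src_pts = [[0, 0], [100, 0], [100, 32], [0, 32]]    # original corner positions
dst_pts = [[4, 3], [97, 0], [100, 29], [0, 32]]     # slightly shifted corners

warped = WarpMLS(src, src_pts, dst_pts, dst_w=100, dst_h=32).generate()
print(warped.shape, warped.dtype)  # (32, 100, 3) uint8
```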