From 16bd2dd0932f8ab98f6264f6bd38d5c1f8616c8f Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Thu, 21 Jan 2021 11:11:03 +0800 Subject: [PATCH] fix sast process --- ppocr/data/imaug/sast_process.py | 291 ++++++++++++++++++++----------- 1 file changed, 188 insertions(+), 103 deletions(-) diff --git a/ppocr/data/imaug/sast_process.py b/ppocr/data/imaug/sast_process.py index b8d6ff89..1536dceb 100644 --- a/ppocr/data/imaug/sast_process.py +++ b/ppocr/data/imaug/sast_process.py @@ -24,11 +24,11 @@ __all__ = ['SASTProcessTrain'] class SASTProcessTrain(object): def __init__(self, - image_shape = [512, 512], - min_crop_size = 24, - min_crop_side_ratio = 0.3, - min_text_size = 10, - max_text_size = 512, + image_shape=[512, 512], + min_crop_size=24, + min_crop_side_ratio=0.3, + min_text_size=10, + max_text_size=512, **kwargs): self.input_size = image_shape[1] self.min_crop_size = min_crop_size @@ -42,12 +42,10 @@ class SASTProcessTrain(object): :param poly: :return: """ - edge = [ - (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), - (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), - (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), - (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]) - ] + edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), + (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), + (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), + (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])] return np.sum(edge) / 2. def gen_quad_from_poly(self, poly): @@ -57,7 +55,8 @@ class SASTProcessTrain(object): point_num = poly.shape[0] min_area_quad = np.zeros((4, 2), dtype=np.float32) if True: - rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation) + rect = cv2.minAreaRect(poly.astype( + np.int32)) # (center (x,y), (width, height), angle of rotation) center_point = rect[0] box = np.array(cv2.boxPoints(rect)) @@ -102,23 +101,33 @@ class SASTProcessTrain(object): if p_area > 0: if tag == False: print('poly in wrong direction') - tag = True # reversed cases should be ignore - poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :] + tag = True # reversed cases should be ignore + poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, + 1), :] quad = quad[(0, 3, 2, 1), :] - len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - quad[2]) - len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]) + len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - + quad[2]) + len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - + quad[2]) hv_tag = 1 - - if len_w * 2.0 < len_h: + + if len_w * 2.0 < len_h: hv_tag = 0 validated_polys.append(poly) validated_tags.append(tag) hv_tags.append(hv_tag) - return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags) + return np.array(validated_polys), np.array(validated_tags), np.array( + hv_tags) - def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25): + def crop_area(self, + im, + polys, + tags, + hv_tags, + crop_background=False, + max_tries=25): """ make random crop from the input image :param im: @@ -137,10 +146,10 @@ class SASTProcessTrain(object): poly = np.round(poly, decimals=0).astype(np.int32) minx = np.min(poly[:, 0]) maxx = np.max(poly[:, 0]) - w_array[minx + pad_w: maxx + pad_w] = 1 + w_array[minx + pad_w:maxx + pad_w] = 1 miny = np.min(poly[:, 1]) maxy = np.max(poly[:, 1]) - h_array[miny + pad_h: maxy + pad_h] = 1 + h_array[miny + pad_h:maxy + pad_h] = 1 # ensure the cropped area not across a text h_axis = np.where(h_array == 0)[0] w_axis = np.where(w_array == 0)[0] @@ -166,17 +175,18 @@ class SASTProcessTrain(object): if polys.shape[0] != 0: poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \ & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax) - selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] + selected_polys = np.where( + np.sum(poly_axis_in_area, axis=1) == 4)[0] else: selected_polys = [] if len(selected_polys) == 0: # no text in this area if crop_background: return im[ymin : ymax + 1, xmin : xmax + 1, :], \ - polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts + polys[selected_polys], tags[selected_polys], hv_tags[selected_polys] else: continue - im = im[ymin: ymax + 1, xmin: xmax + 1, :] + im = im[ymin:ymax + 1, xmin:xmax + 1, :] polys = polys[selected_polys] tags = tags[selected_polys] hv_tags = hv_tags[selected_polys] @@ -192,18 +202,28 @@ class SASTProcessTrain(object): width_list = [] height_list = [] for quad in poly_quads: - quad_w = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0 - quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0 + quad_w = (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) / 2.0 + quad_h = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[2] - quad[1])) / 2.0 width_list.append(quad_w) height_list.append(quad_h) - norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0) + norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0) average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0) for quad in poly_quads: - direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0 - direct_vector = direct_vector_full / (np.linalg.norm(direct_vector_full) + 1e-6) * norm_width - direction_label = tuple(map(float, [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)])) - cv2.fillPoly(direction_map, quad.round().astype(np.int32)[np.newaxis, :, :], direction_label) + direct_vector_full = ( + (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0 + direct_vector = direct_vector_full / ( + np.linalg.norm(direct_vector_full) + 1e-6) * norm_width + direction_label = tuple( + map(float, [ + direct_vector[0], direct_vector[1], 1.0 / (average_height + + 1e-6) + ])) + cv2.fillPoly(direction_map, + quad.round().astype(np.int32)[np.newaxis, :, :], + direction_label) return direction_map def calculate_average_height(self, poly_quads): @@ -211,13 +231,19 @@ class SASTProcessTrain(object): """ height_list = [] for quad in poly_quads: - quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0 + quad_h = (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[2] - quad[1])) / 2.0 height_list.append(quad_h) average_height = max(sum(height_list) / len(height_list), 1.0) return average_height - def generate_tcl_label(self, hw, polys, tags, ds_ratio, - tcl_ratio=0.3, shrink_ratio_of_width=0.15): + def generate_tcl_label(self, + hw, + polys, + tags, + ds_ratio, + tcl_ratio=0.3, + shrink_ratio_of_width=0.15): """ Generate polygon. """ @@ -225,21 +251,30 @@ class SASTProcessTrain(object): h, w = int(h * ds_ratio), int(w * ds_ratio) polys = polys * ds_ratio - score_map = np.zeros((h, w,), dtype=np.float32) + score_map = np.zeros( + ( + h, + w, ), dtype=np.float32) tbo_map = np.zeros((h, w, 5), dtype=np.float32) - training_mask = np.ones((h, w,), dtype=np.float32) - direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape([1, 1, 3]).astype(np.float32) + training_mask = np.ones( + ( + h, + w, ), dtype=np.float32) + direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape( + [1, 1, 3]).astype(np.float32) for poly_idx, poly_tag in enumerate(zip(polys, tags)): - poly = poly_tag[0] + poly = poly_tag[0] tag = poly_tag[1] # generate min_area_quad min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) - min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + - np.linalg.norm(min_area_quad[1] - min_area_quad[2])) - min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + - np.linalg.norm(min_area_quad[2] - min_area_quad[3])) + min_area_quad_h = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2])) + min_area_quad_w = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3])) if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \ or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio: @@ -247,25 +282,37 @@ class SASTProcessTrain(object): if tag: # continue - cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15) + cv2.fillPoly(training_mask, + poly.astype(np.int32)[np.newaxis, :, :], 0.15) else: tcl_poly = self.poly2tcl(poly, tcl_ratio) tcl_quads = self.poly2quads(tcl_poly) poly_quads = self.poly2quads(poly) # stcl map - stcl_quads, quad_index = self.shrink_poly_along_width(tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width, - expand_height_ratio=1.0 / tcl_ratio) + stcl_quads, quad_index = self.shrink_poly_along_width( + tcl_quads, + shrink_ratio_of_width=shrink_ratio_of_width, + expand_height_ratio=1.0 / tcl_ratio) # generate tcl map - cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0) + cv2.fillPoly(score_map, + np.round(stcl_quads).astype(np.int32), 1.0) # generate tbo map for idx, quad in enumerate(stcl_quads): quad_mask = np.zeros((h, w), dtype=np.float32) - quad_mask = cv2.fillPoly(quad_mask, np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0) - tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], quad_mask, tbo_map) + quad_mask = cv2.fillPoly( + quad_mask, + np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0) + tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], + quad_mask, tbo_map) return score_map, tbo_map, training_mask - def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25): + def generate_tvo_and_tco(self, + hw, + polys, + tags, + tcl_ratio=0.3, + ds_ratio=0.25): """ Generate tcl map, tvo map and tbo map. """ @@ -297,35 +344,44 @@ class SASTProcessTrain(object): # generate min_area_quad min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) - min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + - np.linalg.norm(min_area_quad[1] - min_area_quad[2])) - min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + - np.linalg.norm(min_area_quad[2] - min_area_quad[3])) + min_area_quad_h = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + + np.linalg.norm(min_area_quad[1] - min_area_quad[2])) + min_area_quad_w = 0.5 * ( + np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + + np.linalg.norm(min_area_quad[2] - min_area_quad[3])) # generate tcl map and text, 128 * 128 tcl_poly = self.poly2tcl(poly, tcl_ratio) # generate poly_tv_xy_map for idx in range(4): - cv2.fillPoly(poly_tv_xy_map[2 * idx], - np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), - float(min(max(min_area_quad[idx, 0], 0), w))) - cv2.fillPoly(poly_tv_xy_map[2 * idx + 1], - np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), - float(min(max(min_area_quad[idx, 1], 0), h))) + cv2.fillPoly( + poly_tv_xy_map[2 * idx], + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(min(max(min_area_quad[idx, 0], 0), w))) + cv2.fillPoly( + poly_tv_xy_map[2 * idx + 1], + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(min(max(min_area_quad[idx, 1], 0), h))) # generate poly_tc_xy_map for idx in range(2): - cv2.fillPoly(poly_tc_xy_map[idx], - np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), float(center_point[idx])) + cv2.fillPoly( + poly_tc_xy_map[idx], + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(center_point[idx])) # generate poly_short_edge_map - cv2.fillPoly(poly_short_edge_map, - np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), - float(max(min(min_area_quad_h, min_area_quad_w), 1.0))) + cv2.fillPoly( + poly_short_edge_map, + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + float(max(min(min_area_quad_h, min_area_quad_w), 1.0))) # generate poly_mask and training_mask - cv2.fillPoly(poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1) + cv2.fillPoly(poly_mask, + np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), + 1) tvo_map *= poly_mask tvo_map[:8] -= poly_tv_xy_map @@ -356,7 +412,8 @@ class SASTProcessTrain(object): elif point_num > 4: vector_1 = poly[0] - poly[1] vector_2 = poly[1] - poly[2] - cos_theta = np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6) + cos_theta = np.dot(vector_1, vector_2) / ( + np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6) theta = np.arccos(np.round(cos_theta, decimals=4)) if abs(theta) > (70 / 180 * math.pi): @@ -374,7 +431,8 @@ class SASTProcessTrain(object): min_area_quad = poly center_point = np.sum(poly, axis=0) / 4 else: - rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation) + rect = cv2.minAreaRect(poly.astype( + np.int32)) # (center (x,y), (width, height), angle of rotation) center_point = rect[0] box = np.array(cv2.boxPoints(rect)) @@ -394,16 +452,23 @@ class SASTProcessTrain(object): return min_area_quad, center_point - def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.): + def shrink_quad_along_width(self, + quad, + begin_width_ratio=0., + end_width_ratio=1.): """ Generate shrink_quad_along_width. """ - ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) + ratio_pair = np.array( + [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) - def shrink_poly_along_width(self, quads, shrink_ratio_of_width, expand_height_ratio=1.0): + def shrink_poly_along_width(self, + quads, + shrink_ratio_of_width, + expand_height_ratio=1.0): """ shrink poly with given length. """ @@ -421,22 +486,28 @@ class SASTProcessTrain(object): upper_edge_list.append(upper_edge_len) # length of left edge and right edge. - left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio - right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio + left_length = np.linalg.norm(quads[0][0] - quads[0][ + 3]) * expand_height_ratio + right_length = np.linalg.norm(quads[-1][1] - quads[-1][ + 2]) * expand_height_ratio - shrink_length = min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width + shrink_length = min(left_length, right_length, + sum(upper_edge_list)) * shrink_ratio_of_width # shrinking length upper_len_left = shrink_length upper_len_right = sum(upper_edge_list) - shrink_length left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left) - left_quad = self.shrink_quad_along_width(quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1) + left_quad = self.shrink_quad_along_width( + quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1) right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right) - right_quad = self.shrink_quad_along_width(quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio) - + right_quad = self.shrink_quad_along_width( + quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio) + out_quad_list = [] if left_idx == right_idx: - out_quad_list.append([left_quad[0], right_quad[1], right_quad[2], left_quad[3]]) + out_quad_list.append( + [left_quad[0], right_quad[1], right_quad[2], left_quad[3]]) else: out_quad_list.append(left_quad) for idx in range(left_idx + 1, right_idx): @@ -500,7 +571,8 @@ class SASTProcessTrain(object): """ Generate center line by poly clock-wise point. (4, 2) """ - ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + ratio_pair = np.array( + [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]]) @@ -509,12 +581,14 @@ class SASTProcessTrain(object): """ Generate center line by poly clock-wise point. """ - ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) + ratio_pair = np.array( + [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) tcl_poly = np.zeros_like(poly) point_num = poly.shape[0] for idx in range(point_num // 2): - point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair + point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx] + ) * ratio_pair tcl_poly[idx] = point_pair[0] tcl_poly[point_num - 1 - idx] = point_pair[1] return tcl_poly @@ -527,8 +601,10 @@ class SASTProcessTrain(object): up_line = self.line_cross_two_point(quad[0], quad[1]) lower_line = self.line_cross_two_point(quad[3], quad[2]) - quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) - quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) + quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + + np.linalg.norm(quad[1] - quad[2])) + quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + + np.linalg.norm(quad[2] - quad[3])) # average angle of left and right line. angle = self.average_angle(quad) @@ -565,7 +641,8 @@ class SASTProcessTrain(object): quad_num = point_num // 2 - 1 for idx in range(quad_num): # reshape and adjust to clock-wise - quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]]) + quad_list.append((np.array(point_pair_list)[[idx, idx + 1]] + ).reshape(4, 2)[[0, 2, 3, 1]]) return np.array(quad_list) @@ -579,7 +656,8 @@ class SASTProcessTrain(object): return None h, w, _ = im.shape - text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w)) + text_polys, text_tags, hv_tags = self.check_and_validate_polys( + text_polys, text_tags, (h, w)) if text_polys.shape[0] == 0: return None @@ -591,7 +669,7 @@ class SASTProcessTrain(object): if np.random.rand() < 0.5: asp_scale = 1.0 / asp_scale asp_scale = math.sqrt(asp_scale) - + asp_wx = asp_scale asp_hy = 1.0 / asp_scale im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy) @@ -610,7 +688,7 @@ class SASTProcessTrain(object): #no background im, text_polys, text_tags, hv_tags = self.crop_area(im, \ text_polys, text_tags, hv_tags, crop_background=False) - + if text_polys.shape[0] == 0: return None #continue for all ignore case @@ -621,17 +699,18 @@ class SASTProcessTrain(object): return None #resize image std_ratio = float(self.input_size) / max(new_w, new_h) - rand_scales = np.array([0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]) + rand_scales = np.array( + [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]) rz_scale = std_ratio * np.random.choice(rand_scales) im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale) text_polys[:, :, 0] *= rz_scale text_polys[:, :, 1] *= rz_scale - + #add gaussian blur if np.random.rand() < 0.1 * 0.5: ks = np.random.permutation(5)[0] + 1 - ks = int(ks/2)*2 + 1 - im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0) + ks = int(ks / 2) * 2 + 1 + im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0) #add brighter if np.random.rand() < 0.1 * 0.5: im = im * (1.0 + np.random.rand() * 0.5) @@ -640,13 +719,14 @@ class SASTProcessTrain(object): if np.random.rand() < 0.1 * 0.5: im = im * (1.0 - np.random.rand() * 0.5) im = np.clip(im, 0.0, 255.0) - + # Padding the im to [input_size, input_size] new_h, new_w, _ = im.shape if min(new_w, new_h) < self.input_size * 0.5: return None - im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32) + im_padded = np.ones( + (self.input_size, self.input_size, 3), dtype=np.float32) im_padded[:, :, 2] = 0.485 * 255 im_padded[:, :, 1] = 0.456 * 255 im_padded[:, :, 0] = 0.406 * 255 @@ -661,24 +741,29 @@ class SASTProcessTrain(object): sw = int(np.random.rand() * del_w) # Padding - im_padded[sh: sh + new_h, sw: sw + new_w, :] = im.copy() + im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy() text_polys[:, :, 0] += sw text_polys[:, :, 1] += sh - score_map, border_map, training_mask = self.generate_tcl_label((self.input_size, self.input_size), - text_polys, text_tags, 0.25) - + score_map, border_map, training_mask = self.generate_tcl_label( + (self.input_size, self.input_size), text_polys, text_tags, 0.25) + # SAST head - tvo_map, tco_map = self.generate_tvo_and_tco((self.input_size, self.input_size), text_polys, text_tags, tcl_ratio=0.3, ds_ratio=0.25) + tvo_map, tco_map = self.generate_tvo_and_tco( + (self.input_size, self.input_size), + text_polys, + text_tags, + tcl_ratio=0.3, + ds_ratio=0.25) # print("test--------tvo_map shape:", tvo_map.shape) im_padded[:, :, 2] -= 0.485 * 255 im_padded[:, :, 1] -= 0.456 * 255 im_padded[:, :, 0] -= 0.406 * 255 - im_padded[:, :, 2] /= (255.0 * 0.229) - im_padded[:, :, 1] /= (255.0 * 0.224) - im_padded[:, :, 0] /= (255.0 * 0.225) - im_padded = im_padded.transpose((2, 0, 1)) + im_padded[:, :, 2] /= (255.0 * 0.229) + im_padded[:, :, 1] /= (255.0 * 0.224) + im_padded[:, :, 0] /= (255.0 * 0.225) + im_padded = im_padded.transpose((2, 0, 1)) data['image'] = im_padded[::-1, :, :] data['score_map'] = score_map[np.newaxis, :, :] @@ -686,4 +771,4 @@ class SASTProcessTrain(object): data['training_mask'] = training_mask[np.newaxis, :, :] data['tvo_map'] = tvo_map.transpose((2, 0, 1)) data['tco_map'] = tco_map.transpose((2, 0, 1)) - return data \ No newline at end of file + return data -- GitLab