diff --git a/StyleText/engine/predictors.py b/StyleText/engine/predictors.py
index a1ba21f1b6cd084f9f95140d4227d600d4631715..ca9ab9ce6fc471e077970766c252c98b8617c6cc 100644
--- a/StyleText/engine/predictors.py
+++ b/StyleText/engine/predictors.py
@@ -38,7 +38,15 @@ class StyleTextRecPredictor(object):
         self.std = config["Predictor"]["std"]
         self.expand_result = config["Predictor"]["expand_result"]

-    def predict(self, style_input, text_input):
+    def reshape_to_same_height(self, img_list):
+        h = img_list[0].shape[0]
+        for idx in range(1, len(img_list)):
+            new_w = round(1.0 * img_list[idx].shape[1] /
+                          img_list[idx].shape[0] * h)
+            img_list[idx] = cv2.resize(img_list[idx], (new_w, h))
+        return img_list
+
+    def predict_single_image(self, style_input, text_input):
         style_input = self.rep_style_input(style_input, text_input)
         tensor_style_input = self.preprocess(style_input)
         tensor_text_input = self.preprocess(text_input)
@@ -64,6 +72,21 @@ class StyleTextRecPredictor(object):
             "fake_bg": fake_bg,
         }

+    def predict(self, style_input, text_input_list):
+        if not isinstance(text_input_list, (tuple, list)):
+            return self.predict_single_image(style_input, text_input_list)
+
+        synth_result_list = []
+        for text_input in text_input_list:
+            synth_result = self.predict_single_image(style_input, text_input)
+            synth_result_list.append(synth_result)
+
+        for key in synth_result:
+            res = [r[key] for r in synth_result_list]
+            res = self.reshape_to_same_height(res)
+            synth_result[key] = np.concatenate(res, axis=1)
+        return synth_result
+
     def preprocess(self, img):
         img = (img.astype('float32') * self.scale - self.mean) / self.std
         img_height, img_width, channel = img.shape
diff --git a/StyleText/engine/synthesisers.py b/StyleText/engine/synthesisers.py
index 177e3e049a695ecd06f5d2271f21336dd4eff997..6461d9e363f5f6e0c92831a50580c2748dffa248 100644
--- a/StyleText/engine/synthesisers.py
+++ b/StyleText/engine/synthesisers.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import numpy as np
+import cv2

 from utils.config import ArgsParser, load_config, override_config
 from utils.logging import get_logger
@@ -36,8 +38,9 @@ class ImageSynthesiser(object):
         self.predictor = getattr(predictors, predictor_method)(self.config)

     def synth_image(self, corpus, style_input, language="en"):
-        corpus, text_input = self.text_drawer.draw_text(corpus, language)
-        synth_result = self.predictor.predict(style_input, text_input)
+        corpus_list, text_input_list = self.text_drawer.draw_text(
+            corpus, language, style_input_width=style_input.shape[1])
+        synth_result = self.predictor.predict(style_input, text_input_list)
         return synth_result


@@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser):
         for i in range(self.output_num):
             style_data = self.style_sampler.sample()
             style_input = style_data["image"]
-            corpus_language, text_input_label = self.corpus_generator.generate(
-            )
-            text_input_label, text_input = self.text_drawer.draw_text(
-                text_input_label, corpus_language)
+            corpus_language, text_input_label = self.corpus_generator.generate()
+            text_input_label_list, text_input_list = self.text_drawer.draw_text(
+                text_input_label,
+                corpus_language,
+                style_input_width=style_input.shape[1])

-            synth_result = self.predictor.predict(style_input, text_input)
+            text_input_label = "".join(text_input_label_list)
+
+            synth_result = self.predictor.predict(style_input, text_input_list)
             fake_fusion = synth_result["fake_fusion"]
             self.writer.save_image(fake_fusion, text_input_label)
             self.writer.save_label()
diff --git a/StyleText/engine/text_drawers.py b/StyleText/engine/text_drawers.py
index 8aaac06ec50816bb6e2774972644c0a7dfb908c6..aeec75c3378f91b64b4387ef16971165f0b80ebe 100644
--- a/StyleText/engine/text_drawers.py
+++ b/StyleText/engine/text_drawers.py
@@ -1,5 +1,6 @@
 from PIL import Image, ImageDraw, ImageFont
 import numpy as np
+import cv2
 from utils.logging import get_logger


@@ -28,7 +29,11 @@ class StdTextDrawer(object):
         else:
             return int((self.height - 4)**2 / font_height)

-    def draw_text(self, corpus, language="en", crop=True):
+    def draw_text(self,
+                  corpus,
+                  language="en",
+                  crop=True,
+                  style_input_width=None):
         if language not in self.support_languages:
             self.logger.warning(
                 "language {} not supported, use en instead.".format(language))
@@ -37,21 +42,43 @@ class StdTextDrawer(object):
             width = min(self.max_width, len(corpus) * self.height) + 4
         else:
             width = len(corpus) * self.height + 4
-        bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
-        draw = ImageDraw.Draw(bg)
-
-        char_x = 2
-        font = self.font_dict[language]
-        for i, char_i in enumerate(corpus):
-            char_size = font.getsize(char_i)[0]
-            draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
-            char_x += char_size
-            if char_x >= width:
-                corpus = corpus[0:i + 1]
-                self.logger.warning("corpus length exceed limit: {}".format(
-                    corpus))
+
+        if style_input_width is not None:
+            width = min(width, style_input_width)
+
+        corpus_list = []
+        text_input_list = []
+
+        while len(corpus) != 0:
+            bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
+            draw = ImageDraw.Draw(bg)
+            char_x = 2
+            font = self.font_dict[language]
+            i = 0
+            while i < len(corpus):
+                char_i = corpus[i]
+                char_size = font.getsize(char_i)[0]
+                # split when the next char would exceed the width and index is not 0 (at least 1 char should be written on the image)
+                if char_x + char_size >= width and i != 0:
+                    text_input = np.array(bg).astype(np.uint8)
+                    text_input = text_input[:, 0:char_x, :]
+
+                    corpus_list.append(corpus[0:i])
+                    text_input_list.append(text_input)
+                    corpus = corpus[i:]
+                    break
+                draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
+                char_x += char_size
+
+                i += 1
+            # the whole text is shorter than the style input
+            if i == len(corpus):
+                text_input = np.array(bg).astype(np.uint8)
+                text_input = text_input[:, 0:char_x, :]
+
+                corpus_list.append(corpus[0:i])
+                text_input_list.append(text_input)
+                corpus = corpus[i:]
                 break
-        text_input = np.array(bg).astype(np.uint8)
-        text_input = text_input[:, 0:char_x, :]
-        return corpus, text_input
+        return corpus_list, text_input_list
diff --git a/deploy/hubserving/ocr_det/params.py b/deploy/hubserving/ocr_det/params.py
index 132158904d44a5a45600e6cfc9cd3e565ddcef0b..7be88e9bc6673fadfce19a281de4ba4d2b235fd2 100755
--- a/deploy/hubserving/ocr_det/params.py
+++ b/deploy/hubserving/ocr_det/params.py
@@ -20,7 +20,8 @@ def read_params():
     #DB parmas
     cfg.det_db_thresh = 0.3
     cfg.det_db_box_thresh = 0.5
-    cfg.det_db_unclip_ratio = 2.0
+    cfg.det_db_unclip_ratio = 1.6
+    cfg.use_dilation = False

     # #EAST parmas
     # cfg.det_east_score_thresh = 0.8
diff --git a/deploy/hubserving/ocr_system/params.py b/deploy/hubserving/ocr_system/params.py
index add466668eee0be1e1674fce5f5a07c24c0c5e3f..bd56dc2e8fc05309e27227d25975dc784a17c2cf 100755
--- a/deploy/hubserving/ocr_system/params.py
+++ b/deploy/hubserving/ocr_system/params.py
@@ -20,7 +20,8 @@ def read_params():
     #DB parmas
     cfg.det_db_thresh = 0.3
     cfg.det_db_box_thresh = 0.5
-    cfg.det_db_unclip_ratio = 2.0
+    cfg.det_db_unclip_ratio = 1.6
+    cfg.use_dilation = False

     #EAST parmas
     cfg.det_east_score_thresh = 0.8
diff --git a/paddleocr.py b/paddleocr.py
index db24aa59e9237ce9cafa972673ecb0b1a3357c33..7c126261eff1168a1888d72f71fb284e347f9ec9 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -146,7 +146,8 @@ def parse_args(mMain=True, add_help=True):
         # DB parmas
         parser.add_argument("--det_db_thresh", type=float, default=0.3)
         parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
-        parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
+        parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
+        parser.add_argument("--use_dilation", type=bool, default=False)

         # EAST parmas
         parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
@@ -193,7 +194,8 @@ def parse_args(mMain=True, add_help=True):
             det_limit_type='max',
             det_db_thresh=0.3,
             det_db_box_thresh=0.5,
-            det_db_unclip_ratio=2.0,
+            det_db_unclip_ratio=1.6,
+            use_dilation=False,
             det_east_score_thresh=0.8,
             det_east_cover_thresh=0.1,
             det_east_nms_thresh=0.2,
diff --git a/ppocr/losses/det_basic_loss.py b/ppocr/losses/det_basic_loss.py
index 57b3667d9f32a871f748c40a65429551613991ca..eba5526dd2bd1c0328130b50817172df437cc360 100644
--- a/ppocr/losses/det_basic_loss.py
+++ b/ppocr/losses/det_basic_loss.py
@@ -200,6 +200,6 @@ def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
                 i, :, :], ohem_ratio))

     selected_masks = np.concatenate(selected_masks, 0)
-    selected_masks = paddle.to_variable(selected_masks)
+    selected_masks = paddle.to_tensor(selected_masks)

     return selected_masks
diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py
index 59e26c1e20c72c1993a00de96ba66ee5462b74cc..0d222714ff7edebfc717daa81d48ce7424dfbd03 100644
--- a/ppocr/modeling/heads/rec_att_head.py
+++ b/ppocr/modeling/heads/rec_att_head.py
@@ -57,6 +57,9 @@ class AttentionHead(nn.Layer):
         else:
             targets = paddle.zeros(shape=[batch_size], dtype="int32")
             probs = None
+            char_onehots = None
+            outputs = None
+            alpha = None

             for i in range(num_steps):
                 char_onehots = self._char_to_onehot(
@@ -146,9 +149,6 @@ class AttentionLSTM(nn.Layer):
         else:
             targets = paddle.zeros(shape=[batch_size], dtype="int32")
             probs = None
-            char_onehots = None
-            outputs = None
-            alpha = None

             for i in range(num_steps):
                 char_onehots = self._char_to_onehot(
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index b3d9d4907ba35f7cfade795b6d3897c525d41e6d..b24e57dd973bc0216f2875232bcec6e36ab47e29 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -248,9 +248,11 @@ class TextRecognizer(object):
 def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     text_recognizer = TextRecognizer(args)
+    total_run_time = 0.0
+    total_images_num = 0
     valid_image_file_list = []
     img_list = []
-    for image_file in image_file_list:
+    for idx, image_file in enumerate(image_file_list):
         img, flag = check_and_read_gif(image_file)
         if not flag:
             img = cv2.imread(image_file)
@@ -259,22 +261,29 @@ def main(args):
             continue
         valid_image_file_list.append(image_file)
         img_list.append(img)
-    try:
-        rec_res, predict_time = text_recognizer(img_list)
-    except:
-        logger.info(traceback.format_exc())
-        logger.info(
-            "ERROR!!!! \n"
-            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
-            "If your model has tps module: "
-            "TPS does not support variable shape.\n"
-            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
-        exit()
-    for ino in range(len(img_list)):
-        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
-                                               rec_res[ino]))
+        if len(img_list) >= args.rec_batch_num or idx == len(
+                image_file_list) - 1:
+            try:
+                rec_res, predict_time = text_recognizer(img_list)
+                total_run_time += predict_time
+            except:
+                logger.info(traceback.format_exc())
+                logger.info(
+                    "ERROR!!!! \n"
+                    "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
+                    "If your model has tps module: "
+                    "TPS does not support variable shape.\n"
+                    "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
+                )
+                exit()
+            for ino in range(len(img_list)):
+                logger.info("Predicts of {}:{}".format(valid_image_file_list[
+                    ino], rec_res[ino]))
+            total_images_num += len(valid_image_file_list)
+            valid_image_file_list = []
+            img_list = []
     logger.info("Total predict time for {} images, cost: {:.3f}".format(
-        len(img_list), predict_time))
+        total_images_num, total_run_time))


 if __name__ == "__main__":
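
Note (not part of the diff): the new list-based predict() in StyleText/engine/predictors.py renders each text fragment separately, rescales every per-fragment result to a common height, and stitches the fragments side by side. The standalone sketch below reproduces that reshape-and-concatenate step; the dummy image sizes are assumptions for illustration only.

import cv2
import numpy as np

def reshape_to_same_height(img_list):
    # Rescale every image to the height of the first one, preserving
    # each image's aspect ratio (same logic as the patched method).
    h = img_list[0].shape[0]
    for idx in range(1, len(img_list)):
        new_w = round(1.0 * img_list[idx].shape[1] / img_list[idx].shape[0] * h)
        img_list[idx] = cv2.resize(img_list[idx], (new_w, h))
    return img_list

# Two fake result fragments of different heights ...
a = np.zeros((32, 100, 3), dtype=np.uint8)
b = np.zeros((48, 200, 3), dtype=np.uint8)
# ... are normalized to one height and concatenated horizontally,
# which is what predict() now does per result key.
merged = np.concatenate(reshape_to_same_height([a, b]), axis=1)
print(merged.shape)  # (32, 233, 3): 100 + round(200 / 48 * 32) columns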
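Note (not part of the diff): the reworked draw_text() splits a long corpus into several images, each no wider than the style input, and returns parallel lists of text chunks and rendered images. The sketch below shows only the splitting contract under simplifying assumptions: a fixed per-char width stands in for font.getsize(), and no actual PIL rendering is done.

def split_corpus(corpus, char_width, max_width):
    chunks = []
    while corpus:
        x = 2  # left margin, as in draw_text
        for i, ch in enumerate(corpus):
            # split when the next char would overflow and the image
            # already holds at least one char
            if x + char_width(ch) >= max_width and i != 0:
                chunks.append(corpus[:i])
                corpus = corpus[i:]
                break
            x += char_width(ch)
        else:
            # the whole remaining text fits on one image
            chunks.append(corpus)
            corpus = ""
    return chunks

print(split_corpus("HelloWorld", lambda c: 10, 42))
# ['Hell', 'oWor', 'ld'] with a 2 px margin and 10 px per char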
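Note (not part of the diff): the rewritten main() in tools/infer/predict_rec.py switches from one giant recognizer call to an accumulate-and-flush batching loop: images collect in img_list and are run whenever the batch reaches rec_batch_num or the last file is reached. A minimal sketch of that pattern, where run_batch and batch_size are hypothetical stand-ins for text_recognizer and args.rec_batch_num:

def run_in_batches(items, batch_size, run_batch):
    batch, results = [], []
    for idx, item in enumerate(items):
        batch.append(item)
        # flush on a full batch or on the final item
        if len(batch) >= batch_size or idx == len(items) - 1:
            results.extend(run_batch(batch))
            batch = []
    return results

print(run_in_batches(list(range(7)), 3, lambda b: [x * x for x in b]))
# [0, 1, 4, 9, 16, 25, 36] -- processed as batches [0,1,2], [3,4,5], [6]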