diff --git a/StyleText/engine/predictors.py b/StyleText/engine/predictors.py index a1ba21f1b6cd084f9f95140d4227d600d4631715..ca9ab9ce6fc471e077970766c252c98b8617c6cc 100644 --- a/StyleText/engine/predictors.py +++ b/StyleText/engine/predictors.py @@ -38,7 +38,15 @@ class StyleTextRecPredictor(object): self.std = config["Predictor"]["std"] self.expand_result = config["Predictor"]["expand_result"] - def predict(self, style_input, text_input): + def reshape_to_same_height(self, img_list): + h = img_list[0].shape[0] + for idx in range(1, len(img_list)): + new_w = round(1.0 * img_list[idx].shape[1] / + img_list[idx].shape[0] * h) + img_list[idx] = cv2.resize(img_list[idx], (new_w, h)) + return img_list + + def predict_single_image(self, style_input, text_input): style_input = self.rep_style_input(style_input, text_input) tensor_style_input = self.preprocess(style_input) tensor_text_input = self.preprocess(text_input) @@ -64,6 +72,21 @@ class StyleTextRecPredictor(object): "fake_bg": fake_bg, } + def predict(self, style_input, text_input_list): + if not isinstance(text_input_list, (tuple, list)): + return self.predict_single_image(style_input, text_input_list) + + synth_result_list = [] + for text_input in text_input_list: + synth_result = self.predict_single_image(style_input, text_input) + synth_result_list.append(synth_result) + + for key in synth_result: + res = [r[key] for r in synth_result_list] + res = self.reshape_to_same_height(res) + synth_result[key] = np.concatenate(res, axis=1) + return synth_result + def preprocess(self, img): img = (img.astype('float32') * self.scale - self.mean) / self.std img_height, img_width, channel = img.shape diff --git a/StyleText/engine/synthesisers.py b/StyleText/engine/synthesisers.py index 177e3e049a695ecd06f5d2271f21336dd4eff997..6461d9e363f5f6e0c92831a50580c2748dffa248 100644 --- a/StyleText/engine/synthesisers.py +++ b/StyleText/engine/synthesisers.py @@ -12,6 +12,8 @@ # See the License for the specific language governing 
permissions and # limitations under the License. import os +import numpy as np +import cv2 from utils.config import ArgsParser, load_config, override_config from utils.logging import get_logger @@ -36,8 +38,9 @@ class ImageSynthesiser(object): self.predictor = getattr(predictors, predictor_method)(self.config) def synth_image(self, corpus, style_input, language="en"): - corpus, text_input = self.text_drawer.draw_text(corpus, language) - synth_result = self.predictor.predict(style_input, text_input) + corpus_list, text_input_list = self.text_drawer.draw_text( + corpus, language, style_input_width=style_input.shape[1]) + synth_result = self.predictor.predict(style_input, text_input_list) return synth_result @@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser): for i in range(self.output_num): style_data = self.style_sampler.sample() style_input = style_data["image"] - corpus_language, text_input_label = self.corpus_generator.generate( - ) - text_input_label, text_input = self.text_drawer.draw_text( - text_input_label, corpus_language) + corpus_language, text_input_label = self.corpus_generator.generate() + text_input_label_list, text_input_list = self.text_drawer.draw_text( + text_input_label, + corpus_language, + style_input_width=style_input.shape[1]) - synth_result = self.predictor.predict(style_input, text_input) + text_input_label = "".join(text_input_label_list) + + synth_result = self.predictor.predict(style_input, text_input_list) fake_fusion = synth_result["fake_fusion"] self.writer.save_image(fake_fusion, text_input_label) self.writer.save_label() diff --git a/StyleText/engine/text_drawers.py b/StyleText/engine/text_drawers.py index 8aaac06ec50816bb6e2774972644c0a7dfb908c6..aeec75c3378f91b64b4387ef16971165f0b80ebe 100644 --- a/StyleText/engine/text_drawers.py +++ b/StyleText/engine/text_drawers.py @@ -1,5 +1,6 @@ from PIL import Image, ImageDraw, ImageFont import numpy as np +import cv2 from utils.logging import get_logger @@ -28,7 +29,11 @@ class 
StdTextDrawer(object): else: return int((self.height - 4)**2 / font_height) - def draw_text(self, corpus, language="en", crop=True): + def draw_text(self, + corpus, + language="en", + crop=True, + style_input_width=None): if language not in self.support_languages: self.logger.warning( "language {} not supported, use en instead.".format(language)) @@ -37,21 +42,43 @@ class StdTextDrawer(object): width = min(self.max_width, len(corpus) * self.height) + 4 else: width = len(corpus) * self.height + 4 - bg = Image.new("RGB", (width, self.height), color=(127, 127, 127)) - draw = ImageDraw.Draw(bg) - - char_x = 2 - font = self.font_dict[language] - for i, char_i in enumerate(corpus): - char_size = font.getsize(char_i)[0] - draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font) - char_x += char_size - if char_x >= width: - corpus = corpus[0:i + 1] - self.logger.warning("corpus length exceed limit: {}".format( - corpus)) + + if style_input_width is not None: + width = min(width, style_input_width) + + corpus_list = [] + text_input_list = [] + + while len(corpus) != 0: + bg = Image.new("RGB", (width, self.height), color=(127, 127, 127)) + draw = ImageDraw.Draw(bg) + char_x = 2 + font = self.font_dict[language] + i = 0 + while i < len(corpus): + char_i = corpus[i] + char_size = font.getsize(char_i)[0] + # split when char_x + char_size would exceed the image width and index is not 0 (at least 1 char should be written on the image) + if char_x + char_size >= width and i != 0: + text_input = np.array(bg).astype(np.uint8) + text_input = text_input[:, 0:char_x, :] + + corpus_list.append(corpus[0:i]) + text_input_list.append(text_input) + corpus = corpus[i:] + break + draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font) + char_x += char_size + + i += 1 + # the whole text is shorter than style input + if i == len(corpus): + text_input = np.array(bg).astype(np.uint8) + text_input = text_input[:, 0:char_x, :] + + corpus_list.append(corpus[0:i]) + text_input_list.append(text_input) + corpus = 
corpus[i:] break - text_input = np.array(bg).astype(np.uint8) - text_input = text_input[:, 0:char_x, :] - return corpus, text_input + return corpus_list, text_input_list