Unverified commit 6335b0d8, authored by Double_V, committed by GitHub

Merge branch 'dygraph' into fix2013

@@ -38,7 +38,15 @@ class StyleTextRecPredictor(object):
         self.std = config["Predictor"]["std"]
         self.expand_result = config["Predictor"]["expand_result"]
 
-    def predict(self, style_input, text_input):
+    def reshape_to_same_height(self, img_list):
+        h = img_list[0].shape[0]
+        for idx in range(1, len(img_list)):
+            new_w = round(1.0 * img_list[idx].shape[1] /
+                          img_list[idx].shape[0] * h)
+            img_list[idx] = cv2.resize(img_list[idx], (new_w, h))
+        return img_list
+
+    def predict_single_image(self, style_input, text_input):
         style_input = self.rep_style_input(style_input, text_input)
         tensor_style_input = self.preprocess(style_input)
         tensor_text_input = self.preprocess(text_input)
@@ -64,6 +72,21 @@ class StyleTextRecPredictor(object):
             "fake_bg": fake_bg,
         }
 
+    def predict(self, style_input, text_input_list):
+        if not isinstance(text_input_list, (tuple, list)):
+            return self.predict_single_image(style_input, text_input_list)
+
+        synth_result_list = []
+        for text_input in text_input_list:
+            synth_result = self.predict_single_image(style_input, text_input)
+            synth_result_list.append(synth_result)
+
+        for key in synth_result:
+            res = [r[key] for r in synth_result_list]
+            res = self.reshape_to_same_height(res)
+            synth_result[key] = np.concatenate(res, axis=1)
+        return synth_result
+
     def preprocess(self, img):
         img = (img.astype('float32') * self.scale - self.mean) / self.std
         img_height, img_width, channel = img.shape
......
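The new batched predict renders each text chunk separately and then stitches the per-chunk outputs side by side: every result is resized to a common height and concatenated along the width axis. A minimal standalone sketch of that stitching step (hypothetical helper name, outside the StyleText classes, assuming numpy/OpenCV images):

import cv2
import numpy as np


def concat_same_height(img_list):
    # Resize every image to the height of the first one, keeping its
    # aspect ratio, then concatenate the strips horizontally (axis=1).
    h = img_list[0].shape[0]
    resized = [img_list[0]]
    for img in img_list[1:]:
        new_w = round(1.0 * img.shape[1] / img.shape[0] * h)
        resized.append(cv2.resize(img, (new_w, h)))
    return np.concatenate(resized, axis=1)


print(concat_same_height([np.zeros((40, 100, 3), np.uint8),
                          np.zeros((60, 90, 3), np.uint8)]).shape)  # (40, 160, 3)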
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import numpy as np
+import cv2
 
 from utils.config import ArgsParser, load_config, override_config
 from utils.logging import get_logger
@@ -36,8 +38,9 @@ class ImageSynthesiser(object):
         self.predictor = getattr(predictors, predictor_method)(self.config)
 
     def synth_image(self, corpus, style_input, language="en"):
-        corpus, text_input = self.text_drawer.draw_text(corpus, language)
-        synth_result = self.predictor.predict(style_input, text_input)
+        corpus_list, text_input_list = self.text_drawer.draw_text(
+            corpus, language, style_input_width=style_input.shape[1])
+        synth_result = self.predictor.predict(style_input, text_input_list)
         return synth_result
 
@@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser):
         for i in range(self.output_num):
             style_data = self.style_sampler.sample()
             style_input = style_data["image"]
-            corpus_language, text_input_label = self.corpus_generator.generate(
-            )
-            text_input_label, text_input = self.text_drawer.draw_text(
-                text_input_label, corpus_language)
-            synth_result = self.predictor.predict(style_input, text_input)
+            corpus_language, text_input_label = self.corpus_generator.generate()
+            text_input_label_list, text_input_list = self.text_drawer.draw_text(
+                text_input_label,
+                corpus_language,
+                style_input_width=style_input.shape[1])
+            text_input_label = "".join(text_input_label_list)
+            synth_result = self.predictor.predict(style_input, text_input_list)
             fake_fusion = synth_result["fake_fusion"]
             self.writer.save_image(fake_fusion, text_input_label)
         self.writer.save_label()
......
 from PIL import Image, ImageDraw, ImageFont
 import numpy as np
+import cv2
 from utils.logging import get_logger
@@ -28,7 +29,11 @@ class StdTextDrawer(object):
         else:
             return int((self.height - 4)**2 / font_height)
 
-    def draw_text(self, corpus, language="en", crop=True):
+    def draw_text(self,
+                  corpus,
+                  language="en",
+                  crop=True,
+                  style_input_width=None):
         if language not in self.support_languages:
             self.logger.warning(
                 "language {} not supported, use en instead.".format(language))
@@ -37,21 +42,43 @@ class StdTextDrawer(object):
             width = min(self.max_width, len(corpus) * self.height) + 4
         else:
             width = len(corpus) * self.height + 4
 
-        bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
-        draw = ImageDraw.Draw(bg)
-        char_x = 2
-        font = self.font_dict[language]
-        for i, char_i in enumerate(corpus):
-            char_size = font.getsize(char_i)[0]
-            draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
-            char_x += char_size
-        text_input = np.array(bg).astype(np.uint8)
-        text_input = text_input[:, 0:char_x, :]
-        return corpus, text_input
+        if style_input_width is not None:
+            width = min(width, style_input_width)
+
+        corpus_list = []
+        text_input_list = []
+
+        while len(corpus) != 0:
+            bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
+            draw = ImageDraw.Draw(bg)
+            char_x = 2
+            font = self.font_dict[language]
+            i = 0
+            while i < len(corpus):
+                char_i = corpus[i]
+                char_size = font.getsize(char_i)[0]
+                # split when char_x exceeds char size and index is not 0 (at least 1 char should be wroten on the image)
+                if char_x + char_size >= width and i != 0:
+                    text_input = np.array(bg).astype(np.uint8)
+                    text_input = text_input[:, 0:char_x, :]
+                    corpus_list.append(corpus[0:i])
+                    text_input_list.append(text_input)
+                    corpus = corpus[i:]
+                    break
+                draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
+                char_x += char_size
+                if char_x >= width:
+                    corpus = corpus[0:i + 1]
+                    self.logger.warning("corpus length exceed limit: {}".format(
+                        corpus))
+                    break
+                i += 1
+            # the whole text is shorter than style input
+            if i == len(corpus):
+                text_input = np.array(bg).astype(np.uint8)
+                text_input = text_input[:, 0:char_x, :]
+                corpus_list.append(corpus[0:i])
+                text_input_list.append(text_input)
+                corpus = corpus[i:]
+                break
+        return corpus_list, text_input_list
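The rewritten draw_text now returns parallel lists of corpus chunks and rendered strips, cutting the corpus whenever the drawn text would overflow the style image width. A simplified, self-contained sketch of the same greedy splitting idea, using a fixed per-character width instead of measuring glyphs with font.getsize (names and widths here are illustrative only):

def split_corpus_by_width(corpus, max_width, char_width=10, margin=2):
    # Greedily cut the corpus into chunks whose estimated rendered width
    # (a fixed char_width per character plus a left margin) fits max_width.
    chunks = []
    while corpus:
        x = margin
        i = 0
        while i < len(corpus) and x + char_width < max_width:
            x += char_width
            i += 1
        i = max(i, 1)  # always keep at least one character per chunk
        chunks.append(corpus[:i])
        corpus = corpus[i:]
    return chunks


print(split_corpus_by_width("PaddleOCR StyleText", max_width=80))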
@@ -20,7 +20,8 @@ def read_params():
     #DB parmas
     cfg.det_db_thresh = 0.3
     cfg.det_db_box_thresh = 0.5
-    cfg.det_db_unclip_ratio = 2.0
+    cfg.det_db_unclip_ratio = 1.6
+    cfg.use_dilation = False
 
     # #EAST parmas
     # cfg.det_east_score_thresh = 0.8
......
@@ -20,7 +20,8 @@ def read_params():
     #DB parmas
     cfg.det_db_thresh = 0.3
     cfg.det_db_box_thresh = 0.5
-    cfg.det_db_unclip_ratio = 2.0
+    cfg.det_db_unclip_ratio = 1.6
+    cfg.use_dilation = False
 
     #EAST parmas
     cfg.det_east_score_thresh = 0.8
......
@@ -146,7 +146,8 @@ def parse_args(mMain=True, add_help=True):
         # DB parmas
         parser.add_argument("--det_db_thresh", type=float, default=0.3)
         parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
-        parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
+        parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
+        parser.add_argument("--use_dilation", type=bool, default=False)
 
         # EAST parmas
         parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
@@ -193,7 +194,8 @@ def parse_args(mMain=True, add_help=True):
             det_limit_type='max',
             det_db_thresh=0.3,
             det_db_box_thresh=0.5,
-            det_db_unclip_ratio=2.0,
+            det_db_unclip_ratio=1.6,
+            use_dilation=False,
             det_east_score_thresh=0.8,
             det_east_cover_thresh=0.1,
             det_east_nms_thresh=0.2,
......
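For context, a use_dilation switch like the one added above is typically consumed by the DB post-process to slightly expand the predicted text-region mask before boxes are extracted; the exact wiring is not part of this diff, so the following is only a minimal sketch of the idea, with the kernel size assumed:

import cv2
import numpy as np


def maybe_dilate(mask, use_dilation=False):
    # Optionally dilate the binarized segmentation mask so that thin or
    # broken text regions grow into slightly larger, connected regions.
    if not use_dilation:
        return mask
    kernel = np.ones((2, 2), np.uint8)  # small kernel, assumed value
    return cv2.dilate(mask, kernel)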
@@ -200,6 +200,6 @@ def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
             i, :, :], ohem_ratio))
 
     selected_masks = np.concatenate(selected_masks, 0)
-    selected_masks = paddle.to_variable(selected_masks)
+    selected_masks = paddle.to_tensor(selected_masks)
 
     return selected_masks
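paddle.to_variable is the old dygraph API; paddle.to_tensor is its Paddle 2.x replacement for wrapping a numpy array as a tensor. A minimal sketch of the replacement call (array shape chosen arbitrarily for illustration):

import numpy as np
import paddle

masks = np.zeros((4, 640, 640), dtype="float32")
# Paddle 2.x style: convert a numpy array into a dygraph tensor.
masks_tensor = paddle.to_tensor(masks)
print(masks_tensor.shape)  # [4, 640, 640]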
@@ -57,6 +57,9 @@ class AttentionHead(nn.Layer):
         else:
             targets = paddle.zeros(shape=[batch_size], dtype="int32")
             probs = None
+            char_onehots = None
+            outputs = None
+            alpha = None
 
         for i in range(num_steps):
             char_onehots = self._char_to_onehot(
@@ -146,9 +149,6 @@ class AttentionLSTM(nn.Layer):
         else:
             targets = paddle.zeros(shape=[batch_size], dtype="int32")
             probs = None
-            char_onehots = None
-            outputs = None
-            alpha = None
 
         for i in range(num_steps):
             char_onehots = self._char_to_onehot(
......
@@ -248,9 +248,11 @@ class TextRecognizer(object):
 def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     text_recognizer = TextRecognizer(args)
+    total_run_time = 0.0
+    total_images_num = 0
     valid_image_file_list = []
     img_list = []
-    for image_file in image_file_list:
+    for idx, image_file in enumerate(image_file_list):
         img, flag = check_and_read_gif(image_file)
         if not flag:
             img = cv2.imread(image_file)
@@ -259,8 +261,11 @@ def main(args):
             continue
         valid_image_file_list.append(image_file)
         img_list.append(img)
-        try:
-            rec_res, predict_time = text_recognizer(img_list)
-        except:
-            logger.info(traceback.format_exc())
-            logger.info(
+        if len(img_list) >= args.rec_batch_num or idx == len(
+                image_file_list) - 1:
+            try:
+                rec_res, predict_time = text_recognizer(img_list)
+                total_run_time += predict_time
+            except:
+                logger.info(traceback.format_exc())
+                logger.info(
@@ -268,13 +273,17 @@ def main(args):
                 "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
                 "If your model has tps module: "
                 "TPS does not support variable shape.\n"
-                "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
-            exit()
-    for ino in range(len(img_list)):
-        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
-                                               rec_res[ino]))
-    logger.info("Total predict time for {} images, cost: {:.3f}".format(
-        len(img_list), predict_time))
+                    "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
+                )
+                exit()
+            for ino in range(len(img_list)):
+                logger.info("Predicts of {}:{}".format(valid_image_file_list[
+                    ino], rec_res[ino]))
+            total_images_num += len(valid_image_file_list)
+            valid_image_file_list = []
+            img_list = []
+    logger.info("Total predict time for {} images, cost: {:.3f}".format(
+        total_images_num, total_run_time))
 
 if __name__ == "__main__":
......
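The change above makes recognition run in true batches: images accumulate until rec_batch_num is reached (or the input list ends), only then is the recognizer called, and the per-batch times are summed so the final log covers every image exactly once. A simplified standalone sketch of the same accumulate-and-flush pattern (generic names, not PaddleOCR's API):

def run_in_batches(items, batch_size, run_batch):
    # Accumulate items and flush a batch when it is full or at the end
    # of the input list, collecting all per-batch results.
    results, batch = [], []
    for idx, item in enumerate(items):
        batch.append(item)
        if len(batch) >= batch_size or idx == len(items) - 1:
            results.extend(run_batch(batch))
            batch = []
    return results


# Example: "recognize" each batch by upper-casing its items.
print(run_in_batches(["a", "b", "c", "d", "e"], 2,
                     lambda batch: [x.upper() for x in batch]))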