Unverified commit 9b2c0e48, authored by zhoujun, committed by GitHub

Merge pull request #1235 from WenmuZhou/dygraph_rc

Fix the problem of ips being under-counted
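The fix in predict_cls.py and predict_rec.py (with a related change to batch timing in tools/program.py) is to accumulate per-batch latency with elapse += time.time() - starttime instead of overwriting it, so the reported total time, and the ips figure derived from it, covers every batch rather than only the last one. A minimal sketch of the pattern, using an illustrative run_batch callback that is not part of the repo:

```python
import time

def total_inference_time(batches, run_batch):
    # Accumulate latency over all batches; assigning `elapse = ...` inside
    # the loop (the old behaviour) would keep only the last batch's time.
    elapse = 0.0
    for batch in batches:
        starttime = time.time()
        run_batch(batch)
        elapse += time.time() - starttime
    return elapse
```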
@@ -123,7 +123,7 @@ class BaseRecLabelEncode(object):
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if len(text) > self.max_text_len:
if len(text) == 0 or len(text) > self.max_text_len:
return None
if self.character_type == "en":
text = text.lower()
@@ -138,9 +138,6 @@ class BaseRecLabelEncode(object):
return None
return text_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
@@ -160,8 +157,6 @@ class CTCLabelEncode(BaseRecLabelEncode):
text = self.encode(text)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
@@ -195,11 +190,6 @@ class AttnLabelEncode(BaseRecLabelEncode):
text = self.encode(text)
return text
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
......
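The label_ops.py hunks above centralize the label-length check in BaseRecLabelEncode.encode, which now also rejects empty labels, so CTCLabelEncode can drop its duplicated max_text_len check. A condensed sketch of the guard only (the real class additionally builds the character-to-index dictionary and filters unknown characters):

```python
class BaseRecLabelEncode(object):
    # Condensed illustration; not the full class from ppocr/data/imaug/label_ops.py.
    def __init__(self, max_text_len=25, character_type="en"):
        self.max_text_len = max_text_len
        self.character_type = character_type
        self.dict = {}  # character -> index mapping, built elsewhere in the real class

    def encode(self, text):
        # Empty labels and over-long labels are both rejected here, so the
        # subclasses no longer need their own length checks.
        if len(text) == 0 or len(text) > self.max_text_len:
            return None
        if self.character_type == "en":
            text = text.lower()
        return [self.dict.get(ch, 0) for ch in text]  # simplified lookup
```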
@@ -82,7 +82,7 @@ class TextClassifier(object):
cls_res = [['', 0.0]] * img_num
batch_num = self.cls_batch_num
predict_time = 0
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
@@ -107,14 +107,14 @@ class TextClassifier(object):
self.predictor.run([norm_img_batch])
prob_out = self.output_tensors[0].copy_to_cpu()
cls_res = self.postprocess_op(prob_out)
elapse = time.time() - starttime
elapse += time.time() - starttime
for rno in range(len(cls_res)):
label, score = cls_res[rno]
cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res, predict_time
return img_list, cls_res, elapse
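A side note on the rotation branch kept above: cv2.rotate(img_list[...], 1) uses OpenCV's numeric rotate code, and 1 corresponds to cv2.ROTATE_180, so crops classified as '180' with a score above cls_thresh are flipped upside down. A small self-contained check with a dummy array in place of a real crop:

```python
import cv2
import numpy as np

assert cv2.ROTATE_180 == 1                           # the literal 1 used in the code
crop = np.arange(12, dtype=np.uint8).reshape(3, 4)   # stand-in for a text crop
flipped = cv2.rotate(crop, cv2.ROTATE_180)           # equivalent to cv2.rotate(crop, 1)
print(flipped)
```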
def main(args):
@@ -143,10 +143,10 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit()
for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino]))
print("Total predict time for %d images, cost: %.3f" %
(len(img_list), predict_time))
print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
ino]))
print("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time))
if __name__ == "__main__":
main(utility.parse_args())
@@ -174,15 +174,15 @@ if __name__ == "__main__":
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
dt_boxes, elapse = text_detector(img)
if count > 0:
total_time += elapse
count += 1
print("Predict time of %s:" % image_file, elapse)
print("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = image_file.split("/")[-1]
cv2.imwrite(
os.path.join(draw_img_save, "det_res_%s" % img_name_pure), src_im)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im)
if count > 1:
print("Avg Time:", total_time / (count - 1))
@@ -115,7 +115,7 @@ class TextRecognizer(object):
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
elapse = time.time() - starttime
elapse += time.time() - starttime
return rec_res, elapse
@@ -145,9 +145,10 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit()
for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images, cost: %.3f" %
(len(img_list), predict_time))
print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
ino]))
print("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time))
if __name__ == "__main__":
......
@@ -236,7 +236,6 @@ def train(config,
train_batch_cost = 0.0
train_reader_cost = 0.0
batch_sum = 0
batch_start = time.time()
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
@@ -275,6 +274,7 @@
best_model_dict[main_indicator],
global_step)
global_step += 1
batch_start = time.time()
if dist.get_rank() == 0:
save_model(
model,
@@ -333,20 +333,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
return metirc
def save_inference_mode(model, config, logger):
model.eval()
save_path = '{}/infer/{}'.format(config['Global']['save_model_dir'],
config['Architecture']['model_type'])
if config['Architecture']['model_type'] == 'rec':
input_shape = [None, 3, 32, None]
jit_model = paddle.jit.to_static(
model, input_spec=[paddle.static.InputSpec(input_shape)])
paddle.jit.save(jit_model, save_path)
logger.info('inference model save to {}'.format(save_path))
model.train()
def preprocess():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
......
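In tools/program.py the batch_start reset moves from the logging branch to the end of every iteration (right after global_step += 1), so train_batch_cost and train_reader_cost are measured for every batch, not only the ones that fall on a logging step. A simplified sketch of the timing pattern, with an illustrative train_step callback:

```python
import time

def timed_train_loop(dataloader, train_step, print_batch_step=10):
    train_batch_cost = 0.0
    batch_start = time.time()
    for global_step, batch in enumerate(dataloader, start=1):
        train_step(batch)
        train_batch_cost += time.time() - batch_start
        if global_step % print_batch_step == 0:
            print("avg batch cost: {:.5f} s".format(
                train_batch_cost / print_batch_step))
            train_batch_cost = 0.0
        # Reset at the end of every iteration (as in this commit), not only
        # inside the logging branch, so no batch goes untimed.
        batch_start = time.time()
```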
@@ -89,7 +89,6 @@ def main(config, device, logger, vdl_writer):
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer)
program.save_inference_mode(model, config, logger)
def test_reader(config, device, logger):
......