提交 04e71041 编写于 作者: W wangjingyeye

add db++

上级 1315cdfc
...@@ -18,7 +18,7 @@ Global: ...@@ -18,7 +18,7 @@ Global:
save_res_path: ./checkpoints/det_db/predicts_db.txt save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture: Architecture:
model_type: det model_type: det
algorithm: DB algorithm: DB++
Transform: null Transform: null
Backbone: Backbone:
name: ResNet name: ResNet
......
...@@ -18,7 +18,7 @@ Global: ...@@ -18,7 +18,7 @@ Global:
save_res_path: ./checkpoints/det_db/predicts_db.txt save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture: Architecture:
model_type: det model_type: det
algorithm: DB algorithm: DB++
Transform: null Transform: null
Backbone: Backbone:
name: ResNet name: ResNet
......
...@@ -67,6 +67,23 @@ class TextDetector(object): ...@@ -67,6 +67,23 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode postprocess_params["score_mode"] = args.det_db_score_mode
elif self.det_algorithm == "DB++":
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
pre_process_list[1] = {
'NormalizeImage': {
'std': [1.0, 1.0, 1.0],
'mean':
[0.48109378172549, 0.45752457890196, 0.40787054090196],
'scale': '1./255.',
'order': 'hwc'
}
}
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess' postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh postprocess_params["score_thresh"] = args.det_east_score_thresh
...@@ -231,7 +248,7 @@ class TextDetector(object): ...@@ -231,7 +248,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1] preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2] preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3] preds['f_tvo'] = outputs[3]
elif self.det_algorithm in ['DB', 'PSE']: elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds['maps'] = outputs[0] preds['maps'] = outputs[0]
elif self.det_algorithm == 'FCE': elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
......
...@@ -307,7 +307,8 @@ def train(config, ...@@ -307,7 +307,8 @@ def train(config,
train_stats.update(stats) train_stats.update(stats)
if log_writer is not None and dist.get_rank() == 0: if log_writer is not None and dist.get_rank() == 0:
log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step) log_writer.log_metrics(
metrics=train_stats.get(), prefix="TRAIN", step=global_step)
if dist.get_rank() == 0 and ( if dist.get_rank() == 0 and (
(global_step > 0 and global_step % print_batch_step == 0) or (global_step > 0 and global_step % print_batch_step == 0) or
...@@ -354,7 +355,8 @@ def train(config, ...@@ -354,7 +355,8 @@ def train(config,
# logger metric # logger metric
if log_writer is not None: if log_writer is not None:
log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step) log_writer.log_metrics(
metrics=cur_metric, prefix="EVAL", step=global_step)
if cur_metric[main_indicator] >= best_model_dict[ if cur_metric[main_indicator] >= best_model_dict[
main_indicator]: main_indicator]:
...@@ -377,11 +379,18 @@ def train(config, ...@@ -377,11 +379,18 @@ def train(config,
logger.info(best_str) logger.info(best_str)
# logger best metric # logger best metric
if log_writer is not None: if log_writer is not None:
log_writer.log_metrics(metrics={ log_writer.log_metrics(
"best_{}".format(main_indicator): best_model_dict[main_indicator] metrics={
}, prefix="EVAL", step=global_step) "best_{}".format(main_indicator):
best_model_dict[main_indicator]
log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict) },
prefix="EVAL",
step=global_step)
log_writer.log_model(
is_best=True,
prefix="best_accuracy",
metadata=best_model_dict)
reader_start = time.time() reader_start = time.time()
if dist.get_rank() == 0: if dist.get_rank() == 0:
...@@ -413,7 +422,8 @@ def train(config, ...@@ -413,7 +422,8 @@ def train(config,
epoch=epoch, epoch=epoch,
global_step=global_step) global_step=global_step)
if log_writer is not None: if log_writer is not None:
log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch)) log_writer.log_model(
is_best=False, prefix='iter_epoch_{}'.format(epoch))
best_str = 'best metric, {}'.format(', '.join( best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
...@@ -564,7 +574,7 @@ def preprocess(is_train=False): ...@@ -564,7 +574,7 @@ def preprocess(is_train=False):
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR' 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', 'DB++'
] ]
if use_xpu: if use_xpu:
...@@ -585,7 +595,8 @@ def preprocess(is_train=False): ...@@ -585,7 +595,8 @@ def preprocess(is_train=False):
vdl_writer_path = '{}/vdl/'.format(save_model_dir) vdl_writer_path = '{}/vdl/'.format(save_model_dir)
log_writer = VDLLogger(save_model_dir) log_writer = VDLLogger(save_model_dir)
loggers.append(log_writer) loggers.append(log_writer)
if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config: if ('use_wandb' in config['Global'] and
config['Global']['use_wandb']) or 'wandb' in config:
save_dir = config['Global']['save_model_dir'] save_dir = config['Global']['save_model_dir']
wandb_writer_path = "{}/wandb".format(save_dir) wandb_writer_path = "{}/wandb".format(save_dir)
if "wandb" in config: if "wandb" in config:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册