提交 04e71041 编写于 作者: W wangjingyeye

add db++

上级 1315cdfc
......@@ -18,7 +18,7 @@ Global:
save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture:
model_type: det
algorithm: DB
algorithm: DB++
Transform: null
Backbone:
name: ResNet
......
......@@ -18,7 +18,7 @@ Global:
save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture:
model_type: det
algorithm: DB
algorithm: DB++
Transform: null
Backbone:
name: ResNet
......
......@@ -67,6 +67,23 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
elif self.det_algorithm == "DB++":
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
pre_process_list[1] = {
'NormalizeImage': {
'std': [1.0, 1.0, 1.0],
'mean':
[0.48109378172549, 0.45752457890196, 0.40787054090196],
'scale': '1./255.',
'order': 'hwc'
}
}
elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh
......@@ -231,7 +248,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
elif self.det_algorithm in ['DB', 'PSE']:
elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds['maps'] = outputs[0]
elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs):
......
......@@ -307,7 +307,8 @@ def train(config,
train_stats.update(stats)
if log_writer is not None and dist.get_rank() == 0:
log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
log_writer.log_metrics(
metrics=train_stats.get(), prefix="TRAIN", step=global_step)
if dist.get_rank() == 0 and (
(global_step > 0 and global_step % print_batch_step == 0) or
......@@ -354,7 +355,8 @@ def train(config,
# logger metric
if log_writer is not None:
log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
log_writer.log_metrics(
metrics=cur_metric, prefix="EVAL", step=global_step)
if cur_metric[main_indicator] >= best_model_dict[
main_indicator]:
......@@ -377,11 +379,18 @@ def train(config,
logger.info(best_str)
# logger best metric
if log_writer is not None:
log_writer.log_metrics(metrics={
"best_{}".format(main_indicator): best_model_dict[main_indicator]
}, prefix="EVAL", step=global_step)
log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
log_writer.log_metrics(
metrics={
"best_{}".format(main_indicator):
best_model_dict[main_indicator]
},
prefix="EVAL",
step=global_step)
log_writer.log_model(
is_best=True,
prefix="best_accuracy",
metadata=best_model_dict)
reader_start = time.time()
if dist.get_rank() == 0:
......@@ -413,7 +422,8 @@ def train(config,
epoch=epoch,
global_step=global_step)
if log_writer is not None:
log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
log_writer.log_model(
is_best=False, prefix='iter_epoch_{}'.format(epoch))
best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
......@@ -564,7 +574,7 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', 'DB++'
]
if use_xpu:
......@@ -585,7 +595,8 @@ def preprocess(is_train=False):
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
log_writer = VDLLogger(save_model_dir)
loggers.append(log_writer)
if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
if ('use_wandb' in config['Global'] and
config['Global']['use_wandb']) or 'wandb' in config:
save_dir = config['Global']['save_model_dir']
wandb_writer_path = "{}/wandb".format(save_dir)
if "wandb" in config:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册