diff --git a/configs/det/det_r50_db++_ic15.yml b/configs/det/det_r50_db++_ic15.yml
index a90c868cc33797396158c456c9d6c17204fe479b..e0cd6012b660573a79ff013a1b6e2309074a3d86 100644
--- a/configs/det/det_r50_db++_ic15.yml
+++ b/configs/det/det_r50_db++_ic15.yml
@@ -18,7 +18,7 @@ Global:
   save_res_path: ./checkpoints/det_db/predicts_db.txt
 Architecture:
   model_type: det
-  algorithm: DB
+  algorithm: DB++
   Transform: null
   Backbone:
     name: ResNet
diff --git a/configs/det/det_r50_db++_td_tr.yml b/configs/det/det_r50_db++_td_tr.yml
index 5e26ddb0ac4ad8a27fe2cd4ca865421d30e0a6ec..65021bb66184381ba732980ac1b7a65d7bd3a355 100644
--- a/configs/det/det_r50_db++_td_tr.yml
+++ b/configs/det/det_r50_db++_td_tr.yml
@@ -18,7 +18,7 @@ Global:
   save_res_path: ./checkpoints/det_db/predicts_db.txt
 Architecture:
   model_type: det
-  algorithm: DB
+  algorithm: DB++
   Transform: null
   Backbone:
     name: ResNet
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 7b6bebf1fbced2de5bb0e4e75840fb8dd7beb374..394a48948b1f284bd405532769b76eeb298668bd 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -67,6 +67,23 @@ class TextDetector(object):
             postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
             postprocess_params["use_dilation"] = args.use_dilation
             postprocess_params["score_mode"] = args.det_db_score_mode
+        elif self.det_algorithm == "DB++":
+            postprocess_params['name'] = 'DBPostProcess'
+            postprocess_params["thresh"] = args.det_db_thresh
+            postprocess_params["box_thresh"] = args.det_db_box_thresh
+            postprocess_params["max_candidates"] = 1000
+            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
+            postprocess_params["use_dilation"] = args.use_dilation
+            postprocess_params["score_mode"] = args.det_db_score_mode
+            pre_process_list[1] = {
+                'NormalizeImage': {
+                    'std': [1.0, 1.0, 1.0],
+                    'mean':
+                    [0.48109378172549, 0.45752457890196, 0.40787054090196],
+                    'scale': '1./255.',
+                    'order': 'hwc'
+                }
+            }
         elif self.det_algorithm == "EAST":
             postprocess_params['name'] = 'EASTPostProcess'
             postprocess_params["score_thresh"] = args.det_east_score_thresh
@@ -231,7 +248,7 @@ class TextDetector(object):
             preds['f_score'] = outputs[1]
             preds['f_tco'] = outputs[2]
             preds['f_tvo'] = outputs[3]
-        elif self.det_algorithm in ['DB', 'PSE']:
+        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
             preds['maps'] = outputs[0]
         elif self.det_algorithm == 'FCE':
             for i, output in enumerate(outputs):
diff --git a/tools/program.py b/tools/program.py
index aa0d2698cf66c928f87217996c31c042e1c8aa02..620c61f093b6e6626bd85eac0c4b1a2bb482fd59 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -307,7 +307,8 @@ def train(config,
             train_stats.update(stats)
 
             if log_writer is not None and dist.get_rank() == 0:
-                log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
+                log_writer.log_metrics(
+                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)
 
             if dist.get_rank() == 0 and (
                 (global_step > 0 and global_step % print_batch_step == 0) or
@@ -354,7 +355,8 @@ def train(config,
 
                 # logger metric
                 if log_writer is not None:
-                    log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+                    log_writer.log_metrics(
+                        metrics=cur_metric, prefix="EVAL", step=global_step)
 
                 if cur_metric[main_indicator] >= best_model_dict[
                         main_indicator]:
@@ -377,11 +379,18 @@ def train(config,
                     logger.info(best_str)
                     # logger best metric
                     if log_writer is not None:
-                        log_writer.log_metrics(metrics={
-                            "best_{}".format(main_indicator): best_model_dict[main_indicator]
-                        }, prefix="EVAL", step=global_step)
-
-                        log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
+                        log_writer.log_metrics(
+                            metrics={
+                                "best_{}".format(main_indicator):
+                                best_model_dict[main_indicator]
+                            },
+                            prefix="EVAL",
+                            step=global_step)
+
+                        log_writer.log_model(
+                            is_best=True,
+                            prefix="best_accuracy",
+                            metadata=best_model_dict)
 
             reader_start = time.time()
         if dist.get_rank() == 0:
@@ -413,7 +422,8 @@ def train(config,
                 epoch=epoch,
                 global_step=global_step)
             if log_writer is not None:
-                log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+                log_writer.log_model(
+                    is_best=False, prefix='iter_epoch_{}'.format(epoch))
 
     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
@@ -564,7 +574,7 @@ def preprocess(is_train=False):
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
+        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', 'DB++'
     ]
 
     if use_xpu:
@@ -585,7 +595,8 @@ def preprocess(is_train=False):
         vdl_writer_path = '{}/vdl/'.format(save_model_dir)
         log_writer = VDLLogger(save_model_dir)
         loggers.append(log_writer)
-    if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+    if ('use_wandb' in config['Global'] and
+            config['Global']['use_wandb']) or 'wandb' in config:
         save_dir = config['Global']['save_model_dir']
         wandb_writer_path = "{}/wandb".format(save_dir)
         if "wandb" in config:
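
Note: the new "DB++" branch in tools/infer/predict_det.py reuses the existing DBPostProcess and differs from plain DB only in the image normalization statistics, which it applies by replacing the second pre-processing op. The standalone Python sketch below illustrates that override; build_preprocess() and its simplified default op list are hypothetical stand-ins for the actual code in predict_det.py, not PaddleOCR APIs.

# Standalone sketch (assumed helper, not part of PaddleOCR): shows how the
# "DB++" branch swaps the default ImageNet-style NormalizeImage settings for
# raw-mean subtraction with unit std, leaving the rest of the pipeline intact.
def build_preprocess(det_algorithm):
    # Default DB-style normalization (mean/std already scaled by 1/255).
    pre_process_list = [
        {'DetResizeForTest': None},  # placeholder; real op takes limit params
        {'NormalizeImage': {
            'std': [0.229, 0.224, 0.225],
            'mean': [0.485, 0.456, 0.406],
            'scale': '1./255.',
            'order': 'hwc'
        }},
        {'ToCHWImage': None},
        {'KeepKeys': {'keep_keys': ['image', 'shape']}},
    ]
    if det_algorithm == 'DB++':
        # Same override as in the diff above.
        pre_process_list[1] = {
            'NormalizeImage': {
                'std': [1.0, 1.0, 1.0],
                'mean':
                [0.48109378172549, 0.45752457890196, 0.40787054090196],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }
    return pre_process_list

if __name__ == '__main__':
    print(build_preprocess('DB++')[1]['NormalizeImage']['mean'])

With the postprocess branch and the second predict_det.py hunk ('maps' output) in place, passing --det_algorithm="DB++" to tools/infer/predict_det.py should route inference through the DB post-processing path with the adjusted normalization.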