From fcfdd0a39e89699ff39c8cd0a936dcce8b5c44e7 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Tue, 26 May 2020 21:02:27 +0800 Subject: [PATCH] set mode of DB head as 'export' when export model --- ppocr/modeling/architectures/det_model.py | 5 ++++- ppocr/modeling/heads/det_db_head.py | 2 +- tools/program.py | 6 +++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/ppocr/modeling/architectures/det_model.py b/ppocr/modeling/architectures/det_model.py index e09bcbe9..5016546a 100755 --- a/ppocr/modeling/architectures/det_model.py +++ b/ppocr/modeling/architectures/det_model.py @@ -109,7 +109,10 @@ class DetModel(object): """ image, labels, loader = self.create_feed(mode) conv_feas = self.backbone(image) - predicts = self.head(conv_feas) + if self.algorithm == "DB": + predicts = self.head(conv_feas, mode) + else: + predicts = self.head(conv_feas) if mode == "train": losses = self.loss(predicts, labels) return loader, losses diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py index bafacaaa..c89a1255 100644 --- a/ppocr/modeling/heads/det_db_head.py +++ b/ppocr/modeling/heads/det_db_head.py @@ -196,7 +196,7 @@ class DBHead(object): fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1) shrink_maps = self.binarize(fuse) if mode != "train": - return {"maps", shrink_maps} + return {"maps": shrink_maps} threshold_maps = self.thresh(fuse) binary_maps = self.step_function(shrink_maps, threshold_maps) y = fluid.layers.concat( diff --git a/tools/program.py b/tools/program.py index 67cef9bc..18a4ab7d 100755 --- a/tools/program.py +++ b/tools/program.py @@ -191,7 +191,7 @@ def build_export(config, main_prog, startup_prog): func_infor = config['Architecture']['function'] model = create_module(func_infor)(params=config) image, outputs = model(mode='export') - fetches_var_name = sorted([name for name in outputs]) + fetches_var_name = sorted([name for name in outputs.keys()]) fetches_var = [outputs[name] for name in fetches_var_name] feeded_var_names = [image.name] target_vars = fetches_var @@ -271,7 +271,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): train_loader.reset() if epoch == 0 and save_epoch_step == 1: save_path = save_model_dir + "/iter_epoch_0" - save_model(train_info_dict['train_program'],save_path) + save_model(train_info_dict['train_program'], save_path) if epoch > 0 and epoch % save_epoch_step == 0: save_path = save_model_dir + "/iter_epoch_%d" % (epoch) save_model(train_info_dict['train_program'], save_path) @@ -350,7 +350,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): train_loader.reset() if epoch == 0 and save_epoch_step == 1: save_path = save_model_dir + "/iter_epoch_0" - save_model(train_info_dict['train_program'],save_path) + save_model(train_info_dict['train_program'], save_path) if epoch > 0 and epoch % save_epoch_step == 0: save_path = save_model_dir + "/iter_epoch_%d" % (epoch) save_model(train_info_dict['train_program'], save_path) -- GitLab