未验证 提交 e9b7b195 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #97 from LDOUBLEV/fixocr

set mode of DB head as 'export' when export model
...@@ -109,7 +109,10 @@ class DetModel(object): ...@@ -109,7 +109,10 @@ class DetModel(object):
""" """
image, labels, loader = self.create_feed(mode) image, labels, loader = self.create_feed(mode)
conv_feas = self.backbone(image) conv_feas = self.backbone(image)
predicts = self.head(conv_feas) if self.algorithm == "DB":
predicts = self.head(conv_feas, mode)
else:
predicts = self.head(conv_feas)
if mode == "train": if mode == "train":
losses = self.loss(predicts, labels) losses = self.loss(predicts, labels)
return loader, losses return loader, losses
......
...@@ -196,7 +196,7 @@ class DBHead(object): ...@@ -196,7 +196,7 @@ class DBHead(object):
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1) fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
shrink_maps = self.binarize(fuse) shrink_maps = self.binarize(fuse)
if mode != "train": if mode != "train":
return {"maps", shrink_maps} return {"maps": shrink_maps}
threshold_maps = self.thresh(fuse) threshold_maps = self.thresh(fuse)
binary_maps = self.step_function(shrink_maps, threshold_maps) binary_maps = self.step_function(shrink_maps, threshold_maps)
y = fluid.layers.concat( y = fluid.layers.concat(
......
...@@ -191,7 +191,7 @@ def build_export(config, main_prog, startup_prog): ...@@ -191,7 +191,7 @@ def build_export(config, main_prog, startup_prog):
func_infor = config['Architecture']['function'] func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config) model = create_module(func_infor)(params=config)
image, outputs = model(mode='export') image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs]) fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name] fetches_var = [outputs[name] for name in fetches_var_name]
feeded_var_names = [image.name] feeded_var_names = [image.name]
target_vars = fetches_var target_vars = fetches_var
...@@ -271,7 +271,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict): ...@@ -271,7 +271,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset() train_loader.reset()
if epoch == 0 and save_epoch_step == 1: if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0" save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'],save_path) save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0: if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch) save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path) save_model(train_info_dict['train_program'], save_path)
...@@ -350,7 +350,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): ...@@ -350,7 +350,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset() train_loader.reset()
if epoch == 0 and save_epoch_step == 1: if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0" save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'],save_path) save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0: if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch) save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path) save_model(train_info_dict['train_program'], save_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册