未验证 提交 e9b7b195 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #97 from LDOUBLEV/fixocr

set mode of DB head as 'export' when export model
......@@ -109,7 +109,10 @@ class DetModel(object):
"""
image, labels, loader = self.create_feed(mode)
conv_feas = self.backbone(image)
predicts = self.head(conv_feas)
if self.algorithm == "DB":
predicts = self.head(conv_feas, mode)
else:
predicts = self.head(conv_feas)
if mode == "train":
losses = self.loss(predicts, labels)
return loader, losses
......
......@@ -196,7 +196,7 @@ class DBHead(object):
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
shrink_maps = self.binarize(fuse)
if mode != "train":
return {"maps", shrink_maps}
return {"maps": shrink_maps}
threshold_maps = self.thresh(fuse)
binary_maps = self.step_function(shrink_maps, threshold_maps)
y = fluid.layers.concat(
......
......@@ -191,7 +191,7 @@ def build_export(config, main_prog, startup_prog):
func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config)
image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs])
fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name]
feeded_var_names = [image.name]
target_vars = fetches_var
......@@ -271,7 +271,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset()
if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'],save_path)
save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path)
......@@ -350,7 +350,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset()
if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'],save_path)
save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册