提交 29e2fed3 编写于 作者: T tink2123

update code

上级 aa7e9ac3
...@@ -208,14 +208,14 @@ def build_export(config, main_prog, startup_prog): ...@@ -208,14 +208,14 @@ def build_export(config, main_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
func_infor = config['Architecture']['function'] func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config) model = create_module(func_infor)(params=config)
loss_type = config['Global']['loss_type'] algorithm = config['Global']['algorithm']
if loss_type == "srn": if algorithm == "SRN":
image, others, outputs = model(mode='export') image, others, outputs = model(mode='export')
else: else:
image, outputs = model(mode='export') image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs.keys()]) fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name] fetches_var = [outputs[name] for name in fetches_var_name]
if loss_type == "srn": if algorithm == "SRN":
others_var_names = sorted([name for name in others.keys()]) others_var_names = sorted([name for name in others.keys()])
feeded_var_names = [image.name] + others_var_names feeded_var_names = [image.name] + others_var_names
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册