提交 29e2fed3 编写于 作者: T tink2123

update code

上级 aa7e9ac3
......@@ -208,14 +208,14 @@ def build_export(config, main_prog, startup_prog):
with fluid.unique_name.guard():
func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config)
loss_type = config['Global']['loss_type']
if loss_type == "srn":
algorithm = config['Global']['algorithm']
if algorithm == "SRN":
image, others, outputs = model(mode='export')
else:
image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name]
if loss_type == "srn":
if algorithm == "SRN":
others_var_names = sorted([name for name in others.keys()])
feeded_var_names = [image.name] + others_var_names
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册