diff --git a/tools/program.py b/tools/program.py index 09552f41c63c38f5addce6a9ea6d0ca07f36ce85..56f6b6993022d092095d7c3545f9ea8833900bc6 100755 --- a/tools/program.py +++ b/tools/program.py @@ -208,14 +208,14 @@ def build_export(config, main_prog, startup_prog): with fluid.unique_name.guard(): func_infor = config['Architecture']['function'] model = create_module(func_infor)(params=config) - loss_type = config['Global']['loss_type'] - if loss_type == "srn": + algorithm = config['Global']['algorithm'] + if algorithm == "SRN": image, others, outputs = model(mode='export') else: image, outputs = model(mode='export') fetches_var_name = sorted([name for name in outputs.keys()]) fetches_var = [outputs[name] for name in fetches_var_name] - if loss_type == "srn": + if algorithm == "SRN": others_var_names = sorted([name for name in others.keys()]) feeded_var_names = [image.name] + others_var_names else: