diff --git a/tools/eval.py b/tools/eval.py index 8ae270b35588b3dc2b23cbc9a2d74d9f4c2bb022..0120baab0f34d5fadbbf4df20d92d6b62dd176a2 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -55,6 +55,10 @@ def main(): model = build_model(config['Architecture']) use_srn = config['Architecture']['algorithm'] == "SRN" + if "model_type" in config['Architecture'].keys(): + model_type = config['Architecture']['model_type'] + else: + model_type = None best_model_dict = init_model(config, model) if len(best_model_dict):