diff --git a/run.py b/run.py index 94a9386fdcbe89f3ad8e35579783b2e3368d0089..419280898b343ac6a5ae390424951bac5f7f45db 100755 --- a/run.py +++ b/run.py @@ -33,7 +33,9 @@ model_name = "" def engine_registry(): - engines = {"TRANSPILER": {}, "PSLIB": {}} + engines["TRANSPILER"] = {} + engines["PSLIB"] = {} + engines["TRANSPILER"]["SINGLE"] = single_engine engines["TRANSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine engines["TRANSPILER"]["CLUSTER"] = cluster_engine @@ -60,13 +62,16 @@ def get_engine(args): transpiler = get_transpiler() run_extras = get_inters_from_yaml(args.model, "train.") - engine = run_extras.get("train.engine", "") + engine = run_extras.get("train.engine", "single") engine = engine.upper() if engine not in engine_choices: raise ValueError("train.engin can not be chosen in {}".format(engine_choices)) + print("engines: \n{}".format(engines)) + run_engine = engines[transpiler].get(engine, None) + return run_engine