diff --git a/fleetrec/run.py b/fleetrec/run.py index db1b56f2adb4a0db0c6ae8cd8aa9d0e093bf4013..0a052072104908db421f38f09d08d1c3df8c7061 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -45,13 +45,13 @@ if __name__ == "__main__": if not os.path.exists(args.model) or not os.path.isfile(args.model): raise ValueError("argument model: {} error, must specify a existed yaml file".format(args.model)) - if args.engine == "Single": + if args.engine.upper == "SINGLE": print("use SingleTraining to run model: {}".format(args.model)) single_envs = {} single_envs["train.trainer"] = "SingleTraining" single_engine(single_envs, args.model) - elif args.engine == "LocalCluster": + elif args.engine.upper == "LOCAL_CLUSTER": print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) cluster_envs = {} @@ -63,7 +63,7 @@ if __name__ == "__main__": cluster_envs["train.strategy.mode"] = "async" local_cluster_engine(cluster_envs, args.model) - elif args.engine == "LocalMPI": + elif args.engine.upper == "LOCAL_MPI": print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) cluster_envs = {}