diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index 292a2e4c3ab5e0ccc1c757f0bb5a148b8f3d4eac..cfe7341d503657ed3e8926039c2660d907bd6979 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -33,15 +33,15 @@ class TrainerFactory(object): def _build_trainer(config, yaml_path): print(envs.pretty_print_envs(envs.get_global_envs())) - train_mode = envs.get_global_env("train.trainer") + train_mode = envs.get_runtime_envion("train.trainer") if train_mode == "SingleTraining": trainer = SingleTrainer(yaml_path) elif train_mode == "ClusterTraining": trainer = ClusterTrainer(yaml_path) - elif train_mode == "CtrTrainer": + elif train_mode == "CtrTraining": trainer = CtrPaddleTrainer(config) - elif train_mode == "UserDefineTrainer": + elif train_mode == "UserDefineTraining": train_location = envs.get_global_env("train.location") train_dirname = os.path.dirname(train_location) base_name = os.path.splitext(os.path.basename(train_location))[0] diff --git a/fleetrec/core/utils/envs.py b/fleetrec/core/utils/envs.py index eb866fba5b549f924e34d9bbe153fa9d3b381d7d..53d68ce4c43be2cc468a9118d0ffdcf498a761a2 100644 --- a/fleetrec/core/utils/envs.py +++ b/fleetrec/core/utils/envs.py @@ -12,12 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import os import copy global_envs = {} +def set_runtime_envions(envs): + assert isinstance(envs, dict) + + for k, v in envs.items(): + os.environ[k] = v + + +def get_runtime_envion(key): + return os.getenv(key, None) + + def set_global_envs(envs): assert isinstance(envs, dict) @@ -87,4 +98,3 @@ def lazy_instance(package, class_name): model_package = __import__(package, globals(), locals(), package.split(".")) instance = getattr(model_package, class_name) return instance - diff --git a/fleetrec/run.py b/fleetrec/run.py index 1a79ce76fcaabaa0b5124d0ea5d58a420d94abd8..51dbd1fd998f169cc30607c30146da7da2801dcb 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -11,11 +11,10 @@ def run(model_yaml): trainer.run() -def single_engine(model_yaml): - single_envs = {} - single_envs["singleTraning"] = True - +def single_engine(single_envs, model_yaml): print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value"))) + + envs.set_runtime_envions(single_envs) run(model_yaml) @@ -47,25 +46,30 @@ if __name__ == "__main__": if args.engine == "Single": print("use SingleTraining to run model: {}".format(args.model)) - single_engine(args.model) + single_envs = {} + single_envs["train.trainer"] = "SingleTraining" + + single_engine(single_envs, args.model) elif args.engine == "LocalCluster": print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) cluster_envs = {} - cluster_envs["server_num"] = 1 - cluster_envs["worker_num"] = 1 - cluster_envs["start_port"] = 36001 - cluster_envs["log_dir"] = "logs" + cluster_envs["train.server_num"] = 1 + cluster_envs["train.worker_num"] = 1 + cluster_envs["train.start_port"] = 36001 + cluster_envs["train.log_dir"] = "logs" + cluster_envs["train.trainer"] = "SingleTraining" local_cluster_engine(cluster_envs, args.model) elif args.engine == "LocalMPI": print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) cluster_envs = {} - cluster_envs["server_num"] = 1 - cluster_envs["worker_num"] = 1 - cluster_envs["start_port"] = 36001 - cluster_envs["log_dir"] = "logs" + cluster_envs["train.server_num"] = 1 + cluster_envs["train.worker_num"] = 1 + cluster_envs["train.start_port"] = 36001 + cluster_envs["train.log_dir"] = "logs" + cluster_envs["train.trainer"] = "CtrTraining" local_mpi_engine(cluster_envs, args.model) else: