diff --git a/fleetrec/core/factory.py b/fleetrec/core/factory.py index ef882e9ee21a79340c28266701dc6d8834862d82..6bf8a176161e33bbec6430f628fdf71081031ecb 100644 --- a/fleetrec/core/factory.py +++ b/fleetrec/core/factory.py @@ -25,7 +25,7 @@ class TrainerFactory(object): pass @staticmethod - def _build_trainer(config, yaml_path): + def _build_trainer(yaml_path): print(envs.pretty_print_envs(envs.get_global_envs())) train_mode = envs.get_global_env("train.trainer") @@ -40,8 +40,8 @@ class TrainerFactory(object): from fleetrec.core.trainers.cluster_trainer import ClusterTrainer trainer = ClusterTrainer(yaml_path) elif train_mode == "CtrTraining": - from fleetrec.core.trainers.ctr_modul_trainer import CtrPaddleTrainer - trainer = CtrPaddleTrainer(config) + from fleetrec.core.trainers.ctr_coding_trainer import CtrPaddleTrainer + trainer = CtrPaddleTrainer(yaml_path) elif train_mode == "UserDefineTraining": train_location = envs.get_global_env("train.location") train_dirname = os.path.dirname(train_location) @@ -63,7 +63,7 @@ class TrainerFactory(object): raise ValueError("fleetrec's config only support yaml") envs.set_global_envs(_config) - trainer = TrainerFactory._build_trainer(_config, config) + trainer = TrainerFactory._build_trainer(config) return trainer diff --git a/fleetrec/core/trainer.py b/fleetrec/core/trainer.py index 0158202b420b40783d7b9bb6f70dce794a9680bd..b080a87e9c4f0f61bb52823945e2186aa32e41b9 100755 --- a/fleetrec/core/trainer.py +++ b/fleetrec/core/trainer.py @@ -14,6 +14,8 @@ import abc import time + +import yaml from paddle import fluid @@ -28,7 +30,10 @@ class Trainer(object): self._exe = fluid.Executor(self._place) self._exector_context = {} self._context = {'status': 'uninit', 'is_exit': False} - self._config = config + self._config_yaml = config + + with open(config, 'r') as rb: + self._config = yaml.load(rb.read(), Loader=yaml.FullLoader) def regist_context_processor(self, status_name, processor): """ diff --git a/fleetrec/core/trainers/ctr_coding_trainer.py b/fleetrec/core/trainers/ctr_coding_trainer.py index d8751c30c29cb7bf1dcf07cdba532e32634b98b7..7ba3bec71b260acd391ada55df708eb57c22c08a 100755 --- a/fleetrec/core/trainers/ctr_coding_trainer.py +++ b/fleetrec/core/trainers/ctr_coding_trainer.py @@ -62,7 +62,7 @@ class CtrPaddleTrainer(Trainer): reader_class = envs.get_global_env("class", None, namespace) abs_dir = os.path.dirname(os.path.abspath(__file__)) reader = os.path.join(abs_dir, '../utils', 'reader_instance.py') - pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config) + pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) train_data_path = envs.get_global_env("train_data_path", None, namespace) dataset = fluid.DatasetFactory().create_dataset() diff --git a/fleetrec/run.py b/fleetrec/run.py index a3ea24f8a5d51abb1edaab547b48184d6e61648a..a3cdc2175be597924a48ed027beea153952b7b57 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -17,7 +17,6 @@ def run(model_yaml): def single_engine(single_envs, model_yaml): print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value"))) - envs.set_runtime_envions(single_envs) run(model_yaml) @@ -33,8 +32,8 @@ def local_cluster_engine(cluster_envs, model_yaml): def local_mpi_engine(cluster_envs, model_yaml): from fleetrec.core.engine.local_mpi_engine import LocalMPIEngine - print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value"))) + print(envs.pretty_print_envs(cluster_envs, ("Local MPI Cluster Envs", "Value"))) envs.set_runtime_envions(cluster_envs) launch = LocalMPIEngine(cluster_envs, model_yaml) launch.run() @@ -79,7 +78,7 @@ if __name__ == "__main__": if not mpi_path: raise RuntimeError("can not find mpirun, please check environment") - cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"} + cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"} local_mpi_engine(cluster_envs, args.model) elif args.engine.upper() == "LOCAL_CLUSTER": print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) @@ -100,7 +99,7 @@ if __name__ == "__main__": if not mpi_path: raise RuntimeError("can not find mpirun, please check environment") - cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining"} + cluster_envs = {"mpirun": mpi_path, "train.trainer": "CtrTraining", "log_dir": "logs"} local_mpi_engine(cluster_envs, args.model) elif args.engine.upper() == "CLUSTER": print("launch ClusterTraining with cluster to run model: {}".format(args.model))