From f24182e6d85ed66e060ce8c780c68c8c15b75de6 Mon Sep 17 00:00:00 2001 From: tangwei Date: Mon, 20 Apr 2020 17:08:17 +0800 Subject: [PATCH] fix bug --- fleetrec/examples/user_define_trainer.py | 6 ++---- fleetrec/run.py | 22 +++++++++++++++++----- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/fleetrec/examples/user_define_trainer.py b/fleetrec/examples/user_define_trainer.py index 194d1e2b..068fd690 100644 --- a/fleetrec/examples/user_define_trainer.py +++ b/fleetrec/examples/user_define_trainer.py @@ -29,11 +29,9 @@ class UserDefineTrainer(TranspileTrainer): self.regist_context_processor('train_pass', self.train) def init(self, context): - self.model.net() - self.model.metrics() - self.model.avg_loss() + self.model.train_net() optimizer = self.model.optimizer() - optimizer.minimize(self.model._cost) + optimizer.minimize((self.model.get_cost_op())) self.fetch_vars = [] self.fetch_alias = [] diff --git a/fleetrec/run.py b/fleetrec/run.py index 6027a678..51c8a583 100644 --- a/fleetrec/run.py +++ b/fleetrec/run.py @@ -21,10 +21,15 @@ def set_runtime_envs(cluster_envs, engine_yaml): if cluster_envs is None: cluster_envs = {} - cluster_envs.update(cluster_envs) - cluster_envs.update(_envs) - # envs.set_runtime_envions(cluster_envs) - print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value"))) + envs.set_runtime_envions(cluster_envs) + envs.set_runtime_envions(_envs) + + need_print = {} + for k, v in os.environ.items(): + if k.startswith() == "trainer.": + need_print[k] = v + + print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value"))) def get_engine(engine): @@ -45,6 +50,7 @@ def single_engine(args): single_envs = {} single_envs["trainer.trainer"] = "SingleTrainer" single_envs["trainer.threads"] = "2" + single_envs["trainer.engine"] = "single" set_runtime_envs(single_envs, args.engine_extras) trainer = TrainerFactory.create(args.model) return trainer @@ -55,6 +61,7 @@ def cluster_engine(args): cluster_envs = {} cluster_envs["trainer.trainer"] = "ClusterTrainer" + cluster_envs["trainer.engine"] = "cluster" set_runtime_envs(cluster_envs, args.engine_extras) envs.set_runtime_envions(cluster_envs) @@ -85,6 +92,7 @@ def local_cluster_engine(args): cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["trainer.strategy"] = "async" cluster_envs["trainer.threads"] = "2" + cluster_envs["trainer.engine"] = "local_cluster" cluster_envs["CPU_NUM"] = "2" set_runtime_envs(cluster_envs, args.engine_extras) @@ -102,8 +110,12 @@ def local_mpi_engine(args): mpi = util.run_which("mpirun") if not mpi: raise RuntimeError("can not find mpirun, please check environment") + cluster_envs = {} + cluster_envs["mpirun"] = mpi + cluster_envs["trainer.trainer"] = "CtrCodingTrainer" + cluster_envs["log_dir"] = "logs" + cluster_envs["trainer.engine"] = "local_cluster" - cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrCodingTrainer", "log_dir": "logs"} set_runtime_envs(cluster_envs, args.engine_extras) launch = LocalMPIEngine(cluster_envs, args.model) return launch -- GitLab