提交 f24182e6 编写于 作者: T tangwei

fix bug

上级 42347db7
...@@ -29,11 +29,9 @@ class UserDefineTrainer(TranspileTrainer): ...@@ -29,11 +29,9 @@ class UserDefineTrainer(TranspileTrainer):
self.regist_context_processor('train_pass', self.train) self.regist_context_processor('train_pass', self.train)
def init(self, context): def init(self, context):
self.model.net() self.model.train_net()
self.model.metrics()
self.model.avg_loss()
optimizer = self.model.optimizer() optimizer = self.model.optimizer()
optimizer.minimize(self.model._cost) optimizer.minimize((self.model.get_cost_op()))
self.fetch_vars = [] self.fetch_vars = []
self.fetch_alias = [] self.fetch_alias = []
......
...@@ -21,10 +21,15 @@ def set_runtime_envs(cluster_envs, engine_yaml): ...@@ -21,10 +21,15 @@ def set_runtime_envs(cluster_envs, engine_yaml):
if cluster_envs is None: if cluster_envs is None:
cluster_envs = {} cluster_envs = {}
cluster_envs.update(cluster_envs) envs.set_runtime_envions(cluster_envs)
cluster_envs.update(_envs) envs.set_runtime_envions(_envs)
# envs.set_runtime_envions(cluster_envs)
print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value"))) need_print = {}
for k, v in os.environ.items():
if k.startswith() == "trainer.":
need_print[k] = v
print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value")))
def get_engine(engine): def get_engine(engine):
...@@ -45,6 +50,7 @@ def single_engine(args): ...@@ -45,6 +50,7 @@ def single_engine(args):
single_envs = {} single_envs = {}
single_envs["trainer.trainer"] = "SingleTrainer" single_envs["trainer.trainer"] = "SingleTrainer"
single_envs["trainer.threads"] = "2" single_envs["trainer.threads"] = "2"
single_envs["trainer.engine"] = "single"
set_runtime_envs(single_envs, args.engine_extras) set_runtime_envs(single_envs, args.engine_extras)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
...@@ -55,6 +61,7 @@ def cluster_engine(args): ...@@ -55,6 +61,7 @@ def cluster_engine(args):
cluster_envs = {} cluster_envs = {}
cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.engine"] = "cluster"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
envs.set_runtime_envions(cluster_envs) envs.set_runtime_envions(cluster_envs)
...@@ -85,6 +92,7 @@ def local_cluster_engine(args): ...@@ -85,6 +92,7 @@ def local_cluster_engine(args):
cluster_envs["trainer.trainer"] = "ClusterTrainer" cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.strategy"] = "async" cluster_envs["trainer.strategy"] = "async"
cluster_envs["trainer.threads"] = "2" cluster_envs["trainer.threads"] = "2"
cluster_envs["trainer.engine"] = "local_cluster"
cluster_envs["CPU_NUM"] = "2" cluster_envs["CPU_NUM"] = "2"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
...@@ -102,8 +110,12 @@ def local_mpi_engine(args): ...@@ -102,8 +110,12 @@ def local_mpi_engine(args):
mpi = util.run_which("mpirun") mpi = util.run_which("mpirun")
if not mpi: if not mpi:
raise RuntimeError("can not find mpirun, please check environment") raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {}
cluster_envs["mpirun"] = mpi
cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.engine"] = "local_cluster"
cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrCodingTrainer", "log_dir": "logs"}
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
launch = LocalMPIEngine(cluster_envs, args.model) launch = LocalMPIEngine(cluster_envs, args.model)
return launch return launch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册