提交 f24182e6 编写于 作者: T tangwei

fix bug

上级 42347db7
......@@ -29,11 +29,9 @@ class UserDefineTrainer(TranspileTrainer):
self.regist_context_processor('train_pass', self.train)
def init(self, context):
self.model.net()
self.model.metrics()
self.model.avg_loss()
self.model.train_net()
optimizer = self.model.optimizer()
optimizer.minimize(self.model._cost)
optimizer.minimize((self.model.get_cost_op()))
self.fetch_vars = []
self.fetch_alias = []
......
......@@ -21,10 +21,15 @@ def set_runtime_envs(cluster_envs, engine_yaml):
if cluster_envs is None:
cluster_envs = {}
cluster_envs.update(cluster_envs)
cluster_envs.update(_envs)
# envs.set_runtime_envions(cluster_envs)
print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
envs.set_runtime_envions(cluster_envs)
envs.set_runtime_envions(_envs)
need_print = {}
for k, v in os.environ.items():
if k.startswith() == "trainer.":
need_print[k] = v
print(envs.pretty_print_envs(need_print, ("Runtime Envs", "Value")))
def get_engine(engine):
......@@ -45,6 +50,7 @@ def single_engine(args):
single_envs = {}
single_envs["trainer.trainer"] = "SingleTrainer"
single_envs["trainer.threads"] = "2"
single_envs["trainer.engine"] = "single"
set_runtime_envs(single_envs, args.engine_extras)
trainer = TrainerFactory.create(args.model)
return trainer
......@@ -55,6 +61,7 @@ def cluster_engine(args):
cluster_envs = {}
cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.engine"] = "cluster"
set_runtime_envs(cluster_envs, args.engine_extras)
envs.set_runtime_envions(cluster_envs)
......@@ -85,6 +92,7 @@ def local_cluster_engine(args):
cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.strategy"] = "async"
cluster_envs["trainer.threads"] = "2"
cluster_envs["trainer.engine"] = "local_cluster"
cluster_envs["CPU_NUM"] = "2"
set_runtime_envs(cluster_envs, args.engine_extras)
......@@ -102,8 +110,12 @@ def local_mpi_engine(args):
mpi = util.run_which("mpirun")
if not mpi:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {}
cluster_envs["mpirun"] = mpi
cluster_envs["trainer.trainer"] = "CtrCodingTrainer"
cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.engine"] = "local_cluster"
cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrCodingTrainer", "log_dir": "logs"}
set_runtime_envs(cluster_envs, args.engine_extras)
launch = LocalMPIEngine(cluster_envs, args.model)
return launch
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册