提交 8cf79b11 编写于 作者: T tangwei

update setup.py

上级 24934161
...@@ -43,7 +43,7 @@ class TranspileTrainer(Trainer): ...@@ -43,7 +43,7 @@ class TranspileTrainer(Trainer):
batch_size = envs.get_global_env("batch_size", None, namespace) batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../reader', 'reader_instance.py') reader = os.path.join(abs_dir, '../utils', 'reader_instance.py')
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config) pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config)
train_data_path = envs.get_global_env("train_data_path", None, namespace) train_data_path = envs.get_global_env("train_data_path", None, namespace)
......
...@@ -54,21 +54,21 @@ if __name__ == "__main__": ...@@ -54,21 +54,21 @@ if __name__ == "__main__":
print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
cluster_envs = {} cluster_envs = {}
cluster_envs["train.server_num"] = 1 cluster_envs["server_num"] = 1
cluster_envs["train.worker_num"] = 1 cluster_envs["worker_num"] = 1
cluster_envs["train.start_port"] = 36001 cluster_envs["start_port"] = 36001
cluster_envs["train.log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "SingleTraining" cluster_envs["train.trainer"] = "ClusterTraining"
local_cluster_engine(cluster_envs, args.model) local_cluster_engine(cluster_envs, args.model)
elif args.engine == "LocalMPI": elif args.engine == "LocalMPI":
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
cluster_envs = {} cluster_envs = {}
cluster_envs["train.server_num"] = 1 cluster_envs["server_num"] = 1
cluster_envs["train.worker_num"] = 1 cluster_envs["worker_num"] = 1
cluster_envs["train.start_port"] = 36001 cluster_envs["start_port"] = 36001
cluster_envs["train.log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer"] = "CtrTraining" cluster_envs["train.trainer"] = "CtrTraining"
local_mpi_engine(cluster_envs, args.model) local_mpi_engine(cluster_envs, args.model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册