提交 02ec66fd 编写于 作者: T tangwei

update setup.py

上级 7c852a31
...@@ -33,15 +33,15 @@ class TrainerFactory(object): ...@@ -33,15 +33,15 @@ class TrainerFactory(object):
def _build_trainer(config, yaml_path): def _build_trainer(config, yaml_path):
print(envs.pretty_print_envs(envs.get_global_envs())) print(envs.pretty_print_envs(envs.get_global_envs()))
train_mode = envs.get_global_env("train.trainer") train_mode = envs.get_runtime_envion("train.trainer")
if train_mode == "SingleTraining": if train_mode == "SingleTraining":
trainer = SingleTrainer(yaml_path) trainer = SingleTrainer(yaml_path)
elif train_mode == "ClusterTraining": elif train_mode == "ClusterTraining":
trainer = ClusterTrainer(yaml_path) trainer = ClusterTrainer(yaml_path)
elif train_mode == "CtrTrainer": elif train_mode == "CtrTraining":
trainer = CtrPaddleTrainer(config) trainer = CtrPaddleTrainer(config)
elif train_mode == "UserDefineTrainer": elif train_mode == "UserDefineTraining":
train_location = envs.get_global_env("train.location") train_location = envs.get_global_env("train.location")
train_dirname = os.path.dirname(train_location) train_dirname = os.path.dirname(train_location)
base_name = os.path.splitext(os.path.basename(train_location))[0] base_name = os.path.splitext(os.path.basename(train_location))[0]
......
...@@ -12,12 +12,23 @@ ...@@ -12,12 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import copy import copy
global_envs = {} global_envs = {}
def set_runtime_envions(envs):
assert isinstance(envs, dict)
for k, v in envs.items():
os.environ[k] = v
def get_runtime_envion(key):
return os.getenv(key, None)
def set_global_envs(envs): def set_global_envs(envs):
assert isinstance(envs, dict) assert isinstance(envs, dict)
...@@ -87,4 +98,3 @@ def lazy_instance(package, class_name): ...@@ -87,4 +98,3 @@ def lazy_instance(package, class_name):
model_package = __import__(package, globals(), locals(), package.split(".")) model_package = __import__(package, globals(), locals(), package.split("."))
instance = getattr(model_package, class_name) instance = getattr(model_package, class_name)
return instance return instance
...@@ -11,11 +11,10 @@ def run(model_yaml): ...@@ -11,11 +11,10 @@ def run(model_yaml):
trainer.run() trainer.run()
def single_engine(model_yaml): def single_engine(single_envs, model_yaml):
single_envs = {}
single_envs["singleTraning"] = True
print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value"))) print(envs.pretty_print_envs(single_envs, ("Single Envs", "Value")))
envs.set_runtime_envions(single_envs)
run(model_yaml) run(model_yaml)
...@@ -47,25 +46,30 @@ if __name__ == "__main__": ...@@ -47,25 +46,30 @@ if __name__ == "__main__":
if args.engine == "Single": if args.engine == "Single":
print("use SingleTraining to run model: {}".format(args.model)) print("use SingleTraining to run model: {}".format(args.model))
single_engine(args.model) single_envs = {}
single_envs["train.trainer"] = "SingleTraining"
single_engine(single_envs, args.model)
elif args.engine == "LocalCluster": elif args.engine == "LocalCluster":
print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model)) print("use 1X1 ClusterTraining at localhost to run model: {}".format(args.model))
cluster_envs = {} cluster_envs = {}
cluster_envs["server_num"] = 1 cluster_envs["train.server_num"] = 1
cluster_envs["worker_num"] = 1 cluster_envs["train.worker_num"] = 1
cluster_envs["start_port"] = 36001 cluster_envs["train.start_port"] = 36001
cluster_envs["log_dir"] = "logs" cluster_envs["train.log_dir"] = "logs"
cluster_envs["train.trainer"] = "SingleTraining"
local_cluster_engine(cluster_envs, args.model) local_cluster_engine(cluster_envs, args.model)
elif args.engine == "LocalMPI": elif args.engine == "LocalMPI":
print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model)) print("use 1X1 MPI ClusterTraining at localhost to run model: {}".format(args.model))
cluster_envs = {} cluster_envs = {}
cluster_envs["server_num"] = 1 cluster_envs["train.server_num"] = 1
cluster_envs["worker_num"] = 1 cluster_envs["train.worker_num"] = 1
cluster_envs["start_port"] = 36001 cluster_envs["train.start_port"] = 36001
cluster_envs["log_dir"] = "logs" cluster_envs["train.log_dir"] = "logs"
cluster_envs["train.trainer"] = "CtrTraining"
local_mpi_engine(cluster_envs, args.model) local_mpi_engine(cluster_envs, args.model)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册