提交 e7986cb4 编写于 作者: T tangwei

fix readme

上级 3ccb3ea7
......@@ -46,7 +46,7 @@ class TrainerFactory(object):
if trainer_abs is None:
if not os.path.exists(train_mode) or os.path.isfile(train_mode):
raise ValueError("trainer {} can not be recognized")
raise ValueError("trainer {} can not be recognized".format(train_mode))
trainer_abs = train_mode
train_mode = "UserDefineTrainer"
......
......@@ -16,6 +16,8 @@ def set_runtime_envs(cluster_envs, engine_yaml):
if engine_yaml is not None:
with open(engine_yaml, 'r') as rb:
_envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
_envs = {}
if cluster_envs is None:
cluster_envs = {}
......@@ -24,15 +26,6 @@ def set_runtime_envs(cluster_envs, engine_yaml):
print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
def engine_registry():
engines["TRAINSPILER"]["SINGLE"] = single_engine
engines["TRAINSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
engines["TRAINSPILER"]["CLUSTER"] = cluster_engine
engines["PSLIB"]["SINGLE"] = local_mpi_engine
engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
engines["PSLIB"]["CLUSTER"] = cluster_mpi_engine
def get_engine(engine):
engine = engine.upper()
if version.is_transpiler():
......@@ -47,7 +40,7 @@ def get_engine(engine):
def single_engine(args):
print("use single engine to run model: {}".format(args.model))
single_envs = {"trainer.trainer": "SingleTraining"}
single_envs = {"trainer.trainer": "SingleTrainer"}
set_runtime_envs(single_envs, args.engine_extras)
......@@ -58,7 +51,7 @@ def single_engine(args):
def cluster_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "ClusterTraining"}
cluster_envs = {"trainer.trainer": "ClusterTrainer"}
set_runtime_envs(cluster_envs, args.engine_extras)
envs.set_runtime_envions(cluster_envs)
......@@ -69,7 +62,7 @@ def cluster_engine(args):
def cluster_mpi_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "CtrTraining"}
cluster_envs = {"trainer.trainer": "CtrCodingTrainer"}
set_runtime_envs(cluster_envs, args.engine_extras)
trainer = TrainerFactory.create(args.model)
......@@ -85,7 +78,7 @@ def local_cluster_engine(args):
cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.trainer"] = "ClusterTraining"
cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.strategy.mode"] = "async"
set_runtime_envs(cluster_envs, args.engine_extras)
......@@ -104,28 +97,22 @@ def local_mpi_engine(args):
if not mpi:
raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrTraining", "log_dir": "logs"}
cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrCodingTrainer", "log_dir": "logs"}
set_runtime_envs(cluster_envs, args.engine_extras)
launch = LocalMPIEngine(cluster_envs, args.model)
return launch
#
# def yaml_engine(engine_yaml, model_yaml):
# with open(engine_yaml, 'r') as rb:
# _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
# assert _config is not None
#
# envs.set_global_envs(_config)
#
# train_location = envs.get_global_env("engine.file")
# train_dirname = os.path.dirname(train_location)
# base_name = os.path.splitext(os.path.basename(train_location))[0]
# sys.path.append(train_dirname)
# trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
# trainer = trainer_class(model_yaml)
# return trainer
def engine_registry():
engines["TRAINSPILER"]["SINGLE"] = single_engine
engines["TRAINSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
engines["TRAINSPILER"]["CLUSTER"] = cluster_engine
engines["PSLIB"]["SINGLE"] = local_mpi_engine
engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
engines["PSLIB"]["CLUSTER"] = cluster_mpi_engine
engine_registry()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册