提交 e7986cb4 编写于 作者: T tangwei

fix readme

上级 3ccb3ea7
...@@ -46,7 +46,7 @@ class TrainerFactory(object): ...@@ -46,7 +46,7 @@ class TrainerFactory(object):
if trainer_abs is None: if trainer_abs is None:
if not os.path.exists(train_mode) or os.path.isfile(train_mode): if not os.path.exists(train_mode) or os.path.isfile(train_mode):
raise ValueError("trainer {} can not be recognized") raise ValueError("trainer {} can not be recognized".format(train_mode))
trainer_abs = train_mode trainer_abs = train_mode
train_mode = "UserDefineTrainer" train_mode = "UserDefineTrainer"
......
...@@ -16,6 +16,8 @@ def set_runtime_envs(cluster_envs, engine_yaml): ...@@ -16,6 +16,8 @@ def set_runtime_envs(cluster_envs, engine_yaml):
if engine_yaml is not None: if engine_yaml is not None:
with open(engine_yaml, 'r') as rb: with open(engine_yaml, 'r') as rb:
_envs = yaml.load(rb.read(), Loader=yaml.FullLoader) _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
_envs = {}
if cluster_envs is None: if cluster_envs is None:
cluster_envs = {} cluster_envs = {}
...@@ -24,15 +26,6 @@ def set_runtime_envs(cluster_envs, engine_yaml): ...@@ -24,15 +26,6 @@ def set_runtime_envs(cluster_envs, engine_yaml):
print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value"))) print(envs.pretty_print_envs(cluster_envs, ("Runtime Envs", "Value")))
def engine_registry():
engines["TRAINSPILER"]["SINGLE"] = single_engine
engines["TRAINSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
engines["TRAINSPILER"]["CLUSTER"] = cluster_engine
engines["PSLIB"]["SINGLE"] = local_mpi_engine
engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
engines["PSLIB"]["CLUSTER"] = cluster_mpi_engine
def get_engine(engine): def get_engine(engine):
engine = engine.upper() engine = engine.upper()
if version.is_transpiler(): if version.is_transpiler():
...@@ -47,7 +40,7 @@ def get_engine(engine): ...@@ -47,7 +40,7 @@ def get_engine(engine):
def single_engine(args): def single_engine(args):
print("use single engine to run model: {}".format(args.model)) print("use single engine to run model: {}".format(args.model))
single_envs = {"trainer.trainer": "SingleTraining"} single_envs = {"trainer.trainer": "SingleTrainer"}
set_runtime_envs(single_envs, args.engine_extras) set_runtime_envs(single_envs, args.engine_extras)
...@@ -58,7 +51,7 @@ def single_engine(args): ...@@ -58,7 +51,7 @@ def single_engine(args):
def cluster_engine(args): def cluster_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model)) print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "ClusterTraining"} cluster_envs = {"trainer.trainer": "ClusterTrainer"}
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
envs.set_runtime_envions(cluster_envs) envs.set_runtime_envions(cluster_envs)
...@@ -69,7 +62,7 @@ def cluster_engine(args): ...@@ -69,7 +62,7 @@ def cluster_engine(args):
def cluster_mpi_engine(args): def cluster_mpi_engine(args):
print("launch cluster engine with cluster to run model: {}".format(args.model)) print("launch cluster engine with cluster to run model: {}".format(args.model))
cluster_envs = {"trainer.trainer": "CtrTraining"} cluster_envs = {"trainer.trainer": "CtrCodingTrainer"}
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
...@@ -85,7 +78,7 @@ def local_cluster_engine(args): ...@@ -85,7 +78,7 @@ def local_cluster_engine(args):
cluster_envs["worker_num"] = 1 cluster_envs["worker_num"] = 1
cluster_envs["start_port"] = 36001 cluster_envs["start_port"] = 36001
cluster_envs["log_dir"] = "logs" cluster_envs["log_dir"] = "logs"
cluster_envs["trainer.trainer"] = "ClusterTraining" cluster_envs["trainer.trainer"] = "ClusterTrainer"
cluster_envs["trainer.strategy.mode"] = "async" cluster_envs["trainer.strategy.mode"] = "async"
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
...@@ -104,28 +97,22 @@ def local_mpi_engine(args): ...@@ -104,28 +97,22 @@ def local_mpi_engine(args):
if not mpi: if not mpi:
raise RuntimeError("can not find mpirun, please check environment") raise RuntimeError("can not find mpirun, please check environment")
cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrTraining", "log_dir": "logs"} cluster_envs = {"mpirun": mpi, "trainer.trainer": "CtrCodingTrainer", "log_dir": "logs"}
set_runtime_envs(cluster_envs, args.engine_extras) set_runtime_envs(cluster_envs, args.engine_extras)
launch = LocalMPIEngine(cluster_envs, args.model) launch = LocalMPIEngine(cluster_envs, args.model)
return launch return launch
# def engine_registry():
# def yaml_engine(engine_yaml, model_yaml): engines["TRAINSPILER"]["SINGLE"] = single_engine
# with open(engine_yaml, 'r') as rb: engines["TRAINSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
# _config = yaml.load(rb.read(), Loader=yaml.FullLoader) engines["TRAINSPILER"]["CLUSTER"] = cluster_engine
# assert _config is not None engines["PSLIB"]["SINGLE"] = local_mpi_engine
# engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
# envs.set_global_envs(_config) engines["PSLIB"]["CLUSTER"] = cluster_mpi_engine
#
# train_location = envs.get_global_env("engine.file")
# train_dirname = os.path.dirname(train_location)
# base_name = os.path.splitext(os.path.basename(train_location))[0]
# sys.path.append(train_dirname)
# trainer_class = envs.lazy_instance(base_name, "UserDefineTraining")
# trainer = trainer_class(model_yaml)
# return trainer
engine_registry()
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='fleet-rec run') parser = argparse.ArgumentParser(description='fleet-rec run')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册