Commit 3cae956f authored by T tangwei

remove unused flag -d -e

Parent a66a717a
@@ -28,7 +28,7 @@ from paddlerec.core.utils import dataloader_instance
 class TranspileTrainer(Trainer):
     def __init__(self, config=None):
         Trainer.__init__(self, config)
-        device = envs.get_global_env("train.device")
+        device = envs.get_global_env("train.device", "cpu")
         if device == 'gpu':
             self._place = fluid.CUDAPlace(0)
             self._exe = fluid.Executor(self._place)
@@ -26,37 +26,47 @@ from paddlerec.core.utils import util
 engines = {}
 device = ["CPU", "GPU"]
 clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
-custom_model = ['tdm']
+engine_choices = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER",
+                  "TDM_SINGLE", "TDM_LOCAL_CLUSTER", "TDM_CLUSTER"]
+custom_model = ['TDM']
 model_name = ""
 def engine_registry():
-    cpu = {"TRANSPILER": {}, "PSLIB": {}}
-    cpu["TRANSPILER"]["SINGLE"] = single_engine
-    cpu["TRANSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
-    cpu["TRANSPILER"]["CLUSTER"] = cluster_engine
-    cpu["PSLIB"]["SINGLE"] = local_mpi_engine
-    cpu["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
-    cpu["PSLIB"]["CLUSTER"] = cluster_mpi_engine
+    engines = {"TRANSPILER": {}, "PSLIB": {}}
+    engines["TRANSPILER"]["SINGLE"] = single_engine
+    engines["TRANSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
+    engines["TRANSPILER"]["CLUSTER"] = cluster_engine
-    gpu = {"TRANSPILER": {}, "PSLIB": {}}
-    gpu["TRANSPILER"]["SINGLE"] = single_engine
+    engines["PSLIB"]["SINGLE"] = local_mpi_engine
+    engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
+    engines["PSLIB"]["CLUSTER"] = cluster_mpi_engine
-    engines["CPU"] = cpu
-    engines["GPU"] = gpu
+def get_inters_from_yaml(file, filter):
+    with open(file, 'r') as rb:
+        _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
+    flattens = envs.flatten_environs(_envs)
+    inters = {}
+    for k, v in flattens.items():
+        if k.startswith(filter):
+            inters[k] = v
+    return inters
 def get_engine(args):
-    device = args.device
-    d_engine = engines[device]
     transpiler = get_transpiler()
+    run_extras = get_inters_from_yaml(args.model, "train.")
+    engine = run_extras.get("train.engine", "")
+    engine = engine.upper()
-    engine = args.engine
-    run_engine = d_engine[transpiler].get(engine, None)
+    if engine not in engine_choices:
+        raise ValueError("train.engin can not be chosen in {}".format(engine_choices))
-    if run_engine is None:
-        raise ValueError(
-            "engine {} can not be supported on device: {}".format(engine, device))
+    run_engine = engines[transpiler].get(engine, None)
     return run_engine
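
For context, a minimal self-contained sketch (not the PaddleRec API) of the selection logic introduced above: the engine name now comes from `train.engine` in the model YAML instead of the removed `-e` flag, and is checked against `engine_choices`. The flattening helper below stands in for `envs.flatten_environs`, which this diff does not show; the YAML layout and the default engine name are assumptions.

```python
import yaml

ENGINE_CHOICES = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER",
                  "TDM_SINGLE", "TDM_LOCAL_CLUSTER", "TDM_CLUSTER"]


def flatten(conf, prefix=""):
    # {"train": {"engine": "single"}} -> {"train.engine": "single"}
    out = {}
    for key, value in conf.items():
        full = prefix + str(key)
        if isinstance(value, dict):
            out.update(flatten(value, full + "."))
        else:
            out[full] = value
    return out


def pick_engine(model_yaml):
    with open(model_yaml, 'r') as rb:
        conf = yaml.load(rb.read(), Loader=yaml.FullLoader)
    run_extras = flatten(conf)
    engine = str(run_extras.get("train.engine", "single")).upper()
    if engine not in ENGINE_CHOICES:
        raise ValueError("train.engine must be one of {}".format(ENGINE_CHOICES))
    return engine
```
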
@@ -73,24 +83,13 @@ def get_transpiler():
 def set_runtime_envs(cluster_envs, engine_yaml):
-    def get_engine_extras():
-        with open(engine_yaml, 'r') as rb:
-            _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
-        flattens = envs.flatten_environs(_envs)
-        engine_extras = {}
-        for k, v in flattens.items():
-            if k.startswith("train.trainer."):
-                engine_extras[k] = v
-        return engine_extras
     if cluster_envs is None:
         cluster_envs = {}
-    engine_extras = get_engine_extras()
+    engine_extras = get_inters_from_yaml(engine_yaml, "train.trainer.")
+    if "train.trainer.threads" in engine_extras and "CPU_NUM" in cluster_envs:
+        cluster_envs["CPU_NUM"] = engine_extras["train.trainer.threads"]
     envs.set_runtime_environs(cluster_envs)
     envs.set_runtime_environs(engine_extras)
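
A rough, hedged sketch of what the merge in set_runtime_envs amounts to, assuming envs.set_runtime_environs simply writes its keys into the process environment (that helper is not part of this diff):

```python
import os


def merge_runtime_envs(cluster_envs, engine_extras):
    # engine defaults first, then "train.trainer.*" values read from the YAML
    cluster_envs = dict(cluster_envs or {})
    # a threads setting from the YAML overrides the default CPU_NUM
    if "train.trainer.threads" in engine_extras and "CPU_NUM" in cluster_envs:
        cluster_envs["CPU_NUM"] = engine_extras["train.trainer.threads"]
    for key, value in list(cluster_envs.items()) + list(engine_extras.items()):
        os.environ[key] = str(value)


merge_runtime_envs({"CPU_NUM": "2"}, {"train.trainer.threads": "4"})
print(os.environ["CPU_NUM"])  # -> "4"
```
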
@@ -114,7 +113,6 @@ def single_engine(args):
     single_envs["train.trainer.trainer"] = trainer
     single_envs["train.trainer.threads"] = "2"
     single_envs["train.trainer.engine"] = "single"
-    single_envs["train.trainer.device"] = args.device
     single_envs["train.trainer.platform"] = envs.get_platform()
     print("use {} engine to run model: {}".format(trainer, args.model))
@@ -124,7 +122,6 @@ def single_engine(args):
 def cluster_engine(args):
     def update_workspace(cluster_envs):
         workspace = cluster_envs.get("engine_workspace", None)
         if not workspace:
@@ -144,12 +141,13 @@ def cluster_engine(args):
             cluster_envs[name] = value
     def master():
+        role = "MASTER"
         from paddlerec.core.engine.cluster.cluster import ClusterEngine
         with open(args.backend, 'r') as rb:
             _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
         flattens = envs.flatten_environs(_envs, "_")
-        flattens["engine_role"] = args.role
+        flattens["engine_role"] = role
         flattens["engine_run_config"] = args.model
         flattens["engine_temp_path"] = tempfile.mkdtemp()
         update_workspace(flattens)
@@ -161,12 +159,12 @@ def cluster_engine(args):
         return launch
     def worker():
+        role = "WORKER"
         trainer = get_trainer_prefix(args) + "ClusterTrainer"
         cluster_envs = {}
         cluster_envs["train.trainer.trainer"] = trainer
         cluster_envs["train.trainer.engine"] = "cluster"
         cluster_envs["train.trainer.threads"] = envs.get_runtime_environ("CPU_NUM")
-        cluster_envs["train.trainer.device"] = args.device
         cluster_envs["train.trainer.platform"] = envs.get_platform()
         print("launch {} engine with cluster to with model: {}".format(
             trainer, args.model))
@@ -175,7 +173,9 @@ def cluster_engine(args):
         trainer = TrainerFactory.create(args.model)
         return trainer
-    if args.role == "WORKER":
+    role = os.getenv("PADDLE_PADDLEREC_ROLE", "MASTER")
+    if role == "WORKER":
         return worker()
     else:
         return master()
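
Since the -r/--role flag is dropped further down, the master/worker split above is now driven by the PADDLE_PADDLEREC_ROLE environment variable. A tiny illustration of the lookup; the MASTER default matches the diff, and a cluster launcher would presumably export WORKER for each trainer process:

```python
import os

role = os.getenv("PADDLE_PADDLEREC_ROLE", "MASTER")
if role == "WORKER":
    print("this process takes the worker() path and builds a ClusterTrainer")
else:
    print("this process takes the master() path and launches the cluster")
```
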
@@ -186,7 +186,6 @@ def cluster_mpi_engine(args):
     cluster_envs = {}
     cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
-    cluster_envs["train.trainer.device"] = args.device
     cluster_envs["train.trainer.platform"] = envs.get_platform()
     set_runtime_envs(cluster_envs, args.model)
@@ -208,8 +207,6 @@ def local_cluster_engine(args):
     cluster_envs["train.trainer.strategy"] = "async"
     cluster_envs["train.trainer.threads"] = "2"
     cluster_envs["train.trainer.engine"] = "local_cluster"
-    cluster_envs["train.trainer.device"] = args.device
     cluster_envs["train.trainer.platform"] = envs.get_platform()
     cluster_envs["CPU_NUM"] = "2"
@@ -235,7 +232,6 @@ def local_mpi_engine(args):
     cluster_envs["log_dir"] = "logs"
     cluster_envs["train.trainer.engine"] = "local_cluster"
-    cluster_envs["train.trainer.device"] = args.device
     cluster_envs["train.trainer.platform"] = envs.get_platform()
     set_runtime_envs(cluster_envs, args.model)
@@ -258,29 +254,17 @@ def get_abs_model(model):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='paddle-rec run')
     parser.add_argument("-m", "--model", type=str)
-    parser.add_argument("-e", "--engine", type=str,
-                        choices=["single", "local_cluster", "cluster",
-                                 "tdm_single", "tdm_local_cluster", "tdm_cluster"])
-    parser.add_argument("-d", "--device", type=str,
-                        choices=["cpu", "gpu"], default="cpu")
     parser.add_argument("-b", "--backend", type=str, default=None)
-    parser.add_argument("-r", "--role", type=str,
-                        choices=["master", "worker"], default="master")
     abs_dir = os.path.dirname(os.path.abspath(__file__))
     envs.set_runtime_environs({"PACKAGE_BASE": abs_dir})
     args = parser.parse_args()
-    args.engine = args.engine.upper()
-    args.device = args.device.upper()
-    args.role = args.role.upper()
     model_name = args.model.split('.')[-1]
     args.model = get_abs_model(args.model)
     engine_registry()
     which_engine = get_engine(args)
     engine = which_engine(args)
     engine.run()
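
With -e and -d gone (and the role moved to an environment variable), the command-line surface reduces to the model YAML plus an optional backend YAML; engine and device are read from the config instead. A small check of the remaining parser, using a made-up config path:

```python
import argparse

parser = argparse.ArgumentParser(description='paddle-rec run')
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-b", "--backend", type=str, default=None)

# "models/rank/dnn/config.yaml" is only an example path
args = parser.parse_args(["-m", "models/rank/dnn/config.yaml"])
print(args.model, args.backend)
```
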