Commit 3cae956f, authored by tangwei

remove unused flag -d -e
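Note: after this commit the engine is no longer taken from a command-line flag. run.py reads the `train.*` section of the model YAML through the new `get_inters_from_yaml()` helper and picks the engine from `train.engine`, validated against `engine_choices`, while the trainer side falls back to "cpu" when `train.device` is missing. The sketch below only illustrates that selection flow; the `conf` dict, `flatten()` and `get_inters()` are simplified, hypothetical stand-ins for the parsed config, `envs.flatten_environs()` and `get_inters_from_yaml()`, not PaddleRec's actual implementation.

    # Parsed model YAML (what yaml.load() returns for the config file);
    # the keys and values here are illustrative, not a real PaddleRec model.
    conf = {
        "train": {
            "engine": "single",          # replaces the removed -e/--engine flag
            "device": "cpu",             # replaces the removed -d/--device flag
            "trainer": {"threads": 2},
        }
    }

    engine_choices = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER",
                      "TDM_SINGLE", "TDM_LOCAL_CLUSTER", "TDM_CLUSTER"]


    def flatten(d, prefix=""):
        # Simplified stand-in for paddlerec.core.utils.envs.flatten_environs():
        # turns nested dicts into dotted keys such as "train.engine".
        out = {}
        for k, v in d.items():
            key = prefix + str(k)
            if isinstance(v, dict):
                out.update(flatten(v, key + "."))
            else:
                out[key] = v
        return out


    def get_inters(flat_conf, prefix):
        # Same filtering idea as the new get_inters_from_yaml().
        return {k: v for k, v in flat_conf.items() if k.startswith(prefix)}


    run_extras = get_inters(flatten(conf), "train.")
    engine = run_extras.get("train.engine", "").upper()
    if engine not in engine_choices:
        raise ValueError("engine must be one of {}".format(engine_choices))
    print(engine)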

Parent: a66a717a
@@ -28,7 +28,7 @@ from paddlerec.core.utils import dataloader_instance
 class TranspileTrainer(Trainer):
     def __init__(self, config=None):
         Trainer.__init__(self, config)
-        device = envs.get_global_env("train.device")
+        device = envs.get_global_env("train.device", "cpu")
         if device == 'gpu':
             self._place = fluid.CUDAPlace(0)
             self._exe = fluid.Executor(self._place)
...
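On the trainer side, `envs.get_global_env("train.device", "cpu")` now supplies a default, so a config without a `train.device` key still yields a CPU executor. A rough sketch of the resulting place selection, assuming a PaddlePaddle installation that still exposes the `fluid` API (the CPU branch is implied by the surrounding code, not shown in the hunk):

    import paddle.fluid as fluid


    def build_executor(flat_conf):
        # flat_conf holds dotted "train.*" keys; "cpu" is the new default.
        device = flat_conf.get("train.device", "cpu")
        if device == 'gpu':
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        return fluid.Executor(place)


    exe = build_executor({})  # no "train.device" key -> falls back to CPU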
@@ -26,37 +26,47 @@ from paddlerec.core.utils import util
 engines = {}
 device = ["CPU", "GPU"]
 clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
-custom_model = ['tdm']
+engine_choices = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER",
+                  "TDM_SINGLE", "TDM_LOCAL_CLUSTER", "TDM_CLUSTER"]
+custom_model = ['TDM']
 model_name = ""


 def engine_registry():
-    cpu = {"TRANSPILER": {}, "PSLIB": {}}
-    cpu["TRANSPILER"]["SINGLE"] = single_engine
-    cpu["TRANSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
-    cpu["TRANSPILER"]["CLUSTER"] = cluster_engine
-    cpu["PSLIB"]["SINGLE"] = local_mpi_engine
-    cpu["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
-    cpu["PSLIB"]["CLUSTER"] = cluster_mpi_engine
-
-    gpu = {"TRANSPILER": {}, "PSLIB": {}}
-    gpu["TRANSPILER"]["SINGLE"] = single_engine
-
-    engines["CPU"] = cpu
-    engines["GPU"] = gpu
+    engines = {"TRANSPILER": {}, "PSLIB": {}}
+    engines["TRANSPILER"]["SINGLE"] = single_engine
+    engines["TRANSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
+    engines["TRANSPILER"]["CLUSTER"] = cluster_engine
+    engines["PSLIB"]["SINGLE"] = local_mpi_engine
+    engines["PSLIB"]["LOCAL_CLUSTER"] = local_mpi_engine
+    engines["PSLIB"]["CLUSTER"] = cluster_mpi_engine
+
+
+def get_inters_from_yaml(file, filter):
+    with open(file, 'r') as rb:
+        _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
+
+    flattens = envs.flatten_environs(_envs)
+
+    inters = {}
+    for k, v in flattens.items():
+        if k.startswith(filter):
+            inters[k] = v
+    return inters


 def get_engine(args):
-    device = args.device
-    d_engine = engines[device]
     transpiler = get_transpiler()
+    run_extras = get_inters_from_yaml(args.model, "train.")

-    engine = args.engine
-    run_engine = d_engine[transpiler].get(engine, None)
-    if run_engine is None:
-        raise ValueError(
-            "engine {} can not be supported on device: {}".format(engine, device))
+    engine = run_extras.get("train.engine", "")
+    engine = engine.upper()
+    if engine not in engine_choices:
+        raise ValueError("train.engin can not be chosen in {}".format(engine_choices))
+
+    run_engine = engines[transpiler].get(engine, None)

     return run_engine
@@ -73,24 +83,13 @@ def get_transpiler():
 def set_runtime_envs(cluster_envs, engine_yaml):
-    def get_engine_extras():
-        with open(engine_yaml, 'r') as rb:
-            _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
-
-        flattens = envs.flatten_environs(_envs)
-
-        engine_extras = {}
-        for k, v in flattens.items():
-            if k.startswith("train.trainer."):
-                engine_extras[k] = v
-        return engine_extras
-
     if cluster_envs is None:
         cluster_envs = {}

-    engine_extras = get_engine_extras()
+    engine_extras = get_inters_from_yaml(engine_yaml, "train.trainer.")
     if "train.trainer.threads" in engine_extras and "CPU_NUM" in cluster_envs:
         cluster_envs["CPU_NUM"] = engine_extras["train.trainer.threads"]

     envs.set_runtime_environs(cluster_envs)
     envs.set_runtime_environs(engine_extras)
@@ -114,7 +113,6 @@ def single_engine(args):
     single_envs["train.trainer.trainer"] = trainer
     single_envs["train.trainer.threads"] = "2"
     single_envs["train.trainer.engine"] = "single"
-    single_envs["train.trainer.device"] = args.device
     single_envs["train.trainer.platform"] = envs.get_platform()
     print("use {} engine to run model: {}".format(trainer, args.model))
@@ -124,7 +122,6 @@ def single_engine(args):

 def cluster_engine(args):
     def update_workspace(cluster_envs):
         workspace = cluster_envs.get("engine_workspace", None)
         if not workspace:
@@ -144,12 +141,13 @@ def cluster_engine(args):
             cluster_envs[name] = value

     def master():
+        role = "MASTER"
         from paddlerec.core.engine.cluster.cluster import ClusterEngine
         with open(args.backend, 'r') as rb:
             _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)

         flattens = envs.flatten_environs(_envs, "_")
-        flattens["engine_role"] = args.role
+        flattens["engine_role"] = role
         flattens["engine_run_config"] = args.model
         flattens["engine_temp_path"] = tempfile.mkdtemp()
         update_workspace(flattens)
@@ -161,12 +159,12 @@ def cluster_engine(args):
         return launch

     def worker():
+        role = "WORKER"
         trainer = get_trainer_prefix(args) + "ClusterTrainer"
         cluster_envs = {}
         cluster_envs["train.trainer.trainer"] = trainer
         cluster_envs["train.trainer.engine"] = "cluster"
         cluster_envs["train.trainer.threads"] = envs.get_runtime_environ("CPU_NUM")
-        cluster_envs["train.trainer.device"] = args.device
         cluster_envs["train.trainer.platform"] = envs.get_platform()
         print("launch {} engine with cluster to with model: {}".format(
             trainer, args.model))
@@ -175,7 +173,9 @@ def cluster_engine(args):
         trainer = TrainerFactory.create(args.model)
         return trainer

-    if args.role == "WORKER":
+    role = os.getenv("PADDLE_PADDLEREC_ROLE", "MASTER")
+
+    if role == "WORKER":
         return worker()
     else:
         return master()
@@ -186,7 +186,6 @@ def cluster_mpi_engine(args):
     cluster_envs = {}
     cluster_envs["train.trainer.trainer"] = "CtrCodingTrainer"
-    cluster_envs["train.trainer.device"] = args.device
     cluster_envs["train.trainer.platform"] = envs.get_platform()

     set_runtime_envs(cluster_envs, args.model)
@@ -208,8 +207,6 @@ def local_cluster_engine(args):
     cluster_envs["train.trainer.strategy"] = "async"
     cluster_envs["train.trainer.threads"] = "2"
     cluster_envs["train.trainer.engine"] = "local_cluster"
-    cluster_envs["train.trainer.device"] = args.device
     cluster_envs["train.trainer.platform"] = envs.get_platform()
     cluster_envs["CPU_NUM"] = "2"
@@ -235,7 +232,6 @@ def local_mpi_engine(args):
     cluster_envs["log_dir"] = "logs"
     cluster_envs["train.trainer.engine"] = "local_cluster"
-    cluster_envs["train.trainer.device"] = args.device
     cluster_envs["train.trainer.platform"] = envs.get_platform()

     set_runtime_envs(cluster_envs, args.model)
@@ -258,29 +254,17 @@ def get_abs_model(model):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='paddle-rec run')
     parser.add_argument("-m", "--model", type=str)
-    parser.add_argument("-e", "--engine", type=str,
-                        choices=["single", "local_cluster", "cluster",
-                                 "tdm_single", "tdm_local_cluster", "tdm_cluster"])
-    parser.add_argument("-d", "--device", type=str,
-                        choices=["cpu", "gpu"], default="cpu")
     parser.add_argument("-b", "--backend", type=str, default=None)
-    parser.add_argument("-r", "--role", type=str,
-                        choices=["master", "worker"], default="master")

     abs_dir = os.path.dirname(os.path.abspath(__file__))
     envs.set_runtime_environs({"PACKAGE_BASE": abs_dir})

     args = parser.parse_args()
-    args.engine = args.engine.upper()
-    args.device = args.device.upper()
-    args.role = args.role.upper()

     model_name = args.model.split('.')[-1]
     args.model = get_abs_model(args.model)
     engine_registry()

     which_engine = get_engine(args)
     engine = which_engine(args)
     engine.run()
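With `-r/--role` removed, `cluster_engine()` now decides between the master and worker paths from the `PADDLE_PADDLEREC_ROLE` environment variable. Below is a minimal sketch of that dispatch; the two stub functions only stand in for the real launchers (building a ClusterEngine from the `-b/--backend` YAML on the master, creating a ClusterTrainer on the worker).

    import os


    def master():
        return "master: launch ClusterEngine from the backend YAML"


    def worker():
        return "worker: create a ClusterTrainer via TrainerFactory"


    def cluster_dispatch():
        # Mirrors the new cluster_engine() logic: the role comes from the
        # environment and defaults to MASTER when the variable is unset.
        role = os.getenv("PADDLE_PADDLEREC_ROLE", "MASTER")
        if role == "WORKER":
            return worker()
        return master()


    # A scheduler would export PADDLE_PADDLEREC_ROLE=WORKER on worker nodes
    # before invoking run.py, instead of passing the old -r worker flag.
    print(cluster_dispatch())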