Commit 9368aec8 authored by: T tangwei12

Fix bug in cluster training

Parent 2ebef2b7
......@@ -27,7 +27,8 @@
 train:
   threads: 12
   epochs: 10
-  trainer: "LocalClusterTraining"
+  trainer: "ClusterTraining"
+  container: "local"
   pserver_num: 2
   trainer_num: 2
......
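The config change above renames the trainer mode and adds a container key. For orientation, a train block consistent with what this commit's factory reads might look as follows; only the keys shown in this diff plus train.start_port and train.log_dirname (read by the factory further down) come from the source, and the port and log-directory values are placeholders:

train:
  threads: 12
  epochs: 10
  trainer: "ClusterTraining"    # selects the cluster code path in TrainerFactory.create()
  container: "local"            # "local" makes the factory build a Launch engine
  pserver_num: 2                # parameter-server process count
  trainer_num: 2                # worker process count
  start_port: 36001             # placeholder, read as train.start_port
  log_dirname: "logs"           # placeholder, read as train.log_dirname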
......@@ -26,15 +26,11 @@
 import os
 import yaml
 from eleps.trainer.factory import TrainerFactory
 if __name__ == "__main__":
     abs_dir = os.path.dirname(os.path.abspath(__file__))
-    with open(os.path.join(abs_dir, 'ctr-dnn_train_single.yaml'), 'r') as rb:
-        global_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
-    trainer = TrainerFactory.create(global_config)
+    yaml = os.path.join(abs_dir, 'ctr-dnn_train_cluster.yaml')
+    trainer = TrainerFactory.create(yaml)
     trainer.run()
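After this change the script hands TrainerFactory.create() a path instead of a pre-loaded dict, and the factory does the YAML loading itself. One side effect is that the new line reuses the name yaml for the path, shadowing the yaml module (which this file no longer needs). A minimal sketch of the same entry point without the shadowing; the variable name yaml_path is ours, not from the commit:

import os

from eleps.trainer.factory import TrainerFactory

if __name__ == "__main__":
    abs_dir = os.path.dirname(os.path.abspath(__file__))
    # TrainerFactory.create() accepts the config path and loads the YAML itself.
    yaml_path = os.path.join(abs_dir, 'ctr-dnn_train_cluster.yaml')
    trainer = TrainerFactory.create(yaml_path)
    trainer.run()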
......@@ -42,10 +42,9 @@ class ClusterTrainerWithDataloader(TranspileTrainer):
 class ClusterTrainerWithDataset(TranspileTrainer):
     def processor_register(self):
         role = PaddleCloudRoleMaker()
         fleet.init(role)
-        if role.is_server():
+        if fleet.is_server():
             self.regist_context_processor('uninit', self.instance)
             self.regist_context_processor('init_pass', self.init)
             self.regist_context_processor('server_pass', self.server)
......@@ -74,6 +73,9 @@ class ClusterTrainerWithDataset(TranspileTrainer):
         return strategy
     def init(self, context):
         print("init pass")
+        self.model.input()
+        self.model.net()
+        self.metrics = self.model.metrics()
......
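The bug fix here is the guard itself: the role object is handed to fleet.init(), and the server/worker query is then made on the fleet handle rather than on the raw role maker. A minimal sketch of that pattern, assuming the Paddle 1.x parameter-server fleet API that PaddleCloudRoleMaker and fleet refer to in this trainer:

# Sketch only: import paths are an assumption based on Paddle 1.x's fleet API.
from paddle.fluid.incubate.fleet.base.role_maker import PaddleCloudRoleMaker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

role = PaddleCloudRoleMaker()   # reads TRAINING_ROLE / PADDLE_* from the environment
fleet.init(role)                # must run before is_server()/is_worker() queries

if fleet.is_server():
    pass   # register server-side processors, e.g. the 'server_pass' above
else:
    pass   # register trainer-side processors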
......@@ -35,7 +35,7 @@ from eleps.trainer.single_trainer import SingleTrainerWithDataset
 from eleps.trainer.cluster_trainer import ClusterTrainerWithDataloader
 from eleps.trainer.cluster_trainer import ClusterTrainerWithDataset
-from eleps.trainer.local_engine import local_launch
+from eleps.trainer.local_engine import Launch
 from eleps.trainer.ctr_trainer import CtrPaddleTrainer
 from eleps.utils import envs
......@@ -91,24 +91,26 @@ class TrainerFactory(object):
cluster_envs["start_port"] = envs.get_global_env("train.start_port")
cluster_envs["log_dir"] = envs.get_global_env("train.log_dirname")
envs.pretty_print_envs(cluster_envs, ("Cluster Global Envs", "Value"))
print(envs.pretty_print_envs(cluster_envs, ("Cluster Global Envs", "Value")))
local_launch(cluster_envs, yaml_config)
launch = Launch(cluster_envs, yaml_config)
return launch
@staticmethod
def create(config):
_config = None
if os.path.exists(config) and os.path.isfile(config):
with open(config, 'r') as rb:
_config = yaml.load(rb.read())
_config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else:
raise ValueError("eleps's config only support yaml")
envs.set_global_envs(_config)
train_mode = envs.get_global_env("train.trainer")
mode = envs.get_global_env("train.trainer")
container = envs.get_global_env("train.container")
instance = str2bool(os.getenv("CLUSTER_INSTANCE", "0"))
if train_mode == "LocalClusterTraining" and not instance:
if mode == "ClusterTraining" and container == "local" and not instance:
trainer = TrainerFactory._build_engine(config)
else:
trainer = TrainerFactory._build_trainer(_config)
......@@ -120,4 +122,6 @@ class TrainerFactory(object):
 if __name__ == "__main__":
     if len(sys.argv) != 2:
         raise ValueError("need a yaml file path argv")
-    TrainerFactory.create(sys.argv[1])
+    trainer = TrainerFactory.create(sys.argv[1])
+    trainer.run()
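Taken together, create() now dispatches on three inputs: the train.trainer mode, the train.container value, and the CLUSTER_INSTANCE environment variable that separates the spawned server/worker processes from the parent. A simplified, standalone sketch of that decision; the helper name and the plain env check stand in for the repo's str2bool and are ours, not the commit's:

import os

def should_build_engine(mode, container):
    # When CLUSTER_INSTANCE is set (truthy), the factory falls through to
    # _build_trainer() instead of building another launch engine.
    instance = os.getenv("CLUSTER_INSTANCE", "0") != "0"
    return mode == "ClusterTraining" and container == "local" and not instance

# In the parent process (no CLUSTER_INSTANCE in the environment) this is True,
# which is the case where TrainerFactory._build_engine() runs.
print(should_build_engine("ClusterTraining", "local"))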
......@@ -38,7 +38,8 @@ def start_procs(args, yaml):
     user_endpoints_ips = [x.split(":")[0] for x in user_endpoints.split(",")]
     user_endpoints_port = [x.split(":")[1] for x in user_endpoints.split(",")]
-    factory = os.path.join(os.path.abspath(os.path.dirname(__file__)), "factory.py")
+    factory = "eleps.trainer.factory"
+    cmd = [sys.executable, "-u", "-m", factory, yaml]
     for i in range(server_num):
         current_env.update({
......@@ -48,15 +49,14 @@ def start_procs(args, yaml):
"PADDLE_TRAINERS_NUM": str(worker_num),
"POD_IP": user_endpoints_ips[i]
})
cmd = [sys.executable, "-u", factory, yaml]
if args.log_dir is not None:
if logs_dir is not None:
os.system("mkdir -p {}".format(logs_dir))
fn = open("%s/server.%d" % (logs_dir, i), "w")
log_fns.append(fn)
proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
else:
proc = subprocess.Popen(cmd, env=current_env)
proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
procs.append(proc)
for i in range(worker_num):
......@@ -66,16 +66,14 @@ def start_procs(args, yaml):
"TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": str(i)
})
cmd = [sys.executable, "-u", args.training_script
] + args.training_script_args
if args.log_dir is not None:
if logs_dir is not None:
os.system("mkdir -p {}".format(logs_dir))
fn = open("%s/worker.%d" % (logs_dir, i), "w")
log_fns.append(fn)
proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn, cwd=os.getcwd())
else:
proc = subprocess.Popen(cmd, env=current_env)
proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
procs.append(proc)
# only wait worker to finish here
......@@ -93,6 +91,11 @@ def start_procs(args, yaml):
             procs[i].terminate()
     print("all parameter server are killed", file=sys.stderr)
+class Launch():
+    def __init__(self, envs, trainer):
+        self.envs = envs
+        self.trainer = trainer
+    def run(self):
+        start_procs(self.envs, self.trainer)
 def local_launch(envs, trainer):
     start_procs(envs, trainer)
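Two launcher details are worth spelling out. First, each child process now runs the factory as a module (python -u -m eleps.trainer.factory <yaml>) instead of by file path, and the same cmd is reused for both pservers and workers. Second, every Popen call pins cwd=os.getcwd(), so the -m import resolves from the directory the user launched from. A standalone sketch of that spawn pattern, using only environment names visible in this hunk; the config path and trainer index are placeholders:

import os
import subprocess
import sys

factory = "eleps.trainer.factory"
yaml_path = "ctr-dnn_train_cluster.yaml"   # placeholder config path
cmd = [sys.executable, "-u", "-m", factory, yaml_path]

current_env = dict(os.environ)
current_env.update({
    "TRAINING_ROLE": "TRAINER",            # worker role, as in the loop above
    "PADDLE_TRAINER_ID": "0",              # placeholder trainer index
})

# cwd=os.getcwd() keeps the eleps package importable for "-m" from the
# directory the launcher itself was started in.
proc = subprocess.Popen(cmd, env=current_env, cwd=os.getcwd())
proc.wait()

The new Launch class wraps the same start_procs call behind a run() method, which is what lets TrainerFactory._build_engine return it so the caller can invoke trainer.run() uniformly.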
......@@ -48,7 +48,7 @@ def get_global_envs():
     return global_envs
-def pretty_print_envs(envs, header):
+def pretty_print_envs(envs, header=None):
     spacing = 5
     max_k = 45
     max_v = 20
......
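Making header optional here pairs with the factory change above, where the call site now prints the return value: pretty_print_envs builds the formatted table and returns it rather than printing it itself. A usage sketch, assuming nothing beyond what the diff shows; the placeholder dict is ours:

from eleps.utils import envs

cluster_envs = {"server_num": 2, "worker_num": 2}   # placeholder values
# With a header tuple the two columns are titled; header can now be omitted.
print(envs.pretty_print_envs(cluster_envs, ("Cluster Global Envs", "Value")))
print(envs.pretty_print_envs(cluster_envs))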