Commit ca8af949 authored by Xi Chen

update features mentioned by @helin

Parent 82530e24
@@ -88,7 +88,7 @@ parser.add_argument(
     '--pserver_count', type=int, default=1, help="Pserver count")
 parser.add_argument(
-    '--action', type=str, default="serve", help="create|cleanup|status")
+    '--action', type=str, default="create", help="create|cleanup|status")
 parser.add_argument('--pem_path', type=str, help="private key file")
@@ -355,7 +355,8 @@ def status():
 
 
 def get_master_web_url(path):
-    return "http://" + args.master_server_public_ip + ":" + args.master_server_port + path
+    return "http://" + args.master_server_public_ip + ":" + str(
+        args.master_server_port) + path
 
 
 if __name__ == "__main__":
...
@@ -127,6 +127,8 @@ ec2client = boto3.client('ec2')
 logging.basicConfig(
     filename='master.log', level=logging.INFO, format='%(asctime)s %(message)s')
 
+log_files = ["master.log"]
+
 
 def create_subnet():
     # if no vpc id provided, list vpcs
@@ -299,28 +301,103 @@ def create_pservers():
         cleanup(args.task_name)
 
 
+def log_to_file(source, filename):
+    if not filename in log_files:
+        log_files.append(filename)
+    with open(filename, "a") as log_file:
+        for line in iter(source.readline, ""):
+            log_file.write(line)
+
+
 def create_trainers(kickoff_cmd, pserver_endpoints_str):
+    def create_and_start_trainer(trainer_index):
+        logging.info("trainer " + str(trainer_index) + " is starting")
+
+        instance_response = run_instances(
+            image_id=args.trainer_image_id,
+            instance_type=args.trainer_instance_type,
+            count=1,
+            role="TRAINER", )[0]
+        trainer_ip = instance_response["PrivateIpAddress"]
+
+        logging.info("trainer " + str(trainer_index) + " started")
+
+        ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
+        ssh_client = paramiko.SSHClient()
+        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        ssh_client.connect(hostname=trainer_ip, username="ubuntu", pkey=ssh_key)
+
+        logging.info("trainer " + str(trainer_index) +
+                     " terminal connected via ssh")
+
+        cmd = kickoff_cmd.format(
+            PSERVER_HOSTS=pserver_endpoints_str,
+            DOCKER_IMAGE=args.docker_image,
+            TRAINER_INDEX=str(trainer_index),
+            TASK_NAME=args.task_name,
+            MASTER_ENDPOINT=args.master_server_ip + ":" +
+            str(args.master_server_port))
+        logging.info(cmd)
+
+        stdin, stdout, stderr = ssh_client.exec_command(command=cmd)
+
+        # read and save output log
+        logging.info("trainer " + str(trainer_index) +
+                     " command executed, keep fetching log")
+
+        stdout_thread = threading.Thread(
+            target=log_to_file,
+            args=(
+                stdout,
+                "trainer_" + str(trainer_index) + ".log", ))
+        stderr_thread = threading.Thread(
+            target=log_to_file,
+            args=(
+                stderr,
+                "trainer_" + str(trainer_index) + "_err.log", ))
+        stdout_thread.start()
+        stderr_thread.start()
+
+        stdout_thread.join()
+        stderr_thread.join()
+
+        return_code = stdout.channel.recv_exit_status()
+        if return_code != 0:
+            trainer_create_results[trainer_index] = {'has_error': True}
+            raise ValueError("trainer didn't finish with exit code 0")
+
+        ssh_client.close()
+
+    # multi thread starting trainer instance and run kickoff command
+    trainer_threads = []
+    trainer_create_results = {}
     try:
-        responses = []
         for i in xrange(args.trainer_count):
-            cmd = kickoff_cmd.format(
-                PSERVER_HOSTS=pserver_endpoints_str,
-                DOCKER_IMAGE=args.docker_image,
-                TRAINER_INDEX=str(i),
-                TASK_NAME=args.task_name,
-                MASTER_ENDPOINT=args.master_server_ip + ":" +
-                str(args.master_server_port))
-            logging.info(cmd)
-            responses.append(
-                run_instances(
-                    image_id=args.trainer_image_id,
-                    instance_type=args.trainer_instance_type,
-                    count=1,
-                    role="TRAINER",
-                    cmd=cmd, )[0])
-        return responses
-    except Exception:
-        logging.exception("error while trying to create trainers")
+            logging.info("starting thread for trainer " + str(i))
+            trainer_thread = threading.Thread(
+                target=create_and_start_trainer, args=(i, ))
+            trainer_thread.start()
+            trainer_threads.append(trainer_thread)
+
+        for trainer_thread in trainer_threads:
+            trainer_thread.join()
+
+        for result in trainer_create_results.values():
+            if result["has_error"]:
+                logging.error(
+                    "error during trainer starting or training, destroying the whole cluster"
+                )
+                cleanup(args.task_name)
+                break
+
+        logging.info("all trainers stopped")
+    except Exception, e:
+        logging.info(
+            "Training exception, clean up resources, please check log for more info"
+        )
+    finally:
         cleanup(args.task_name)
@@ -373,6 +450,21 @@ def kickoff_pserver(host, pserver_endpoints_str):
             str(args.master_server_port))
         logging.info(cmd)
         stdin, stdout, stderr = ssh_client.exec_command(command=cmd)
+
+        stdout_thread = threading.Thread(
+            target=log_to_file, args=(
+                stdout,
+                "pserver_" + host + ".log", ))
+        stderr_thread = threading.Thread(
+            target=log_to_file, args=(
+                stderr,
+                "pserver_" + host + "_err.log", ))
+        stdout_thread.start()
+        stderr_thread.start()
+
+        stdout_thread.join()
+        stderr_thread.join()
+
         return_code = stdout.channel.recv_exit_status()
         logging.info(return_code)
         if return_code != 0:
@@ -421,20 +513,21 @@ def create_cluster():
     for pserver in pserver_create_response:
         pserver_thread = threading.Thread(
             target=kickoff_pserver,
-            args=(pserver["PublicIpAddress"], pserver_endpoints_str))
+            args=(pserver["PrivateIpAddress"], pserver_endpoints_str))
         pserver_thread.start()
         pserver_threads.append(pserver_thread)
 
-    for pserver_thread in pserver_threads:
-        pserver_thread.join()
-
     logging.info("all pserver training process started")
 
     logging.info("creating trainers and kicking off trainer training process")
     create_trainers(
        kickoff_cmd=script_to_str(args.trainer_bash_file),
        pserver_endpoints_str=pserver_endpoints_str)
-    logging.info("trainers created")
+
+    for pserver_thread in pserver_threads:
+        pserver_thread.join()
+
+    logging.info("all process ended")
 
 
 def start_server(args):
@@ -455,12 +548,20 @@ def start_server(args):
             self.wfile.write("NO ACTION FOUND")
 
         def do_GET(self):
-            self._set_headers()
             request_path = self.path
-            if request_path == "/status" or request_path == "/logs":
+            if request_path == "/status" or request_path == "/master_logs":
+                self._set_headers()
                 logging.info("Received request to return status")
                 with open("master.log", "r") as logfile:
                     self.wfile.write(logfile.read().strip())
+            elif request_path == "/list_logs":
+                self._set_headers()
+                self.wfile.write("\n".join(log_files))
+            elif "/log/" in request_path:
+                self._set_headers()
+                log_file_path = request_path.replace("/log/", "")
+                with open(log_file_path, "r") as logfile:
+                    self.wfile.write(logfile.read().strip())
             else:
                 self.do_404()
@@ -484,16 +585,6 @@ def start_server(args):
                 cleanup(args.task_name)
                 self.wfile.write("cleanup in progress")
-            elif request_path == "/trainer_job_done":
-                self._set_headers()
-                logging.info("Received request to increase job done count")
-                args.trainers_job_done_count += 1
-                self.wfile.write(
-                    str(args.trainers_job_done_count) + " tainers job done")
-                if args.trainers_job_done_count >= args.trainer_count:
-                    logging.info("going to clean up")
-                    cleanup(args.task_name)
-
             else:
                 self.do_404()
@@ -539,3 +630,12 @@ if __name__ == "__main__":
         create_cluster()
         server_thread.join()
+    elif args.action == "test":
+        init_args()
+        if not args.subnet_id:
+            logging.info("creating subnet for this task")
+            args.subnet_id = create_subnet()
+            logging.info("subnet %s created" % (args.subnet_id))
+        create_trainers(
+            kickoff_cmd=script_to_str(args.trainer_bash_file),
+            pserver_endpoints_str="11.22.33.44:5476")
 #!/bin/bash
-nvidia-docker run -p {PSERVER_PORT}:{PSERVER_PORT} -e "MASTER_ENDPOINT={MASTER_ENDPOINT}" -e "TASK_NAME={TASK_NAME}" -e "TRAINING_ROLE=PSERVER" -e "PSERVER_HOSTS={PSERVER_HOSTS}" {DOCKER_IMAGE}
+nvidia-docker run -i -p {PSERVER_PORT}:{PSERVER_PORT} -e "MASTER_ENDPOINT={MASTER_ENDPOINT}" -e "TASK_NAME={TASK_NAME}" -e "TRAINING_ROLE=PSERVER" -e "PSERVER_HOSTS={PSERVER_HOSTS}" {DOCKER_IMAGE}
\ No newline at end of file
 #!/bin/bash
-nvidia-docker run -e "MASTER_ENDPOINT={MASTER_ENDPOINT}" -e "TASK_NAME={TASK_NAME}" -e "TRAINER_INDEX={TRAINER_INDEX}" -e "TRAINING_ROLE=TRAINER" -e "PSERVER_HOSTS={PSERVER_HOSTS}" {DOCKER_IMAGE}
+nvidia-docker run -i -e "MASTER_ENDPOINT={MASTER_ENDPOINT}" -e "TASK_NAME={TASK_NAME}" -e "TRAINER_INDEX={TRAINER_INDEX}" -e "TRAINING_ROLE=TRAINER" -e "PSERVER_HOSTS={PSERVER_HOSTS}" {DOCKER_IMAGE}
\ No newline at end of file
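
The {MASTER_ENDPOINT}, {TASK_NAME}, {TRAINER_INDEX}, {PSERVER_HOSTS}, {PSERVER_PORT} and {DOCKER_IMAGE} placeholders in these bash templates are filled by the master through kickoff_cmd.format(...), as seen in create_and_start_trainer and kickoff_pserver above. A minimal sketch of that substitution, with illustrative values only:

# Minimal sketch with illustrative values: how the master expands the trainer
# bash template via str.format before running it over SSH.
template = ('nvidia-docker run -i -e "MASTER_ENDPOINT={MASTER_ENDPOINT}" '
            '-e "TASK_NAME={TASK_NAME}" -e "TRAINER_INDEX={TRAINER_INDEX}" '
            '-e "TRAINING_ROLE=TRAINER" -e "PSERVER_HOSTS={PSERVER_HOSTS}" '
            '{DOCKER_IMAGE}')

print template.format(
    MASTER_ENDPOINT="1.2.3.4:8080",       # placeholder master address
    TASK_NAME="my_task",                  # placeholder task name
    TRAINER_INDEX="0",
    PSERVER_HOSTS="11.22.33.44:5476",     # dummy endpoint reused from the "test" action
    DOCKER_IMAGE="my/docker-image")       # placeholder image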