提交 45d87ade 编写于 作者: X Xi Chen

minor tweaks

上级 94ad30e5
...@@ -49,8 +49,8 @@ parser.add_argument( ...@@ -49,8 +49,8 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
'--pserver_instance_type', '--pserver_instance_type',
type=str, type=str,
default="p2.8xlarge", default="c5.2xlarge",
help="your pserver instance type, p2.8xlarge by default") help="your pserver instance type, c5.2xlarge by default")
parser.add_argument( parser.add_argument(
'--trainer_instance_type', '--trainer_instance_type',
type=str, type=str,
...@@ -68,6 +68,10 @@ parser.add_argument( ...@@ -68,6 +68,10 @@ parser.add_argument(
default="ami-da2c1cbf", default="ami-da2c1cbf",
help="ami id for system image, default one has nvidia-docker ready, \ help="ami id for system image, default one has nvidia-docker ready, \
use ami-1ae93962 for us-east-2") use ami-1ae93962 for us-east-2")
parser.add_argument(
'--pserver_command', type=str, default="", help="pserver start command")
parser.add_argument( parser.add_argument(
'--trainer_image_id', '--trainer_image_id',
type=str, type=str,
...@@ -75,6 +79,9 @@ parser.add_argument( ...@@ -75,6 +79,9 @@ parser.add_argument(
help="ami id for system image, default one has nvidia-docker ready, \ help="ami id for system image, default one has nvidia-docker ready, \
use ami-1ae93962 for us-west-2") use ami-1ae93962 for us-west-2")
parser.add_argument(
'--trainer_command', type=str, default="", help="trainer start command")
parser.add_argument( parser.add_argument(
'--availability_zone', '--availability_zone',
type=str, type=str,
...@@ -104,6 +111,12 @@ parser.add_argument( ...@@ -104,6 +111,12 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
'--master_server_public_ip', type=str, help="master server public ip") '--master_server_public_ip', type=str, help="master server public ip")
parser.add_argument(
'--master_docker_image',
type=str,
default="putcn/paddle_aws_master:latest",
help="master docker image id")
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
...@@ -322,14 +335,16 @@ def create(): ...@@ -322,14 +335,16 @@ def create():
# set arguments and start docker # set arguments and start docker
kick_off_cmd = "docker run -d -v /home/ubuntu/.aws:/root/.aws/" kick_off_cmd = "docker run -d -v /home/ubuntu/.aws:/root/.aws/"
kick_off_cmd += " -v /home/ubuntu/" + args.key_name + ".pem:/root/" + args.key_name + ".pem" kick_off_cmd += " -v /home/ubuntu/" + args.key_name + ".pem:/root/" + args.key_name + ".pem"
kick_off_cmd += " -v /home/ubuntu/logs/:/root/logs/"
kick_off_cmd += " -p " + str(args.master_server_port) + ":" + str( kick_off_cmd += " -p " + str(args.master_server_port) + ":" + str(
args.master_server_port) args.master_server_port)
kick_off_cmd += " putcn/paddle_aws_master" kick_off_cmd += " " + args.master_docker_image
args_to_pass = copy.copy(args) args_to_pass = copy.copy(args)
args_to_pass.action = "serve" args_to_pass.action = "serve"
del args_to_pass.pem_path del args_to_pass.pem_path
del args_to_pass.security_group_ids del args_to_pass.security_group_ids
del args_to_pass.master_docker_image
del args_to_pass.master_server_public_ip del args_to_pass.master_server_public_ip
for arg, value in sorted(vars(args_to_pass).iteritems()): for arg, value in sorted(vars(args_to_pass).iteritems()):
kick_off_cmd += ' --%s %s' % (arg, value) kick_off_cmd += ' --%s %s' % (arg, value)
......
...@@ -53,8 +53,8 @@ parser.add_argument( ...@@ -53,8 +53,8 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
'--pserver_instance_type', '--pserver_instance_type',
type=str, type=str,
default="p2.8xlarge", default="c5.2xlarge",
help="your pserver instance type, p2.8xlarge by default") help="your pserver instance type, c5.2xlarge by default")
parser.add_argument( parser.add_argument(
'--trainer_instance_type', '--trainer_instance_type',
type=str, type=str,
...@@ -97,12 +97,18 @@ parser.add_argument( ...@@ -97,12 +97,18 @@ parser.add_argument(
default=os.path.join(os.path.dirname(__file__), "pserver.sh.template"), default=os.path.join(os.path.dirname(__file__), "pserver.sh.template"),
help="pserver bash file path") help="pserver bash file path")
parser.add_argument(
'--pserver_command', type=str, default="", help="pserver start command")
parser.add_argument( parser.add_argument(
'--trainer_bash_file', '--trainer_bash_file',
type=str, type=str,
default=os.path.join(os.path.dirname(__file__), "trainer.sh.template"), default=os.path.join(os.path.dirname(__file__), "trainer.sh.template"),
help="trainer bash file path") help="trainer bash file path")
parser.add_argument(
'--trainer_command', type=str, default="", help="trainer start command")
parser.add_argument( parser.add_argument(
'--action', type=str, default="serve", help="create|cleanup|serve") '--action', type=str, default="serve", help="create|cleanup|serve")
...@@ -124,8 +130,12 @@ args = parser.parse_args() ...@@ -124,8 +130,12 @@ args = parser.parse_args()
ec2client = boto3.client('ec2') ec2client = boto3.client('ec2')
args.log_path = os.path.join(os.path.dirname(__file__), "logs/")
logging.basicConfig( logging.basicConfig(
filename='master.log', level=logging.INFO, format='%(asctime)s %(message)s') filename=args.log_path + 'master.log',
level=logging.INFO,
format='%(asctime)s %(message)s')
log_files = ["master.log"] log_files = ["master.log"]
...@@ -304,7 +314,7 @@ def create_pservers(): ...@@ -304,7 +314,7 @@ def create_pservers():
def log_to_file(source, filename): def log_to_file(source, filename):
if not filename in log_files: if not filename in log_files:
log_files.append(filename) log_files.append(filename)
with open(filename, "a") as log_file: with open(args.log_path + filename, "a") as log_file:
for line in iter(source.readline, ""): for line in iter(source.readline, ""):
log_file.write(line) log_file.write(line)
...@@ -335,6 +345,8 @@ def create_trainers(kickoff_cmd, pserver_endpoints_str): ...@@ -335,6 +345,8 @@ def create_trainers(kickoff_cmd, pserver_endpoints_str):
DOCKER_IMAGE=args.docker_image, DOCKER_IMAGE=args.docker_image,
TRAINER_INDEX=str(trainer_index), TRAINER_INDEX=str(trainer_index),
TASK_NAME=args.task_name, TASK_NAME=args.task_name,
TRAINER_COUNT=args.trainer_count,
COMMAND=args.trainer_command,
MASTER_ENDPOINT=args.master_server_ip + ":" + MASTER_ENDPOINT=args.master_server_ip + ":" +
str(args.master_server_port)) str(args.master_server_port))
logging.info(cmd) logging.info(cmd)
...@@ -446,6 +458,9 @@ def kickoff_pserver(host, pserver_endpoints_str): ...@@ -446,6 +458,9 @@ def kickoff_pserver(host, pserver_endpoints_str):
DOCKER_IMAGE=args.docker_image, DOCKER_IMAGE=args.docker_image,
PSERVER_PORT=args.pserver_port, PSERVER_PORT=args.pserver_port,
TASK_NAME=args.task_name, TASK_NAME=args.task_name,
COMMAND=args.pserver_command,
TRAINER_COUNT=args.trainer_count,
SERVER_ENDPOINT=host + ":" + str(args.pserver_port),
MASTER_ENDPOINT=args.master_server_ip + ":" + MASTER_ENDPOINT=args.master_server_ip + ":" +
str(args.master_server_port)) str(args.master_server_port))
logging.info(cmd) logging.info(cmd)
...@@ -553,14 +568,17 @@ def start_server(args): ...@@ -553,14 +568,17 @@ def start_server(args):
if request_path == "/status" or request_path == "/master_logs": if request_path == "/status" or request_path == "/master_logs":
self._set_headers() self._set_headers()
logging.info("Received request to return status") logging.info("Received request to return status")
with open("master.log", "r") as logfile: with open(args.log_path + "master.log", "r") as logfile:
self.wfile.write(logfile.read().strip()) self.wfile.write(logfile.read().strip())
elif request_path == "/list_logs": elif request_path == "/list_logs":
self._set_headers() self._set_headers()
self.wfile.write("\n".join(log_files)) self.wfile.write("\n".join(log_files))
elif "/log/" in request_path: elif "/log/" in request_path:
log_file_path = request_path.replace("/log/") self._set_headers()
with open(log_file_path, "r") as logfile: log_file_path = request_path.replace("/log/", "")
logging.info("requesting log file path is" + args.log_path +
log_file_path)
with open(args.log_path + log_file_path, "r") as logfile:
self.wfile.write(logfile.read().strip()) self.wfile.write(logfile.read().strip())
else: else:
self.do_404() self.do_404()
...@@ -631,11 +649,4 @@ if __name__ == "__main__": ...@@ -631,11 +649,4 @@ if __name__ == "__main__":
create_cluster() create_cluster()
server_thread.join() server_thread.join()
elif args.action == "test": elif args.action == "test":
init_args() start_server(args)
if not args.subnet_id:
logging.info("creating subnet for this task")
args.subnet_id = create_subnet()
logging.info("subnet %s created" % (args.subnet_id))
create_trainers(
kickoff_cmd=script_to_str(args.trainer_bash_file),
pserver_endpoints_str="11.22.33.44:5476")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册