cluster_master.py 17.2 KB
Newer Older
X
Xi Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import json
import math
import time
X
Xi Chen 已提交
20
import threading
X
Xi Chen 已提交
21
import logging
X
Xi Chen 已提交
22 23 24 25 26 27

import netaddr
import boto3
import namesgenerator
import paramiko

X
Xi Chen 已提交
28 29
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer

X
Xi Chen 已提交
30 31 32 33
# You must have aws_access_key_id, aws_secret_access_key, region set in
# ~/.aws/credentials and ~/.aws/config


def _build_parser():
    """Return the command-line parser for the cluster master.

    Options are registered from a table, in order, each with its type,
    default and help text.
    """
    script_dir = os.path.dirname(__file__)
    option_table = [
        ('--key_name',
         dict(type=str, default="", help="required, key pair name")),
        ('--security_group_id',
         dict(
             type=str,
             default="",
             help="required, the security group id associated with your VPC")),
        ('--vpc_id',
         dict(
             type=str,
             default="",
             help="The VPC in which you wish to run test")),
        ('--subnet_id',
         dict(
             type=str,
             default="",
             help="The Subnet_id in which you wish to run test")),
        ('--pserver_instance_type',
         dict(
             type=str,
             default="p2.8xlarge",
             help="your pserver instance type, p2.8xlarge by default")),
        ('--trainer_instance_type',
         dict(
             type=str,
             default="p2.8xlarge",
             help="your trainer instance type, p2.8xlarge by default")),
        ('--task_name',
         dict(
             type=str,
             default="",
             help="the name you want to identify your job")),
        ('--pserver_image_id',
         dict(
             type=str,
             default="ami-da2c1cbf",
             help="ami id for system image, default one has nvidia-docker ready, use ami-1ae93962 for us-east-2"
         )),
        ('--trainer_image_id',
         dict(
             type=str,
             default="ami-da2c1cbf",
             help="ami id for system image, default one has nvidia-docker ready, use ami-1ae93962 for us-west-2"
         )),
        ('--availability_zone',
         dict(
             type=str,
             default="us-east-2a",
             help="aws zone id to place ec2 instances")),
        ('--trainer_count',
         dict(type=int, default=1, help="Trainer count")),
        ('--pserver_count',
         dict(type=int, default=1, help="Pserver count")),
        ('--pserver_bash_file',
         dict(
             type=str,
             default=os.path.join(script_dir, "pserver.sh.template"),
             help="pserver bash file path")),
        ('--trainer_bash_file',
         dict(
             type=str,
             default=os.path.join(script_dir, "trainer.sh.template"),
             help="trainer bash file path")),
        ('--action',
         dict(type=str, default="serve", help="create|cleanup|serve")),
        # no default: init_args derives ~/<key_name>.pem when unset
        ('--pem_path', dict(type=str, help="private key file")),
        ('--pserver_port',
         dict(type=str, default="5436", help="pserver port")),
        ('--docker_image',
         dict(type=str, default="busybox", help="training docker image")),
        ('--master_server_port',
         dict(type=int, default=5436, help="master server port")),
        ('--master_server_ip',
         dict(type=str, default="", help="master server private ip")),
    ]

    cli = argparse.ArgumentParser(description=__doc__)
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


parser = _build_parser()
args = parser.parse_args()

# Shared EC2 client; credentials/region come from the ~/.aws files above.
ec2client = boto3.client('ec2')

# All progress is logged to master.log (also served by GET /status and /logs).
logging.basicConfig(
    filename='master.log', level=logging.INFO, format='%(asctime)s %(message)s')


def create_subnet():
    """Create and tag a dedicated subnet for this task and return its id.

    If ``args.vpc_id`` is empty, the account's default VPC is looked up and
    stored back into ``args.vpc_id``. The VPC's ``CidrBlock`` is then looked
    up if it is not yet known. A block is carved out of the part of that
    CIDR not used by any existing subnet of the VPC. The block is sized for
    ``pserver_count + trainer_count`` nodes plus a 10-address buffer. The
    subnet is created in ``args.availability_zone`` and tagged
    ``Task_name=args.task_name`` so ``cleanup`` can find it.

    Returns:
        The id of the newly created subnet.

    Raises:
        ValueError: no default VPC / the VPC cannot be found, too many nodes
            for the VPC, or no free address range is large enough.
    """
    logging.info("start creating subnet")
    # Bound up front: when --vpc_id is supplied the default-VPC branch is
    # skipped, and the check below used to raise NameError.
    vpc_cidr_block = None

    # if no vpc id provided, fall back to the account's default vpc
    if not args.vpc_id:
        logging.info("no vpc provided, trying to find the default one")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "isDefault",
                "Values": ["true", ]
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No default VPC')
        args.vpc_id = vpcs_desc["Vpcs"][0]["VpcId"]
        vpc_cidr_block = vpcs_desc["Vpcs"][0]["CidrBlock"]

        logging.info("default vpc fount with id %s and CidrBlock %s" %
                     (args.vpc_id, vpc_cidr_block))

    if not vpc_cidr_block:
        logging.info("trying to find cidrblock for vpc")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "vpc-id",
                "Values": [args.vpc_id, ],
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No VPC found')
        vpc_cidr_block = vpcs_desc["Vpcs"][0]["CidrBlock"]
        logging.info("cidrblock for vpc is %s" % vpc_cidr_block)

    # list subnets in vpc in order to create a new one
    logging.info("trying to find ip blocks for new subnet")
    subnets_desc = ec2client.describe_subnets(
        Filters=[{
            "Name": "vpc-id",
            "Values": [args.vpc_id, ],
        }], )
    ips_taken = [subnet["CidrBlock"] for subnet in subnets_desc["Subnets"]]

    # Addresses of the VPC block not used by any existing subnet. Plain set
    # difference: the symmetric difference (^) used previously would also
    # report as "free" any taken block lying outside vpc_cidr_block
    # (e.g. a subnet in another CIDR associated with the VPC).
    free_blocks = netaddr.IPSet([vpc_cidr_block]) - netaddr.IPSet(ips_taken)

    # Prefix length of the smallest power-of-two block that holds every node
    # plus a buffer of 10 addresses.
    node_count = args.pserver_count + args.trainer_count
    cidr_prefix = int(32 - math.ceil(math.log(node_count + 10, 2)))
    if cidr_prefix <= 16:
        raise ValueError('Too many nodes to fit in current VPC')

    subnet_cidr = None
    for ipnetwork in free_blocks.iter_cidrs():
        # A free range with a longer prefix is smaller than the requested
        # block; skip it instead of relying on an exception to do so.
        if ipnetwork.prefixlen > cidr_prefix:
            continue
        subnet_cidr = next(ipnetwork.subnet(cidr_prefix))
        logging.info("subnet ip block found %s" % (subnet_cidr))
        break

    if subnet_cidr is None:
        raise ValueError(
            'No avaliable subnet to fit required nodes in current VPC')

    logging.info("trying to create subnet")
    subnet_desc = ec2client.create_subnet(
        CidrBlock=str(subnet_cidr),
        VpcId=args.vpc_id,
        AvailabilityZone=args.availability_zone)

    subnet_id = subnet_desc["Subnet"]["SubnetId"]

    subnet_waiter = ec2client.get_waiter('subnet_available')
    # sleep for 1s before checking its state
    time.sleep(1)
    subnet_waiter.wait(SubnetIds=[subnet_id, ])

    logging.info("subnet created")

    logging.info("adding tags to newly created subnet")
    ec2client.create_tags(
        Resources=[subnet_id, ],
        Tags=[{
            "Key": "Task_name",
            'Value': args.task_name
        }])
    return subnet_id


def generate_task_name():
    """Return a random, human-readable name used to identify this task."""
    random_name = namesgenerator.get_random_name()
    return random_name


def script_to_str(file_path):
    """Return the contents of a startup-script template, stripped.

    Args:
        file_path: path of the template file. When empty or None, a trivial
            script that just echoes ``$PSERVER_HOSTS`` is returned instead.

    Returns:
        The file's text with leading/trailing whitespace removed.
    """
    if not file_path:
        return "echo $PSERVER_HOSTS"
    # `with` closes the file even if reading fails (the old code leaked it).
    with open(file_path, 'r') as script_file:
        return script_file.read().strip()


def run_instances(image_id, instance_type, count, role, cmd=""):
    """Launch ``count`` EC2 instances for this task and wait until they pass
    their status checks.

    Instances are placed in ``args.subnet_id`` / ``args.availability_zone``
    with a public IP. They use ``args.key_name`` and
    ``args.security_group_ids``, and are tagged ``Task_name=args.task_name``
    and ``Role=role``.

    Args:
        image_id: AMI id to boot.
        instance_type: EC2 instance type.
        count: exact number of instances (used as both MinCount and MaxCount).
        role: value of the ``Role`` tag, e.g. "PSERVER" or "TRAINER".
        cmd: passed to EC2 as the instances' ``UserData``.

    Returns:
        List of instance description dicts from ``describe_instances``;
        an empty list if no instance was created.
    """
    response = ec2client.run_instances(
        ImageId=image_id,
        InstanceType=instance_type,
        MaxCount=count,
        MinCount=count,
        UserData=cmd,
        DryRun=False,
        InstanceInitiatedShutdownBehavior="stop",
        KeyName=args.key_name,
        Placement={'AvailabilityZone': args.availability_zone},
        NetworkInterfaces=[{
            'DeviceIndex': 0,
            'SubnetId': args.subnet_id,
            "AssociatePublicIpAddress": True,
            'Groups': args.security_group_ids
        }],
        TagSpecifications=[{
            'ResourceType': "instance",
            'Tags': [{
                "Key": 'Task_name',
                "Value": args.task_name
            }, {
                "Key": 'Role',
                "Value": role
            }]
        }])

    instance_ids = [instance["InstanceId"] for instance in response["Instances"]]

    if not instance_ids:
        logging.info("no instance created")
        # An empty InstanceIds list would leave the waiter and
        # describe_instances below unfiltered by instance id, so stop here.
        return []
    logging.info(str(len(instance_ids)) + " instance(s) created")

    # create waiter to make sure it's running
    logging.info("waiting for instance to become accessible")
    waiter = ec2client.get_waiter('instance_status_ok')
    waiter.wait(
        Filters=[{
            "Name": "instance-status.status",
            "Values": ["ok"]
        }, {
            "Name": "instance-status.reachability",
            "Values": ["passed"]
        }, {
            "Name": "instance-state-name",
            "Values": ["running"]
        }],
        InstanceIds=instance_ids)

    instances_response = ec2client.describe_instances(InstanceIds=instance_ids)

    # Collect from every reservation rather than assuming exactly one.
    return [
        instance
        for reservation in instances_response["Reservations"]
        for instance in reservation["Instances"]
    ]


def create_pservers():
    """Launch ``args.pserver_count`` parameter-server instances.

    Returns:
        The instance descriptions returned by ``run_instances``.

    Raises:
        Re-raises whatever launching raised, after logging it and cleaning up
        the task. The error used to be swallowed and ``None`` returned, which
        made ``create_cluster`` fail later with an unrelated ``TypeError``.
    """
    try:
        return run_instances(
            image_id=args.pserver_image_id,
            instance_type=args.pserver_instance_type,
            count=args.pserver_count,
            role="PSERVER", )
    except Exception:
        logging.exception("error while trying to create pservers")
        cleanup(args.task_name)
        raise


def create_trainers(kickoff_cmd, pserver_endpoints_str):
    """Launch ``args.trainer_count`` trainer instances, one at a time.

    Each trainer gets its own user-data script rendered from ``kickoff_cmd``
    via ``str.format`` with ``PSERVER_HOSTS``, ``DOCKER_IMAGE``,
    ``TRAINER_INDEX`` (0-based), ``TASK_NAME`` and ``MASTER_ENDPOINT``
    (``<master_server_ip>:<master_server_port>``).

    Args:
        kickoff_cmd: trainer script template.
        pserver_endpoints_str: comma-separated pserver endpoints.

    Returns:
        List with the instance description of each launched trainer.

    Raises:
        Re-raises any launch error after logging it and cleaning up the task,
        so callers do not carry on as if the trainers had been created.
    """
    # Same for every trainer; build once outside the loop.
    master_endpoint = args.master_server_ip + ":" + str(args.master_server_port)
    try:
        responses = []
        # range works on both Python 2 and 3 (xrange is Python 2 only).
        for i in range(args.trainer_count):
            cmd = kickoff_cmd.format(
                PSERVER_HOSTS=pserver_endpoints_str,
                DOCKER_IMAGE=args.docker_image,
                TRAINER_INDEX=str(i),
                TASK_NAME=args.task_name,
                MASTER_ENDPOINT=master_endpoint)
            logging.info(cmd)
            responses.append(
                run_instances(
                    image_id=args.trainer_image_id,
                    instance_type=args.trainer_instance_type,
                    count=1,
                    role="TRAINER",
                    cmd=cmd, )[0])
        return responses
    except Exception:
        logging.exception("error while trying to create trainers")
        cleanup(args.task_name)
        raise


def cleanup(task_name):
    """Terminate every instance and delete every subnet tagged
    ``Task_name=<task_name>``.

    Waits until the instances are reported terminated before deleting the
    subnets (presumably so the subnets are no longer in use). Subnet deletion
    itself is not waited on.

    Args:
        task_name: value of the ``Task_name`` tag identifying the task.
    """
    # shutdown all ec2 instances
    print("going to clean up " + task_name + " instances")
    instances_response = ec2client.describe_instances(Filters=[{
        "Name": "tag:Task_name",
        "Values": [task_name]
    }])

    instance_ids = [
        instance["InstanceId"]
        for reservation in instances_response["Reservations"]
        for instance in reservation["Instances"]
    ]

    if instance_ids:
        ec2client.terminate_instances(InstanceIds=instance_ids)

        instance_termination_waiter = ec2client.get_waiter(
            'instance_terminated')
        instance_termination_waiter.wait(InstanceIds=instance_ids)

    # delete the subnets created for this task; every tagged subnet is
    # removed, not just the first one returned
    subnets_desc = ec2client.describe_subnets(Filters=[{
        "Name": "tag:Task_name",
        "Values": [task_name]
    }])

    for subnet in subnets_desc["Subnets"]:
        ec2client.delete_subnet(SubnetId=subnet["SubnetId"])
    # no subnet delete waiter, just leave it.
    logging.info("Clearnup done")
    return


def kickoff_pserver(host, pserver_endpoints_str):
    """SSH into one pserver host and run the rendered pserver start script.

    The script is read from ``args.pserver_bash_file``. It is rendered with
    ``PSERVER_HOSTS``, ``DOCKER_IMAGE``, ``PSERVER_PORT``, ``TASK_NAME`` and
    ``MASTER_ENDPOINT``, then executed as user "ubuntu" using the key at
    ``args.pem_path``. A non-zero exit status is treated as failure. On any
    failure the error is logged and the whole task is cleaned up. Nothing is
    re-raised, because ``create_cluster`` runs this as a thread target.

    Args:
        host: address to SSH to (the pserver's public IP).
        pserver_endpoints_str: comma-separated pserver endpoints.
    """
    # Created before `try` so `finally` can always close it; the old code
    # raised UnboundLocalError from `finally` when loading the key failed.
    ssh_client = paramiko.SSHClient()
    try:
        ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh_client.connect(hostname=host, username="ubuntu", pkey=ssh_key)
        cmd = (script_to_str(args.pserver_bash_file)).format(
            PSERVER_HOSTS=pserver_endpoints_str,
            DOCKER_IMAGE=args.docker_image,
            PSERVER_PORT=args.pserver_port,
            TASK_NAME=args.task_name,
            MASTER_ENDPOINT=args.master_server_ip + ":" +
            str(args.master_server_port))
        logging.info(cmd)
        stdin, stdout, stderr = ssh_client.exec_command(command=cmd)
        # blocks until the remote command exits
        return_code = stdout.channel.recv_exit_status()
        logging.info(return_code)
        if return_code != 0:
            raise Exception("Error while kicking off pserver training process")
    except Exception:
        logging.exception("Error while kicking off pserver training process")
        cleanup(args.task_name)
    finally:
        ssh_client.close()


def init_args():
    """Fill in derived settings on the global ``args`` namespace.

    - ``task_name``: a random name when empty (the generated name is logged).
    - ``pem_path``: ``<home>/<key_name>.pem`` when not given.
    - ``security_group_ids``: a 1-tuple of ``security_group_id``, set only
      when that option is non-empty.
    - ``trainers_job_done_count``: reset to 0.
    """
    needs_task_name = not args.task_name
    if needs_task_name:
        generated_name = generate_task_name()
        args.task_name = generated_name
        logging.info("task name generated %s" % (generated_name))

    args.pem_path = args.pem_path or (
        os.path.expanduser("~") + "/" + args.key_name + ".pem")

    if args.security_group_id:
        args.security_group_ids = tuple([args.security_group_id])

    args.trainers_job_done_count = 0


def create_cluster():
    """Provision the training cluster for ``args.task_name``.

    Steps:
    1. Create a subnet if ``args.subnet_id`` is empty.
    2. Launch the pservers.
    3. Start the pserver process on each host over SSH, one thread per host,
       and wait for all threads.
    4. Launch the trainers with the comma-separated list of pserver
       ``<private ip>:<pserver_port>`` endpoints.
    """
    if not args.subnet_id:
        logging.info("creating subnet for this task")
        args.subnet_id = create_subnet()
        logging.info("subnet %s created" % (args.subnet_id))

    logging.info("creating pservers")
    pservers = create_pservers()
    logging.info("pserver created, collecting pserver ips")

    pserver_endpoints_str = ",".join(
        node["NetworkInterfaces"][0]["PrivateIpAddress"] + ":" +
        args.pserver_port for node in pservers)

    logging.info("kicking off pserver training process")

    def _spawn_kickoff(node):
        # start the SSH kickoff for one pserver (via its public ip) right away
        worker = threading.Thread(
            target=kickoff_pserver,
            args=(node["PublicIpAddress"], pserver_endpoints_str))
        worker.start()
        return worker

    kickoff_workers = [_spawn_kickoff(node) for node in pservers]
    for worker in kickoff_workers:
        worker.join()

    logging.info("all pserver training process started")

    logging.info("creating trainers and kicking off trainer training process")
    trainer_template = script_to_str(args.trainer_bash_file)
    create_trainers(
        kickoff_cmd=trainer_template,
        pserver_endpoints_str=pserver_endpoints_str)
    logging.info("trainers created")


def start_server(args):
    """Serve the master's HTTP control API on ``args.master_server_port``.

    Blocks forever in ``serve_forever``; ``__main__`` runs it in its own
    thread.

    Routes:
        GET  /status, /logs     -> contents of master.log
        POST /save_data         -> append the request body to <task_name>.txt
        POST /cleanup           -> tear down the task's instances and subnets
        POST /trainer_job_done  -> count a finished trainer; clean up once
                                   ``args.trainer_count`` have reported
        any other GET/POST path -> 404 "NO ACTION FOUND"

    ``args.trainers_job_done_count`` is updated without a lock. This relies
    on ``HTTPServer`` handling one request at a time.

    Args:
        args: the parsed/initialised argument namespace.
    """

    class S(BaseHTTPRequestHandler):
        def _set_headers(self):
            self.send_response(200)
            self.send_header('Content-type', 'text/text')
            self.end_headers()

        def do_HEAD(self):
            self._set_headers()

        def do_404(self):
            self.send_response(404)
            self.send_header('Content-type', 'text/text')
            self.end_headers()
            logging.info("Received invalid GET request" + self.path)
            self.wfile.write("NO ACTION FOUND")

        def do_GET(self):
            request_path = self.path
            if request_path == "/status" or request_path == "/logs":
                # Send 200 only on the success path; previously it was sent
                # before do_404() emitted a second (404) status line.
                self._set_headers()
                logging.info("Received request to return status")
                with open("master.log", "r") as logfile:
                    self.wfile.write(logfile.read().strip())
            else:
                self.do_404()

        def do_POST(self):

            request_path = self.path

            if request_path == "/save_data":
                self._set_headers()
                logging.info("Received request to save data")
                # A missing Content-Length is read as an empty body instead of
                # crashing the handler.
                content_length = int(self.headers.get('Content-Length', 0))
                post_data = self.rfile.read(content_length)
                if args.task_name:
                    with open(args.task_name + ".txt", "a") as text_file:
                        text_file.write(post_data + "\n")
                # Acknowledge only after the data has actually been written.
                self.wfile.write("DATA SAVED!")

            elif request_path == "/cleanup":
                self._set_headers()
                logging.info("Received request to cleanup cluster")
                cleanup(args.task_name)
                self.wfile.write("cleanup in progress")

            elif request_path == "/trainer_job_done":
                self._set_headers()
                logging.info("Received request to increase job done count")
                args.trainers_job_done_count += 1
                self.wfile.write(
                    str(args.trainers_job_done_count) + " tainers job done")
                if args.trainers_job_done_count >= args.trainer_count:
                    logging.info("going to clean up")
                    cleanup(args.task_name)

            else:
                self.do_404()

    server_address = ('', args.master_server_port)
    httpd = HTTPServer(server_address, S)
    logging.info("HTTP server is starting")
    httpd.serve_forever()


def print_arguments():
    """Log every attribute of the global ``args``, sorted by name, between
    two banner lines."""
    settings = vars(args)
    logging.info('-----------  Configuration Arguments -----------')
    for option_name in sorted(settings):
        logging.info('%s: %s' % (option_name, settings[option_name]))
    logging.info('------------------------------------------------')


if __name__ == "__main__":
    # Dispatch on --action:
    #   create  - provision the cluster and exit
    #   cleanup - tear down everything tagged with --task_name
    #   serve   - run the HTTP control server and provision the cluster
    print_arguments()
    if args.action == "create":
        logging.info("going to create cluster")
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")
        init_args()
        create_cluster()
    elif args.action == "cleanup":
        logging.info("going to cleanup cluster")
        if not args.task_name:
            raise ValueError("task_name is required")
        cleanup(args.task_name)
    elif args.action == "serve":
        # serve mode
        if not args.master_server_ip:
            raise ValueError(
                "No master server ip set, please run with --action create")
        # serve launches instances too. run_instances reads
        # args.security_group_ids, which init_args only sets when
        # --security_group_id is given, so fail fast here instead of
        # partway through the launch.
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")

        logging.info("going to start serve and create cluster")

        init_args()

        logging.info("starting server in another thread")
        server_thread = threading.Thread(target=start_server, args=(args, ))
        server_thread.start()

        create_cluster()
        server_thread.join()
    else:
        # previously an unknown action silently did nothing
        raise ValueError("unknown action %s, expected create|cleanup|serve" %
                         args.action)