cluster_launcher.py 12.5 KB
Newer Older
X
Xi Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import time
X
Xi Chen 已提交
18 19 20
import math
import logging
import copy
X
Xi Chen 已提交
21 22 23 24 25

import netaddr
import boto3
import namesgenerator
import paramiko
X
Xi Chen 已提交
26 27
from scp import SCPClient
import requests
X
Xi Chen 已提交
28

X
Xi Chen 已提交
29 30 31 32 33 34 35 36 37 38

def str2bool(v):
    """Convert a command-line string to a bool, for use as an argparse ``type``.

    Matching is case-insensitive.

    Args:
        v: the raw option string.

    Returns:
        True for 'yes', 'true', 't', 'y', '1';
        False for 'no', 'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is none of the spellings above.
    """
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    lowered = v.lower()
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


X
Xi Chen 已提交
39
# Command-line interface. Options whose help text says "required" are not
# enforced by argparse (they default to ""); the __main__ block checks them
# for the actions that need them.
parser = argparse.ArgumentParser(description=__doc__)

# --- AWS access and networking -------------------------------------------
parser.add_argument(
    '--key_name', type=str, default="", help="required, key pair name")
parser.add_argument(
    '--security_group_id',
    type=str,
    default="",
    help="required, the security group id associated with your VPC")

parser.add_argument(
    '--vpc_id',
    type=str,
    default="",
    help="The VPC in which you wish to run test")
parser.add_argument(
    '--subnet_id',
    type=str,
    default="",
    help="The Subnet_id in which you wish to run test")

# --- Instance types --------------------------------------------------------
parser.add_argument(
    '--pserver_instance_type',
    type=str,
    default="c5.2xlarge",
    help="your pserver instance type, c5.2xlarge by default")
parser.add_argument(
    '--trainer_instance_type',
    type=str,
    default="p2.8xlarge",
    help="your trainer instance type, p2.8xlarge by default")

# --- Task identity, images and start commands ------------------------------
parser.add_argument(
    '--task_name',
    type=str,
    default="",
    help="the name you want to identify your job")
parser.add_argument(
    '--pserver_image_id',
    type=str,
    default="ami-da2c1cbf",
    help="ami id for system image, default one has nvidia-docker ready, \
    use ami-1ae93962 for us-east-2")

parser.add_argument(
    '--pserver_command',
    type=str,
    default="",
    help="pserver start command, format example: python,vgg.py,batch_size:128,is_local:yes"
)

# NOTE(review): this help text names us-west-2 for ami-1ae93962 while the
# --pserver_image_id help names us-east-2 for the same AMI -- confirm which
# region is correct.
parser.add_argument(
    '--trainer_image_id',
    type=str,
    default="ami-da2c1cbf",
    help="ami id for system image, default one has nvidia-docker ready, \
    use ami-1ae93962 for us-west-2")

parser.add_argument(
    '--trainer_command',
    type=str,
    default="",
    help="trainer start command, format example: python,vgg.py,batch_size:128,is_local:yes"
)

# --- Placement and cluster size --------------------------------------------
parser.add_argument(
    '--availability_zone',
    type=str,
    default="us-east-2a",
    help="aws zone id to place ec2 instances")

parser.add_argument(
    '--trainer_count', type=int, default=1, help="Trainer count")

parser.add_argument(
    '--pserver_count', type=int, default=1, help="Pserver count")

# --- Action and connection details -----------------------------------------
parser.add_argument(
    '--action', type=str, default="create", help="create|cleanup|status")

# No default here: init_args() derives the path from --key_name when unset.
parser.add_argument('--pem_path', type=str, help="private key file")

parser.add_argument(
    '--pserver_port', type=str, default="5436", help="pserver port")

parser.add_argument(
    '--docker_image', type=str, default="busybox", help="training docker image")

# --- Master node -----------------------------------------------------------
parser.add_argument(
    '--master_server_port', type=int, default=5436, help="master server port")

# Filled in by create(); must be supplied by the user for cleanup/status.
parser.add_argument(
    '--master_server_public_ip', type=str, help="master server public ip")

parser.add_argument(
    '--master_docker_image',
    type=str,
    default="putcn/paddle_aws_master:latest",
    help="master docker image id")

parser.add_argument(
    '--no_clean_up',
    type=str2bool,
    default=False,
    help="whether to clean up after training")

# Parsed once at import time; the functions below read and update this
# namespace directly.
args = parser.parse_args()

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

# Shared EC2 client used by every AWS call in this module.
ec2client = boto3.client('ec2')


X
Xi Chen 已提交
151 152 153 154 155 156 157
def print_arguments():
    """Print every parsed command-line argument as ``name: value``.

    Reads the module-level ``args`` namespace and writes one line per
    argument to stdout, sorted by argument name and framed by a header and a
    footer line.
    """
    print('-----------  Configuration Arguments -----------')
    # items() exists on both Python 2 and 3 (iteritems() is Python 2 only);
    # sorted() materialises the pairs either way, so the output is unchanged.
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


X
Xi Chen 已提交
158 159
def create_subnet():
    """Create a subnet in the target VPC sized for the requested cluster.

    Uses ``args.vpc_id`` when it is set; otherwise looks up the default VPC
    and stores its id in ``args.vpc_id``. The subnet is carved out of the
    part of the VPC CIDR block not already used by existing subnets, placed
    in ``args.availability_zone`` and tagged with ``args.task_name``.

    Returns:
        The id of the newly created subnet.

    Raises:
        ValueError: if no (default) VPC is found, if the node count needs a
            prefix of /16 or shorter, or if no free block is large enough.
    """
    logging.info("start creating subnet")

    # Must be bound before the check further down: when --vpc_id is supplied
    # the default-VPC branch is skipped and nothing else assigns it.
    vpc_cidrBlock = None

    # if no vpc id provided, fall back to the default vpc
    if not args.vpc_id:
        logging.info("no vpc provided, trying to find the default one")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "isDefault",
                "Values": ["true", ]
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No default VPC')
        args.vpc_id = vpcs_desc["Vpcs"][0]["VpcId"]
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]

        logging.info("default vpc fount with id %s and CidrBlock %s" %
                     (args.vpc_id, vpc_cidrBlock))

    if not vpc_cidrBlock:
        logging.info("trying to find cidrblock for vpc")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "vpc-id",
                "Values": [args.vpc_id, ],
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No VPC found')
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]
        logging.info("cidrblock for vpc is %s" % vpc_cidrBlock)

    # list subnets in vpc in order to create a new one
    logging.info("trying to find ip blocks for new subnet")
    subnets_desc = ec2client.describe_subnets(
        Filters=[{
            "Name": "vpc-id",
            "Values": [args.vpc_id, ],
        }], )

    ips_taken = [subnet["CidrBlock"] for subnet in subnets_desc["Subnets"]]

    # Set difference rather than symmetric difference: a subnet lying
    # outside vpc_cidrBlock must not be reported as free space.
    ip_blocks_avaliable = netaddr.IPSet(
        [vpc_cidrBlock]) - netaddr.IPSet(ips_taken)

    # adding 10 addresses as buffer; int() because math.ceil returns a float
    # on Python 2.
    cidr_prefix = 32 - int(
        math.ceil(math.log(args.pserver_count + args.trainer_count + 10, 2)))
    if cidr_prefix <= 16:
        raise ValueError('Too many nodes to fit in current VPC')

    subnet_cidr = None
    for ipnetwork in ip_blocks_avaliable.iter_cidrs():
        # subnet() yields nothing when this free block is smaller than the
        # requested prefix; in that case try the next free block.
        subnet_cidr = next(ipnetwork.subnet(cidr_prefix), None)
        if subnet_cidr is not None:
            logging.info("subnet ip block found %s" % (subnet_cidr))
            break

    if subnet_cidr is None:
        raise ValueError(
            'No avaliable subnet to fit required nodes in current VPC')

    logging.info("trying to create subnet")
    subnet_desc = ec2client.create_subnet(
        CidrBlock=str(subnet_cidr),
        VpcId=args.vpc_id,
        AvailabilityZone=args.availability_zone)

    subnet_id = subnet_desc["Subnet"]["SubnetId"]

    subnet_waiter = ec2client.get_waiter('subnet_available')
    # sleep for 1s before checking its state
    time.sleep(1)
    subnet_waiter.wait(SubnetIds=[subnet_id, ])

    logging.info("subnet created")

    logging.info("adding tags to newly created subnet")
    ec2client.create_tags(
        Resources=[subnet_id, ],
        Tags=[{
            "Key": "Task_name",
            'Value': args.task_name
        }])
    return subnet_id


X
Xi Chen 已提交
246
def run_instances(image_id, instance_type, count=1, role="MASTER", cmd=""):
    """Launch EC2 instances and block until they pass their status checks.

    Instances are started in ``args.subnet_id`` / ``args.availability_zone``
    with a public IP, the key pair ``args.key_name`` and the security groups
    in ``args.security_group_ids``.

    Args:
        image_id: AMI id to boot.
        instance_type: EC2 instance type.
        count: exact number of instances to start (used as both MinCount and
            MaxCount).
        role: value of the ``Role`` tag; its lower-cased form is also used as
            the suffix of the ``Task_name`` tag.
        cmd: user data passed to the instances.

    Returns:
        The ``Instances`` list of the first reservation returned by
        ``describe_instances`` for the launched instance ids.
    """
    response = ec2client.run_instances(
        ImageId=image_id,
        InstanceType=instance_type,
        MaxCount=count,
        MinCount=count,
        UserData=cmd,
        DryRun=False,
        InstanceInitiatedShutdownBehavior="stop",
        KeyName=args.key_name,
        Placement={'AvailabilityZone': args.availability_zone},
        NetworkInterfaces=[{
            'DeviceIndex': 0,
            'SubnetId': args.subnet_id,
            "AssociatePublicIpAddress": True,
            'Groups': args.security_group_ids
        }],
        TagSpecifications=[{
            'ResourceType': "instance",
            'Tags': [{
                "Key": 'Task_name',
                # Suffix follows the role instead of being hard-coded; the
                # default role "MASTER" still yields "<task_name>_master".
                "Value": args.task_name + "_" + role.lower()
            }, {
                "Key": 'Role',
                "Value": role
            }]
        }])

    instance_ids = [
        instance["InstanceId"] for instance in response["Instances"]
    ]

    if len(instance_ids) > 0:
        logging.info(str(len(instance_ids)) + " instance(s) created")
    else:
        logging.info("no instance created")

    # create waiter to make sure it's running
    logging.info("waiting for instance to become accessible")
    waiter = ec2client.get_waiter('instance_status_ok')
    waiter.wait(
        Filters=[{
            "Name": "instance-status.status",
            "Values": ["ok"]
        }, {
            "Name": "instance-status.reachability",
            "Values": ["passed"]
        }, {
            "Name": "instance-state-name",
            "Values": ["running"]
        }],
        InstanceIds=instance_ids)

    instances_response = ec2client.describe_instances(InstanceIds=instance_ids)

    return instances_response["Reservations"][0]["Instances"]


X
Xi Chen 已提交
304 305 306 307 308
def generate_task_name():
    """Return a random name produced by ``namesgenerator.get_random_name()``."""
    random_name = namesgenerator.get_random_name()
    return random_name


def init_args():
    """Fill in derived values on the module-level ``args`` namespace.

    * ``task_name``: generated randomly when not supplied.
    * ``pem_path``: defaults to ``~/<key_name>.pem`` when not supplied.
    * ``security_group_ids``: always set; a one-element tuple holding
      ``security_group_id``, or an empty tuple when none was given, so later
      reads of the attribute never raise ``AttributeError``.
    """
    if not args.task_name:
        args.task_name = generate_task_name()
        logging.info("task name generated %s" % (args.task_name))

    if not args.pem_path:
        args.pem_path = os.path.join(
            os.path.expanduser("~"), args.key_name + ".pem")

    if args.security_group_id:
        args.security_group_ids = (args.security_group_id, )
    else:
        args.security_group_ids = ()


X
Xi Chen 已提交
320
def create():
    """Provision the master node and start the master container on it.

    Steps: fill in derived arguments, create a subnet when none was given,
    launch one master instance, copy the local ``~/.aws`` directory and the
    private key to it over SCP, then start the master docker image over SSH
    with this script's arguments forwarded.

    Side effects: sets ``args.subnet_id`` (when created here),
    ``args.master_server_public_ip`` and ``args.master_server_ip``.

    Raises:
        Exception: if the remote ``docker run`` command exits non-zero.
    """
    init_args()

    # create subnet
    if not args.subnet_id:
        args.subnet_id = create_subnet()

    # create master node
    master_instance_response = run_instances(
        image_id="ami-7a05351f", instance_type="t2.nano")

    logging.info("master server started")

    args.master_server_public_ip = master_instance_response[0][
        "PublicIpAddress"]
    args.master_server_ip = master_instance_response[0]["PrivateIpAddress"]

    logging.info("master server started, master_ip=%s, task_name=%s" %
                 (args.master_server_public_ip, args.task_name))

    ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
    ssh_client = paramiko.SSHClient()
    ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh_client.connect(
            hostname=args.master_server_public_ip,
            username="ubuntu",
            pkey=ssh_key)

        # cp config file and pems to master node
        with SCPClient(ssh_client.get_transport()) as scp:
            scp.put(os.path.expanduser("~") + "/" + ".aws",
                    recursive=True,
                    remote_path='/home/ubuntu/')
            scp.put(args.pem_path,
                    remote_path='/home/ubuntu/' + args.key_name + ".pem")

        logging.info("credentials and pem copied to master")

        # set arguments and start docker
        kick_off_cmd = _build_master_kick_off_cmd()
        logging.info(kick_off_cmd)
        _stdin, stdout, stderr = ssh_client.exec_command(command=kick_off_cmd)
        return_code = stdout.channel.recv_exit_status()
        logging.info(return_code)
        if return_code != 0:
            # Surface the remote error output before failing.
            logging.error(stderr.read())
            raise Exception("Error while kicking off master")
    finally:
        # Release the SSH connection whether or not the kick-off succeeded;
        # the container is started with -d so it does not need the session.
        ssh_client.close()

    logging.info(
        "master server finished init process, visit %s to check master log" %
        (get_master_web_url("/status")))


def _build_master_kick_off_cmd():
    """Return the ``docker run`` command line that starts the master image.

    Mounts the copied AWS credentials, the private key and the log directory,
    publishes ``args.master_server_port`` and forwards every truthy argument
    as ``--name value`` with ``action`` overridden to ``serve``.
    """
    kick_off_cmd = "docker run -d -v /home/ubuntu/.aws:/root/.aws/"
    kick_off_cmd += " -v /home/ubuntu/" + args.key_name + ".pem:/root/" + args.key_name + ".pem"
    kick_off_cmd += " -v /home/ubuntu/logs/:/root/logs/"
    kick_off_cmd += " -p " + str(args.master_server_port) + ":" + str(
        args.master_server_port)
    kick_off_cmd += " " + args.master_docker_image

    # Work on a copy so the live ``args`` namespace is left untouched.
    forwarded = dict(vars(args))
    forwarded["action"] = "serve"
    # These arguments are not forwarded to the master. pop() with a default
    # tolerates any of them being absent.
    for local_only in ("pem_path", "security_group_ids",
                       "master_docker_image", "master_server_public_ip"):
        forwarded.pop(local_only, None)

    for arg, value in sorted(forwarded.items()):
        if value:
            kick_off_cmd += ' --%s %s' % (arg, value)
    return kick_off_cmd
X
Xi Chen 已提交
387 388 389 390 391 392 393


def cleanup():
    """POST to the master's ``/cleanup`` endpoint and print the response body."""
    # Call form of print: same output on Python 2 for a single argument, and
    # valid syntax on Python 3.
    print(requests.post(get_master_web_url("/cleanup")).text)


def status():
    """POST to the master's ``/status`` endpoint and print the response body."""
    # Call form of print: same output on Python 2 for a single argument, and
    # valid syntax on Python 3.
    print(requests.post(get_master_web_url("/status")).text)
X
Xi Chen 已提交
395 396 397


def get_master_web_url(path):
    """Build an ``http://`` URL on the master server for ``path``.

    Args:
        path: URL path appended verbatim after the port, e.g. ``"/status"``.

    Returns:
        ``"http://<args.master_server_public_ip>:<args.master_server_port><path>"``.
    """
    host = args.master_server_public_ip
    port = str(args.master_server_port)
    return "http://" + host + ":" + port + path
X
Xi Chen 已提交
400 401 402 403


if __name__ == "__main__":
    # Entry point: echo the configuration, validate what the chosen action
    # needs, then dispatch.
    print_arguments()
    if args.action == "create":
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")
        create()
    elif args.action == "cleanup":
        if not args.master_server_public_ip:
            raise ValueError("master_server_public_ip is required")
        cleanup()
    elif args.action == "status":
        if not args.master_server_public_ip:
            raise ValueError("master_server_public_ip is required")
        status()
    else:
        # Fail loudly instead of silently doing nothing on a mistyped action.
        raise ValueError("unknown action: %s" % args.action)