cluster_launcher.py 11.5 KB
Newer Older
X
Xi Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import time
X
Xi Chen 已提交
18 19 20
import math
import logging
import copy
X
Xi Chen 已提交
21 22 23 24 25

import netaddr
import boto3
import namesgenerator
import paramiko
X
Xi Chen 已提交
26 27
from scp import SCPClient
import requests
X
Xi Chen 已提交
28 29

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    '--key_name', type=str, default="", help="required, key pair name")
parser.add_argument(
    '--security_group_id',
    type=str,
    default="",
    help="required, the security group id associated with your VPC")

parser.add_argument(
    '--vpc_id',
    type=str,
    default="",
    help="The VPC in which you wish to run test")
parser.add_argument(
    '--subnet_id',
    type=str,
    default="",
    help="The Subnet_id in which you wish to run test")

parser.add_argument(
    '--pserver_instance_type',
    type=str,
    default="p2.8xlarge",
    help="your pserver instance type, p2.8xlarge by default")
parser.add_argument(
    '--trainer_instance_type',
    type=str,
    default="p2.8xlarge",
    help="your trainer instance type, p2.8xlarge by default")

parser.add_argument(
    '--task_name',
    type=str,
    default="",
    help="the name you want to identify your job")
# NOTE(review): AMI ids are region specific and the default zone below is
# us-east-2a, so the alternative AMI is presumably the us-west-2 one (this
# matches the --trainer_image_id help). Confirm against the AMI catalogue.
parser.add_argument(
    '--pserver_image_id',
    type=str,
    default="ami-da2c1cbf",
    help="ami id for system image, default one has nvidia-docker ready, "
    "use ami-1ae93962 for us-west-2")
parser.add_argument(
    '--trainer_image_id',
    type=str,
    default="ami-da2c1cbf",
    help="ami id for system image, default one has nvidia-docker ready, "
    "use ami-1ae93962 for us-west-2")

parser.add_argument(
    '--availability_zone',
    type=str,
    default="us-east-2a",
    help="aws zone id to place ec2 instances")

parser.add_argument(
    '--trainer_count', type=int, default=1, help="Trainer count")

parser.add_argument(
    '--pserver_count', type=int, default=1, help="Pserver count")

parser.add_argument(
    '--action', type=str, default="create", help="create|cleanup|status")

# When omitted, init_args() falls back to ~/<key_name>.pem.
parser.add_argument(
    '--pem_path',
    type=str,
    help="private key file, defaults to ~/<key_name>.pem")

parser.add_argument(
    '--pserver_port', type=str, default="5436", help="pserver port")

parser.add_argument(
    '--docker_image', type=str, default="busybox", help="training docker image")

parser.add_argument(
    '--master_server_port', type=int, default=5436, help="master server port")

parser.add_argument(
    '--master_server_public_ip', type=str, help="master server public ip")

# Parsed once at import time; every function below reads (and some mutate)
# this module-level namespace.
args = parser.parse_args()

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

# Shared EC2 client used by every AWS call in this script.
ec2client = boto3.client('ec2')


X
Xi Chen 已提交
114 115 116 117 118 119 120
def print_arguments():
    """Print every parsed command-line argument, sorted by name.

    Each argument is printed as ``name: value`` between two banner lines.
    Reads the module-level ``args`` namespace.
    """
    print('-----------  Configuration Arguments -----------')
    # items() works on both Python 2 and 3; iteritems() is Python 2 only.
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


X
Xi Chen 已提交
121 122
def create_subnet():
    """Create a subnet for this task inside ``args.vpc_id`` and return its id.

    If ``args.vpc_id`` is empty, the account's default VPC is looked up and
    ``args.vpc_id`` is set to it. The function then picks a free CIDR block
    inside the VPC's CIDR block (as reported by describe_vpcs) that does not
    overlap any existing subnet of the VPC. The block is sized for
    pserver_count + trainer_count + 10 addresses. The subnet is created in
    ``args.availability_zone``, the function waits until it is available,
    and it is tagged with ``Task_name=args.task_name``.

    Returns:
        str: the SubnetId of the new subnet.

    Raises:
        ValueError: no default VPC exists, the given VPC is not found, too
            many nodes are requested, or no free CIDR range is large enough.
    """
    logging.info("start creating subnet")
    # Must be bound on every path: when --vpc_id is supplied the default-VPC
    # branch is skipped, and the check below would otherwise raise
    # UnboundLocalError.
    vpc_cidr_block = None

    # if no vpc id provided, use the account's default vpc
    if not args.vpc_id:
        logging.info("no vpc provided, trying to find the default one")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "isDefault",
                "Values": ["true", ]
            }], )
        if not vpcs_desc["Vpcs"]:
            raise ValueError('No default VPC')
        args.vpc_id = vpcs_desc["Vpcs"][0]["VpcId"]
        vpc_cidr_block = vpcs_desc["Vpcs"][0]["CidrBlock"]

        logging.info("default vpc found with id %s and CidrBlock %s" %
                     (args.vpc_id, vpc_cidr_block))

    if not vpc_cidr_block:
        logging.info("trying to find cidrblock for vpc")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "vpc-id",
                "Values": [args.vpc_id, ],
            }], )
        if not vpcs_desc["Vpcs"]:
            raise ValueError('No VPC found')
        vpc_cidr_block = vpcs_desc["Vpcs"][0]["CidrBlock"]
        logging.info("cidrblock for vpc is %s" % vpc_cidr_block)

    # list subnets in vpc in order to create a new one
    logging.info("trying to find ip blocks for new subnet")
    subnets_desc = ec2client.describe_subnets(
        Filters=[{
            "Name": "vpc-id",
            "Values": [args.vpc_id, ],
        }], )
    ips_taken = [subnet["CidrBlock"] for subnet in subnets_desc["Subnets"]]

    # Plain set difference: a symmetric difference would add back subnets
    # lying outside vpc_cidr_block (e.g. in secondary CIDR blocks) as if they
    # were free address space.
    ip_blocks_available = (netaddr.IPSet([vpc_cidr_block]) -
                           netaddr.IPSet(ips_taken))
    # adding 10 addresses as buffer
    cidr_prefix = int(32 - math.ceil(
        math.log(args.pserver_count + args.trainer_count + 10, 2)))
    if cidr_prefix <= 16:
        raise ValueError('Too many nodes to fit in current VPC')

    subnet_cidr = None
    for ipnetwork in ip_blocks_available.iter_cidrs():
        # IPNetwork.subnet() yields nothing when ipnetwork is already smaller
        # than the requested prefix; next(..., None) skips such blocks.
        subnet_cidr = next(ipnetwork.subnet(cidr_prefix), None)
        if subnet_cidr is not None:
            logging.info("subnet ip block found %s" % (subnet_cidr))
            break

    if subnet_cidr is None:
        raise ValueError(
            'No available subnet to fit required nodes in current VPC')

    logging.info("trying to create subnet")
    subnet_desc = ec2client.create_subnet(
        CidrBlock=str(subnet_cidr),
        VpcId=args.vpc_id,
        AvailabilityZone=args.availability_zone)

    subnet_id = subnet_desc["Subnet"]["SubnetId"]

    subnet_waiter = ec2client.get_waiter('subnet_available')
    # sleep for 1s before checking its state (presumably so the new subnet
    # is visible to describe calls before the waiter polls - TODO confirm)
    time.sleep(1)
    subnet_waiter.wait(SubnetIds=[subnet_id, ])

    logging.info("subnet created")

    logging.info("adding tags to newly created subnet")
    ec2client.create_tags(
        Resources=[subnet_id, ],
        Tags=[{
            "Key": "Task_name",
            'Value': args.task_name
        }])
    return subnet_id


X
Xi Chen 已提交
209
def run_instances(image_id, instance_type, count=1, role="MASTER", cmd=""):
    """Launch ``count`` EC2 instances and wait until they pass status checks.

    The instances are placed in ``args.subnet_id`` and
    ``args.availability_zone``. Each gets a public IP, the security groups in
    ``args.security_group_ids`` and the ``args.key_name`` key pair. They are
    tagged ``Task_name=<task_name>_<role lowercased>`` and ``Role=<role>``.

    Args:
        image_id: AMI id to boot.
        instance_type: EC2 instance type.
        count: exact number of instances (used as both MinCount and MaxCount).
        role: value of the ``Role`` tag, also used as the Task_name suffix.
        cmd: EC2 UserData passed to the instances.

    Returns:
        list: instance descriptions from the first reservation returned by
        describe_instances, or ``[]`` if no instance was created.
    """
    response = ec2client.run_instances(
        ImageId=image_id,
        InstanceType=instance_type,
        MaxCount=count,
        MinCount=count,
        UserData=cmd,
        DryRun=False,
        InstanceInitiatedShutdownBehavior="stop",
        KeyName=args.key_name,
        Placement={'AvailabilityZone': args.availability_zone},
        NetworkInterfaces=[{
            'DeviceIndex': 0,
            'SubnetId': args.subnet_id,
            "AssociatePublicIpAddress": True,
            'Groups': args.security_group_ids
        }],
        TagSpecifications=[{
            'ResourceType': "instance",
            'Tags': [{
                "Key": 'Task_name',
                # Suffix with the actual role so that non-master instances
                # are not labelled "_master"; role="MASTER" gives "_master".
                "Value": args.task_name + "_" + role.lower()
            }, {
                "Key": 'Role',
                "Value": role
            }]
        }])

    instance_ids = [instance["InstanceId"] for instance in response["Instances"]]

    if not instance_ids:
        logging.info("no instance created")
        # With an empty id list, the waiter and describe_instances below would
        # match other instances in the account rather than none.
        return []
    logging.info(str(len(instance_ids)) + " instance(s) created")

    #create waiter to make sure it's running
    logging.info("waiting for instance to become accessible")
    waiter = ec2client.get_waiter('instance_status_ok')
    waiter.wait(
        Filters=[{
            "Name": "instance-status.status",
            "Values": ["ok"]
        }, {
            "Name": "instance-status.reachability",
            "Values": ["passed"]
        }, {
            "Name": "instance-state-name",
            "Values": ["running"]
        }],
        InstanceIds=instance_ids)

    instances_response = ec2client.describe_instances(InstanceIds=instance_ids)

    return instances_response["Reservations"][0]["Instances"]


X
Xi Chen 已提交
267 268 269 270 271
def generate_task_name():
    """Return a random human-readable name from namesgenerator.

    init_args() uses it as the task name when --task_name is empty.
    """
    random_name = namesgenerator.get_random_name()
    return random_name


def init_args():
    """Fill in the arguments that are derived from others.

    Sets ``args.task_name`` to a generated name when empty, sets
    ``args.pem_path`` to ``~/<key_name>.pem`` when empty, and always defines
    ``args.security_group_ids``. That tuple is the single
    ``security_group_id``, or ``()`` when none was given.
    """
    if not args.task_name:
        args.task_name = generate_task_name()
        logging.info("task name generated %s" % (args.task_name))

    if not args.pem_path:
        args.pem_path = os.path.join(
            os.path.expanduser("~"), args.key_name + ".pem")

    # Always define the attribute: run_instances() passes it as the network
    # interface's Groups and create() deletes it from the forwarded args, so
    # leaving it unset would raise AttributeError later.
    if args.security_group_id:
        args.security_group_ids = (args.security_group_id, )
    else:
        args.security_group_ids = ()


X
Xi Chen 已提交
283
def create():
    """Launch the master node for this task and start the master container.

    Steps:
      1. init_args() fills in derived arguments.
      2. create_subnet() runs when --subnet_id is empty.
      3. One t2.nano master instance is launched from a fixed AMI.
      4. ~/.aws and the key pair .pem are copied over SCP to /home/ubuntu on
         the master, connecting as user "ubuntu".
      5. The putcn/paddle_aws_master image is started with ``docker run -d``.
         This script's arguments are forwarded to it, with --action set to
         "serve".

    Mutates the module-level ``args``: subnet_id, master_server_public_ip,
    master_server_ip, plus whatever init_args() sets.

    Raises:
        Exception: if the remote ``docker run`` exits with a non-zero code.
    """
    try:
        from shlex import quote as shell_quote  # Python 3.3+
    except ImportError:
        from pipes import quote as shell_quote  # Python 2

    init_args()

    # create subnet
    if not args.subnet_id:
        args.subnet_id = create_subnet()

    # create master node
    master_instance_response = run_instances(
        image_id="ami-7a05351f", instance_type="t2.nano")
    master_instance = master_instance_response[0]
    args.master_server_public_ip = master_instance["PublicIpAddress"]
    args.master_server_ip = master_instance["PrivateIpAddress"]

    logging.info("master server started, master_ip=%s, task_name=%s" %
                 (args.master_server_public_ip, args.task_name))

    remote_pem_path = "/home/ubuntu/" + args.key_name + ".pem"

    ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
    ssh_client = paramiko.SSHClient()
    ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh_client.connect(
            hostname=args.master_server_public_ip,
            username="ubuntu",
            pkey=ssh_key)

        # cp config file and pems to master node
        with SCPClient(ssh_client.get_transport()) as scp:
            scp.put(os.path.join(os.path.expanduser("~"), ".aws"),
                    recursive=True,
                    remote_path='/home/ubuntu/')
            scp.put(args.pem_path, remote_path=remote_pem_path)

        logging.info("credentials and pem copied to master")

        # set arguments and start docker
        cmd_parts = [
            "docker run -d",
            "-v /home/ubuntu/.aws:/root/.aws/",
            "-v " + shell_quote(remote_pem_path + ":/root/" + args.key_name +
                                ".pem"),
            "-p %s:%s" % (args.master_server_port, args.master_server_port),
            "putcn/paddle_aws_master",
        ]

        args_to_pass = copy.copy(args)
        args_to_pass.action = "serve"
        # These three are not forwarded to the master container.
        del args_to_pass.pem_path
        del args_to_pass.security_group_ids
        del args_to_pass.master_server_public_ip
        for arg, value in sorted(vars(args_to_pass).items()):
            # Quote each value so an empty string (e.g. an unset --vpc_id)
            # still fills its argument slot instead of swallowing the next
            # flag, and so values cannot inject shell syntax remotely.
            cmd_parts.append("--%s %s" % (arg, shell_quote(str(value))))
        kick_off_cmd = " ".join(cmd_parts)

        logging.info(kick_off_cmd)
        stdin, stdout, stderr = ssh_client.exec_command(command=kick_off_cmd)
        return_code = stdout.channel.recv_exit_status()
        logging.info(return_code)
        if return_code != 0:
            logging.error(stderr.read())
            raise Exception("Error while kicking off master")
    finally:
        # Release the SSH connection on success and on every failure path.
        ssh_client.close()

    logging.info(
        "master server finished init process, visit %s to check master log" %
        (get_master_web_url("/logs")))


def cleanup():
    """POST to the master's /cleanup endpoint and print the response body."""
    # print() with one argument behaves the same on Python 2 and 3; the
    # Python 2 print statement is a SyntaxError on Python 3.
    print(requests.post(get_master_web_url("/cleanup")).text)


def status():
    """POST to the master's /logs endpoint and print the response body."""
    # print() with one argument behaves the same on Python 2 and 3; the
    # Python 2 print statement is a SyntaxError on Python 3.
    print(requests.post(get_master_web_url("/logs")).text)


def get_master_web_url(path):
    """Build ``http://<master_server_public_ip>:<master_server_port><path>``.

    Callers in this file pass paths such as "/logs" and "/cleanup".
    """
    host = args.master_server_public_ip
    port = str(args.master_server_port)
    return "".join(["http://", host, ":", port, path])
X
Xi Chen 已提交
360 361 362 363


if __name__ == "__main__":
    print_arguments()
    if args.action == "create":
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")
        create()
    elif args.action == "cleanup":
        if not args.master_server_public_ip:
            raise ValueError("master_server_public_ip is required")
        cleanup()
    elif args.action == "status":
        if not args.master_server_public_ip:
            raise ValueError("master_server_public_ip is required")
        status()
    else:
        # Fail loudly on a mistyped --action instead of exiting as a no-op.
        raise ValueError("unknown action %s, expected create|cleanup|status" %
                         args.action)