cluster_launcher.py 12.0 KB
Newer Older
X
Xi Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import time
X
Xi Chen 已提交
18 19 20
import math
import logging
import copy
X
Xi Chen 已提交
21 22 23 24 25

import netaddr
import boto3
import namesgenerator
import paramiko
X
Xi Chen 已提交
26 27
from scp import SCPClient
import requests
X
Xi Chen 已提交
28 29

# Command-line interface. Arguments are parsed once at import time into the
# module-level ``args`` that every function below reads (and partly mutates).
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    '--key_name', type=str, default="", help="required, key pair name")
parser.add_argument(
    '--security_group_id',
    type=str,
    default="",
    help="required, the security group id associated with your VPC")

parser.add_argument(
    '--vpc_id',
    type=str,
    default="",
    help="The VPC in which you wish to run test")
parser.add_argument(
    '--subnet_id',
    type=str,
    default="",
    help="The Subnet_id in which you wish to run test")

parser.add_argument(
    '--pserver_instance_type',
    type=str,
    default="c5.2xlarge",
    help="your pserver instance type, c5.2xlarge by default")
parser.add_argument(
    '--trainer_instance_type',
    type=str,
    default="p2.8xlarge",
    help="your trainer instance type, p2.8xlarge by default")

parser.add_argument(
    '--task_name',
    type=str,
    default="",
    help="the name you want to identify your job")
# The default AMI goes with the default us-east-2a zone; the alternative AMI
# named in the help text is for us-west-2 (same wording as trainer_image_id).
# Implicit string concatenation keeps indentation out of the help text.
parser.add_argument(
    '--pserver_image_id',
    type=str,
    default="ami-da2c1cbf",
    help=("ami id for system image, default one has nvidia-docker ready, "
          "use ami-1ae93962 for us-west-2"))

parser.add_argument(
    '--pserver_command', type=str, default="", help="pserver start command")

parser.add_argument(
    '--trainer_image_id',
    type=str,
    default="ami-da2c1cbf",
    help=("ami id for system image, default one has nvidia-docker ready, "
          "use ami-1ae93962 for us-west-2"))

parser.add_argument(
    '--trainer_command', type=str, default="", help="trainer start command")

parser.add_argument(
    '--availability_zone',
    type=str,
    default="us-east-2a",
    help="aws zone id to place ec2 instances")

parser.add_argument(
    '--trainer_count', type=int, default=1, help="Trainer count")

parser.add_argument(
    '--pserver_count', type=int, default=1, help="Pserver count")

parser.add_argument(
    '--action', type=str, default="create", help="create|cleanup|status")

parser.add_argument('--pem_path', type=str, help="private key file")

parser.add_argument(
    '--pserver_port', type=str, default="5436", help="pserver port")

parser.add_argument(
    '--docker_image', type=str, default="busybox", help="training docker image")

parser.add_argument(
    '--master_server_port', type=int, default=5436, help="master server port")

parser.add_argument(
    '--master_server_public_ip', type=str, help="master server public ip")

parser.add_argument(
    '--master_docker_image',
    type=str,
    default="putcn/paddle_aws_master:latest",
    help="master docker image id")

args = parser.parse_args()

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

# Shared EC2 API client; region and credentials come from the boto3 defaults.
ec2client = boto3.client('ec2')


X
Xi Chen 已提交
127 128 129 130 131 132 133
def print_arguments(arguments=None):
    """Print every command-line argument and its value, sorted by name.

    :param arguments: namespace to print; when omitted, the module-level
        ``args`` parsed at import time is used.
    """
    if arguments is None:
        arguments = args
    print('-----------  Configuration Arguments -----------')
    # items() works on Python 2 and 3; iteritems() does not exist on Python 3.
    for arg, value in sorted(vars(arguments).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


X
Xi Chen 已提交
134 135
def create_subnet():
    """Create a tagged subnet in ``args.vpc_id`` big enough for the job.

    If ``args.vpc_id`` is empty, the account's default VPC is looked up and
    ``args.vpc_id`` is set to it. A free CIDR range is chosen inside the VPC's
    CIDR block (excluding ranges used by existing subnets). It is sized for
    ``args.pserver_count + args.trainer_count`` nodes plus 10 spare addresses,
    rounded up to a power of two. The subnet is created in
    ``args.availability_zone``. The call waits until it is available, then
    tags it with ``Task_name=args.task_name``.

    :return: the id of the newly created subnet.
    :raises ValueError: if no (default) VPC is found, if the required block
        would be /16 or larger, or if no free range of that size exists.
    """
    logging.info("start creating subnet")
    # Only filled in by the default-VPC lookup; when the user supplies
    # vpc_id it must be resolved below (previously this name was unbound).
    vpc_cidrBlock = None
    if not args.vpc_id:
        logging.info("no vpc provided, trying to find the default one")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "isDefault",
                "Values": ["true", ]
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No default VPC')
        args.vpc_id = vpcs_desc["Vpcs"][0]["VpcId"]
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]

        logging.info("default vpc found with id %s and CidrBlock %s" %
                     (args.vpc_id, vpc_cidrBlock))

    if not vpc_cidrBlock:
        logging.info("trying to find cidrblock for vpc")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "vpc-id",
                "Values": [args.vpc_id, ],
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No VPC found')
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]
        logging.info("cidrblock for vpc is %s" % vpc_cidrBlock)

    # list subnets in vpc in order to create a new one
    logging.info("trying to find ip blocks for new subnet")
    subnets_desc = ec2client.describe_subnets(
        Filters=[{
            "Name": "vpc-id",
            "Values": [args.vpc_id, ],
        }], )
    ips_taken = [subnet["CidrBlock"] for subnet in subnets_desc["Subnets"]]

    # Free space = VPC block minus existing subnets. A symmetric difference
    # (^) would also report any taken range lying outside vpc_cidrBlock as
    # free.
    ip_blocks_available = (netaddr.IPSet([vpc_cidrBlock]) -
                           netaddr.IPSet(ips_taken))

    # Room for every node plus 10 spare addresses, rounded up to a power of
    # two. For n >= 1, (n - 1).bit_length() == ceil(log2(n)), computed
    # exactly in integers rather than via a float logarithm.
    required_addresses = args.pserver_count + args.trainer_count + 10
    cidr_prefix = 32 - (required_addresses - 1).bit_length()
    if cidr_prefix <= 16:
        raise ValueError('Too many nodes to fit in current VPC')

    subnet_cidr = None
    for ipnetwork in ip_blocks_available.iter_cidrs():
        try:
            # subnet() is a generator; the None default covers the case
            # where it yields nothing (presumably when ipnetwork is smaller
            # than the requested size).
            subnet_cidr = next(ipnetwork.subnet(cidr_prefix), None)
        except ValueError:
            # Skip blocks netaddr refuses to split; keep searching.
            continue
        if subnet_cidr is not None:
            logging.info("subnet ip block found %s" % (subnet_cidr))
            break

    if subnet_cidr is None:
        raise ValueError(
            'No available subnet to fit required nodes in current VPC')

    logging.info("trying to create subnet")
    subnet_desc = ec2client.create_subnet(
        CidrBlock=str(subnet_cidr),
        VpcId=args.vpc_id,
        AvailabilityZone=args.availability_zone)

    subnet_id = subnet_desc["Subnet"]["SubnetId"]

    subnet_waiter = ec2client.get_waiter('subnet_available')
    # sleep for 1s before checking its state
    time.sleep(1)
    subnet_waiter.wait(SubnetIds=[subnet_id, ])

    logging.info("subnet created")

    logging.info("adding tags to newly created subnet")
    ec2client.create_tags(
        Resources=[subnet_id, ],
        Tags=[{
            "Key": "Task_name",
            'Value': args.task_name
        }])
    return subnet_id


X
Xi Chen 已提交
222
def run_instances(image_id, instance_type, count=1, role="MASTER", cmd=""):
    """Launch ``count`` EC2 instances and block until they pass status checks.

    Instances are placed in ``args.subnet_id`` / ``args.availability_zone``
    with a public IP address. They use key pair ``args.key_name`` and security
    groups ``args.security_group_ids``, and are tagged ``Task_name`` and
    ``Role``.

    :param image_id: AMI id to boot.
    :param instance_type: EC2 instance type, e.g. "t2.nano".
    :param count: exact number of instances (used as both MinCount and
        MaxCount).
    :param role: value of the ``Role`` tag.
    :param cmd: user data handed to the instances.
    :return: list of instance description dicts from describe_instances,
        taken after the instances are running and pass both status checks.
    :raises ValueError: if the launch response contains no instances.
    """
    response = ec2client.run_instances(
        ImageId=image_id,
        InstanceType=instance_type,
        MaxCount=count,
        MinCount=count,
        UserData=cmd,
        DryRun=False,
        InstanceInitiatedShutdownBehavior="stop",
        KeyName=args.key_name,
        Placement={'AvailabilityZone': args.availability_zone},
        NetworkInterfaces=[{
            'DeviceIndex': 0,
            'SubnetId': args.subnet_id,
            "AssociatePublicIpAddress": True,
            'Groups': args.security_group_ids
        }],
        TagSpecifications=[{
            'ResourceType': "instance",
            'Tags': [{
                "Key": 'Task_name',
                # NOTE(review): the "_master" suffix is applied regardless
                # of ``role``; this file only launches the master today.
                # Confirm before reusing for other roles.
                "Value": args.task_name + "_master"
            }, {
                "Key": 'Role',
                "Value": role
            }]
        }])

    instance_ids = [instance["InstanceId"] for instance in response["Instances"]]

    if not instance_ids:
        # Waiting on / describing an empty id list would match every
        # instance in the account, so never continue without ids.
        logging.info("no instance created")
        raise ValueError("no instance created")
    logging.info(str(len(instance_ids)) + " instance(s) created")

    # create waiter to make sure it's running
    logging.info("waiting for instance to become accessible")
    waiter = ec2client.get_waiter('instance_status_ok')
    waiter.wait(
        Filters=[{
            "Name": "instance-status.status",
            "Values": ["ok"]
        }, {
            "Name": "instance-status.reachability",
            "Values": ["passed"]
        }, {
            "Name": "instance-state-name",
            "Values": ["running"]
        }],
        InstanceIds=instance_ids)

    instances_response = ec2client.describe_instances(InstanceIds=instance_ids)

    # Collect from every reservation rather than assuming a single one.
    instances = []
    for reservation in instances_response["Reservations"]:
        instances.extend(reservation["Instances"])
    return instances


X
Xi Chen 已提交
280 281 282 283 284
def generate_task_name():
    """Return a random, human-readable name from ``namesgenerator``.

    Used as the task name when the user did not pass ``--task_name``.
    """
    random_name = namesgenerator.get_random_name()
    return random_name


def init_args():
    """Fill in derived defaults on the module-level ``args``.

    - ``task_name``: a random name when empty (logged).
    - ``pem_path``: ``~/<key_name>.pem`` when not given.
    - ``security_group_ids``: a 1-tuple holding ``security_group_id``, set
      only when an id was given (run_instances passes it as ``Groups``).
    """
    if not args.task_name:
        args.task_name = generate_task_name()
        logging.info("task name generated %s" % (args.task_name))

    if not args.pem_path:
        # os.path.join avoids a doubled separator when the home directory
        # already ends with "/".
        args.pem_path = os.path.join(
            os.path.expanduser("~"), args.key_name + ".pem")
    if args.security_group_id:
        args.security_group_ids = (args.security_group_id, )


X
Xi Chen 已提交
296
def create():
    """Bootstrap a cluster by launching and starting its master node.

    Steps:
    1. Fill in derived arguments (``init_args``).
    2. Create a subnet unless ``args.subnet_id`` was given.
    3. Launch a t2.nano master instance and record its public and private
       IPs on ``args``.
    4. Copy the local ``~/.aws`` directory and the key-pair pem to the
       master over SSH/SCP.
    5. Start ``args.master_docker_image`` there, forwarding the remaining
       arguments as ``--name value`` flags with ``--action serve``.

    :raises Exception: if the remote ``docker run`` exits non-zero.
    """
    init_args()

    # create subnet
    if not args.subnet_id:
        args.subnet_id = create_subnet()

    # create master node
    master_instance_response = run_instances(
        image_id="ami-7a05351f", instance_type="t2.nano")

    logging.info("master server started")

    args.master_server_public_ip = master_instance_response[0][
        "PublicIpAddress"]
    args.master_server_ip = master_instance_response[0]["PrivateIpAddress"]

    logging.info("master server started, master_ip=%s, task_name=%s" %
                 (args.master_server_public_ip, args.task_name))

    ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
    ssh_client = paramiko.SSHClient()
    ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        ssh_client.connect(
            hostname=args.master_server_public_ip,
            username="ubuntu",
            pkey=ssh_key)

        # cp config file and pems to master node
        _copy_credentials_to_master(ssh_client)
        logging.info("credentials and pem copied to master")

        # set arguments and start docker
        kick_off_cmd = _build_master_kick_off_cmd()
        logging.info(kick_off_cmd)
        stdin, stdout, stderr = ssh_client.exec_command(command=kick_off_cmd)
        return_code = stdout.channel.recv_exit_status()
        logging.info(return_code)
        if return_code != 0:
            raise Exception("Error while kicking off master: %s" %
                            stderr.read())
    finally:
        # Always release the SSH transport, including on failure.
        ssh_client.close()

    logging.info(
        "master server finished init process, visit %s to check master log" %
        (get_master_web_url("/status")))


def _copy_credentials_to_master(ssh_client):
    """Copy ~/.aws and the key-pair pem to /home/ubuntu/ on the master."""
    with SCPClient(ssh_client.get_transport()) as scp:
        scp.put(os.path.join(os.path.expanduser("~"), ".aws"),
                recursive=True,
                remote_path='/home/ubuntu/')
        scp.put(args.pem_path,
                remote_path='/home/ubuntu/' + args.key_name + ".pem")


def _build_master_kick_off_cmd():
    """Return the ``docker run`` command line that starts the master.

    All of ``args`` except the entries dropped below is appended as
    shell-quoted ``--name value`` flags, with ``action`` set to "serve".
    """
    try:
        from shlex import quote  # Python 3
    except ImportError:
        from pipes import quote  # Python 2

    pem_name = args.key_name + ".pem"
    kick_off_cmd = "docker run -d -v /home/ubuntu/.aws:/root/.aws/"
    kick_off_cmd += " -v /home/ubuntu/" + pem_name + ":/root/" + pem_name
    kick_off_cmd += " -v /home/ubuntu/logs/:/root/logs/"
    kick_off_cmd += " -p " + str(args.master_server_port) + ":" + str(
        args.master_server_port)
    kick_off_cmd += " " + args.master_docker_image

    args_to_pass = copy.copy(args)
    args_to_pass.action = "serve"
    # Drop settings that are not forwarded to the master.
    del args_to_pass.pem_path
    del args_to_pass.security_group_ids
    del args_to_pass.master_docker_image
    del args_to_pass.master_server_public_ip
    for arg, value in sorted(vars(args_to_pass).items()):
        # Quote each value so empty strings (e.g. the default "" for
        # --pserver_command) and multi-word commands reach the remote
        # shell as exactly one argument.
        kick_off_cmd += ' --%s %s' % (arg, quote(str(value)))
    return kick_off_cmd
X
Xi Chen 已提交
362 363 364 365 366 367 368


def cleanup():
    """POST to the master's ``/cleanup`` endpoint and print the response body."""
    # Call-form print parses on both Python 2 and Python 3.
    print(requests.post(get_master_web_url("/cleanup")).text)


def status():
    """POST to the master's ``/status`` endpoint and print the response body."""
    # Call-form print parses on both Python 2 and Python 3.
    print(requests.post(get_master_web_url("/status")).text)
X
Xi Chen 已提交
370 371 372


def get_master_web_url(path):
    """Return the master's HTTP URL for ``path``.

    The URL is built from ``args.master_server_public_ip`` and
    ``args.master_server_port``. ``path`` is appended verbatim; callers in
    this file pass values such as "/status".
    """
    host_and_port = args.master_server_public_ip + ":" + str(
        args.master_server_port)
    return "http://" + host_and_port + path
X
Xi Chen 已提交
375 376 377 378


if __name__ == "__main__":
    print_arguments()
    if args.action == "create":
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")
        create()
    elif args.action in ("cleanup", "status"):
        # Both actions talk to an already-running master.
        if not args.master_server_public_ip:
            raise ValueError("master_server_public_ip is required")
        if args.action == "cleanup":
            cleanup()
        else:
            status()
    else:
        # Previously an unrecognised action silently did nothing.
        raise ValueError("unknown action %s, expected create|cleanup|status" %
                         args.action)