cluster_master.py 20.8 KB
Newer Older
X
Xi Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import json
import math
import time
X
Xi Chen 已提交
20
import threading
X
Xi Chen 已提交
21
import logging
X
Xi Chen 已提交
22 23 24 25 26 27

import netaddr
import boto3
import namesgenerator
import paramiko

X
Xi Chen 已提交
28 29
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer

X
Xi Chen 已提交
30 31 32 33
# You must have aws_access_key_id, aws_secret_access_key, region set in
# ~/.aws/credentials and ~/.aws/config

# Command-line interface definition.
#
# Every entry is (flag, keyword arguments for ``parser.add_argument``); the
# flags are registered in exactly the order they are listed here.
_SCRIPT_DIR = os.path.dirname(__file__)

_ARGUMENT_SPECS = [
    ('--key_name', dict(
        type=str, default="", help="required, key pair name")),
    ('--security_group_id', dict(
        type=str,
        default="",
        help="required, the security group id associated with your VPC")),
    ('--vpc_id', dict(
        type=str, default="", help="The VPC in which you wish to run test")),
    ('--subnet_id', dict(
        type=str,
        default="",
        help="The Subnet_id in which you wish to run test")),
    ('--pserver_instance_type', dict(
        type=str,
        default="c5.2xlarge",
        help="your pserver instance type, c5.2xlarge by default")),
    ('--trainer_instance_type', dict(
        type=str,
        default="p2.8xlarge",
        help="your trainer instance type, p2.8xlarge by default")),
    ('--task_name', dict(
        type=str, default="", help="the name you want to identify your job")),
    ('--pserver_image_id', dict(
        type=str,
        default="ami-da2c1cbf",
        help="ami id for system image, default one has nvidia-docker ready, use ami-1ae93962 for us-east-2"
    )),
    ('--trainer_image_id', dict(
        type=str,
        default="ami-da2c1cbf",
        help="ami id for system image, default one has nvidia-docker ready, use ami-1ae93962 for us-west-2"
    )),
    ('--availability_zone', dict(
        type=str,
        default="us-east-2a",
        help="aws zone id to place ec2 instances")),
    ('--trainer_count', dict(type=int, default=1, help="Trainer count")),
    ('--pserver_count', dict(type=int, default=1, help="Pserver count")),
    ('--pserver_bash_file', dict(
        type=str,
        default=os.path.join(_SCRIPT_DIR, "pserver.sh.template"),
        help="pserver bash file path")),
    ('--pserver_command', dict(
        type=str, default="", help="pserver start command")),
    ('--trainer_bash_file', dict(
        type=str,
        default=os.path.join(_SCRIPT_DIR, "trainer.sh.template"),
        help="trainer bash file path")),
    ('--trainer_command', dict(
        type=str, default="", help="trainer start command")),
    ('--action', dict(
        type=str, default="serve", help="create|cleanup|serve")),
    # No default on purpose: init_args() derives the path from --key_name
    # when this is left unset.
    ('--pem_path', dict(type=str, help="private key file")),
    ('--pserver_port', dict(type=str, default="5436", help="pserver port")),
    ('--docker_image', dict(
        type=str, default="busybox", help="training docker image")),
    ('--master_server_port', dict(
        type=int, default=5436, help="master server port")),
    ('--master_server_ip', dict(
        type=str, default="", help="master server private ip")),
]

parser = argparse.ArgumentParser(description=__doc__)
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

X
Xi Chen 已提交
129 130 131 132
args = parser.parse_args()

ec2client = boto3.client('ec2')

# Directory (next to this script) that holds the master, pserver and trainer
# log files.
args.log_path = os.path.join(os.path.dirname(__file__), "logs/")

# logging.basicConfig(filename=...) opens the log file right away and fails
# when its directory does not exist, so create the directory first.
if not os.path.isdir(args.log_path):
    os.makedirs(args.log_path)

logging.basicConfig(
    filename=args.log_path + 'master.log',
    level=logging.INFO,
    format='%(asctime)s %(message)s')

# Names of the log files written so far; log_to_file() appends to this list
# and the HTTP server returns it for /list_logs.
log_files = ["master.log"]

X
Xi Chen 已提交
142 143 144

def create_subnet():
    """Create a subnet for this task in the target VPC and return its id.

    When ``args.vpc_id`` is empty, the default VPC is looked up and stored in
    ``args.vpc_id``. The subnet is taken from the part of the VPC CIDR block
    not covered by existing subnets and is sized for
    ``args.pserver_count + args.trainer_count`` nodes plus 10 spare
    addresses. It is tagged with ``Task_name`` so ``cleanup`` can find it.

    Returns:
        The id of the newly created subnet.

    Raises:
        ValueError: if no VPC is found, if the node count requires a prefix
            of /16 or shorter, or if no free block can hold the subnet.
    """
    logging.info("start creating subnet")

    # Must be bound even when --vpc_id was supplied, because the lookup
    # below tests it.
    vpc_cidrBlock = None

    # if no vpc id provided, fall back to the default vpc
    if not args.vpc_id:
        logging.info("no vpc provided, trying to find the default one")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "isDefault",
                "Values": ["true", ]
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No default VPC')
        args.vpc_id = vpcs_desc["Vpcs"][0]["VpcId"]
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]

        logging.info("default vpc fount with id %s and CidrBlock %s" %
                     (args.vpc_id, vpc_cidrBlock))

    if not vpc_cidrBlock:
        logging.info("trying to find cidrblock for vpc")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "vpc-id",
                "Values": [args.vpc_id, ],
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No VPC found')
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]
        logging.info("cidrblock for vpc is %s" % vpc_cidrBlock)

    # list subnets in vpc in order to create a new one
    logging.info("trying to find ip blocks for new subnet")
    subnets_desc = ec2client.describe_subnets(
        Filters=[{
            "Name": "vpc-id",
            "Values": [args.vpc_id, ],
        }], )

    ips_taken = [
        subnet_dec["CidrBlock"] for subnet_dec in subnets_desc["Subnets"]
    ]

    ip_blocks_avaliable = netaddr.IPSet(
        [vpc_cidrBlock]) ^ netaddr.IPSet(ips_taken)

    # adding 10 addresses as buffer
    cidr_prefix = 32 - int(
        math.ceil(
            math.log(args.pserver_count + args.trainer_count + 10, 2)))
    if cidr_prefix <= 16:
        raise ValueError('Too many nodes to fit in current VPC')

    # Stays None when no free block is large enough, so the check after the
    # loop raises the intended ValueError instead of a NameError.
    subnet_cidr = None
    for ipnetwork in ip_blocks_avaliable.iter_cidrs():
        try:
            # A free block that cannot be split to the requested prefix is
            # expected to produce no subnets (see netaddr docs); the default
            # of None then means "try the next free block".
            subnet_cidr = next(ipnetwork.subnet(cidr_prefix), None)
        except Exception:
            # Best effort: log the failure and move on to the next block.
            logging.exception(
                "could not derive a subnet from free block %s" % ipnetwork)
            continue
        if subnet_cidr is not None:
            logging.info("subnet ip block found %s" % (subnet_cidr))
            break

    if subnet_cidr is None:
        raise ValueError(
            'No avaliable subnet to fit required nodes in current VPC')

    logging.info("trying to create subnet")
    subnet_desc = ec2client.create_subnet(
        CidrBlock=str(subnet_cidr),
        VpcId=args.vpc_id,
        AvailabilityZone=args.availability_zone)

    subnet_id = subnet_desc["Subnet"]["SubnetId"]

    subnet_waiter = ec2client.get_waiter('subnet_available')
    # sleep for 1s before checking its state
    time.sleep(1)
    subnet_waiter.wait(SubnetIds=[subnet_id, ])

    logging.info("subnet created")

    logging.info("adding tags to newly created subnet")
    ec2client.create_tags(
        Resources=[subnet_id, ],
        Tags=[{
            "Key": "Task_name",
            'Value': args.task_name
        }])
    return subnet_id


def generate_task_name():
    """Return a random name produced by ``namesgenerator``."""
    random_name = namesgenerator.get_random_name()
    return random_name


def script_to_str(file_path):
    """Return the contents of a script file with surrounding whitespace removed.

    Args:
        file_path: path of the file to read. When it is empty or ``None`` a
            default command that echoes ``$PSERVER_HOSTS`` is returned.

    Returns:
        The stripped file text, or the default command string.
    """
    if not file_path:
        return "echo $PSERVER_HOSTS"
    # "with" closes the handle even if read() raises.
    with open(file_path, 'r') as script_file:
        return script_file.read().strip()


def run_instances(image_id, instance_type, count, role, cmd=""):
    """Launch EC2 instances and wait until they pass their status checks.

    Exactly ``count`` instances are requested (MinCount == MaxCount) in
    ``args.subnet_id`` and tagged with ``Task_name`` and ``Role``.

    Args:
        image_id: AMI id to boot.
        instance_type: EC2 instance type.
        count: number of instances to launch.
        role: value stored in the ``Role`` tag.
        cmd: user data passed to the instances.

    Returns:
        The ``Instances`` list of the first reservation returned by
        ``describe_instances`` for the launched instance ids.
    """
    network_interface = {
        'DeviceIndex': 0,
        'SubnetId': args.subnet_id,
        "AssociatePublicIpAddress": True,
        'Groups': args.security_group_ids
    }
    instance_tags = [{
        "Key": 'Task_name',
        "Value": args.task_name
    }, {
        "Key": 'Role',
        "Value": role
    }]

    launch_result = ec2client.run_instances(
        ImageId=image_id,
        InstanceType=instance_type,
        MaxCount=count,
        MinCount=count,
        UserData=cmd,
        DryRun=False,
        InstanceInitiatedShutdownBehavior="stop",
        KeyName=args.key_name,
        Placement={'AvailabilityZone': args.availability_zone},
        NetworkInterfaces=[network_interface],
        TagSpecifications=[{
            'ResourceType': "instance",
            'Tags': instance_tags
        }])

    instance_ids = [
        launched["InstanceId"] for launched in launch_result["Instances"]
    ]

    if instance_ids:
        logging.info(str(len(instance_ids)) + " instance(s) created")
    else:
        logging.info("no instance created")

    # block until the instances are running and reachable
    logging.info("waiting for instance to become accessible")
    status_filters = [{
        "Name": "instance-status.status",
        "Values": ["ok"]
    }, {
        "Name": "instance-status.reachability",
        "Values": ["passed"]
    }, {
        "Name": "instance-state-name",
        "Values": ["running"]
    }]
    ec2client.get_waiter('instance_status_ok').wait(
        Filters=status_filters, InstanceIds=instance_ids)

    described = ec2client.describe_instances(InstanceIds=instance_ids)
    return described["Reservations"][0]["Instances"]


def create_pservers():
    """Launch the parameter-server instances.

    Returns:
        The instance descriptions returned by ``run_instances``.

    Raises:
        Exception: whatever ``run_instances`` raised. The task's resources
            are cleaned up first, then the error is re-raised so the caller
            does not continue with a ``None`` result.
    """
    try:
        return run_instances(
            image_id=args.pserver_image_id,
            instance_type=args.pserver_instance_type,
            count=args.pserver_count,
            role="PSERVER", )
    except Exception:
        logging.exception("error while trying to create pservers")
        cleanup(args.task_name)
        # Without this the function returns None and create_cluster() fails
        # later with an unrelated TypeError while iterating the result.
        raise
X
Xi Chen 已提交
312 313


X
Xi Chen 已提交
314 315 316
def log_to_file(source, filename):
    """Append everything read from ``source`` to a file in the log directory.

    ``filename`` is added to ``log_files`` if it is not there yet. Lines are
    read with ``source.readline()`` until it returns an empty string.

    Args:
        source: object with a ``readline()`` method.
        filename: file name, relative to ``args.log_path``.
    """
    if filename not in log_files:
        log_files.append(filename)

    destination = args.log_path + filename
    with open(destination, "a") as sink:
        while True:
            line = source.readline()
            if line == "":
                break
            sink.write(line)


X
Xi Chen 已提交
322
def create_trainers(kickoff_cmd, pserver_endpoints_str):
    """Launch every trainer, run the kickoff command on it and wait for all.

    One thread per trainer launches a single instance, connects over ssh,
    runs the formatted ``kickoff_cmd`` and streams stdout/stderr into
    ``trainer_<index>.log`` / ``trainer_<index>_err.log``. When all threads
    have finished, ``cleanup`` is always called for the task.

    Args:
        kickoff_cmd: command template; it is formatted with PSERVER_HOSTS,
            DOCKER_IMAGE, TRAINER_INDEX, TASK_NAME, TRAINER_COUNT, COMMAND
            and MASTER_ENDPOINT.
        pserver_endpoints_str: comma separated pserver endpoints.
    """
    trainer_threads = []
    # trainer index -> {'has_error': True}; only failed trainers get an entry.
    trainer_create_results = {}

    def create_and_start_trainer(trainer_index):
        """Thread body: launch one trainer and run its training command."""
        ssh_client = None
        try:
            logging.info("trainer " + str(trainer_index) + " is starting")

            instance_response = run_instances(
                image_id=args.trainer_image_id,
                instance_type=args.trainer_instance_type,
                count=1,
                role="TRAINER", )[0]
            trainer_ip = instance_response["PrivateIpAddress"]

            logging.info("trainer " + str(trainer_index) + " started")

            ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
            ssh_client = paramiko.SSHClient()
            ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh_client.connect(
                hostname=trainer_ip, username="ubuntu", pkey=ssh_key)

            logging.info("trainer " + str(trainer_index) +
                         " terminal connected via ssh")

            cmd = kickoff_cmd.format(
                PSERVER_HOSTS=pserver_endpoints_str,
                DOCKER_IMAGE=args.docker_image,
                TRAINER_INDEX=str(trainer_index),
                TASK_NAME=args.task_name,
                TRAINER_COUNT=args.trainer_count,
                COMMAND=args.trainer_command,
                MASTER_ENDPOINT=args.master_server_ip + ":" +
                str(args.master_server_port))
            logging.info(cmd)

            stdin, stdout, stderr = ssh_client.exec_command(command=cmd)

            # read and save output log
            logging.info("trainer " + str(trainer_index) +
                         " command executed, keep fetching log")

            stdout_thread = threading.Thread(
                target=log_to_file,
                args=(
                    stdout,
                    "trainer_" + str(trainer_index) + ".log", ))
            stderr_thread = threading.Thread(
                target=log_to_file,
                args=(
                    stderr,
                    "trainer_" + str(trainer_index) + "_err.log", ))
            stdout_thread.start()
            stderr_thread.start()

            stdout_thread.join()
            stderr_thread.join()

            return_code = stdout.channel.recv_exit_status()
            if return_code != 0:
                raise ValueError("trainer didn't finish with exit code 0")
        except Exception:
            # Record every failure (launch, ssh or non-zero exit) so the
            # parent thread can see it; an exception raised in a thread
            # never reaches the parent on its own.
            trainer_create_results[trainer_index] = {'has_error': True}
            logging.exception("trainer " + str(trainer_index) + " failed")
        finally:
            if ssh_client is not None:
                ssh_client.close()

    # multi thread starting trainer instance and run kickoff command
    try:
        for i in range(args.trainer_count):
            logging.info("starting tread for trainer " + str(i))
            trainer_thread = threading.Thread(
                target=create_and_start_trainer, args=(i, ))
            trainer_thread.start()
            trainer_threads.append(trainer_thread)

        for trainer_thread in trainer_threads:
            trainer_thread.join()

        # Iterate the values: iterating the dict itself yields the integer
        # trainer indexes, which cannot be subscripted.
        if any(result["has_error"]
               for result in trainer_create_results.values()):
            logging.error(
                "error during trainer starting or training, destorying the while cluster "
            )

        logging.info("all trainers stopped")
    except Exception:
        logging.exception(
            "Training exception, clean up resources, please check log for more info"
        )
    finally:
        # Single cleanup point for both the success and the failure path.
        cleanup(args.task_name)
X
Xi Chen 已提交
414 415 416 417


def cleanup(task_name):
    """Terminate the task's instances and delete its subnet.

    Instances and subnets are located through their ``Task_name`` tag. The
    call waits for instance termination. Only the first matching subnet is
    deleted, and that deletion is not waited for.

    Args:
        task_name: value of the ``Task_name`` tag to clean up.
    """
    # shutdown all ec2 instances
    print("going to clean up " + task_name + " instances")
    reservations = ec2client.describe_instances(Filters=[{
        "Name": "tag:Task_name",
        "Values": [task_name]
    }])["Reservations"]

    if reservations:
        instance_ids = [
            member["InstanceId"]
            for reservation in reservations
            for member in reservation["Instances"]
        ]

        ec2client.terminate_instances(InstanceIds=instance_ids)

        terminated_waiter = ec2client.get_waiter('instance_terminated')
        terminated_waiter.wait(InstanceIds=instance_ids)

    # delete the subnet created
    tagged_subnets = ec2client.describe_subnets(Filters=[{
        "Name": "tag:Task_name",
        "Values": [task_name]
    }])["Subnets"]

    if tagged_subnets:
        ec2client.delete_subnet(SubnetId=tagged_subnets[0]["SubnetId"])
    # no subnet delete waiter, just leave it.
    logging.info("Clearnup done")
    return


X
Xi Chen 已提交
450 451 452 453 454 455 456 457 458
def kickoff_pserver(host, pserver_endpoints_str):
    """Run the pserver start script on ``host`` over ssh and collect its logs.

    The script from ``args.pserver_bash_file`` is formatted and executed;
    stdout/stderr are streamed to ``pserver_<host>.log`` and
    ``pserver_<host>_err.log``. On any failure, including a non-zero exit
    code, the error is logged and ``cleanup`` is called for the task.

    Args:
        host: address of the pserver instance to connect to.
        pserver_endpoints_str: comma separated pserver endpoints.
    """
    # Bound before the try block so the finally clause can tell whether a
    # client exists; loading the key may fail before one is created.
    ssh_client = None
    try:
        ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
        ssh_client = paramiko.SSHClient()
        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh_client.connect(hostname=host, username="ubuntu", pkey=ssh_key)
        cmd = (script_to_str(args.pserver_bash_file)).format(
            PSERVER_HOSTS=pserver_endpoints_str,
            DOCKER_IMAGE=args.docker_image,
            PSERVER_PORT=args.pserver_port,
            TASK_NAME=args.task_name,
            COMMAND=args.pserver_command,
            TRAINER_COUNT=args.trainer_count,
            SERVER_ENDPOINT=host + ":" + str(args.pserver_port),
            MASTER_ENDPOINT=args.master_server_ip + ":" +
            str(args.master_server_port))
        logging.info(cmd)
        stdin, stdout, stderr = ssh_client.exec_command(command=cmd)

        stdout_thread = threading.Thread(
            target=log_to_file, args=(
                stdout,
                "pserver_" + host + ".log", ))
        stderr_thread = threading.Thread(
            target=log_to_file, args=(
                stderr,
                "pserver_" + host + "_err.log", ))
        stdout_thread.start()
        stderr_thread.start()

        stdout_thread.join()
        stderr_thread.join()

        return_code = stdout.channel.recv_exit_status()
        logging.info(return_code)
        if return_code != 0:
            raise Exception("Error while kicking off pserver training process")
    except Exception:
        logging.exception("Error while kicking off pserver training process")
        cleanup(args.task_name)
    finally:
        if ssh_client is not None:
            ssh_client.close()


X
Xi Chen 已提交
494 495
def init_args():
    """Fill in derived values on the global ``args`` namespace.

    * ``task_name``: generated when not given.
    * ``pem_path``: defaults to ``<home>/<key_name>.pem``.
    * ``security_group_ids``: one-element tuple, set only when
      ``security_group_id`` is given.
    * ``trainers_job_done_count``: reset to 0.
    """
    task_name_missing = not args.task_name
    if task_name_missing:
        args.task_name = generate_task_name()
        logging.info("task name generated %s" % (args.task_name))

    if not args.pem_path:
        home_dir = os.path.expanduser("~")
        args.pem_path = "{0}/{1}.pem".format(home_dir, args.key_name)

    if args.security_group_id:
        args.security_group_ids = (args.security_group_id, )

    args.trainers_job_done_count = 0


def create_cluster():
    """Bring up the whole cluster and block until every process has ended.

    Steps: create a subnet when none was given, launch the pservers, start
    one kickoff thread per pserver, create and run the trainers, then join
    the pserver threads.
    """
    if not args.subnet_id:
        logging.info("creating subnet for this task")
        args.subnet_id = create_subnet()
        logging.info("subnet %s created" % (args.subnet_id))

    logging.info("creating pservers")
    pserver_create_response = create_pservers()
    logging.info("pserver created, collecting pserver ips")

    # "<private ip>:<port>" for every pserver, joined with commas
    pserver_endpoints = [
        node["NetworkInterfaces"][0]["PrivateIpAddress"] + ":" +
        args.pserver_port for node in pserver_create_response
    ]
    pserver_endpoints_str = ",".join(pserver_endpoints)

    logging.info("kicking off pserver training process")
    pserver_threads = []
    for node in pserver_create_response:
        worker = threading.Thread(
            target=kickoff_pserver,
            args=(node["PrivateIpAddress"], pserver_endpoints_str))
        worker.start()
        pserver_threads.append(worker)

    logging.info("all pserver training process started")

    logging.info("creating trainers and kicking off trainer training process")
    trainer_script = script_to_str(args.trainer_bash_file)
    create_trainers(
        kickoff_cmd=trainer_script,
        pserver_endpoints_str=pserver_endpoints_str)

    for worker in pserver_threads:
        worker.join()

    logging.info("all process ended")
X
Xi Chen 已提交
546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565


def start_server(args):
    """Start the master HTTP server and serve requests forever.

    GET endpoints:
        /status, /master_logs  contents of master.log
        /list_logs             newline separated names from ``log_files``
        /log/<name>            contents of one file inside ``args.log_path``
    POST endpoints:
        /save_data             append the request body to <task_name>.txt
        /cleanup               tear the cluster down through ``cleanup``
    Every other path is answered with 404.

    Args:
        args: parsed arguments; ``log_path``, ``task_name`` and
            ``master_server_port`` are read.
    """

    class S(BaseHTTPRequestHandler):
        def _set_headers(self):
            self.send_response(200)
            self.send_header('Content-type', 'text/text')
            self.end_headers()

        def do_HEAD(self):
            self._set_headers()

        def do_404(self):
            self.send_response(404)
            self.send_header('Content-type', 'text/text')
            self.end_headers()
            logging.info("Received invalid GET request" + self.path)
            self.wfile.write("NO ACTION FOUND")

        def _send_log(self, relative_name):
            """Send one file from the log directory, or 404 if not allowed."""
            log_root = os.path.realpath(args.log_path)
            log_file = os.path.realpath(args.log_path + relative_name)
            # The name comes straight from the request URL: refuse anything
            # that resolves outside the log directory (e.g. "../../x").
            inside_log_root = log_file.startswith(log_root + os.sep)
            if not inside_log_root or not os.path.isfile(log_file):
                self.do_404()
                return
            # Read before sending the 200 headers so that a read failure
            # cannot follow a success status.
            with open(log_file, "r") as logfile:
                content = logfile.read().strip()
            self._set_headers()
            self.wfile.write(content)

        def do_GET(self):
            request_path = self.path
            if request_path == "/status" or request_path == "/master_logs":
                logging.info("Received request to return status")
                self._send_log("master.log")
            elif request_path == "/list_logs":
                self._set_headers()
                self.wfile.write("\n".join(log_files))
            elif "/log/" in request_path:
                log_file_path = request_path.replace("/log/", "")
                logging.info("requesting log file path is" + args.log_path +
                             log_file_path)
                self._send_log(log_file_path)
            else:
                self.do_404()

        def do_POST(self):
            request_path = self.path

            if request_path == "/save_data":
                self._set_headers()
                logging.info("Received request to save data")
                self.wfile.write("DATA SAVED!")
                # A request without Content-Length is treated as empty.
                content_length = int(self.headers.get('Content-Length', 0))
                post_data = self.rfile.read(content_length)
                if args.task_name:
                    with open(args.task_name + ".txt", "a") as text_file:
                        text_file.write(post_data + "\n")

            elif request_path == "/cleanup":
                self._set_headers()
                logging.info("Received request to cleanup cluster")
                cleanup(args.task_name)
                self.wfile.write("cleanup in progress")

            else:
                self.do_404()

    server_address = ('', args.master_server_port)
    httpd = HTTPServer(server_address, S)
    logging.info("HTTP server is starting")
    httpd.serve_forever()
X
Xi Chen 已提交
613 614 615


def print_arguments():
    """Log every value in ``args`` as ``name: value``, sorted by name."""
    logging.info('-----------  Configuration Arguments -----------')
    settings = vars(args)
    for name in sorted(settings):
        logging.info('%s: %s' % (name, settings[name]))
    logging.info('------------------------------------------------')
X
Xi Chen 已提交
620 621 622 623


if __name__ == "__main__":
    # Entry point. --action selects the mode:
    #   create  - build the cluster and run the job
    #   cleanup - tear down everything tagged with --task_name
    #   serve   - start the master HTTP server, then build the cluster
    #   test    - run only the HTTP server
    print_arguments()
    if args.action == "create":
        logging.info("going to create cluster")
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")
        init_args()
        create_cluster()
    elif args.action == "cleanup":
        logging.info("going to cleanup cluster")
        if not args.task_name:
            raise ValueError("task_name is required")
        cleanup(args.task_name)
    elif args.action == "serve":
        # serve mode
        if not args.master_server_ip:
            raise ValueError(
                "No master server ip set, please run with --action create")
        # Serve mode creates the cluster too, so it needs the same inputs as
        # "create"; without them init_args() never sets
        # args.security_group_ids and run_instances() fails.
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")

        logging.info("going to start serve and create cluster")

        init_args()

        logging.info("starting server in another thread")
        server_thread = threading.Thread(target=start_server, args=(args, ))
        server_thread.start()

        create_cluster()
        server_thread.join()
    elif args.action == "test":
        start_server(args)
    else:
        # Fail loudly instead of silently doing nothing on a mistyped action.
        raise ValueError("unknown action '%s', expected create|cleanup|serve"
                         % args.action)