cluster_master.py 21.5 KB
Newer Older
X
Xi Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import json
import math
import time
X
Xi Chen 已提交
20
import threading
X
Xi Chen 已提交
21
import logging
X
Xi Chen 已提交
22 23 24 25 26 27

import netaddr
import boto3
import namesgenerator
import paramiko

X
Xi Chen 已提交
28 29
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer

X
Xi Chen 已提交
30

X
Xi Chen 已提交
31 32
# You must have aws_access_key_id, aws_secret_access_key, region set in
# ~/.aws/credentials and ~/.aws/config
X
Xi Chen 已提交
33 34 35 36 37 38 39 40
def str2bool(v):
    """Convert a command-line string into a boolean.

    Matching is case-insensitive. 'yes', 'true', 't', 'y' and '1' map to
    True; 'no', 'false', 'f', 'n' and '0' map to False.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is none of the accepted words.
    """
    lowered = v.lower()
    truthy_words = ('yes', 'true', 't', 'y', '1')
    falsy_words = ('no', 'false', 'f', 'n', '0')
    if lowered in truthy_words:
        return True
    if lowered in falsy_words:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

X
Xi Chen 已提交
41 42

# Command-line interface. Every option has a default so the script can be
# started with a subset of flags; the entry point validates required ones.
parser = argparse.ArgumentParser(description=__doc__)

# --- AWS account / network placement ---
parser.add_argument(
    '--key_name',
    type=str,
    default="",
    help="required, key pair name")
parser.add_argument(
    '--security_group_id',
    type=str,
    default="",
    help="required, the security group id associated with your VPC")
parser.add_argument(
    '--vpc_id',
    type=str,
    default="",
    help="The VPC in which you wish to run test")
parser.add_argument(
    '--subnet_id',
    type=str,
    default="",
    help="The Subnet_id in which you wish to run test")

# --- instance types ---
parser.add_argument(
    '--pserver_instance_type',
    type=str,
    default="c5.2xlarge",
    help="your pserver instance type, c5.2xlarge by default")
parser.add_argument(
    '--trainer_instance_type',
    type=str,
    default="p2.8xlarge",
    help="your trainer instance type, p2.8xlarge by default")

# --- task identity and machine images ---
parser.add_argument(
    '--task_name',
    type=str,
    default="",
    help="the name you want to identify your job")
parser.add_argument(
    '--pserver_image_id',
    type=str,
    default="ami-da2c1cbf",
    help="ami id for system image, default one has nvidia-docker ready, use ami-1ae93962 for us-east-2"
)
parser.add_argument(
    '--trainer_image_id',
    type=str,
    default="ami-da2c1cbf",
    help="ami id for system image, default one has nvidia-docker ready, use ami-1ae93962 for us-west-2"
)
parser.add_argument(
    '--availability_zone',
    type=str,
    default="us-east-2a",
    help="aws zone id to place ec2 instances")

# --- cluster size ---
parser.add_argument(
    '--trainer_count',
    type=int,
    default=1,
    help="Trainer count")
parser.add_argument(
    '--pserver_count',
    type=int,
    default=1,
    help="Pserver count")

# --- kickoff scripts and commands; templates default to files next to this script ---
parser.add_argument(
    '--pserver_bash_file',
    type=str,
    default=os.path.join(os.path.dirname(__file__), "pserver.sh.template"),
    help="pserver bash file path")
parser.add_argument(
    '--pserver_command',
    type=str,
    default="",
    help="pserver start command")
parser.add_argument(
    '--trainer_bash_file',
    type=str,
    default=os.path.join(os.path.dirname(__file__), "trainer.sh.template"),
    help="trainer bash file path")
parser.add_argument(
    '--trainer_command',
    type=str,
    default="",
    help="trainer start command")

# --- run mode, credentials and endpoints ---
parser.add_argument(
    '--action',
    type=str,
    default="serve",
    help="create|cleanup|serve")
parser.add_argument(
    '--pem_path',
    type=str,
    help="private key file")
parser.add_argument(
    '--pserver_port',
    type=str,
    default="5436",
    help="pserver port")
parser.add_argument(
    '--docker_image',
    type=str,
    default="busybox",
    help="training docker image")
parser.add_argument(
    '--master_server_port',
    type=int,
    default=5436,
    help="master server port")
parser.add_argument(
    '--master_server_ip',
    type=str,
    default="",
    help="master server private ip")

# --- teardown behaviour ---
parser.add_argument(
    '--no_clean_up',
    type=str2bool,
    default=False,
    help="whether to clean up after training")

# Parsed once at import time; the rest of the module reads and mutates `args`.
args = parser.parse_args()

# EC2 client shared by every function in this module.
ec2client = boto3.client('ec2')

# Master and per-node logs are written to a "logs/" directory that sits next
# to this script.
args.log_path = os.path.join(os.path.dirname(__file__), "logs/")

# logging.basicConfig opens the log file right away, so the directory must
# exist beforehand; create it on first run instead of failing.
if not os.path.isdir(args.log_path):
    os.makedirs(args.log_path)

logging.basicConfig(
    filename=args.log_path + 'master.log',
    level=logging.INFO,
    format='%(asctime)s %(message)s')

# Names of the log files known to this process; log_to_file() appends to it
# and the HTTP server lists it.
log_files = ["master.log"]

X
Xi Chen 已提交
157 158 159

def create_subnet():
    """Create a subnet for this task in the target VPC and return its id.

    When ``args.vpc_id`` is empty the account's default VPC is looked up and
    its id is stored back on ``args.vpc_id``. The subnet is sized to hold
    ``args.pserver_count + args.trainer_count`` nodes plus 10 spare
    addresses, placed in ``args.availability_zone`` and tagged with
    ``Task_name`` = ``args.task_name`` (``cleanup`` finds it by that tag).

    Returns:
        The id of the newly created subnet.

    Raises:
        ValueError: if no (default) VPC is found, if the required prefix
            would be /16 or larger, or if no free block fits.
    """
    logging.info("start creating subnet")

    # Initialised up front: when --vpc_id is supplied the default-VPC branch
    # is skipped and the CIDR block has to be fetched by the second lookup.
    vpc_cidrBlock = None

    # if no vpc id provided, fall back to the default vpc
    if not args.vpc_id:
        logging.info("no vpc provided, trying to find the default one")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "isDefault",
                "Values": ["true", ]
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No default VPC')
        args.vpc_id = vpcs_desc["Vpcs"][0]["VpcId"]
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]

        logging.info("default vpc fount with id %s and CidrBlock %s" %
                     (args.vpc_id, vpc_cidrBlock))

    if not vpc_cidrBlock:
        logging.info("trying to find cidrblock for vpc")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "vpc-id",
                "Values": [args.vpc_id, ],
            }], )
        if len(vpcs_desc["Vpcs"]) == 0:
            raise ValueError('No VPC found')
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]
        logging.info("cidrblock for vpc is %s" % vpc_cidrBlock)

    # list subnets in vpc in order to create a new one
    logging.info("trying to find ip blocks for new subnet")
    subnets_desc = ec2client.describe_subnets(
        Filters=[{
            "Name": "vpc-id",
            "Values": [args.vpc_id, ],
        }], )

    ips_taken = [subnet_dec["CidrBlock"] for subnet_dec in subnets_desc["Subnets"]]

    ip_blocks_avaliable = netaddr.IPSet(
        [vpc_cidrBlock]) ^ netaddr.IPSet(ips_taken)

    # adding 10 addresses as buffer; math.ceil may return a float, so cast
    # once here instead of at every use
    node_count = args.pserver_count + args.trainer_count + 10
    cidr_prefix = 32 - int(math.ceil(math.log(node_count, 2)))
    if cidr_prefix <= 16:
        raise ValueError('Too many nodes to fit in current VPC')

    # Take the first sub-block of the wanted size from the first free block
    # that can provide one.
    subnet_cidr = None
    for ipnetwork in ip_blocks_avaliable.iter_cidrs():
        try:
            subnet_cidr = next(ipnetwork.subnet(cidr_prefix))
        except Exception:
            # this free block cannot yield a subnet of that size; try the next
            continue
        logging.info("subnet ip block found %s" % (subnet_cidr))
        break

    if subnet_cidr is None:
        raise ValueError(
            'No avaliable subnet to fit required nodes in current VPC')

    logging.info("trying to create subnet")
    subnet_desc = ec2client.create_subnet(
        CidrBlock=str(subnet_cidr),
        VpcId=args.vpc_id,
        AvailabilityZone=args.availability_zone)

    subnet_id = subnet_desc["Subnet"]["SubnetId"]

    subnet_waiter = ec2client.get_waiter('subnet_available')
    # sleep for 1s before checking its state
    time.sleep(1)
    subnet_waiter.wait(SubnetIds=[subnet_id, ])

    logging.info("subnet created")

    logging.info("adding tags to newly created subnet")
    ec2client.create_tags(
        Resources=[subnet_id, ],
        Tags=[{
            "Key": "Task_name",
            'Value': args.task_name
        }])
    return subnet_id


def generate_task_name():
    """Return a random name produced by ``namesgenerator``."""
    random_name = namesgenerator.get_random_name()
    return random_name


def script_to_str(file_path):
    """Return the contents of a script file with surrounding whitespace removed.

    Args:
        file_path: path of the file to read. When empty or None, a fallback
            command that echoes ``$PSERVER_HOSTS`` is returned instead.

    Returns:
        The stripped file contents, or the fallback command.
    """
    if not file_path:
        return "echo $PSERVER_HOSTS"
    # `with` closes the handle even if read() raises
    with open(file_path, 'r') as script_file:
        return script_file.read().strip()


def run_instances(image_id, instance_type, count, role, cmd=""):
    """Launch EC2 instances and wait until their status checks pass.

    Instances are started in ``args.subnet_id`` / ``args.availability_zone``
    with a public IP, ``args.security_group_ids`` and key pair
    ``args.key_name``, and are tagged with ``Task_name`` and ``Role``.

    Args:
        image_id: AMI id to boot.
        instance_type: EC2 instance type.
        count: used as both MinCount and MaxCount.
        role: value of the ``Role`` tag.
        cmd: passed as the instances' UserData.

    Returns:
        The "Instances" list of the first reservation returned by
        ``describe_instances`` for the launched instance ids.
    """
    instance_tags = [
        {"Key": 'Task_name', "Value": args.task_name},
        {"Key": 'Role', "Value": role},
    ]
    primary_interface = {
        'DeviceIndex': 0,
        'SubnetId': args.subnet_id,
        "AssociatePublicIpAddress": True,
        'Groups': args.security_group_ids,
    }

    response = ec2client.run_instances(
        ImageId=image_id,
        InstanceType=instance_type,
        MaxCount=count,
        MinCount=count,
        UserData=cmd,
        DryRun=False,
        InstanceInitiatedShutdownBehavior="stop",
        KeyName=args.key_name,
        Placement={'AvailabilityZone': args.availability_zone},
        NetworkInterfaces=[primary_interface],
        TagSpecifications=[{
            'ResourceType': "instance",
            'Tags': instance_tags
        }])

    instance_ids = [launched["InstanceId"] for launched in response["Instances"]]

    if instance_ids:
        logging.info(str(len(instance_ids)) + " instance(s) created")
    else:
        logging.info("no instance created")

    # block on a waiter until every instance is running and reachable
    logging.info("waiting for instance to become accessible")
    readiness_filters = [
        {"Name": "instance-status.status", "Values": ["ok"]},
        {"Name": "instance-status.reachability", "Values": ["passed"]},
        {"Name": "instance-state-name", "Values": ["running"]},
    ]
    status_waiter = ec2client.get_waiter('instance_status_ok')
    status_waiter.wait(Filters=readiness_filters, InstanceIds=instance_ids)

    described = ec2client.describe_instances(InstanceIds=instance_ids)
    first_reservation = described["Reservations"][0]
    return first_reservation["Instances"]


def create_pservers():
    """Launch the parameter-server instances for this task.

    Returns:
        The instance descriptions returned by ``run_instances``.

    Raises:
        Exception: whatever ``run_instances`` raised. The error is logged
            and ``cleanup`` is run first, then the exception is re-raised so
            the caller does not go on to iterate over ``None``.
    """
    try:
        return run_instances(
            image_id=args.pserver_image_id,
            instance_type=args.pserver_instance_type,
            count=args.pserver_count,
            role="PSERVER", )
    except Exception:
        logging.exception("error while trying to create pservers")
        cleanup(args.task_name)
        raise
X
Xi Chen 已提交
327 328


X
Xi Chen 已提交
329 330 331
def log_to_file(source, filename):
    """Append every line read from ``source`` to a file in the log directory.

    ``filename`` is added to the module-level ``log_files`` list if it is not
    there yet. Reading stops when ``source.readline()`` returns an empty
    string.

    Args:
        source: object with a ``readline()`` method.
        filename: name of the target file inside ``args.log_path``.
    """
    if filename not in log_files:
        log_files.append(filename)

    destination = args.log_path + filename
    with open(destination, "a") as log_file:
        while True:
            line = source.readline()
            if line == "":
                break
            log_file.write(line)


X
Xi Chen 已提交
337
def create_trainers(kickoff_cmd, pserver_endpoints_str):
    """Start all trainer instances, run the kickoff command on each, then clean up.

    One thread per trainer launches an instance, connects over ssh, runs the
    formatted kickoff command and streams stdout/stderr into
    ``trainer_<index>.log`` / ``trainer_<index>_err.log``. When every thread
    has finished, ``cleanup`` is called for the task.

    Args:
        kickoff_cmd: command template; formatted with PSERVER_HOSTS,
            DOCKER_IMAGE, TRAINER_INDEX, TASK_NAME, TRAINER_COUNT, COMMAND
            and MASTER_ENDPOINT.
        pserver_endpoints_str: value substituted for PSERVER_HOSTS.
    """
    # Filled in by the worker threads: trainer index -> {'has_error': True}.
    trainer_create_results = {}

    def create_and_start_trainer(trainer_index):
        """Launch one trainer, run its command and record any failure."""
        ssh_client = None
        try:
            logging.info("trainer " + str(trainer_index) + " is starting")

            instance_response = run_instances(
                image_id=args.trainer_image_id,
                instance_type=args.trainer_instance_type,
                count=1,
                role="TRAINER", )[0]
            trainer_ip = instance_response["PrivateIpAddress"]

            logging.info("trainer " + str(trainer_index) + " started")

            ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
            ssh_client = paramiko.SSHClient()
            ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh_client.connect(
                hostname=trainer_ip, username="ubuntu", pkey=ssh_key)

            logging.info("trainer " + str(trainer_index) +
                         " terminal connected via ssh")

            cmd = kickoff_cmd.format(
                PSERVER_HOSTS=pserver_endpoints_str,
                DOCKER_IMAGE=args.docker_image,
                TRAINER_INDEX=str(trainer_index),
                TASK_NAME=args.task_name,
                TRAINER_COUNT=args.trainer_count,
                COMMAND=args.trainer_command,
                MASTER_ENDPOINT=args.master_server_ip + ":" +
                str(args.master_server_port))
            logging.info(cmd)

            stdin, stdout, stderr = ssh_client.exec_command(command=cmd)

            # read and save output log
            logging.info("trainer " + str(trainer_index) +
                         " command executed, keep fetching log")

            stdout_thread = threading.Thread(
                target=log_to_file,
                args=(
                    stdout,
                    "trainer_" + str(trainer_index) + ".log", ))
            stderr_thread = threading.Thread(
                target=log_to_file,
                args=(
                    stderr,
                    "trainer_" + str(trainer_index) + "_err.log", ))
            stdout_thread.start()
            stderr_thread.start()

            stdout_thread.join()
            stderr_thread.join()

            return_code = stdout.channel.recv_exit_status()
            if return_code != 0:
                raise ValueError("trainer didn't finish with exit code 0")
        except Exception:
            # An exception raised in a thread never reaches the parent, so
            # record every failure (not only a bad exit code) for it to see.
            trainer_create_results[trainer_index] = {'has_error': True}
            logging.exception("trainer " + str(trainer_index) + " failed")
        finally:
            # ssh_client is still None if the failure happened before it was
            # created
            if ssh_client is not None:
                ssh_client.close()

    # multi thread starting trainer instance and run kickoff command
    trainer_threads = []
    try:
        for i in range(args.trainer_count):
            logging.info("starting tread for trainer " + str(i))
            trainer_thread = threading.Thread(
                target=create_and_start_trainer, args=(i, ))
            trainer_thread.start()
            trainer_threads.append(trainer_thread)

        for trainer_thread in trainer_threads:
            trainer_thread.join()

        # Look at the recorded values; iterating the dict itself would yield
        # the integer trainer indices.
        if any(result["has_error"]
               for result in trainer_create_results.values()):
            logging.error(
                "error during trainer starting or training, destorying the while cluster "
            )

        logging.info("all trainers stopped")
    except Exception:
        logging.exception(
            "Training exception, clean up resources, please check log for more info"
        )
    finally:
        # single cleanup point for both the success and the failure path
        cleanup(args.task_name)
X
Xi Chen 已提交
429 430 431


def cleanup(task_name):
    """Terminate the task's instances and delete its subnet.

    Does nothing except log when ``args.no_clean_up`` is set. Otherwise every
    instance tagged ``Task_name`` = ``task_name`` is terminated (waiting for
    termination to complete), then the first subnet carrying the same tag is
    deleted.

    Args:
        task_name: value of the ``Task_name`` tag to look for.
    """
    if args.no_clean_up:
        logging.info("no clean up option set, going to leave the setup running")
        return

    # shut down all ec2 instances that belong to the task
    print("going to clean up " + task_name + " instances")
    task_tag_filter = [{"Name": "tag:Task_name", "Values": [task_name]}]

    reservations = ec2client.describe_instances(
        Filters=task_tag_filter)["Reservations"]
    if reservations:
        instance_ids = []
        for reservation in reservations:
            instance_ids.extend(
                member["InstanceId"] for member in reservation["Instances"])

        ec2client.terminate_instances(InstanceIds=instance_ids)
        terminated_waiter = ec2client.get_waiter('instance_terminated')
        terminated_waiter.wait(InstanceIds=instance_ids)

    # delete the subnet created for the task
    tagged_subnets = ec2client.describe_subnets(
        Filters=task_tag_filter)["Subnets"]
    if tagged_subnets:
        ec2client.delete_subnet(SubnetId=tagged_subnets[0]["SubnetId"])

    # no subnet delete waiter, just leave it.
    logging.info("Clearnup done")


X
Xi Chen 已提交
468 469 470 471 472 473 474 475 476
def kickoff_pserver(host, pserver_endpoints_str):
    """Run the pserver kickoff script on ``host`` over ssh and collect its logs.

    The script template from ``args.pserver_bash_file`` is formatted and
    executed remotely; stdout/stderr are streamed into
    ``pserver_<host>.log`` / ``pserver_<host>_err.log``. Any failure,
    including a non-zero exit status, is logged and triggers ``cleanup``.

    Args:
        host: address to ssh into; also used to build SERVER_ENDPOINT.
        pserver_endpoints_str: value substituted for PSERVER_HOSTS.
    """
    # Bound before the try so the finally clause can tell whether a client
    # was ever created (key loading may fail first).
    ssh_client = None
    try:
        ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
        ssh_client = paramiko.SSHClient()
        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh_client.connect(hostname=host, username="ubuntu", pkey=ssh_key)
        cmd = (script_to_str(args.pserver_bash_file)).format(
            PSERVER_HOSTS=pserver_endpoints_str,
            DOCKER_IMAGE=args.docker_image,
            PSERVER_PORT=args.pserver_port,
            TASK_NAME=args.task_name,
            COMMAND=args.pserver_command,
            TRAINER_COUNT=args.trainer_count,
            TRAINER_INDEX=0,
            # there is no way to use 0.0.0.0:port to start pserver
            # has to docker --network="host" with host ip to make this work
            SERVER_ENDPOINT=host + ":" + str(args.pserver_port),
            MASTER_ENDPOINT=args.master_server_ip + ":" +
            str(args.master_server_port))
        logging.info(cmd)
        stdin, stdout, stderr = ssh_client.exec_command(command=cmd)

        stdout_thread = threading.Thread(
            target=log_to_file, args=(
                stdout,
                "pserver_" + host + ".log", ))
        stderr_thread = threading.Thread(
            target=log_to_file, args=(
                stderr,
                "pserver_" + host + "_err.log", ))
        stdout_thread.start()
        stderr_thread.start()

        stdout_thread.join()
        stderr_thread.join()

        return_code = stdout.channel.recv_exit_status()
        logging.info(return_code)
        if return_code != 0:
            raise Exception("Error while kicking off pserver training process")
    except Exception:
        logging.exception("Error while kicking off pserver training process")
        cleanup(args.task_name)
    finally:
        if ssh_client is not None:
            ssh_client.close()


X
Xi Chen 已提交
515 516
def init_args():
    """Fill in derived values on the module-level ``args`` namespace.

    - generates ``task_name`` when none was given;
    - defaults ``pem_path`` to ``<home>/<key_name>.pem``;
    - sets ``security_group_ids`` to a one-element tuple when
      ``security_group_id`` was given;
    - resets ``trainers_job_done_count`` to 0.
    """
    if not args.task_name:
        args.task_name = generate_task_name()
        logging.info("task name generated %s", args.task_name)

    if not args.pem_path:
        home_dir = os.path.expanduser("~")
        args.pem_path = "%s/%s.pem" % (home_dir, args.key_name)

    group_id = args.security_group_id
    if group_id:
        args.security_group_ids = (group_id, )

    args.trainers_job_done_count = 0


def create_cluster():
    """Bring up the whole cluster and run the task to completion.

    Steps: create a subnet unless ``args.subnet_id`` is set, launch the
    pservers, start one kickoff thread per pserver, create and run the
    trainers, then wait for the pserver threads to finish.
    """
    if not args.subnet_id:
        logging.info("creating subnet for this task")
        args.subnet_id = create_subnet()
        logging.info("subnet %s created" % (args.subnet_id))

    logging.info("creating pservers")
    pserver_instances = create_pservers()
    logging.info("pserver created, collecting pserver ips")

    # comma separated "ip:port" list handed to every pserver and trainer
    pserver_endpoints_str = ",".join([
        node["NetworkInterfaces"][0]["PrivateIpAddress"] + ":" +
        args.pserver_port for node in pserver_instances
    ])

    logging.info("kicking off pserver training process")
    pserver_threads = [
        threading.Thread(
            target=kickoff_pserver,
            args=(node["PrivateIpAddress"], pserver_endpoints_str))
        for node in pserver_instances
    ]
    for worker in pserver_threads:
        worker.start()

    logging.info("all pserver training process started")

    logging.info("creating trainers and kicking off trainer training process")
    trainer_script = script_to_str(args.trainer_bash_file)
    create_trainers(
        kickoff_cmd=trainer_script,
        pserver_endpoints_str=pserver_endpoints_str)

    for worker in pserver_threads:
        worker.join()

    logging.info("all process ended")
X
Xi Chen 已提交
567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586


def start_server(args):
    """Run the master HTTP server; blocks in ``serve_forever``.

    GET endpoints:
        /status, /master_logs  -> contents of master.log
        /list_logs, /logs      -> newline separated names from ``log_files``
        .../log/<name>         -> contents of log file <name>
    POST endpoints:
        /save_data             -> append the request body to <task_name>.txt
        /cleanup               -> run ``cleanup`` for the task
    Anything else answers 404.

    Args:
        args: namespace providing ``log_path``, ``task_name`` and
            ``master_server_port``.
    """

    class S(BaseHTTPRequestHandler):
        def _set_headers(self):
            self.send_response(200)
            self.send_header('Content-type', 'text/text')
            self.end_headers()

        def do_HEAD(self):
            self._set_headers()

        def do_404(self):
            self.send_response(404)
            self.send_header('Content-type', 'text/text')
            self.end_headers()
            logging.info("Received invalid GET request" + self.path)
            self.wfile.write("NO ACTION FOUND")

        def do_GET(self):

            request_path = self.path
            if request_path == "/status" or request_path == "/master_logs":
                self._set_headers()
                logging.info("Received request to return status")
                with open(args.log_path + "master.log", "r") as logfile:
                    self.wfile.write(logfile.read().strip())
            elif request_path == "/list_logs" or request_path == "/logs":
                self._set_headers()
                self.wfile.write("\n".join(log_files))
            elif "/log/" in request_path:
                log_file_path = request_path.replace("/log/", "")
                logging.info("requesting log file path is" + args.log_path +
                             log_file_path)
                # The path comes from the network: resolve it and only serve
                # an existing file located directly inside the log
                # directory, so "../" or absolute paths cannot read other
                # files.
                log_dir = os.path.abspath(args.log_path)
                requested = os.path.abspath(
                    os.path.join(log_dir, log_file_path))
                if (os.path.dirname(requested) != log_dir or
                        not os.path.isfile(requested)):
                    self.do_404()
                    return
                # headers are sent only once the file is known to exist
                self._set_headers()
                with open(requested, "r") as logfile:
                    self.wfile.write(logfile.read().strip())
            else:
                self.do_404()

        def do_POST(self):

            request_path = self.path

            if request_path == "/save_data":
                self._set_headers()
                logging.info("Received request to save data")
                self.wfile.write("DATA SAVED!")
                content_length = int(self.headers['Content-Length'])
                post_data = self.rfile.read(content_length)
                if args.task_name:
                    with open(args.task_name + ".txt", "a") as text_file:
                        text_file.write(post_data + "\n")

            elif request_path == "/cleanup":
                self._set_headers()
                logging.info("Received request to cleanup cluster")
                cleanup(args.task_name)
                self.wfile.write("cleanup in progress")

            else:
                self.do_404()

    # empty host string: listen on all interfaces
    server_address = ('', args.master_server_port)
    httpd = HTTPServer(server_address, S)
    logging.info("HTTP server is starting")
    httpd.serve_forever()
X
Xi Chen 已提交
634 635 636


def print_arguments():
    """Log every attribute of ``args`` as ``name: value``, sorted by name."""
    logging.info('-----------  Configuration Arguments -----------')
    settings = vars(args)
    for name in sorted(settings):
        logging.info('%s: %s' % (name, settings[name]))
    logging.info('------------------------------------------------')
X
Xi Chen 已提交
641 642 643 644


# Entry point: dispatch on --action.
#   create  - build the cluster and run the task
#   cleanup - tear down the resources tagged with --task_name
#   serve   - start the master HTTP server in a thread, then build the cluster
#   test    - only run the master HTTP server
if __name__ == "__main__":
    print_arguments()
    if args.action == "create":
        logging.info("going to create cluster")
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")
        init_args()
        create_cluster()
    elif args.action == "cleanup":
        logging.info("going to cleanup cluster")
        if not args.task_name:
            raise ValueError("task_name is required")
        cleanup(args.task_name)
    elif args.action == "serve":
        # serve mode
        if not args.master_server_ip:
            raise ValueError(
                "No master server ip set, please run with --action create")
        # serve also creates the cluster, so it needs the same inputs as
        # "create"; fail early instead of while launching instances
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")

        logging.info("going to start serve and create cluster")

        init_args()

        logging.info("starting server in another thread")
        server_thread = threading.Thread(target=start_server, args=(args, ))
        server_thread.start()

        create_cluster()
        server_thread.join()
    elif args.action == "test":
        start_server(args)
    else:
        # a mistyped action used to exit silently without doing anything
        raise ValueError("unknown action: %s" % args.action)