cluster_master.py 23.6 KB
Newer Older
X
Xi Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import json
import math
import time
X
Xi Chen 已提交
20
import threading
X
Xi Chen 已提交
21
import logging
X
Xi Chen 已提交
22
import copy
X
Xi Chen 已提交
23
import csv
X
Xi Chen 已提交
24 25 26 27 28 29

import netaddr
import boto3
import namesgenerator
import paramiko

X
Xi Chen 已提交
30 31
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer

X
Xi Chen 已提交
32

X
Xi Chen 已提交
33 34
# You must have aws_access_key_id, aws_secret_access_key, region set in
# ~/.aws/credentials and ~/.aws/config
X
Xi Chen 已提交
35 36 37 38 39 40 41 42
def str2bool(v):
    """Convert a command-line string into a boolean (argparse `type=` hook).

    The comparison is case-insensitive. 'yes', 'true', 't', 'y' and '1' map
    to True; 'no', 'false', 'f', 'n' and '0' map to False.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    normalized = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

X
Xi Chen 已提交
43 44

parser = argparse.ArgumentParser(description=__doc__)

# Command-line options as (flag, add_argument keyword options) pairs.
# They are registered in the order listed, which is also the order shown
# by --help.
_ARGUMENT_SPECS = (
    ('--key_name', dict(
        type=str, default="", help="required, key pair name")),
    ('--security_group_id', dict(
        type=str,
        default="",
        help="required, the security group id associated with your VPC")),
    ('--vpc_id', dict(
        type=str,
        default="",
        help="The VPC in which you wish to run test")),
    ('--subnet_id', dict(
        type=str,
        default="",
        help="The Subnet_id in which you wish to run test")),
    ('--pserver_instance_type', dict(
        type=str,
        default="c5.2xlarge",
        help="your pserver instance type, c5.2xlarge by default")),
    ('--trainer_instance_type', dict(
        type=str,
        default="p2.8xlarge",
        help="your trainer instance type, p2.8xlarge by default")),
    ('--task_name', dict(
        type=str,
        default="",
        help="the name you want to identify your job")),
    ('--pserver_image_id', dict(
        type=str,
        default="ami-da2c1cbf",
        help="ami id for system image, default one has nvidia-docker ready, use ami-1ae93962 for us-east-2"
    )),
    ('--trainer_image_id', dict(
        type=str,
        default="ami-da2c1cbf",
        help="ami id for system image, default one has nvidia-docker ready, use ami-1ae93962 for us-west-2"
    )),
    ('--availability_zone', dict(
        type=str,
        default="us-east-2a",
        help="aws zone id to place ec2 instances")),
    ('--trainer_count', dict(type=int, default=1, help="Trainer count")),
    ('--pserver_count', dict(type=int, default=1, help="Pserver count")),
    ('--pserver_bash_file', dict(
        type=str,
        default=os.path.join(os.path.dirname(__file__), "pserver.sh.template"),
        help="pserver bash file path")),
    ('--pserver_command', dict(
        type=str, default="", help="pserver start command")),
    ('--trainer_bash_file', dict(
        type=str,
        default=os.path.join(os.path.dirname(__file__), "trainer.sh.template"),
        help="trainer bash file path")),
    ('--trainer_command', dict(
        type=str, default="", help="trainer start command")),
    ('--action', dict(
        type=str, default="serve", help="create|cleanup|serve")),
    ('--pem_path', dict(type=str, help="private key file")),
    ('--pserver_port', dict(
        type=str, default="5436", help="pserver port")),
    ('--docker_image', dict(
        type=str, default="busybox", help="training docker image")),
    ('--master_server_port', dict(
        type=int, default=5436, help="master server port")),
    ('--master_server_ip', dict(
        type=str, default="", help="master server private ip")),
    ('--metric_data_identifier', dict(
        type=str,
        default="**metrics_data: ",
        help="key string to identify metrics data")),
    ('--no_clean_up', dict(
        type=str2bool,
        default=False,
        help="whether to clean up after training")),
)

for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

X
Xi Chen 已提交
152 153 154 155
# Parsed once at import time; the rest of the module reads (and in places
# mutates) this shared namespace.
args = parser.parse_args()

ec2client = boto3.client('ec2')

# All log and metrics files are written to a "logs/" directory that sits
# next to this script.
args.log_path = os.path.join(os.path.dirname(__file__), "logs/")

# logging.basicConfig(filename=...) opens the log file right away, so the
# directory must exist beforehand; create it on first run instead of
# failing at start-up.
if not os.path.isdir(args.log_path):
    os.makedirs(args.log_path)

logging.basicConfig(
    filename=args.log_path + 'master.log',
    level=logging.INFO,
    format='%(asctime)s %(message)s')

# File names (relative to args.log_path) of the logs known so far;
# log_to_file() appends to this list and the HTTP server lists it.
log_files = ["master.log"]

# metric name -> list of float values collected by save_metrics_data()
metrics = {}

metrics_csv_file_name = "metrics.csv"
# Flipped to True by save_metrics_data() once the CSV header is written.
is_metrics_file_created = False

X
Xi Chen 已提交
170 171 172

def create_subnet():
    """Create a tagged subnet for this task and return its subnet id.

    Uses args.vpc_id when given; otherwise looks up the default VPC and
    stores its id in args.vpc_id. The subnet is sized to hold
    pserver_count + trainer_count + 10 addresses, taken from the part of the
    VPC's CidrBlock not covered by existing subnets, created in
    args.availability_zone and tagged with Task_name=args.task_name.

    Returns:
        The SubnetId of the newly created subnet.

    Raises:
        ValueError: if no default VPC / no VPC with the given id exists, if
            the required prefix would be /16 or shorter, or if no free block
            is large enough for the subnet.
    """
    logging.info("start creating subnet")

    # Bound up front so the lookup below also runs when --vpc_id was given
    # (the name used to be unbound on that path).
    vpc_cidrBlock = None

    # if no vpc id provided, list vpcs
    if not args.vpc_id:
        logging.info("no vpc provided, trying to find the default one")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "isDefault",
                "Values": ["true", ]
            }], )
        if not vpcs_desc["Vpcs"]:
            raise ValueError('No default VPC')
        args.vpc_id = vpcs_desc["Vpcs"][0]["VpcId"]
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]

        logging.info("default vpc fount with id %s and CidrBlock %s" %
                     (args.vpc_id, vpc_cidrBlock))

    if not vpc_cidrBlock:
        logging.info("trying to find cidrblock for vpc")
        vpcs_desc = ec2client.describe_vpcs(
            Filters=[{
                "Name": "vpc-id",
                "Values": [args.vpc_id, ],
            }], )
        if not vpcs_desc["Vpcs"]:
            raise ValueError('No VPC found')
        vpc_cidrBlock = vpcs_desc["Vpcs"][0]["CidrBlock"]
        logging.info("cidrblock for vpc is %s" % vpc_cidrBlock)

    # list subnets in vpc in order to create a new one
    logging.info("trying to find ip blocks for new subnet")
    subnets_desc = ec2client.describe_subnets(
        Filters=[{
            "Name": "vpc-id",
            "Values": [args.vpc_id, ],
        }], )

    ips_taken = [
        existing_subnet["CidrBlock"]
        for existing_subnet in subnets_desc["Subnets"]
    ]

    # Plain set difference: remove the taken ranges from the VPC block.
    # A symmetric difference would additionally report any taken range that
    # lies outside this block as "available".
    ip_blocks_avaliable = netaddr.IPSet([vpc_cidrBlock]) - netaddr.IPSet(
        ips_taken)

    # adding 10 addresses as buffer; math.ceil may return a float, so make
    # the prefix length an int once here.
    cidr_prefix = 32 - int(
        math.ceil(math.log(args.pserver_count + args.trainer_count + 10, 2)))
    if cidr_prefix <= 16:
        raise ValueError('Too many nodes to fit in current VPC')

    subnet_cidr = None
    for ipnetwork in ip_blocks_avaliable.iter_cidrs():
        # subnet() yields nothing when this free block is smaller than the
        # requested prefix; in that case move on to the next free block.
        subnet_cidr = next(ipnetwork.subnet(cidr_prefix), None)
        if subnet_cidr is not None:
            logging.info("subnet ip block found %s" % (subnet_cidr))
            break

    if subnet_cidr is None:
        raise ValueError(
            'No avaliable subnet to fit required nodes in current VPC')

    logging.info("trying to create subnet")
    subnet_desc = ec2client.create_subnet(
        CidrBlock=str(subnet_cidr),
        VpcId=args.vpc_id,
        AvailabilityZone=args.availability_zone)

    subnet_id = subnet_desc["Subnet"]["SubnetId"]

    subnet_waiter = ec2client.get_waiter('subnet_available')
    # sleep for 1s before checking its state
    time.sleep(1)
    subnet_waiter.wait(SubnetIds=[subnet_id, ])

    logging.info("subnet created")

    logging.info("adding tags to newly created subnet")
    ec2client.create_tags(
        Resources=[subnet_id, ],
        Tags=[{
            "Key": "Task_name",
            'Value': args.task_name
        }])
    return subnet_id


def generate_task_name():
    """Return a random name from namesgenerator to identify this task."""
    random_name = namesgenerator.get_random_name()
    return random_name


def script_to_str(file_path):
    """Return the stripped text of a script file.

    Args:
        file_path: path of the file to read. When empty or None, a default
            one-line command is returned instead.

    Returns:
        The file content with leading/trailing whitespace removed, or
        "echo $PSERVER_HOSTS" when no path is given.
    """
    if not file_path:
        return "echo $PSERVER_HOSTS"
    # `with` closes the handle even if read() raises.
    with open(file_path, 'r') as script_file:
        return script_file.read().strip()


def run_instances(image_id, instance_type, count, role, cmd=""):
    """Launch `count` EC2 instances and wait until their status checks pass.

    Instances are placed in args.availability_zone / args.subnet_id with
    args.security_group_ids, get a public IP, use args.key_name, and are
    tagged with Task_name=args.task_name and Role=role.

    Args:
        image_id: AMI id to launch.
        instance_type: EC2 instance type.
        count: number of instances (used as both MinCount and MaxCount).
        role: value of the "Role" tag.
        cmd: passed to EC2 as UserData.

    Returns:
        [] when count is 0; otherwise the "Instances" list of the first
        reservation returned by describe_instances for the new instance ids.
    """
    if count == 0:
        return []

    network_interface = {
        'DeviceIndex': 0,
        'SubnetId': args.subnet_id,
        "AssociatePublicIpAddress": True,
        'Groups': args.security_group_ids
    }
    instance_tags = [{
        "Key": 'Task_name',
        "Value": args.task_name
    }, {
        "Key": 'Role',
        "Value": role
    }]

    response = ec2client.run_instances(
        ImageId=image_id,
        InstanceType=instance_type,
        MaxCount=count,
        MinCount=count,
        UserData=cmd,
        DryRun=False,
        InstanceInitiatedShutdownBehavior="stop",
        KeyName=args.key_name,
        Placement={'AvailabilityZone': args.availability_zone},
        NetworkInterfaces=[network_interface],
        TagSpecifications=[{
            'ResourceType': "instance",
            'Tags': instance_tags
        }])

    instance_ids = [
        launched["InstanceId"] for launched in response["Instances"]
    ]

    if instance_ids:
        logging.info("%d instance(s) created" % len(instance_ids))
    else:
        logging.info("no instance created")

    # create waiter to make sure it's running
    logging.info("waiting for instance to become accessible")
    status_filters = [{
        "Name": "instance-status.status",
        "Values": ["ok"]
    }, {
        "Name": "instance-status.reachability",
        "Values": ["passed"]
    }, {
        "Name": "instance-state-name",
        "Values": ["running"]
    }]
    ec2client.get_waiter('instance_status_ok').wait(
        Filters=status_filters, InstanceIds=instance_ids)

    described = ec2client.describe_instances(InstanceIds=instance_ids)
    first_reservation = described["Reservations"][0]
    return first_reservation["Instances"]


def create_pservers():
    """Launch args.pserver_count parameter-server instances.

    Returns:
        The instance descriptions returned by run_instances.

    Raises:
        Exception: whatever run_instances raised, re-raised after the
            failure has been logged and cleanup() has been run.
    """
    try:
        return run_instances(
            image_id=args.pserver_image_id,
            instance_type=args.pserver_instance_type,
            count=args.pserver_count,
            role="PSERVER", )
    except Exception:
        logging.exception("error while trying to create pservers")
        cleanup(args.task_name)
        # Propagate the real error. Falling through would return None, and
        # the caller would then fail with an unrelated TypeError while
        # iterating the result.
        raise
X
Xi Chen 已提交
342 343


X
Xi Chen 已提交
344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369
def save_metrics_data(str_msg):
    """Parse one metrics line and append it as a row to the metrics CSV.

    Args:
        str_msg: comma-separated "name=value" pairs, e.g. "loss=0.3,acc=0.9".
            Each value must parse as a float.

    Every parsed value is also appended to the module-level `metrics` dict
    (name -> list of values). The CSV header is written the first time a row
    is written by this process. Segments that are not "name=<float>" are
    skipped with a warning, so one bad entry cannot kill the calling
    log-forwarding thread; a line with no valid segment writes nothing.
    """
    global is_metrics_file_created
    logging.info("found metrics data, saving it to csv file")

    # Parse before touching the file, so a fully malformed line leaves no
    # trace.
    csv_fieldnames = []
    csv_write_data = {}
    for metric in str_msg.split(","):
        raw_key, separator, raw_val = metric.partition("=")
        metric_key = raw_key.strip()
        try:
            if not separator:
                raise ValueError("missing '='")
            metric_val = float(raw_val.strip())
        except ValueError:
            logging.warning("skipping malformed metric entry %r", metric)
            continue
        metrics.setdefault(metric_key, []).append(metric_val)
        csv_fieldnames.append(metric_key)
        csv_write_data[metric_key] = metric_val

    if not csv_fieldnames:
        return

    with open(args.log_path + metrics_csv_file_name, 'a') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=csv_fieldnames)
        if not is_metrics_file_created:
            writer.writeheader()
            is_metrics_file_created = True
        writer.writerow(csv_write_data)
        logging.info("csv file appended")


X
Xi Chen 已提交
370 371 372
def log_to_file(source, filename):
    """Copy lines from `source` into a log file until `source` is exhausted.

    Args:
        source: object with a readline() method; reading stops at the first
            empty string it returns.
        filename: log file name, relative to args.log_path. It is added to
            the module-level `log_files` list if not already there.

    Lines that start with args.metric_data_identifier are additionally
    handed to save_metrics_data() with the identifier removed.
    """
    if filename not in log_files:
        log_files.append(filename)

    with open(args.log_path + filename, "a") as log_file:
        while True:
            line = source.readline()
            if line == "":
                break
            log_file.write(line)
            if not line.startswith(args.metric_data_identifier):
                continue
            # found key data, trying to add to csv
            metrics_payload = line.replace(args.metric_data_identifier, "")
            save_metrics_data(metrics_payload)
X
Xi Chen 已提交
380 381


X
Xi Chen 已提交
382 383
def parse_command(command_raw, defaults=None):
    """Turn a compact "seg,key:value,..." spec into a command-line string.

    Args:
        command_raw: comma-separated segments. A segment containing ':' is a
            "key:value" parameter (split on the first ':' only, so values may
            themselves contain ':'); any other non-empty segment is copied
            verbatim. None or "" means no segments.
        defaults: optional mapping of parameter defaults; parameters given in
            command_raw override them. The mapping is not modified.

    Returns:
        The verbatim segments followed by "--key value" for each parameter,
        joined with single spaces.
    """
    parameter_map = copy.copy(defaults) if defaults is not None else {}
    commands_processed = []
    for seg in (command_raw or "").split(","):
        if ":" in seg:
            key, val = seg.split(":", 1)
            parameter_map[key] = val
        elif seg:
            # Empty segments (e.g. from an empty command) would only add
            # stray spaces to the result.
            commands_processed.append(seg)
    for key, val in parameter_map.items():
        commands_processed.append("--" + key + " " + str(val))
    return " ".join(commands_processed)


X
Xi Chen 已提交
398
def create_trainers(kickoff_cmd, pserver_endpoints_str):
    """Start all trainers in parallel, wait for them, then clean up.

    One thread per trainer launches an EC2 instance, connects to it over SSH
    as "ubuntu" with the key at args.pem_path, runs the formatted kickoff
    command and streams stdout/stderr into per-trainer log files.

    Args:
        kickoff_cmd: str.format template for the trainer start script. It is
            formatted with PSERVER_HOSTS, DOCKER_IMAGE, TRAINER_INDEX,
            TASK_NAME, TRAINER_COUNT, COMMAND and MASTER_ENDPOINT.
        pserver_endpoints_str: value substituted for PSERVER_HOSTS.

    cleanup(args.task_name) is always called once all trainer threads have
    finished, whether or not any of them failed.
    """
    trainer_threads = []
    # trainer index -> {'has_error': bool}, filled in by the worker threads.
    trainer_create_results = {}

    def create_and_start_trainer(trainer_index):
        # Thread body: any failure is recorded in trainer_create_results and
        # logged instead of escaping the thread unnoticed.
        ssh_client = None
        try:
            logging.info("trainer " + str(trainer_index) + " is starting")

            instance_response = run_instances(
                image_id=args.trainer_image_id,
                instance_type=args.trainer_instance_type,
                count=1,
                role="TRAINER", )[0]
            trainer_ip = instance_response["PrivateIpAddress"]

            logging.info("trainer " + str(trainer_index) + " started")

            ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
            ssh_client = paramiko.SSHClient()
            ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            ssh_client.connect(
                hostname=trainer_ip, username="ubuntu", pkey=ssh_key)

            logging.info("trainer " + str(trainer_index) +
                         " terminal connected via ssh")

            cmd = kickoff_cmd.format(
                PSERVER_HOSTS=pserver_endpoints_str,
                DOCKER_IMAGE=args.docker_image,
                TRAINER_INDEX=str(trainer_index),
                TASK_NAME=args.task_name,
                TRAINER_COUNT=args.trainer_count,
                COMMAND=parse_command(args.trainer_command,
                                      {"device": "GPU"}),
                MASTER_ENDPOINT=args.master_server_ip + ":" +
                str(args.master_server_port))
            logging.info(cmd)

            stdin, stdout, stderr = ssh_client.exec_command(command=cmd)

            # read and save output log
            logging.info("trainer " + str(trainer_index) +
                         " command executed, keep fetching log")

            stdout_thread = threading.Thread(
                target=log_to_file,
                args=(
                    stdout,
                    "trainer_" + str(trainer_index) + ".log", ))
            stderr_thread = threading.Thread(
                target=log_to_file,
                args=(
                    stderr,
                    "trainer_" + str(trainer_index) + "_err.log", ))
            stdout_thread.start()
            stderr_thread.start()

            stdout_thread.join()
            stderr_thread.join()

            return_code = stdout.channel.recv_exit_status()
            if return_code != 0:
                raise ValueError("trainer didn't finish with exit code 0")
            trainer_create_results[trainer_index] = {'has_error': False}
        except Exception:
            trainer_create_results[trainer_index] = {'has_error': True}
            logging.exception("trainer " + str(trainer_index) + " failed")
        finally:
            # Close the SSH session on every path, including failures.
            if ssh_client is not None:
                ssh_client.close()

    # multi thread starting trainer instance and run kickoff command
    try:
        for i in range(args.trainer_count):
            logging.info("starting tread for trainer " + str(i))
            trainer_thread = threading.Thread(
                target=create_and_start_trainer, args=(i, ))
            trainer_thread.start()
            trainer_threads.append(trainer_thread)

        for trainer_thread in trainer_threads:
            trainer_thread.join()

        # Inspect the recorded result dicts (the values, not the int keys).
        if any(result["has_error"]
               for result in trainer_create_results.values()):
            logging.error(
                "error during trainer starting or training, destorying the while cluster "
            )

        logging.info("all trainers stopped")
    except Exception:
        logging.exception(
            "Training exception, clean up resources, please check log for more info"
        )
    finally:
        # Single cleanup point for both the success and the failure path.
        cleanup(args.task_name)
X
Xi Chen 已提交
490 491 492


def cleanup(task_name):
    """Terminate the task's EC2 instances and delete its subnet.

    Does nothing except log a message when args.no_clean_up is set.
    Otherwise every instance tagged Task_name=task_name is terminated (and
    waited for), then the first subnet carrying the same tag is deleted.

    Args:
        task_name: value of the "Task_name" tag identifying the resources.
    """
    if args.no_clean_up:
        logging.info("no clean up option set, going to leave the setup running")
        return

    def task_tag_filter():
        # Fresh filter list for each EC2 call.
        return [{"Name": "tag:Task_name", "Values": [task_name]}]

    # shutdown all ec2 instances
    print("going to clean up " + task_name + " instances")
    reservations = ec2client.describe_instances(
        Filters=task_tag_filter())["Reservations"]

    if reservations:
        instance_ids = [
            instance["InstanceId"]
            for reservation in reservations
            for instance in reservation["Instances"]
        ]

        ec2client.terminate_instances(InstanceIds=instance_ids)

        terminated_waiter = ec2client.get_waiter('instance_terminated')
        terminated_waiter.wait(InstanceIds=instance_ids)

    # delete the subnet created
    tagged_subnets = ec2client.describe_subnets(
        Filters=task_tag_filter())["Subnets"]
    if tagged_subnets:
        ec2client.delete_subnet(SubnetId=tagged_subnets[0]["SubnetId"])
    # no subnet delete waiter, just leave it.
    logging.info("Clearnup done")
    return


X
Xi Chen 已提交
529 530 531 532 533 534 535 536 537
def kickoff_pserver(host, pserver_endpoints_str):
    """Start the pserver process on `host` over SSH and stream its logs.

    Connects as "ubuntu" with the key at args.pem_path, runs the script from
    args.pserver_bash_file (formatted with the cluster parameters) and copies
    stdout/stderr into "pserver_<host>.log" / "pserver_<host>_err.log".
    Blocks until the remote command exits.

    Args:
        host: address of the pserver instance to connect to; also used in
            SERVER_ENDPOINT and in the log file names.
        pserver_endpoints_str: value substituted for PSERVER_HOSTS.

    Any exception (including a non-zero remote exit code) is logged and
    triggers cleanup(args.task_name); it is not re-raised.
    """
    # Bound before the try block so the finally clause is safe even when
    # loading the key or constructing the client fails.
    ssh_client = None
    try:
        ssh_key = paramiko.RSAKey.from_private_key_file(args.pem_path)
        ssh_client = paramiko.SSHClient()
        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh_client.connect(hostname=host, username="ubuntu", pkey=ssh_key)
        cmd = (script_to_str(args.pserver_bash_file)).format(
            PSERVER_HOSTS=pserver_endpoints_str,
            DOCKER_IMAGE=args.docker_image,
            PSERVER_PORT=args.pserver_port,
            TASK_NAME=args.task_name,
            COMMAND=parse_command(args.pserver_command, {"device": "CPU"}),
            TRAINER_COUNT=args.trainer_count,
            TRAINER_INDEX=0,
            # there is no way to use 0.0.0.0:port to start pserver
            # has to docker --network="host" with host ip to make this work
            SERVER_ENDPOINT=host + ":" + str(args.pserver_port),
            MASTER_ENDPOINT=args.master_server_ip + ":" +
            str(args.master_server_port))
        logging.info(cmd)
        stdin, stdout, stderr = ssh_client.exec_command(command=cmd)

        stdout_thread = threading.Thread(
            target=log_to_file, args=(
                stdout,
                "pserver_" + host + ".log", ))
        stderr_thread = threading.Thread(
            target=log_to_file, args=(
                stderr,
                "pserver_" + host + "_err.log", ))
        stdout_thread.start()
        stderr_thread.start()

        stdout_thread.join()
        stderr_thread.join()

        return_code = stdout.channel.recv_exit_status()
        logging.info(return_code)
        if return_code != 0:
            raise Exception("Error while kicking off pserver training process")
    except Exception:
        logging.exception("Error while kicking off pserver training process")
        cleanup(args.task_name)
    finally:
        if ssh_client is not None:
            ssh_client.close()


X
Xi Chen 已提交
576 577
def init_args():
    """Fill in derived values on the shared `args` namespace.

    - task_name: generated randomly when not supplied.
    - pem_path: defaults to "<home>/<key_name>.pem".
    - security_group_ids: one-element tuple built from security_group_id,
      set only when security_group_id is non-empty.
    - trainers_job_done_count: reset to 0.
    """
    if not args.task_name:
        args.task_name = generate_task_name()
        logging.info("task name generated %s" % (args.task_name))

    if not args.pem_path:
        home_dir = os.path.expanduser("~")
        args.pem_path = "%s/%s.pem" % (home_dir, args.key_name)

    if args.security_group_id:
        args.security_group_ids = (args.security_group_id, )

    args.trainers_job_done_count = 0


def create_cluster():
    """Bring up the whole cluster and block until every process has ended.

    Steps: create a subnet when args.subnet_id is empty, launch the pserver
    instances, start one kickoff_pserver thread per pserver, run
    create_trainers() (which blocks until the trainers are done), then join
    the pserver threads.
    """
    if not args.subnet_id:
        logging.info("creating subnet for this task")
        args.subnet_id = create_subnet()
        logging.info("subnet %s created" % (args.subnet_id))

    logging.info("creating pservers")
    pserver_create_response = create_pservers()
    logging.info("pserver created, collecting pserver ips")

    # "ip:port" of every pserver, comma separated.
    pserver_endpoints_str = ",".join(
        pserver["NetworkInterfaces"][0]["PrivateIpAddress"] + ":" +
        args.pserver_port for pserver in pserver_create_response)

    logging.info("kicking off pserver training process")
    pserver_threads = []
    for pserver in pserver_create_response:
        worker = threading.Thread(
            target=kickoff_pserver,
            args=(pserver["PrivateIpAddress"], pserver_endpoints_str))
        worker.start()
        pserver_threads.append(worker)

    logging.info("all pserver training process started")

    logging.info("creating trainers and kicking off trainer training process")
    trainer_kickoff_cmd = script_to_str(args.trainer_bash_file)
    create_trainers(
        kickoff_cmd=trainer_kickoff_cmd,
        pserver_endpoints_str=pserver_endpoints_str)

    for worker in pserver_threads:
        worker.join()

    logging.info("all process ended")
X
Xi Chen 已提交
628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647


def start_server(args):
    """Run the master HTTP server on args.master_server_port (blocks forever).

    GET endpoints:
        /status, /master_logs  - content of master.log
        /list_logs, /logs      - newline-separated names in `log_files`
        /log/<name>            - content of one file in args.log_path
    POST endpoints:
        /save_data             - append the request body to "<task_name>.txt"
        /cleanup               - force cleanup(args.task_name)
    Anything else gets a 404 with the body "NO ACTION FOUND".

    Args:
        args: the parsed argument namespace (log_path, task_name,
            master_server_port and no_clean_up are used).
    """

    class S(BaseHTTPRequestHandler):
        def _set_headers(self):
            self.send_response(200)
            self.send_header('Content-type', 'text/text')
            self.end_headers()

        def do_HEAD(self):
            self._set_headers()

        def do_404(self):
            self.send_response(404)
            self.send_header('Content-type', 'text/text')
            self.end_headers()
            logging.info("Received invalid GET request" + self.path)
            self.wfile.write("NO ACTION FOUND")

        def do_GET(self):

            request_path = self.path
            if request_path == "/status" or request_path == "/master_logs":
                self._set_headers()
                logging.info("Received request to return status")
                with open(args.log_path + "master.log", "r") as logfile:
                    self.wfile.write(logfile.read().strip())
            elif request_path == "/list_logs" or request_path == "/logs":
                self._set_headers()
                self.wfile.write("\n".join(log_files))
            elif "/log/" in request_path:
                log_file_path = request_path.replace("/log/", "")
                logging.info("requesting log file path is" + args.log_path +
                             log_file_path)
                # The path comes from the network: accept only a bare file
                # name that exists inside the log directory, so "../" or
                # absolute paths cannot read files elsewhere on the master.
                is_plain_name = (log_file_path != "" and
                                 os.path.basename(log_file_path) ==
                                 log_file_path)
                if not is_plain_name or not os.path.isfile(args.log_path +
                                                           log_file_path):
                    self.do_404()
                    return
                self._set_headers()
                with open(args.log_path + log_file_path, "r") as logfile:
                    self.wfile.write(logfile.read().strip())
            else:
                self.do_404()

        def do_POST(self):

            request_path = self.path

            if request_path == "/save_data":
                self._set_headers()
                logging.info("Received request to save data")
                self.wfile.write("DATA SAVED!")
                content_length = int(self.headers['Content-Length'])
                post_data = self.rfile.read(content_length)
                if args.task_name:
                    with open(args.task_name + ".txt", "a") as text_file:
                        text_file.write(post_data + "\n")

            elif request_path == "/cleanup":
                self._set_headers()
                logging.info("Received request to cleanup cluster")
                # An explicit cleanup request overrides --no_clean_up.
                args.no_clean_up = False
                cleanup(args.task_name)
                self.wfile.write("cleanup in progress")

            else:
                self.do_404()

    server_address = ('', args.master_server_port)
    httpd = HTTPServer(server_address, S)
    logging.info("HTTP server is starting")
    httpd.serve_forever()
X
Xi Chen 已提交
696 697 698


def print_arguments():
    """Log every entry of the `args` namespace, sorted by argument name."""
    arguments = vars(args)
    logging.info('-----------  Configuration Arguments -----------')
    for name in sorted(arguments):
        logging.info('%s: %s' % (name, arguments[name]))
    logging.info('------------------------------------------------')
X
Xi Chen 已提交
703 704 705 706


if __name__ == "__main__":
    # Entry point: dispatch on --action.
    #   create  - build the cluster and run the job
    #   cleanup - tear down the resources tagged with --task_name
    #   serve   - start the master HTTP server, then build the cluster
    #   test    - only run the master HTTP server
    print_arguments()
    if args.action == "create":
        logging.info("going to create cluster")
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")
        init_args()
        create_cluster()
    elif args.action == "cleanup":
        logging.info("going to cleanup cluster")
        if not args.task_name:
            raise ValueError("task_name is required")
        cleanup(args.task_name)
    elif args.action == "serve":
        # serve mode
        if not args.master_server_ip:
            raise ValueError(
                "No master server ip set, please run with --action create")
        # Serve mode launches instances too, so it needs the same inputs as
        # "create"; fail here rather than midway through cluster creation.
        if not args.key_name or not args.security_group_id:
            raise ValueError("key_name and security_group_id are required")

        logging.info("going to start serve and create cluster")

        init_args()

        logging.info("starting server in another thread")
        server_thread = threading.Thread(target=start_server, args=(args, ))
        server_thread.start()

        create_cluster()
        server_thread.join()
    elif args.action == "test":
        start_server(args)
    else:
        # Previously an unrecognised action silently did nothing.
        raise ValueError("unknown action %r, expected create|cleanup|serve|test"
                         % (args.action, ))