From f3a55f2192b4094e81b3a65b15e3ee00250dd339 Mon Sep 17 00:00:00 2001 From: Xi Chen Date: Fri, 13 Apr 2018 11:20:45 -0700 Subject: [PATCH] add no clean option --- .../client/cluster_launcher.py | 19 ++++++++++++++++++- .../aws_benchmarking/server/cluster_master.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/tools/aws_benchmarking/client/cluster_launcher.py b/tools/aws_benchmarking/client/cluster_launcher.py index 3a6cc57b3..594378ff8 100644 --- a/tools/aws_benchmarking/client/cluster_launcher.py +++ b/tools/aws_benchmarking/client/cluster_launcher.py @@ -26,6 +26,16 @@ import paramiko from scp import SCPClient import requests + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( '--key_name', type=str, default="", help="required, key pair name") @@ -117,6 +127,12 @@ parser.add_argument( default="putcn/paddle_aws_master:latest", help="master docker image id") +parser.add_argument( + '--no_clean_up', + type=str2bool, + default=False, + help="whether to clean up after training") + args = parser.parse_args() logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') @@ -347,7 +363,8 @@ def create(): del args_to_pass.master_docker_image del args_to_pass.master_server_public_ip for arg, value in sorted(vars(args_to_pass).iteritems()): - kick_off_cmd += ' --%s %s' % (arg, value) + if value: + kick_off_cmd += ' --%s %s' % (arg, value) logging.info(kick_off_cmd) stdin, stdout, stderr = ssh_client.exec_command(command=kick_off_cmd) diff --git a/tools/aws_benchmarking/server/cluster_master.py b/tools/aws_benchmarking/server/cluster_master.py index 5e63b5a8b..798228b35 100644 --- a/tools/aws_benchmarking/server/cluster_master.py +++ b/tools/aws_benchmarking/server/cluster_master.py @@ -27,8 +27,17 @@ import paramiko from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer + # You must have aws_access_key_id, aws_secret_access_key, region set in # ~/.aws/credentials and ~/.aws/config +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -126,6 +135,12 @@ parser.add_argument( parser.add_argument( '--master_server_ip', type=str, default="", help="master server private ip") +parser.add_argument( + '--no_clean_up', + type=str2bool, + default=False, + help="whether to clean up after training") + args = parser.parse_args() ec2client = boto3.client('ec2') @@ -414,6 +429,9 @@ def create_trainers(kickoff_cmd, pserver_endpoints_str): def cleanup(task_name): + if args.no_clean_up: + logging.info("no clean up option set, going to leave the setup running") + return #shutdown all ec2 instances print("going to clean up " + task_name + " instances") instances_response = ec2client.describe_instances(Filters=[{ -- GitLab