launch.py 6.1 KB
Newer Older
W
wandongdong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""launch train script"""
import os
import sys
import json
W
wandongdong 已提交
19
import subprocess
W
wandongdong 已提交
20
import shutil
W
wandongdong 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
from argparse import ArgumentParser

def parse_args():
    """
    parse args .

    Args:

    Returns:
        args.

    Examples:
        >>> parse_args()
    """
    parser = ArgumentParser(description="mindspore distributed training launch "
                                        "helper utilty that will spawn up "
                                        "multiple distributed processes")
    parser.add_argument("--nproc_per_node", type=int, default=1,
                        help="The number of processes to launch on each node, "
                             "for D training, this is recommended to be set "
                             "to the number of D in your system so that "
                             "each process can be bound to a single D.")
    parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7",
                        help="will use the visible devices sequentially")
    parser.add_argument("--server_id", type=str, default="",
                        help="server ip")
    parser.add_argument("--training_script", type=str,
                        help="The full path to the single D training "
                             "program/script to be launched in parallel, "
                             "followed by all the arguments for the "
                             "training script")
    # rest from the training program
    args, unknown = parser.parse_known_args()
    args.training_script_args = unknown
    return args


def main():
    print("start", __file__)
    args = parse_args()
    print(args)
    visible_devices = args.visible_devices.split(',')
    assert os.path.isfile(args.training_script)
    assert len(visible_devices) >= args.nproc_per_node
    print('visible_devices:{}'.format(visible_devices))
    if not args.server_id:
        print('pleaser input server ip!!!')
        exit(0)
    print('server_id:{}'.format(args.server_id))

    # construct hccn_table
    hccn_configs = open('/etc/hccn.conf', 'r').readlines()
    device_ips = {}
    for hccn_item in hccn_configs:
        hccn_item = hccn_item.strip()
        if hccn_item.startswith('address_'):
            device_id, device_ip = hccn_item.split('=')
            device_id = device_id.split('_')[1]
            device_ips[device_id] = device_ip
            print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
    hccn_table = {}
    hccn_table['board_id'] = '0x0000'
    hccn_table['chip_info'] = '910'
    hccn_table['deploy_mode'] = 'lab'
    hccn_table['group_count'] = '1'
    hccn_table['group_list'] = []
    instance_list = []
    usable_dev = ''
    for instance_id in range(args.nproc_per_node):
        instance = {}
        instance['devices'] = []
        device_id = visible_devices[instance_id]
        device_ip = device_ips[device_id]
        usable_dev += str(device_id)
        instance['devices'].append({
            'device_id': device_id,
            'device_ip': device_ip,
        })
        instance['rank_id'] = str(instance_id)
        instance['server_id'] = args.server_id
        instance_list.append(instance)
    hccn_table['group_list'].append({
        'device_num': str(args.nproc_per_node),
        'server_num': '1',
        'group_name': '',
        'instance_count': str(args.nproc_per_node),
        'instance_list': instance_list,
    })
    hccn_table['para_plane_nic_location'] = 'device'
    hccn_table['para_plane_nic_name'] = []
    for instance_id in range(args.nproc_per_node):
        eth_id = visible_devices[instance_id]
        hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
    hccn_table['para_plane_nic_num'] = str(args.nproc_per_node)
    hccn_table['status'] = 'completed'

    # save hccn_table to file
    table_path = os.getcwd()
    if not os.path.exists(table_path):
        os.mkdir(table_path)
    table_fn = os.path.join(table_path,
                            'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id))
    with open(table_fn, 'w') as table_fp:
        json.dump(hccn_table, table_fp, indent=4)
    sys.stdout.flush()

    # spawn the processes
W
wandongdong 已提交
128 129
    processes = []
    cmds = []
W
wandongdong 已提交
130 131 132
    log_files = []
    env = os.environ.copy()
    env['RANK_SIZE'] = str(args.nproc_per_node)
W
wandongdong 已提交
133
    cur_path = os.getcwd()
W
wandongdong 已提交
134
    for rank_id in range(0, args.nproc_per_node):
W
wandongdong 已提交
135
        os.chdir(cur_path)
136
        device_id = visible_devices[rank_id]
W
wandongdong 已提交
137
        device_dir = os.path.join(cur_path, 'device{}'.format(rank_id))
W
wandongdong 已提交
138 139
        env['RANK_ID'] = str(rank_id)
        env['DEVICE_ID'] = str(device_id)
140
        if args.nproc_per_node > 1:
W
wandongdong 已提交
141 142 143 144 145
            env['MINDSPORE_HCCL_CONFIG_PATH'] = table_fn
            env['RANK_TABLE_FILE'] = table_fn
        if os.path.exists(device_dir):
            shutil.rmtree(device_dir)
        os.mkdir(device_dir)
W
wandongdong 已提交
146
        os.chdir(device_dir)
W
wandongdong 已提交
147 148 149 150
        cmd = [sys.executable, '-u']
        cmd.append(args.training_script)
        cmd.extend(args.training_script_args)
        log_file = open('{dir}/log{id}.log'.format(dir=device_dir, id=rank_id), 'w')
W
wandongdong 已提交
151
        process = subprocess.Popen(cmd, stdout=log_file, stderr=log_file, env=env)
W
wandongdong 已提交
152
        processes.append(process)
W
wandongdong 已提交
153 154 155
        cmds.append(cmd)
        log_files.append(log_file)
    for process, cmd, log_file in zip(processes, cmds, log_files):
W
wandongdong 已提交
156 157 158
        process.wait()
        if process.returncode != 0:
            raise subprocess.CalledProcessError(returncode=process, cmd=cmd)
W
wandongdong 已提交
159
        log_file.close()
W
wandongdong 已提交
160 161 162 163


if __name__ == "__main__":
    main()