# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """launch train script""" import os import sys import json from argparse import ArgumentParser def parse_args(): """ parse args . Args: Returns: args. Examples: >>> parse_args() """ parser = ArgumentParser(description="mindspore distributed training launch " "helper utilty that will spawn up " "multiple distributed processes") parser.add_argument("--nproc_per_node", type=int, default=1, help="The number of processes to launch on each node, " "for D training, this is recommended to be set " "to the number of D in your system so that " "each process can be bound to a single D.") parser.add_argument("--visible_devices", type=str, default="0,1,2,3,4,5,6,7", help="will use the visible devices sequentially") parser.add_argument("--server_id", type=str, default="", help="server ip") parser.add_argument("--training_script", type=str, help="The full path to the single D training " "program/script to be launched in parallel, " "followed by all the arguments for the " "training script") # rest from the training program args, unknown = parser.parse_known_args() args.training_script_args = unknown return args def main(): print("start", __file__) args = parse_args() print(args) visible_devices = args.visible_devices.split(',') assert os.path.isfile(args.training_script) assert len(visible_devices) >= args.nproc_per_node print('visible_devices:{}'.format(visible_devices)) if not args.server_id: print('pleaser input server ip!!!') exit(0) print('server_id:{}'.format(args.server_id)) # construct hccn_table hccn_configs = open('/etc/hccn.conf', 'r').readlines() device_ips = {} for hccn_item in hccn_configs: hccn_item = hccn_item.strip() if hccn_item.startswith('address_'): device_id, device_ip = hccn_item.split('=') device_id = device_id.split('_')[1] device_ips[device_id] = device_ip print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) hccn_table = {} hccn_table['board_id'] = '0x0000' hccn_table['chip_info'] = '910' hccn_table['deploy_mode'] = 'lab' hccn_table['group_count'] = '1' hccn_table['group_list'] = [] instance_list = [] usable_dev = '' for instance_id in range(args.nproc_per_node): instance = {} instance['devices'] = [] device_id = visible_devices[instance_id] device_ip = device_ips[device_id] usable_dev += str(device_id) instance['devices'].append({ 'device_id': device_id, 'device_ip': device_ip, }) instance['rank_id'] = str(instance_id) instance['server_id'] = args.server_id instance_list.append(instance) hccn_table['group_list'].append({ 'device_num': str(args.nproc_per_node), 'server_num': '1', 'group_name': '', 'instance_count': str(args.nproc_per_node), 'instance_list': instance_list, }) hccn_table['para_plane_nic_location'] = 'device' hccn_table['para_plane_nic_name'] = [] for instance_id in range(args.nproc_per_node): eth_id = visible_devices[instance_id] hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) hccn_table['para_plane_nic_num'] = str(args.nproc_per_node) hccn_table['status'] = 'completed' # save hccn_table to file table_path = os.getcwd() if not os.path.exists(table_path): os.mkdir(table_path) table_fn = os.path.join(table_path, 'rank_table_{}p_{}_{}.json'.format(args.nproc_per_node, usable_dev, args.server_id)) with open(table_fn, 'w') as table_fp: json.dump(hccn_table, table_fp, indent=4) sys.stdout.flush() # spawn the processes for rank_id in range(0, args.nproc_per_node): device_id = visible_devices[rank_id] device_dir = os.path.join(os.getcwd(), 'device{}'.format(rank_id)) rank_process = 'export RANK_SIZE={} && export RANK_ID={} && export DEVICE_ID={} && '.format(args.nproc_per_node, rank_id, device_id) if args.nproc_per_node > 1: rank_process += 'export MINDSPORE_HCCL_CONFIG_PATH={} && '.format(table_fn) rank_process += 'export RANK_TABLE_FILE={} && '.format(table_fn) rank_process += 'rm -rf {dir} && mkdir {dir} && cd {dir} && python {script} '.format(dir=device_dir, script=args.training_script ) rank_process += ' '.join(args.training_script_args) + ' > log{}.log 2>&1 &'.format(rank_id) os.system(rank_process) if __name__ == "__main__": main()