# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import os
import copy
import paddle
import threading
import numpy as np
import warnings
import logging
from functools import reduce

import paddle.fluid.core as core
from paddle.framework.io import _to_LodTensor
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid.io import is_parameter, is_belong_to_optimizer

# Shared by every call of print_program_with_dist_attr so that concurrent
# callers really are serialized while the default context is swapped.
_g_print_program_lock = threading.Lock()


def is_valid_list_index(list, index):
    """Return True if `index` (possibly negative) is a valid index of `list`."""
    return -len(list) <= index < len(list)


def is_dim_shard(mapping):
    """Return True if the dim mapping is anything other than -1."""
    return bool(mapping != -1)


def is_dim_replicate(mapping):
    """Return True if the dim mapping is -1."""
    return bool(mapping == -1)


def compute_compatible_dim_mapping(dim_mappings):
    """
    Fold a list of dim mappings into one compatible value.

    -1 is compatible with everything; any two values different from -1 must
    be equal. Returns None when the list is empty or the values conflict.
    """
    if not dim_mappings:
        return None
    compatible_mapping = dim_mappings[0]
    for mapping in dim_mappings:
        if compatible_mapping == -1:
            compatible_mapping = mapping
        elif mapping == -1:
            continue
        elif compatible_mapping == mapping:
            continue
        else:
            return None
    return compatible_mapping


def compute_compatible_dims_mapping(dims_mapping_list):
    """
    Compute the element-wise compatible dims mapping of several dims mappings.

    All entries must be non-None and of equal length. Returns None when the
    list is empty or any position has no compatible value.
    """
    if not dims_mapping_list:
        return None
    length = len(dims_mapping_list[0])
    for dims_mapping in dims_mapping_list:
        assert dims_mapping is not None, \
            "Dims mapping must not be None for compatible computation"
        assert len(dims_mapping) == length, \
            "The length of dims_mapping in list must be same for compatible computation."
    compatible_result = []
    for dim_mappings in zip(*dims_mapping_list):
        compatible_dim_mapping = compute_compatible_dim_mapping(
            list(dim_mappings))
        if compatible_dim_mapping is None:
            return None
        compatible_result.append(compatible_dim_mapping)
    return compatible_result


def compute_compatible_process_mesh(process_mesh_list):
    """
    Return the single process mesh shared by all non-None entries.

    Returns None when the list is empty, contains only None, or contains two
    entries that compare unequal.
    """
    compatible_process_mesh = None
    if not process_mesh_list:
        return compatible_process_mesh
    for process_mesh in process_mesh_list:
        if process_mesh is not None:
            if compatible_process_mesh is None or compatible_process_mesh == process_mesh:
                compatible_process_mesh = process_mesh
            else:
                return None
    return compatible_process_mesh


def compute_compatible_and_update_dim_mapping(dims_mapping_list, index_list):
    """
    Make `dims_mapping_list[i][index_list[i]]` compatible for every i.

    The dims mappings are updated in place. Returns True if any entry was
    changed, False if nothing changed or no compatible value exists.
    """
    assert len(dims_mapping_list) == len(index_list)
    changed = False
    dim_mappings = []
    for dims_mapping, index in zip(dims_mapping_list, index_list):
        assert is_valid_list_index(dims_mapping, index)
        dim_mappings.append(dims_mapping[index])
    compatible_dim_mapping = compute_compatible_dim_mapping(dim_mappings)
    if compatible_dim_mapping is None:
        return False
    for dims_mapping, index in zip(dims_mapping_list, index_list):
        if compatible_dim_mapping != dims_mapping[index]:
            dims_mapping[index] = compatible_dim_mapping
            changed = True
    return changed


def append_distributed_attr_suffix(name):
    """
    Append auto parallel suffix for distributed attribute name.
    """
    return name + core.kAutoParallelSuffix()


def remove_distributed_attr_suffix(name):
    """
    Remove auto parallel suffix from distributed attribute name.

    Only a trailing occurrence of the suffix is removed; a name that does
    not end with the suffix is returned unchanged.
    """
    suffix = core.kAutoParallelSuffix()
    # str.strip(suffix) would treat the suffix as a *set of characters* and
    # eat matching characters from both ends of the name.
    if suffix and name.endswith(suffix):
        return name[:-len(suffix)]
    return name


def check_distributed_attr_for_program(program, dist_context=None):
    """
    Check that every tensor and op of `program` that has a distributed
    attribute in `dist_context` (default context if None) is valid.

    Returns True if all checks pass, False at the first invalid item.
    """
    from .dist_context import get_default_distributed_context
    if dist_context is None:
        dist_context = get_default_distributed_context()
    assert dist_context.is_initialized_for_program(), \
        "Distributed attributes must be initialized before check."
    for block in program.blocks:
        for tensor in block.vars.values():
            dist_tensor = dist_context.get_dist_tensor_for_graph(tensor)
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                tensor)
            if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()):
                return False
        for op in block.ops:
            # Look the dist op up by the op itself (not by the last tensor of
            # the loop above).
            dist_op = dist_context.get_dist_op_for_graph(op)
            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
            if (op_dist_attr is not None) and (not dist_op.is_valid()):
                return False
    return True


def print_program_with_dist_attr(program, dist_context=None):
    """
    This function reuses the original program output ability with a distributed context.
    Using lock can avoid multiple threads change the default distributed context simultaneously.

    When `dist_context` is given it temporarily becomes the default context
    while the program is printed; the previous default is always restored.
    """
    from .dist_context import get_default_distributed_context
    from .dist_context import set_default_distributed_context
    with _g_print_program_lock:
        if dist_context is None:
            # Kept from the original flow: fetch the default context before
            # printing.
            dist_context = get_default_distributed_context()
            print(program)
        else:
            original_default_context = get_default_distributed_context()
            set_default_distributed_context(dist_context)
            try:
                print(program)
            finally:
                set_default_distributed_context(original_default_context)


def _get_comm_group(processes, shape, axis, rank):
    """
    Given a rank and the processes mesh the rank belongs to,
    compute the communication peers of the rank based on the give axis in the mesh.

    Example: 16 processes managed in a 4-Dimensinal mesh with shape of [2, 2, 2, 2].
    The mesh is row-major (the last axis varies fastest), so the
    communication peers of rank 0 (included) are following:
    in axis 0: [0, 8]
    in axis 1: [0, 4]
    in axis 2: [0, 2]
    in axis 3: [0, 1]
    """

    # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous
    # tricks to support processes mesh when it is not start with 0 or continuous
    assert rank in processes, "rank [{}] is NOT in processes group {}".format(
        rank, processes)
    rank_relative = processes.index(rank)
    coordinate = _linear_idx2coordinate(shape, rank_relative)
    coordinates_in_group = [coordinate[:] for i in range(shape[axis])]

    # select comm group
    for i in range(shape[axis]):
        coordinates_in_group[i][axis] = i

    ranks_in_group_relative = [
        _coordinate2linear_idx(shape, coordinate)
        for coordinate in coordinates_in_group
    ]
    ranks_in_group = [processes[idx] for idx in ranks_in_group_relative]

    return sorted(ranks_in_group)


def _get_idx_in_axis(processes, shape, axis, rank):
    """
    Given a rank and the processes mesh the rank belongs to,
    compute the index of the rank in given axis.

    Example: 27 processes managed in a 3-Dimensinal mesh with shape of [3, 3, 3].
    The coordinate of rank 22 is [2, 1, 1] (22 = 2 * 9 + 1 * 3 + 1), so its
    index is:
    in axis 0: 2
    in axis 1: 1
    in axis 2: 1
    """

    # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous
    # tricks to support processes mesh when it is not start with 0 or continuous
    rank_relative = processes.index(rank)
    coordinate = _linear_idx2coordinate(shape, rank_relative)
    return coordinate[axis]


def _coordinate2linear_idx(mesh_shape, coordinate):
    """
    convert a coordinate in multidimensional mesh space into a scala idx in linear space.

    it use Row-major order for dimension conversion.
    so it has:  [most_significant_dim, ..., least_significant_dim]
    assume:

        the size of i-th dimension to be:  S[i]
        the index of j-th dimension is: I[j]

    linear_idx of a n dimensional coordinate is:

        I[0]   * (S[1] * S[2] * .... S[n-1]) +
        I[1]   * (       S[2] * .... S[n-1]) +
        ...
        I[n-2] * (                   S[n-1]) +
        I[n-1]
    """
    # NOTE the following function work based on a strong an assumption
    # that the processes in mesh are
    #    1. starts from 0
    #    2. continuous
    # it will be wrong if ths above condition doesnot meet,
    # e.g. process_mesh = { process_groups = [7, 8, 9,10, 12, 13, 14, 15], mesh = [2, 4]}
    # if you want a more general mapping, you should use cartesian product

    assert len(mesh_shape) == len(
        coordinate
    ), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format(
        mesh_shape, coordinate)
    for i in range(len(mesh_shape)):
        assert coordinate[
            i] >= 0, "index in dimension [{}] is least than zero. coordinate: {}".format(
                i, coordinate)
        assert coordinate[i] < mesh_shape[
            i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format(
                i, mesh_shape, coordinate)

    base = mesh_shape[-1]
    linear_idx = coordinate[-1]

    # row major order
    for i in range(len(mesh_shape) - 2, -1, -1):
        linear_idx += base * coordinate[i]
        base *= mesh_shape[i]

    return linear_idx


def _linear_idx2coordinate(mesh_shape, linear_idx):
    """
    mapping a linear scala into multidimensional mesh space, return it coordinate in that space.

    it is the inverse function of _coordinate2linear_idx.
    assume:

        the size of i-th dimension to be:  S[i]
        the index of j-th dimension is: I[j]

    the coordinate given linear_idx (n dimensions, row-major) is:

        I[n-1] = linear_idx                       % S[n-1]
        I[n-2] = (linear_idx // S[n-1])           % S[n-2]
        I[n-3] = (linear_idx // (S[n-1] * S[n-2])) % S[n-3]
        ....
    """

    assert linear_idx >= 0, "linear index [{}] is least than zero".format(
        linear_idx)
    assert linear_idx < np.prod(
        mesh_shape
    ), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format(
        mesh_shape, linear_idx)

    base = 1
    coordinate = [-1] * len(mesh_shape)

    for i in reversed(range(len(mesh_shape))):
        # Integer floor division keeps the arithmetic exact.
        offset = linear_idx // base
        coordinate[i] = int(offset % mesh_shape[i])
        base *= mesh_shape[i]

    # row major order
    return coordinate


def _get_corresponding_rank(dist_context, target_mesh, rank):
    """
    Return the process of `target_mesh` that sits at the same coordinate as
    `rank` does in a registered mesh with the same topology.
    """

    # TODO(JZ-LIANG) a hack method to support varying mesh in Pipeline parallelism case.
    # we assume that all mesh are evenly divide from a parent mesh and should have same size.
    # to revise this in future.

    coordinate = None
    for mesh in dist_context.process_meshes:
        if rank in mesh.processes and mesh.topology == target_mesh.topology:
            coordinate = _linear_idx2coordinate(mesh.topology,
                                                mesh.processes.index(rank))
            break

    assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
        rank)
    return target_mesh.processes[_coordinate2linear_idx(mesh.topology,
                                                        coordinate)]


def _get_unshard_dist_shape(var, dist_attr):
    """
    Return the shape of `var` with every sharded dimension multiplied by the
    size of the mesh axis it is mapped to. Dimensions that are -1 or mapped
    to -1 are kept as they are.
    """
    var_shape = var.shape
    mapping = dist_attr.dims_mapping
    mesh = dist_attr.process_mesh.topology
    assert len(var_shape) == len(
        mapping
    ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
        var_shape, mapping)
    new_shape = []
    for idx in range(len(var_shape)):
        if var_shape[idx] == -1 or mapping[idx] == -1:
            new_shape.append(var_shape[idx])
        else:
            new_shape.append(var_shape[idx] * mesh[mapping[idx]])

    return new_shape


def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None):
    """
    For every data variable of `dist_main_prog`, restore the unsharded shape
    and reset its dims mapping to all -1 in `dist_context` (default context
    if None). `dist_startup_prog` is not used here.
    """
    from .dist_context import get_default_distributed_context
    if dist_context is None:
        dist_context = get_default_distributed_context()

    for var in dist_main_prog.list_vars():
        if var.is_data:
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                var)
            inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr)
            var.desc.set_shape(inverse_shape)
            dim_mapping = tensor_dist_attr.dims_mapping
            dim_mapping = [-1] * len(dim_mapping)
            tensor_dist_attr.dims_mapping = dim_mapping
            dist_context.set_tensor_dist_attr_for_program(var,
                                                          tensor_dist_attr)


def _update_addition_info(addition_info):
    """
    Update default addition_info with inputs.

    Raises TypeError if `addition_info` is not a dict, ValueError for an
    unknown key or a non-int value.
    """
    add_info = {"epoch": 0, "batch": 0, "batch_size": 0}
    if not addition_info:
        return add_info
    elif not isinstance(addition_info, dict):
        raise TypeError("The type of 'addition_info' should be 'dict', "
                        "but got '{}'.".format(str(type(addition_info))))
    else:
        for item, value in addition_info.items():
            if item not in ["epoch", "batch", "batch_size"]:
                raise ValueError(
                    "The key of 'addition_info' should be one of the "
                    "['epoch', 'batch', 'batch_size'], but got '{}'."
                    .format(str(item)))
            if not isinstance(value, int):
                raise ValueError(
                    "The value of 'addition_info' should be 'int', "
                    "but got '{}'.".format(str(type(value))))
            add_info[item] = value
        return add_info


def _check_valid_path(file_path):
    """
    Validity check of input file path.

    A falsy input is returned unchanged. Otherwise it must be a list of
    existing path strings, which is returned; TypeError / ValueError are
    raised for wrong types or missing files.
    """
    if not file_path:
        return file_path
    elif isinstance(file_path, list):
        for file in file_path:
            if not isinstance(file, str):
                raise TypeError("The type of file path should be 'str', "
                                "but got '{}'.".format(str(type(file))))
            if not os.path.exists(file):
                raise ValueError("The file path '{}' does not exist."
                                 .format(file))
        return file_path
    else:
        raise TypeError("The type of file path should be 'list', "
                        "but got '{}'.".format(str(type(file_path))))


def _check_param_dict(param_dict):
    """
    Check that `param_dict` is a non-empty dict of str -> LoDTensor and
    return it; raise ValueError / TypeError otherwise.
    """
    if not param_dict:
        raise ValueError("'param_dict' cannot be None.")
    elif not isinstance(param_dict, dict):
        raise TypeError("The type of 'param_dict' should be 'dict', "
                        "but got '{}'.".format(str(type(param_dict))))
    else:
        for name, value in param_dict.items():
            if not isinstance(name, str):
                raise TypeError(
                    "The type of key of 'param_dict' should be 'str', "
                    "but got '{}'.".format(str(type(name))))
            if not isinstance(value, paddle.fluid.LoDTensor):
                raise TypeError(
                    "The type of value of 'param_dict' should be 'LoDTensor', "
                    "but got '{}'.".format(str(type(value))))
        return param_dict


def _check_dist_attr(dist_attr):
    """
    Check the structure of a distributed attribute dict and return it.

    A falsy input is returned unchanged. Otherwise it must map str names to
    dicts whose keys are exactly 'process_shape', 'process_group' and
    'dims_mapping' (in any order).
    """
    if not dist_attr:
        return dist_attr
    elif not isinstance(dist_attr, dict):
        raise TypeError("The type of 'dist_attr' should be 'dict', "
                        "but got '{}'.".format(str(type(dist_attr))))
    else:
        attr = ['process_shape', 'process_group', 'dims_mapping']
        for name, value in dist_attr.items():
            if not isinstance(name, str):
                raise TypeError(
                    "The type of param name of 'dist_attr' should be 'str', "
                    "but got '{}'.".format(str(type(name))))
            if not isinstance(value, dict):
                raise TypeError(
                    "The type of distributed attribute should be 'dict', "
                    "but got '{}'".format(str(type(value))))
            # Compare as sets: the insertion order of the keys is irrelevant.
            if set(value.keys()) != set(attr):
                raise ValueError(
                    "The key of distributed attribute should be "
                    "'['process_shape', 'process_group', 'dims_mapping']', "
                    "but got {}.".format(str(value.keys())))
        return dist_attr


def save_distributed_checkpoint(program,
                                checkpoint_path,
                                dist_attr_path,
                                addition_info=None,
                                is_integrated=False,
                                dist_context=None):
    """
    Save model parameter state, optimzer state, distributed attribute and
    additional information of each rank.

    Args:
        program(Program): The program to be saved.
        checkpoint_path(str): The path of the checkpoint file to be saved.
        dist_attr_path(str): The path of distributed attribute file to be saved.
        addition_info(dict, optional): Additional information, key should be selected in ['epoch', 'batch', 'batch_size'].
            Default values are 0, when 'addition_info' is None. Default: None.
        is_integrated(bool, optional): Whether to integrate param before save. Default: False.
        dist_context(DistributedContext ,optional): collect related distributed information for program

    Returns:
        None

    Examples:
        .. code-block:: python

            path = os.path.join("./output", "step_%d" % step)
            os.makedirs(path, exist_ok=True)
            add_info = {'batch': step, "batch_size": global_batch_size}
            save_distributed_checkpoint(program, path, path, add_info)
    """
    from .dist_context import get_default_distributed_context

    assert isinstance(program, paddle.fluid.framework.Program)
    assert isinstance(is_integrated, bool)
    if dist_context is None:
        dist_context = get_default_distributed_context()
    addition_info = _update_addition_info(addition_info)

    if not is_integrated:
        _save_distributed_state_dict(program, addition_info, checkpoint_path)
        _save_distributed_attribute(program, dist_attr_path, dist_context)
    else:
        # TODO: integrate param before save
        raise NotImplementedError(
            "Integrating parameter has not been implemented.")


def load_distributed_checkpoint(checkpoint_path, dist_attr_path):
    """
    Load parameter, optimizer, distributed attribute and addition_info.

    Args:
        checkpoint_path(list[str]): model parameter file path, must be in order of rank id.
        dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id.

    Returns:
        param_dict(dict): parameters' value of all ranks.
        dist_attr(dict): parameters' distributed attribute.
        addition_info(dict): additional information user saved in last training.

    Notes:
        The return, 'addition_info', is belonging to the first file of checkpoint_path by default.

    Examples:
        .. code-block:: python

            ckpt_path = ['./model_state_rank0.pdmodel',
                         './model_state_rank1.pdmodel']
            dist_attr_path = ['./dist_attr_rank0.pdattr',
                              './dist_attr_rank1.pdattr']
            param_dict, dist_attr, add_info = load_distributed_checkpoint(ckpt_path, dist_attr_path)
    """
    assert _check_valid_path(checkpoint_path), \
        "'checkpoint_path' cannot be None."
    assert _check_valid_path(dist_attr_path), \
        "'dist_attr_path' cannot be None."

    state_dict_info = _load_distributed_state_dict(checkpoint_path)
    dist_attr = _load_distributed_attribute(dist_attr_path)
    param_dict = state_dict_info["model"]
    addition_info = state_dict_info["addition_info"]
    return param_dict, dist_attr, addition_info


def load_checkpoint_into_program(checkpoint_path,
                                 dist_attr_path,
                                 program,
                                 dist_context=None):
    """
    Load parameter, optimizer, distributed attribute and addition_info into model.

    Args:
        checkpoint_path(list[str]): model parameter file path, must be in order of rank id.
        dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id.
        program(Program): the program to be updated with checkpoint_path.
        dist_context(DistributedContext ,optional): collect related distributed information for program

    Returns:
        addition_info(dict): user saved in last train.

    Notes:
        The return, 'addition_info', is belonging to the first file of checkpoint_path by default.

    Examples:
        .. code-block:: python

            exe.run(startup_program)
            ckpt_path = ['./model_state_rank0.pdmodel',
                         './model_state_rank1.pdmodel']
            dist_attr_path = ['./dist_attr_rank0.pdattr',
                              './dist_attr_rank1.pdattr']
            load_checkpoint_into_program(ckpt_path, dist_attr_path, main_program)
    """
    from .dist_context import get_default_distributed_context

    assert isinstance(program, paddle.fluid.framework.Program)
    assert _check_valid_path(checkpoint_path), \
        "'checkpoint_path' cannot be None."
    assert _check_valid_path(dist_attr_path), \
        "'dist_attr_path' cannot be None."
    if dist_context is None:
        dist_context = get_default_distributed_context()
    all_state_dict_info = _load_distributed_state_dict(checkpoint_path)
    all_pre_dist_attr = _load_distributed_attribute(dist_attr_path)
    all_cur_dist_attr = get_dist_attr(program, dist_context)
    all_param_dict = all_state_dict_info["model"]
    addition_info = all_state_dict_info["addition_info"]
    sliced_param_dict = merge_and_slice_parameter(
        all_param_dict, all_pre_dist_attr, all_cur_dist_attr)
    load_parameter_into_program(sliced_param_dict, program)

    return addition_info


def load_parameter_into_program(param_dict, program):
    """
    Load parameters into program.

    Args:
        param_dict(dict): parameters' name and value.
        program(Program): the program to be updated
    """
    _check_param_dict(param_dict)
    assert program and isinstance(program, paddle.fluid.framework.Program)
    program.set_state_dict(param_dict)


def _save_distributed_attribute(program, dist_attr_path, dist_context):
    """
    Save distributed attribute of all parameters.

    Writes "dist_attr_rank<rank>.pdattr" under `dist_attr_path`, holding the
    attributes returned by get_dist_attr and the current world size.
    """
    # TODO: just save a complete distributed attribute file
    rank_id = paddle.distributed.get_rank()
    dist_attr_name = os.path.join(dist_attr_path,
                                  "dist_attr_rank{}.pdattr".format(rank_id))
    dist_attr_dict = {
        "model": get_dist_attr(program, dist_context),
        "world_size": paddle.distributed.get_world_size()
    }
    paddle.save(dist_attr_dict, dist_attr_name)
    logging.info("Already saved distributed attribute to '{}'.".format(
        dist_attr_path))


def _load_distributed_attribute(dist_attr_path):
    """
    Load parameters' distributed attribute from dist_attr_path.

    For a name that appears in several files the first occurrence wins.
    """
    total_dist_attr = {}
    for dist_attr_file in dist_attr_path:
        dist_attr = paddle.load(dist_attr_file)
        pre_world_size = dist_attr["world_size"]
        assert pre_world_size == len(dist_attr_path), \
            "The number of 'dist_attr_path' must be equal to the last training world size."
        for name, attr in dist_attr["model"].items():
            if name not in total_dist_attr:
                total_dist_attr[name] = attr

    return total_dist_attr


def _save_distributed_state_dict(program, addition_info, checkpoint_path):
    """
    Save parameters' state_dict.

    Writes "model_state_rank<rank>.pdmodel" under `checkpoint_path`, holding
    the program state dict, the world size and `addition_info`.
    """
    rank = paddle.distributed.get_rank()
    ckpt_file_name = os.path.join(checkpoint_path,
                                  "model_state_rank{}.pdmodel".format(rank))
    state_dict = {
        "model": program.state_dict(),
        "world_size": paddle.distributed.get_world_size(),
        "addition_info": addition_info
    }
    paddle.save(state_dict, ckpt_file_name)
    logging.info("Already saved model to '{}'.".format(checkpoint_path))


def _load_distributed_state_dict(checkpoint_path):
    """
    Load parameters' state_dict from checkpoint_path.

    Returns a dict with "model" (name -> list of numpy arrays, one per file
    that contains the name, in file order) and "addition_info" (taken from
    the first file; None when `checkpoint_path` is empty).
    """
    all_state_dict = {}
    # Defined up front so an empty path list cannot leave the name unbound.
    addition_info = None
    for idx, ckpt_file in enumerate(checkpoint_path):
        state_dict_info = paddle.load(ckpt_file, return_numpy=True)
        pre_world_size = state_dict_info["world_size"]
        assert pre_world_size == len(checkpoint_path), \
            "The number of 'checkpoint_path' must be equal to the last training world size."
        if idx == 0:
            addition_info = state_dict_info["addition_info"]
        for name, value in state_dict_info["model"].items():
            if name in all_state_dict:
                all_state_dict[name].append(np.array(value))
            else:
                all_state_dict[name] = [np.array(value)]

    all_state_dict_info = {
        "model": all_state_dict,
        "addition_info": addition_info
    }
    return all_state_dict_info


def get_dist_attr(program, dist_context=None):
    """
    Get distributed attribute of current rank.

    Args:
        program(Program): main program for training
        dist_context(DistributedContext, optional): context to read the
            attributes from; the default context is used when None.

    Returns:
        dict: for every parameter / optimizer variable of `program`, its
        name mapped to a dict with keys 'process_shape', 'process_group'
        and 'dims_mapping'.
    """
    from .dist_context import get_default_distributed_context

    assert isinstance(program, paddle.fluid.framework.Program)
    if dist_context is None:
        dist_context = get_default_distributed_context()
    dist_attr = {}
    for var in program.list_vars():
        if is_parameter(var) or is_belong_to_optimizer(var):
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                var)
            process_mesh = tensor_dist_attr.process_mesh
            dims_mapping = tensor_dist_attr.dims_mapping
            dist_attr[var.name] = {
                "process_shape": process_mesh.topology,
                "process_group": process_mesh.processes,
                "dims_mapping": dims_mapping
            }
    return dist_attr


def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr):
    """
    Merge parameters with previous dist_attr and slice parameters with current dist_attr

    Args:
        dist_param_dict(dict): parameters' value of all ranks.
        pre_dist_attr(dict): parameters' dist_attr of last training process.
        cur_dist_attr(dict): parameters' dist_attr of current training process.

    Returns:
        dist_param_dict(dict): parameters' value of current rank. The input
        dict is updated in place and returned.
    """
    assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None."
    assert _check_dist_attr(cur_dist_attr), "'cur_dist_attr' cannot be None."
    assert isinstance(dist_param_dict, dict), \
        "The type of 'dist_param_dict' should be 'dict', but got {}.".format(
            str(type(dist_param_dict)))
    for name, value in dist_param_dict.items():
        if not isinstance(name, str):
            raise TypeError("The key of 'dist_param_dict' is parameter's name, "
                            "and its type should be 'str', but got {}."
                            .format(str(type(name))))
        if not isinstance(value, list) or not all(
                isinstance(v, np.ndarray) for v in value):
            raise TypeError(
                "The value of 'dist_param_dict' is parameter's value of all ranks, "
                "and its type should be 'list(numpy.ndarray)'.")

    param_not_in_pre = []
    param_not_in_cur = []
    logging.info("Start to merge and slice parameters.")
    for var_name in cur_dist_attr.keys():
        if var_name not in pre_dist_attr:
            param_not_in_pre.append(var_name)
            continue

        pre_attr = pre_dist_attr[var_name]
        cur_attr = cur_dist_attr[var_name]
        if pre_attr == cur_attr:
            # skip merge and slice
            rank_id = paddle.distributed.get_rank()
            index = cur_attr["process_group"].index(rank_id)
            param = dist_param_dict[var_name][index]
            dist_param_dict[var_name] = _to_LodTensor(param)
            continue

        pre_param = dist_param_dict[var_name]
        pre_dims_mapping = pre_attr["dims_mapping"]
        cur_dims_mapping = cur_attr["dims_mapping"]

        # The mapping contains at least one value other than -1: the saved
        # pieces have to be merged first.
        if len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping:
            complete_param = _merge_parameter_with_dist_attr(pre_param,
                                                             pre_attr)
            dist_param_dict[var_name] = complete_param
        else:
            complete_param = pre_param[0]
            dist_param_dict[var_name] = _to_LodTensor(complete_param)

        if len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping:
            sliced_param = _slice_parameter_with_dist_attr(complete_param,
                                                           cur_attr)
            dist_param_dict[var_name] = sliced_param

    for var_name in pre_dist_attr:
        if var_name not in cur_dist_attr:
            param_not_in_cur.append(var_name)
            dist_param_dict.pop(var_name)

    if param_not_in_pre:
        warnings.warn("Parameters '{}' are not found in last training process."
                      .format(str(param_not_in_pre)))
    if param_not_in_cur:
        warnings.warn(
            "Parameters '{}' are not found in current training process."
            .format(str(param_not_in_cur)))

    return dist_param_dict


def _merge_parameter_with_dist_attr(param_list, dist_attr):
    """
    Merge parameter with distributed attribute.

    `param_list` holds one array per process of `dist_attr["process_group"]`
    (same order). Returns the merged parameter converted by _to_LodTensor.
    """
    from .reshard import _compute_complete_shape, _compute_partition_index

    dims_mapping = dist_attr["dims_mapping"]
    process_shape = dist_attr["process_shape"]
    process_group = dist_attr["process_group"]
    # get the complete shape of the parameter
    complete_shape = _compute_complete_shape(param_list[0].shape,
                                             process_shape, dims_mapping)
    # merge the parameter with dist_attr
    partition_param_list = []
    merged_partiton = []
    for index, process in enumerate(process_group):
        partition_index = _compute_partition_index(
            process, complete_shape, dims_mapping, process_shape, process_group)
        # Replicated pieces share a partition index; merge each one once.
        if partition_index not in merged_partiton:
            merged_partiton.append(partition_index)
            _merge_parameter(partition_param_list, param_list[index],
                             partition_index, complete_shape)

    # Exactly one merged piece must remain, otherwise the indexing below
    # would be meaningless (or fail on an empty list).
    assert len(partition_param_list) == 1, \
        "Fail to merge parameter"
    complete_param = _to_LodTensor(partition_param_list[0][0])
    return complete_param


def _slice_parameter_with_dist_attr(param, dist_attr):
    """
    Slice parameter with distributed attribute.

    Returns the slice belonging to the current rank, converted by
    _to_LodTensor.
    """
    param = np.array(param) if isinstance(param,
                                          paddle.fluid.LoDTensor) else param
    dims_mapping = dist_attr["dims_mapping"]
    process_shape = dist_attr["process_shape"]
    process_group = dist_attr["process_group"]
    # slice the parameter with dist_attr
    partition_index_list = _get_split_indices(param.shape, dims_mapping,
                                              process_shape, process_group)
    sliced_param_list = _slice_parameter(param, partition_index_list,
                                         len(partition_index_list))
    # get the current parameter's index in sliced_param_list
    rank_id = paddle.distributed.get_rank()
    sliced_param_index = _get_sliced_param_index(
        rank_id, param.shape, dims_mapping, process_shape, process_group)
    sliced_param = _to_LodTensor(sliced_param_list[sliced_param_index])
    return sliced_param


def _merge_parameter(partition_param_list, param, partition_index,
                     complete_shape):
    """
    Merge partitial parameters to a complete one.

    `partition_param_list` is a list of (array, partition_index) tuples and
    is updated in place.

    Returns:
        None

    Examples:
        .. code-block:: python

            import numpy as np
            partition_param_list = [(np.array([[[1.11, 1.12]]]), [[0,1],[0,1],[0,2]])]
            param = np.array([[[1.13, 1.14]]])
            partition_index = [[0,1],[0,1],[2,4]]
            complete_shape = [1, 1, 4]

            _merge_parameter(partition_param_list, param, partition_index, complete_shape)
            # partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])]
    """
    from .reshard import _compute_concat_info

    # Nothing to do when the single stored piece already spans the whole
    # parameter.
    if len(partition_param_list) == 1:
        is_complete_data = True
        for idx, item in enumerate(partition_param_list[0][1]):
            if item[0] != 0 or item[1] != complete_shape[idx]:
                is_complete_data = False
                break
        if is_complete_data:
            return

    if not partition_param_list:
        partition_param_list.append((param, partition_index))
    else:
        i = 0
        while i < len(partition_param_list):
            concat_axis, first_order, new_partition = _compute_concat_info(
                partition_param_list[i][1], partition_index)
            if concat_axis != -1:
                if first_order == 0:
                    new_param = np.concatenate(
                        (partition_param_list[i][0], param), axis=concat_axis)
                else:
                    new_param = np.concatenate(
                        (param, partition_param_list[i][0]), axis=concat_axis)

                partition_param_list.pop(i)
                # The grown piece may now be adjacent to another stored piece.
                _merge_parameter(partition_param_list, new_param,
                                 new_partition, complete_shape)
                break
            i += 1


def _slice_parameter(complete_param, partition_index_list, length):
    """
    Slice a complete parameter.

    Splits recursively, one axis per recursion level, starting at axis
    `ndim - length`.

    Returns:
        sliced_param_list(list): sliced parameters with 'partition_index_list'

    Examples:
        .. code-block:: python

            import numpy as np
            complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]])
            rank = 2
            complete_shape = [1, 1, 6]
            dims_mapping = [-1, -1, 0]
            process_shape = [3]
            process_group = [0, 1, 2]

            sliced_param_list = _slice_parameter(complete_param, [[], [], [2, 4]], 3)
            # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])]
    """
    sliced_param_list = []
    axis = len(complete_param.shape) - length
    sliced_param = np.split(
        complete_param, partition_index_list[axis], axis=axis)
    if length == 1:
        return sliced_param
    for param in sliced_param:
        sliced_param_list.extend(
            _slice_parameter(param, partition_index_list, length - 1))
    return sliced_param_list


def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape,
                            process_group):
    """
    Get sliced_param's index of current rank in all sliced parameters list.

    Returns:
        sliced_param_index(int): the index of sliced param in sliced_param_list

    Examples:
        .. code-block:: python

            import numpy as np
            complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]])
            rank = 2
            complete_shape = [1, 1, 6]
            dims_mapping = [-1, -1, 0]
            process_shape = [3]
            process_group = [0, 1, 2]

            slice_param = _slice_parameter(complete_param, [[], [], [2, 4]], 3)
            # slice_param:
            # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])]

            index = _get_sliced_param_index(rank, complete_shape, dims_mapping,
                                            process_shape, process_group)
            # index: 2
    """
    from .reshard import _compute_partition_index

    partition_index = _compute_partition_index(
        rank, complete_shape, dims_mapping, process_shape, process_group)
    sliced_param_index = 0
    for i, shape in enumerate(complete_shape):
        if dims_mapping[i] == -1:
            slice_shape = shape
        else:
            slice_shape = shape // process_shape[dims_mapping[i]]
        if slice_shape == 1:
            # With slices of size one the slice start *is* the slice index;
            # the "+ 1" form below would be off by one here.
            index = partition_index[i][0]
        else:
            index = (partition_index[i][0] + 1) // slice_shape
        sliced_param_index = sliced_param_index * (shape // slice_shape) + index
    return sliced_param_index


def _get_split_indices(complete_shape, dims_mapping, process_shape,
                       process_group):
    """
    Get split indices of every dimension.

    Returns:
        split_indices_list(list): the split indices of every dimension of the parameter

    Examples:
        .. code-block:: python

            import numpy as np
            complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]])
            complete_shape = [1, 1, 6]
            dims_mapping = [-1, -1, 0]
            process_shape = [3]
            process_group = [0, 1, 2]

            index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group)
            # index: [[], [], [2, 4]]
    """
    from .reshard import _compute_partition_index

    split_indices_list = []
    for process in process_group:
        partition_index = _compute_partition_index(
            process, complete_shape, dims_mapping, process_shape, process_group)
        if split_indices_list:
            for dim in range(len(partition_index)):
                split_indices_list[dim].extend(partition_index[dim])
        else:
            split_indices_list = partition_index
    # Drop duplicates and the two outer borders (0 and the dimension size):
    # only interior cut points are split indices.
    split_indices_list = list(
        map(lambda x, y: list(set(x) - set([y]) - set([0])),
            split_indices_list, complete_shape))
    split_indices_list = [sorted(x) for x in split_indices_list]
    return split_indices_list


def set_grad_var_shape(program, dist_context):
    """
    Set the shape of every gradient variable written by a backward op of the
    global block to the shape inferred (via infer_shape) from its forward
    variable and the matching input distributed attribute. "sum" ops are
    skipped.
    """
    from .operators.common import infer_shape
    from paddle.distributed.fleet.meta_optimizers.common import OpRole

    block = program.global_block()
    block_vars = block.vars

    # Grad op types whose dist attr is taken from the forward op, paired by
    # position with the type of that forward op.
    need_set_shape_list = [
        "reshape2_grad", "softmax_with_cross_entropy_grad",
        "transpose2_grad", "softmax_grad", "cross_entropy_grad2",
        "dropout_grad", "unsqueeze2_grad"
    ]
    forward_list = [
        "reshape2", "softmax_with_cross_entropy", "transpose2", "softmax",
        "cross_entropy2", "dropout", "unsqueeze2"
    ]

    for op in block.ops:

        if op.type == "sum":
            continue
        if int(op.attr('op_role')) == int(OpRole.Backward):
            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
            assert op_dist_attr is not None

            for var_name in op.output_arg_names:

                assert "@GRAD" in var_name
                forward_var_name = var_name[:var_name.find("@GRAD")]
                if op.type == "c_allreduce_sum" or op.type == "c_identity" or op.type == "scale":
                    forward_var_name = op.input_arg_names[0]

                if op.type in need_set_shape_list:
                    forward_op_name = forward_list[need_set_shape_list.index(
                        op.type)]
                    # Scan from the top of the block; the first forward op of
                    # the paired type that reads the forward variable wins.
                    for forward_op in block.ops:
                        assert int(forward_op.attr('op_role')) != int(
                            OpRole.Backward)
                        if forward_op.type == forward_op_name and forward_var_name in forward_op.input_arg_names:
                            op_dist_attr = dist_context.get_op_dist_attr_for_program(
                                forward_op)
                            break

                forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
                    forward_var_name)

                assert forward_input_dist_attr is not None, f"{forward_var_name}"
                forward_var = block_vars[forward_var_name]
                forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
                    forward_var)
                assert forward_var_dist_attr is not None
                grad_var = block_vars[var_name]
                ref_shape = infer_shape(block, forward_var,
                                        forward_var_dist_attr,
                                        forward_input_dist_attr)

                if list(grad_var.shape) != ref_shape:
                    grad_var.desc.set_shape(ref_shape)


def update_op_dims_mapping_by_default_dist_impl(dist_op):
    """
    Make the batch-dimension mapping of all non-parameter inputs and outputs
    of `dist_op` compatible, updating the dims mappings in place.

    The batch dimension is dimension 0, except for "XShape" outputs where it
    is dimension 1. Every other dimension must be unsharded (-1). "shape"
    and "slice" ops are left untouched.

    Returns:
        bool: True if any dims mapping was changed.
    """
    changed = False
    op_dist_attr = dist_op.dist_attr
    op_desc = dist_op.serial_op.desc
    # The following statement will be replaced by a more elegent way
    if op_desc.type() == "shape" or op_desc.type() == "slice":
        return False
    output_names = op_desc.output_names()
    xshape_arg_names = []
    if "XShape" in output_names:
        xshape_arg_names = op_desc.output("XShape")
    batch_dim_mappings = []
    for arg_name in op_desc.input_arg_names():
        serial_tensor = dist_op.get_serial_input(arg_name)
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
        if len(dims_mapping) > 1:
            # start=1 so that the message reports the real dimension index.
            for idx, mapping in enumerate(dims_mapping[1:], start=1):
                assert mapping == -1, \
                    "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
                        .format(op_desc.type(), idx, mapping)
        batch_dim_mappings.append(dims_mapping[0])
    for arg_name in op_desc.output_arg_names():
        serial_tensor = dist_op.get_serial_output(arg_name)
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
        if arg_name not in xshape_arg_names:
            if len(dims_mapping) > 1:
                for idx, mapping in enumerate(dims_mapping[1:], start=1):
                    assert mapping == -1, \
                        "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
                            .format(op_desc.type(), idx, mapping)
            batch_dim_mappings.append(dims_mapping[0])
        else:
            # Report the offending value itself, not a loop variable left
            # over from an earlier iteration.
            assert dims_mapping[0] == -1, \
                "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\
                    .format(op_desc.type(), dims_mapping[0])
            if len(dims_mapping) > 2:
                for idx, mapping in enumerate(dims_mapping[2:], start=2):
                    assert mapping == -1, \
                        "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\
                            .format(op_desc.type(), idx, mapping)
            batch_dim_mappings.append(dims_mapping[1])

    compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
    assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
    for arg_name in op_desc.input_arg_names():
        serial_tensor = dist_op.get_serial_input(arg_name)
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
        if compatible_dim_mapping != dims_mapping[0]:
            dims_mapping[0] = compatible_dim_mapping
            changed = True
    for arg_name in op_desc.output_arg_names():
        serial_tensor = dist_op.get_serial_output(arg_name)
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
        if arg_name not in xshape_arg_names:
            if compatible_dim_mapping != dims_mapping[0]:
                dims_mapping[0] = compatible_dim_mapping
                changed = True
        else:
            if compatible_dim_mapping != dims_mapping[1]:
                dims_mapping[1] = compatible_dim_mapping
                changed = True

    return changed


def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op):
    changed = False
    op_dist_attr = dist_op.dist_attr
    op_desc = dist_op.serial_op.desc
    input_arg_names = op_desc.input_arg_names()
    input_dims_mapping_dict = {}
    input_dims_mapping_lens = {}
    max_dims_mapping_len = -1
    for arg_name in input_arg_names:
        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
        if max_dims_mapping_len < len(dims_mapping):
            max_dims_mapping_len = len(dims_mapping)
        input_dims_mapping_dict[arg_name] = dims_mapping
        input_dims_mapping_lens[arg_name] = len(dims_mapping)

    dims_mapping_list = []
    for arg_name in input_arg_names:
        if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
            new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
            for i in range(input_dims_mapping_lens[arg_name]):
                new_idx = (max_dims_mapping_len -
                           input_dims_mapping_lens[arg_name]) + i
                new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i]
            dims_mapping_list.append(new_dims_mapping)
        else:
            dims_mapping_list.append(input_dims_mapping_dict[arg_name])
    output_arg_names = op_desc.output_arg_names()
    for arg_name in output_arg_names:
        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
        assert len(dims_mapping) == max_dims_mapping_len
        dims_mapping_list.append(dims_mapping)

    compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list)
    assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for arg_name in input_arg_names: if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: new_dims_mapping = [ -1 for _ in range(input_dims_mapping_lens[arg_name]) ] for i in range(input_dims_mapping_lens[arg_name]): new_idx = (max_dims_mapping_len - input_dims_mapping_lens[arg_name]) + i new_dims_mapping[i] = compatible_dims_mapping[new_idx] if new_dims_mapping != input_dims_mapping_dict[arg_name]: op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) changed = True else: if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: op_dist_attr.set_input_dims_mapping(arg_name, compatible_dims_mapping) changed = True for arg_name in output_arg_names: dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if compatible_dims_mapping != dims_mapping: op_dist_attr.set_output_dims_mapping(arg_name, compatible_dims_mapping) changed = True return changed def get_all_distributed_main_program(serial_program_info, dist_context): "Get all distributed main programs by dist_context." from .dist_context import DistributedOperatorContext cluster = serial_program_info.cluster all_dist_main_program = [] ranks = paddle.distributed.get_world_size() if cluster is None else len( cluster.get_all_devices("GPU")) for rank_id in range(ranks): used_dist_context = copy.deepcopy(dist_context) used_dist_context._dist_op_context = DistributedOperatorContext() dist_main_program, dist_startup_program = get_specified_distributed_main_program( serial_program_info, used_dist_context, rank_id) all_dist_main_program.append(dist_main_program) return all_dist_main_program def get_specified_distributed_main_program(serial_program_info, dist_context, rank_id): "Get distributed main program by the given dist_context and rank_id." 
from .partitioner import Partitioner from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER from .process_group import _g_process_group_map, ProcessGroup dist_strategy = paddle.distributed.fleet.DistributedStrategy() train_program = serial_program_info.train_program startup_program = serial_program_info.startup_program loss = serial_program_info.loss optimizer = serial_program_info.optimizer partitioner = Partitioner(dist_strategy, dist_context, rank_id) dist_main_program, dist_startup_program = partitioner.transpile_forward( train_program, startup_program) dist_params_grads = partitioner.apply_backward( loss, train_program, startup_program, dist_main_program, dist_startup_program) opt_ops = partitioner.apply_optimize( copy.deepcopy(optimizer), dist_params_grads, dist_main_program, dist_startup_program) set_grad_var_shape(dist_main_program, dist_context) make_data_unshard(dist_main_program, dist_startup_program, dist_context) reshard(dist_main_program, dist_startup_program, rank_id, dist_context) HAS_SENT.clear() HAS_RECV.clear() HAS_ALLGATHER.clear() _g_process_group_map.clear() _g_process_group_map[0] = ProcessGroup(0, []) return dist_main_program, dist_startup_program class SerialProgramInfo: def __init__(self, train_program, satrtup_program, loss, optimizer, cluster=None): self._train_program = train_program self._startup_program = satrtup_program self._loss = loss self._optimizer = optimizer self._cluster = cluster @property def train_program(self): return self._train_program @property def startup_program(self): return self._startup_program @property def loss(self): return self._loss @property def optimizer(self): return self._optimizer @property def cluster(self): return self._cluster def get_standalone_cost_data(distributed_programs): def _compute_runtime(op_cost, op, vars): runtime = 0 try: runtime = float(op_cost["op_time"]) except: return runtime op_config = op_cost["config"] total_static_input_size = 0 total_actual_input_size = 0 parsed_info = 
op_config.split("\n") variable = "(Variable)" for info in parsed_info: variable = "(Variable)" if "(Variable)" in info else "(list" if variable in info: arg_name_lower = info[:info.find(variable) - 1] shape_left_boundary = info.find("[") shape_right_boundary = info.find("]") assert shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary, "Get shape failed." shape = info[shape_left_boundary + 1: shape_right_boundary].split(",") shape = list(map(lambda x: int(x.strip()), shape)) dtype_factor = 1 total_static_input_size += reduce(lambda x, y: x * y, shape) # print(arg_name_lower) if op.type == "c_embedding": arg_name_lower = "w" if arg_name_lower == "weight" else "ids" for arg_name in op.input_names: if arg_name.lower() == arg_name_lower: for var_name in op.input(arg_name): var = vars[var_name] total_actual_input_size += reduce( lambda x, y: x * y, var.shape) break assert total_static_input_size > 0 and total_actual_input_size > 0, "Get input size failed." 
actual_runtime = total_actual_input_size / total_static_input_size * runtime return actual_runtime cost_model = paddle.cost_model.CostModel() cost_model.static_cost_data() DEFAULT_MULTIPLE = 2 OP_NAME_MAPPING = { "c_embedding": "embedding", "matmul_v2": "matmul", "transpose2": "transpose", "reshape2": "reshape", "unsqueeze2": "unsqueeze", "reduce_sum": "sum", "elementwise_div": "divide" } standalone_cost_data = [] not_enum_ops = ["create_py_reader", "create_double_buffer_reader", "read"] for distributed_program in distributed_programs: cost_data = {} vars = distributed_program.global_block().vars for op in distributed_program.global_block().ops: runtime = 0 if op.type in not_enum_ops: cost_data[op.desc.id()] = runtime continue dtype = str(vars[op.input_arg_names[0]] .dtype) if op.input_arg_names else "float32" if int(op.attr('op_role')) == int(OpRole.Backward): if "_grad" in op.type: forward_op_name = op.type[:-5] if forward_op_name in OP_NAME_MAPPING.keys(): forward_op_name = OP_NAME_MAPPING[forward_op_name] op_cost = cost_model.get_static_op_time( forward_op_name, forward=False, dtype=dtype) if op_cost: runtime = _compute_runtime(op_cost, op, vars) else: op_cost = cost_model.get_static_op_time( forward_op_name, dtype=dtype) if op_cost: runtime = 2 * _compute_runtime(op_cost, op, vars) elif int(op.attr('op_role')) == int(OpRole.Forward): op_name = OP_NAME_MAPPING[ op.type] if op.type in OP_NAME_MAPPING.keys() else op.type op_cost = cost_model.get_static_op_time(op_name) if op_cost: runtime = _compute_runtime(op_cost, op, vars) cost_data[op.desc.id()] = runtime standalone_cost_data.append(cost_data) return standalone_cost_data