# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License import copy import logging import os import threading import warnings from functools import reduce import numpy as np import paddle from paddle.fluid.wrapped_decorator import wrap_decorator from paddle.framework import core from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter from paddle.static import Variable from ..process_mesh import ProcessMesh from .dist_attribute import DistTensorSpec, OperatorDistAttr, TensorDistAttr OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() __no_shape_var_type__ = [ core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES, core.VarDesc.VarType.LOD_TENSOR_ARRAY, core.VarDesc.VarType.FEED_MINIBATCH, core.VarDesc.VarType.FETCH_LIST, ] __not_naive_data_parallel_op__ = ["expand_v2"] def get_logger(log_level, name="auto_parallel"): logger = logging.getLogger(name) logger.propagate = False if not logger.handlers: logger.setLevel(log_level) log_handler = logging.StreamHandler() log_format = logging.Formatter( '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' ) log_handler.setFormatter(log_format) logger.addHandler(log_handler) else: logger.setLevel(log_level) return logger def is_valid_list_index(list, index): if index >= -len(list) and index < len(list): return True else: return False def is_dim_shard(mapping): if mapping != -1: return True else: return False def is_dim_replicate(mapping): if mapping == -1: return True else: return False def verify_dims_mapping(dims_mapping, process_mesh): if dims_mapping is None: return False if not all(isinstance(d, int) for d in dims_mapping): return False for i in range(len(dims_mapping)): if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape): return False for i in range(len(process_mesh.shape)): if dims_mapping.count(i) > 1: return False return True def convert_to_dims_mapping(shard_spec, process_mesh): dims_mapping = [] for shard in shard_spec: if shard is None: dims_mapping.append(-1) elif process_mesh.shape[process_mesh.dim_names.index(shard)] == 1: dims_mapping.append(-1) else: dims_mapping.append(process_mesh.dim_names.index(shard)) return dims_mapping def convert_to_shard_spec(dims_mapping, process_mesh): shard_spec = [] for dim_mapping in dims_mapping: if dim_mapping == -1: shard_spec.append(None) else: shard_spec.append(process_mesh.dim_names[dim_mapping]) return shard_spec def verify_shard_spec(shard_spec, tensor_shape, process_mesh): if len(shard_spec) != len(tensor_shape): return False for shard in shard_spec: if shard is not None and not isinstance(shard, str): return False if shard is not None and shard not in process_mesh.dim_names: return False dims_mapping = convert_to_dims_mapping(shard_spec, process_mesh) if not verify_dims_mapping(dims_mapping, process_mesh): return False for i in range(len(tensor_shape)): if ( dims_mapping[i] != -1 and tensor_shape[i] > 0 and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0 ): return False return True def compute_compatible_dim_mapping(dim_mappings): if not dim_mappings: return None compatible_mapping = dim_mappings[0] for mapping in dim_mappings: if compatible_mapping == -1: compatible_mapping = mapping elif mapping == -1: continue elif compatible_mapping == mapping: continue else: return None return compatible_mapping def compute_compatible_dims_mapping(dims_mapping_list): if not dims_mapping_list: return None length = len(dims_mapping_list[0]) for dims_mapping in dims_mapping_list: assert ( dims_mapping is not None ), "Dims mapping must not be None for compatible computation" assert ( len(dims_mapping) == length ), "The length of dims_mapping in list must be same for compatible computation." compatible_result = [] for dim_mappings in zip(*dims_mapping_list): compatible_dim_mapping = compute_compatible_dim_mapping( list(dim_mappings) ) if compatible_dim_mapping is None: return None compatible_result.append(compatible_dim_mapping) return compatible_result def compute_compatible_process_mesh(process_mesh_list): compatible_process_mesh = None if not process_mesh_list: return compatible_process_mesh for process_mesh in process_mesh_list: if process_mesh is not None: if ( compatible_process_mesh is None or compatible_process_mesh == process_mesh ): compatible_process_mesh = process_mesh else: return None return compatible_process_mesh def compute_compatible_and_update_dim_mapping(dims_mapping_list, index_list): assert len(dims_mapping_list) == len(index_list) changed = False dim_mappings = [] for i in range(len(dims_mapping_list)): assert is_valid_list_index(dims_mapping_list[i], index_list[i]) dim_mappings.append(dims_mapping_list[i][index_list[i]]) compatible_dim_mapping = compute_compatible_dim_mapping(dim_mappings) if compatible_dim_mapping is None: return False for i in range(len(dims_mapping_list)): if compatible_dim_mapping != dims_mapping_list[i][index_list[i]]: dims_mapping_list[i][index_list[i]] = compatible_dim_mapping changed = True return changed def append_distributed_attr_suffix(name): """ Append auto parallel suffix for distributed attribute name. """ return name + core.kAutoParallelSuffix() def remove_distributed_attr_suffix(name): """ Remove auto parallel suffix from distributed attribute name. """ return name.strip(core.kAutoParallelSuffix()) def check_distributed_attr_for_program(program, dist_context=None): from .dist_context import get_default_distributed_context if dist_context is None: dist_context = get_default_distributed_context() assert ( dist_context.is_initialized_for_program() ), "Distributed attributes must be initialized before check." for block in program.blocks: for tensor in block.vars.values(): dist_tensor = dist_context.get_dist_tensor_for_graph(tensor) tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( tensor ) if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()): return False for op in block.ops: dist_op = dist_context.get_dist_op_for_graph(tensor) op_dist_attr = dist_context.get_op_dist_attr_for_program(op) if (op_dist_attr is not None) and (not dist_op.is_valid()): return False return True def print_program_with_dist_attr(program, dist_context=None): """ This function reuses the original program output ability with a distributed context. Using lock can avoid multiple threads change the default distributed context simultaneously. """ lock = threading.Lock() lock.acquire() from .dist_context import ( get_default_distributed_context, set_default_distributed_context, ) if dist_context is None: dist_context = get_default_distributed_context() print(program, flush=True) else: original_default_context = get_default_distributed_context() set_default_distributed_context(dist_context) print(program, flush=True) set_default_distributed_context(original_default_context) lock.release() def _get_comm_group(processes, shape, axis, rank): """ Given a rank and the processes mesh the rank belongs to, compute the communication peers of the rank based on the give axis in the mesh. Example: 16 processes managed in a 4-Dimensional mesh with shape of [2, 2, 2, 2]. the rank communication peers of rank 0 (included) are following: in axis 0: [0, 1] in axis 1: [0, 2] in axis 2: [0, 4] in axis 3: [0, 8] """ # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous # tricks to support processes mesh when it is not start with 0 or continuous assert rank in processes, "rank [{}] is NOT in processes group {}".format( rank, processes ) rank_relatvie = processes.index(rank) coordinate = _linear_idx2coordinate(shape, rank_relatvie) coordinates_in_group = [coordinate[:] for i in range(shape[axis])] # select comm group for i in range(shape[axis]): coordinates_in_group[i][axis] = i ranks_in_group_relative = [ _coordinate2linear_idx(shape, coordinate) for coordinate in coordinates_in_group ] ranks_in_group = [processes[idx] for idx in ranks_in_group_relative] return sorted(ranks_in_group) def _get_idx_in_axis(processes, shape, axis, rank): """ Given a rank and the processes mesh the rank belongs to, compute the index of the rank in given axis. Example: 27 processes managed in a 3-Dimensinal mesh with shape of [3, 3, 3]. the index of rank 22 are: in axis 0: 1 in axis 1: 1 in axis 2: 2 """ # NOTE _linear_idx2coordinate assume processes mesh start with 0 and continuous # tricks to support processes mesh when it is not start with 0 or continuous rank_relatvie = processes.index(rank) coordinate = _linear_idx2coordinate(shape, rank_relatvie) return coordinate[axis] def _coordinate2linear_idx(mesh_shape, coordinate): """ convert a coordinate in multidimensional mesh space into a scala idx in linear space. it use Row-major order for dimension conversion. so it has: [most_significant_dim, ..., least_significant_dim] assume: the size of i-th dimension to be: S[i] the index of j-th dimension is: I[j] linear_idx of a n dimensional coordinate is: I[n-1] * (S[n-2] * S[n-3] * S[n-4] * .... S[0]) + I[n-2] * ( S[n-3] * S[n-4] * .... S[0]) + I[n-3] * ( S[n-4] * .... S[0]) + ... I[1] * ( S[0]) + I[0] """ # NOTE the following function work based on a strong an assumption # that the processes in mesh are # 1. starts from 0 # 2. continuous # it will be wrong if ths above condition does not meet, # e.g. process_mesh = { process_groups = [7, 8, 9,10, 12, 13, 14, 15], mesh = [2, 4]} # if you want a more general mapping, you should use cartesian product assert len(mesh_shape) == len( coordinate ), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format( mesh_shape, coordinate ) for i in range(len(mesh_shape)): assert ( coordinate[i] >= 0 ), "index in dimension [{}] is least than zero. coordinate: {}".format( i, coordinate ) assert ( coordinate[i] < mesh_shape[i] ), "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format( i, mesh_shape, coordinate ) base = mesh_shape[-1] linear_idx = coordinate[-1] # row major order for i in range(len(mesh_shape) - 2, -1, -1): linear_idx += base * coordinate[i] base *= mesh_shape[i] return linear_idx def _linear_idx2coordinate(mesh_shape, linear_idx): """ mapping a linear scala into multidimensional mesh space, return it coordinate in that space. it is the inverse function of _coordinate2linear_idx. assume: the size of i-th dimension to be: S[i] the index of j-th dimension is: I[j] the coordinate given linear_idx is: I[0] = linear_idx % S[0] I[0] = (linear_idx / S[0]) % S[1] I[0] = (linear_idx / (S[0] * S[1])) % S[2] .... """ assert linear_idx >= 0, "linear index [{}] is least than zero".format( linear_idx ) assert linear_idx < np.prod( mesh_shape ), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format( mesh_shape, linear_idx ) base = 1 coordinate = [-1] * len(mesh_shape) for i in reversed(range(len(mesh_shape))): offset = linear_idx / base coordinate[i] = int(offset % mesh_shape[i]) base *= mesh_shape[i] # row major order return coordinate def _get_corresponding_rank(dist_context, target_mesh, rank): # TODO(JZ-LIANG) a hack method to support varying mesh in Pipeline parallelism case. # we assume that all mesh are evenly divide from a parent mesh and should have same size. # to revise this in future. coordinate = None for mesh in dist_context.process_meshes: if rank in mesh.process_ids and mesh.shape == target_mesh.shape: coordinate = _linear_idx2coordinate( mesh.shape, mesh.process_ids.index(rank) ) break # assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( # rank) if coordinate is not None: return target_mesh.process_ids[ _coordinate2linear_idx(mesh.shape, coordinate) ] else: return target_mesh.process_ids[0] def _get_unshard_dist_shape(var, dist_attr): var_shape = var.shape mapping = dist_attr.dims_mapping mesh = dist_attr.process_mesh.shape assert len(var_shape) == len( mapping ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( var_shape, mapping ) new_shape = [] for idx in range(len(var_shape)): if var_shape[idx] == -1 or mapping[idx] == -1: new_shape.append(var_shape[idx]) else: new_shape.append(var_shape[idx] * mesh[mapping[idx]]) return new_shape def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None): from .dist_context import get_default_distributed_context if dist_context is None: dist_context = get_default_distributed_context() for var in dist_main_prog.list_vars(): if var.is_data: tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( var ) inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr) var.desc.set_shape(inverse_shape) dim_mapping = tensor_dist_attr.dims_mapping dim_mapping = [-1] * len(dim_mapping) tensor_dist_attr.dims_mapping = dim_mapping dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) def _update_addition_info(addition_info): """Update default addition_info with inputs""" add_info = {"epoch": 0, "batch": 0, "batch_size": 0} if not addition_info: return add_info elif not isinstance(addition_info, dict): raise TypeError( "The type of 'addition_info' should be 'dict', " "but got '{}'.".format(str(type(addition_info))) ) else: for item, value in addition_info.items(): if item not in ["epoch", "batch", "batch_size"]: raise ValueError( "The key of 'addition_info' should be one of the " "['epoch', 'batch', 'batch_size'], but got '{}'.".format( str(item) ) ) if not isinstance(value, int): raise ValueError( "The value of 'addition_info' should be 'int', " "but got '{}'.".format(str(type(value))) ) add_info[item] = value return add_info def _check_valid_path(file_path): """Validity check of input file path""" if not file_path: return file_path elif isinstance(file_path, list): for file in file_path: if not isinstance(file, str): raise TypeError( "The type of file path should be 'str', " "but got '{}'.".format(str(type(file))) ) if not os.path.exists(file): raise ValueError(f"The file path '{file}' does not exist.") return file_path else: raise TypeError( "The type of file path should be 'list', " "but got '{}'.".format(str(type(file_path))) ) def _check_param_dict(param_dict): if not param_dict: raise ValueError("'param_dict' cannot be None.") elif not isinstance(param_dict, dict): raise TypeError( "The type of 'param_dict' should be 'dict', " "but got '{}'.".format(str(type(param_dict))) ) else: for name, value in param_dict.items(): if not isinstance(name, str): raise TypeError( "The type of key of 'param_dict' should be 'str', " "but got '{}'.".format(str(type(name))) ) if not isinstance(value, paddle.fluid.LoDTensor): raise TypeError( "The type of value of 'param_dict' should be 'LoDTensor', " "but got '{}'.".format(str(type(value))) ) return param_dict def _check_dist_attr(dist_attr): if not dist_attr: return dist_attr elif not isinstance(dist_attr, dict): raise TypeError( "The type of 'dist_attr' should be 'dict', " "but got '{}'.".format(str(type(dist_attr))) ) else: for name, value in dist_attr.items(): if not isinstance(name, str): raise TypeError( "The type of param name of 'dist_attr' should be 'str', " "but got '{}'.".format(str(type(name))) ) if not isinstance(value, dict): raise TypeError( "The type of distributed attribute should be 'dict', " "but got '{}'".format(str(type(value))) ) attr = ['process_shape', 'process_group', 'dims_mapping'] if list(value.keys()) != attr: raise ValueError( "The key of distributed attribute should be " "'['process_shape', 'process_group', 'dims_mapping']', " "but got {}.".format(str(value.keys())) ) return dist_attr def save_distributed_checkpoint( program, checkpoint_path, dist_attr_path, addition_info=None, is_integrated=False, dist_context=None, ): """ Save model parameter state, optimizer state, distributed attribute and additional information of each rank. Args: program(Program): The program to be saved. checkpoint_path(str): The path of the checkpoint file to be saved. dist_attr_path(str): The path of distributed attribute file to be saved. addition_info(dict, optional): Additional information, key should be selected in ['epoch', 'batch', 'batch_size']. Default values are 0, when 'addition_info' is None. Default: None. is_integrated(bool, optional): Whether to integrate param before save. Default: False. dist_context(DistributedContext ,optional): collect related distributed information for program Returns: None Examples: .. code-block:: python >>> import os >>> from paddle.distributed.auto_parallel.static.utils import save_distributed_checkpoint >>> step = 16000 >>> global_batch_size = 32 >>> path = os.path.join("./output", "step_%d" % step) >>> os.makedirs(path, exist_ok=True) >>> program = paddle.static.Program() >>> add_info = {'batch': step, "batch_size": global_batch_size} >>> save_distributed_checkpoint(program, path, path, add_info) """ from .dist_context import get_default_distributed_context assert isinstance(program, paddle.static.Program) assert isinstance(is_integrated, bool) if dist_context is None: dist_context = get_default_distributed_context() addition_info = _update_addition_info(addition_info) if not is_integrated: _save_distributed_state_dict(program, addition_info, checkpoint_path) _save_distributed_attribute(program, dist_attr_path, dist_context) else: # TODO: integrate param before save raise NotImplementedError( "Integrating parameter has not been implemented." ) def load_distributed_checkpoint(checkpoint_path, dist_attr_path): """ Load parameter, optimizer, distributed attribute and addition_info. Args: checkpoint_path(list[str]): model parameter file path, must be in order of rank id. dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id. Returns: param_dict(dict): parameters' value of all ranks. dist_attr(dict): parameters' distributed attribute. addition_info(dict): additional information user saved in last training. Notes: The return, 'addition_info', is belonging to the first file of checkpoint_path by default. Examples: .. code-block:: python >>> # doctest: +SKIP('Depends on external files.') >>> from paddle.distributed.auto_parallel.static.utils import load_distributed_checkpoint >>> ckpt_path = [ ... './model_state_rank0.pdmodel', ... './model_state_rank1.pdmodel', ... ] >>> dist_attr_path = [ ... './dist_attr_rank0.pdattr', ... './dist_attr_rank1.pdattr', ... ] >>> param_dict, dist_attr, add_info = load_distributed_checkpoint(ckpt_path, dist_attr_path) """ assert _check_valid_path( checkpoint_path ), "'checkpoint_path' cannot be None." assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None." state_dict_info = _load_distributed_state_dict(checkpoint_path) dist_attr = _load_distributed_attribute(dist_attr_path) param_dict = state_dict_info["model"] addition_info = state_dict_info["addition_info"] return param_dict, dist_attr, addition_info def load_checkpoint_into_program( checkpoint_path, dist_attr_path, program, dist_context=None ): """ Load parameter, optimizer, distributed attribute and addition_info into model. Args: checkpoint_path(list[str]): model parameter file path, must be in order of rank id. dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id. program(Program): the program to be updated with checkpoint_path. dist_context(DistributedContext ,optional): collect related distributed information for program Returns: addition_info(dict): user saved in last train. Notes: The return, 'addition_info', is belonging to the first file of checkpoint_path by default. Examples: .. code-block:: python >>> # doctest: +SKIP('Depends on external files.') >>> from paddle.distributed.auto_parallel.static.utils import load_checkpoint_into_program >>> exe.run(startup_program) >>> ckpt_path = [ ... './model_state_rank0.pdmodel', ... './model_state_rank1.pdmodel', ... ] >>> dist_attr_path = [ ... './dist_attr_rank0.pdattr', ... './dist_attr_rank1.pdattr', ... ] >>> load_checkpoint_into_program(ckpt_path, dist_attr_path, main_program) """ from .dist_context import get_default_distributed_context assert isinstance(program, paddle.static.Program) assert _check_valid_path( checkpoint_path ), "'checkpoint_path' cannot be None." assert _check_valid_path(dist_attr_path), "'dist_attr_path' cannot be None." if dist_context is None: dist_context = get_default_distributed_context() all_state_dict_info = _load_distributed_state_dict(checkpoint_path) all_pre_dist_attr = _load_distributed_attribute(dist_attr_path) all_cur_dist_attr = get_dist_attr(program, dist_context) all_param_dict = all_state_dict_info["model"] addition_info = all_state_dict_info["addition_info"] sliced_param_dict = merge_and_slice_parameter( all_param_dict, all_pre_dist_attr, all_cur_dist_attr ) load_parameter_into_program(sliced_param_dict, program) return addition_info def load_parameter_into_program(param_dict, program): """ Load parameters into program. Args: param_dict(dict): parameters' name and value. program(Program): the program to be updated """ assert isinstance(param_dict, dict) assert program and isinstance(program, paddle.static.Program) if not param_dict: return program.set_state_dict(param_dict) def _save_distributed_attribute(program, dist_attr_path, dist_context): """Save distributed attribute of all parameters""" # TODO: just save a complete distributed attribute file rank_id = paddle.distributed.get_rank() dist_attr_name = os.path.join( dist_attr_path, f"dist_attr_rank{rank_id}.pdattr" ) dist_attr_dict = { "model": get_dist_attr(program, dist_context), "world_size": paddle.distributed.get_world_size(), } paddle.save(dist_attr_dict, dist_attr_name) logging.info(f"Already saved distributed attribute to '{dist_attr_path}'.") def _load_distributed_attribute(dist_attr_path): """Load parameters' distributed attribute from dist_attr_path""" total_dist_attr = {} for dist_attr_file in dist_attr_path: dist_attr = paddle.load(dist_attr_file) pre_world_size = dist_attr["world_size"] assert pre_world_size == len( dist_attr_path ), "The number of 'dist_attr_path' must be equal to the last training world size." for name, attr in dist_attr["model"].items(): if name not in total_dist_attr: total_dist_attr[name] = attr return total_dist_attr def _save_distributed_state_dict(program, addition_info, checkpoint_path): """Save parameters' state_dict""" rank = paddle.distributed.get_rank() ckpt_file_name = os.path.join( checkpoint_path, f"model_state_rank{rank}.pdmodel" ) state_dict = { "model": program.state_dict(), "world_size": paddle.distributed.get_world_size(), "addition_info": addition_info, } paddle.save(state_dict, ckpt_file_name) logging.info(f"Already saved model to '{checkpoint_path}'.") def _load_distributed_state_dict(checkpoint_path): """Load parameters' state_dict from checkpoint_path""" all_state_dict = {} for idx, ckpt_file in enumerate(checkpoint_path): state_dict_info = paddle.load(ckpt_file, return_numpy=True) pre_world_size = state_dict_info["world_size"] assert pre_world_size == len( checkpoint_path ), "The number of 'checkpoint_path' must be equal to the last training world size." if idx == 0: addition_info = state_dict_info["addition_info"] for name, value in state_dict_info["model"].items(): if name in all_state_dict: all_state_dict[name].append(np.array(value)) else: all_state_dict[name] = [np.array(value)] all_state_dict_info = { "model": all_state_dict, "addition_info": addition_info, } return all_state_dict_info def get_dist_attr(program, dist_context=None): """ Get distributed attribute of current rank. Args: program(Program): main program for training """ from .dist_context import get_default_distributed_context assert isinstance(program, paddle.static.Program) if dist_context is None: dist_context = get_default_distributed_context() dist_attr = {} for var in program.list_vars(): if is_parameter(var) or is_belong_to_optimizer(var): tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( var ) process_mesh = tensor_dist_attr.process_mesh dims_mapping = tensor_dist_attr.dims_mapping dist_attr[var.name] = { "process_shape": process_mesh.shape, "process_group": process_mesh.process_ids, "dims_mapping": dims_mapping, } return dist_attr def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): """ Merge parameters with previous dist_attr and slice parameters with current dist_attr Arags: dist_param_dict(dict): parameters' value of all ranks. pre_dist_attr(dict): parameters' dist_attr of last training process. cur_dist_attr(dict): parameters' dist_attr of current training process. Returns: dist_param_dict(dict): parameters' value of current rank. """ assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None." assert isinstance( dist_param_dict, dict ), "The type of 'dist_param_dict' should be 'dict', but got {}.".format( str(type(dist_param_dict)) ) for name, value in dist_param_dict.items(): if not isinstance(name, str): raise TypeError( "The key of 'dist_param_dict' is parameter's name, " "and its type should be 'str', but got {}.".format( str(type(name)) ) ) if not isinstance(value, list) or not all( isinstance(v, np.ndarray) for v in value ): raise TypeError( "The value of 'dist_param_dict' is parameter's value of all ranks, " "and its type should be 'list(numpy.ndarray)'." ) if cur_dist_attr is None: return {} param_not_in_pre = [] param_not_in_cur = [] logging.info("Start to merge and slice parameters.") for var_name in cur_dist_attr.keys(): if var_name not in pre_dist_attr: param_not_in_pre.append(var_name) continue pre_attr = pre_dist_attr[var_name] cur_attr = cur_dist_attr[var_name] if pre_attr == cur_attr: # skip merge and slice rank_id = paddle.distributed.get_rank() index = cur_attr["process_group"].index(rank_id) param = dist_param_dict[var_name][index] dist_param_dict[var_name] = param continue pre_param = dist_param_dict[var_name] pre_dims_mapping = pre_attr["dims_mapping"] cur_dims_mapping = cur_attr["dims_mapping"] if len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping: complete_param = _merge_parameter_with_dist_attr( pre_param, pre_attr ) dist_param_dict[var_name] = complete_param else: complete_param = pre_param[0] dist_param_dict[var_name] = complete_param if len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping: sliced_param = _slice_parameter_with_dist_attr( complete_param, cur_attr ) dist_param_dict[var_name] = sliced_param for var_name in pre_dist_attr: if var_name not in cur_dist_attr: param_not_in_cur.append(var_name) dist_param_dict.pop(var_name) if param_not_in_pre: warnings.warn( "Parameters '{}' are not found in last training process.".format( str(param_not_in_pre) ) ) if param_not_in_cur: warnings.warn( "Parameters '{}' are not found in current training process.".format( str(param_not_in_cur) ) ) return dist_param_dict def _merge_parameter_with_dist_attr(param_list, dist_attr): """Merge parameter with distributed attribute""" from .reshard import Resharder dims_mapping = dist_attr["dims_mapping"] process_shape = dist_attr["process_shape"] process_group = dist_attr["process_group"] # get the complete shape of the parameter complete_shape = Resharder.compute_complete_shape( param_list[0].shape, process_shape, dims_mapping ) # merge the parameter with dist_attr partition_param_list = [] merged_partiton = [] for process in process_group: partition_index = Resharder.compute_partition_index( process, complete_shape, dims_mapping, process_shape, process_group ) index = process_group.index(process) if partition_index not in merged_partiton: merged_partiton.append(partition_index) _merge_parameter( partition_param_list, param_list[index], partition_index, complete_shape, ) assert ( len(partition_param_list) == 1 or not partition_param_list ), "Fail to merge parameter" complete_param = partition_param_list[0][0] return complete_param def _slice_parameter_with_dist_attr(param, dist_attr): """Slice parameter with distributed attribute""" param = ( np.array(param) if isinstance(param, paddle.fluid.LoDTensor) else param ) dims_mapping = dist_attr["dims_mapping"] process_shape = dist_attr["process_shape"] process_group = dist_attr["process_group"] # slice the parameter with dist_attr partition_index_list = _get_split_indices( param.shape, dims_mapping, process_shape, process_group ) sliced_param_list = _slice_parameter( param, partition_index_list, len(partition_index_list) ) # get the current parameter's index in sliced_param_list rank_id = paddle.distributed.get_rank() sliced_param_index = _get_sliced_param_index( rank_id, param.shape, dims_mapping, process_shape, process_group ) sliced_param = sliced_param_list[sliced_param_index] return sliced_param def _merge_parameter( partition_param_list, param, partition_index, complete_shape ): """ Merge partitial parameters to a complete one. Returns: None Examples: .. code-block:: python >>> import numpy as np >>> from paddle.distributed.auto_parallel.static.utils import _merge_parameter >>> partition_param_list = [(np.array([[[1.11, 1.12]]]), [[0, 1],[0, 1],[0, 2]])] >>> param = np.array([[[1.13, 1.14]]]) >>> partition_index = [[0, 1],[0, 1],[2, 4]] >>> complete_shape = [2, 2, 4] >>> _merge_parameter(partition_param_list, param, partition_index, complete_shape) >>> print(partition_param_list) [(array([[[1.11, 1.12, 1.13, 1.14]]]), [[0, 1],[0, 1],[0, 4]])] """ from .reshard import Resharder if len(partition_param_list) == 1: is_complete_data = True for idx, item in enumerate(partition_param_list[0][1]): if item[0] != 0 or item[1] != complete_shape[idx]: is_complete_data = False break if is_complete_data: return if not partition_param_list: partition_param_list.append((param, partition_index)) else: i = 0 while i < len(partition_param_list): ( concat_axis, first_order, new_partition, ) = Resharder.compute_concat_info( partition_param_list[i][1], partition_index ) if concat_axis != -1: if first_order == 0: new_param = np.concatenate( (partition_param_list[i][0], param), axis=concat_axis ) else: new_param = np.concatenate( (param, partition_param_list[i][0]), axis=concat_axis ) partition_param_list.pop(i) _merge_parameter( partition_param_list, new_param, new_partition, complete_shape, ) break i += 1 def _slice_parameter(complete_param, partition_index_list, length): """ Slice a complete parameter. Returns: sliced_param_list(list): sliced parameters with 'partition_index_list' Examples: .. code-block:: python >>> import numpy as np >>> from paddle.distributed.auto_parallel.static.utils import _slice_parameter >>> complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) >>> rank = 2 >>> complete_shape = [1, 1, 6] >>> dims_mapping = [-1, -1, 0] >>> process_shape = [3] >>> process_group = [0, 1, 2] >>> sliced_param_list = _slice_parameter(complete_param, [[], [], [2, 4]], 3) >>> print(sliced_param_list) [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] """ sliced_param_list = [] axis = len(complete_param.shape) - length sliced_param = np.split( complete_param, partition_index_list[axis], axis=axis ) if length == 1: return sliced_param for param in sliced_param: sliced_param_list.extend( _slice_parameter(param, partition_index_list, length - 1) ) return sliced_param_list def _get_sliced_param_index( rank, complete_shape, dims_mapping, process_shape, process_group ): """ Get sliced_param's index of current rank in all sliced parameters list. Returns: sliced_param_index(int): the index of sliced param in sliced_param_list Examples: .. code-block:: python >>> import numpy as np >>> from paddle.distributed.auto_parallel.static.utils import _get_sliced_param_index >>> complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) >>> rank = 2 >>> complete_shape = [1, 1, 6] >>> dims_mapping = [-1, -1, 0] >>> process_shape = [3] >>> process_group = [0, 1, 2] >>> slice_param = _slice_parameter(complete_param, [[], [], [2, 4]], 3) >>> print(slice_param) [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] >>> index = _get_sliced_param_index(rank, complete_shape, dims_mapping, ... process_shape, process_group) >>> print(index) 2 """ from .reshard import Resharder partition_index = Resharder.compute_partition_index( rank, complete_shape, dims_mapping, process_shape, process_group ) sliced_param_index = 0 for i, shape in enumerate(complete_shape): if dims_mapping[i] == -1: slice_shape = shape else: slice_shape = shape // process_shape[dims_mapping[i]] if slice_shape == 1: index = partition_index[i][0] else: index = (partition_index[i][0] + 1) // slice_shape sliced_param_index = sliced_param_index * (shape // slice_shape) + index return sliced_param_index def _get_split_indices( complete_shape, dims_mapping, process_shape, process_group ): """ Get split indices of every dimension. Returns: split_indices_list(list): the split indices of every dimension of the parameter Examples: .. code-block:: python >>> import numpy as np >>> from paddle.distributed.auto_parallel.static.utils import _get_split_indices >>> complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) >>> complete_shape = [1, 1, 6] >>> dims_mapping = [-1, -1, 0] >>> process_shape = [3] >>> process_group = [0, 1, 2] >>> index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group) >>> print(index) [[], [], [2, 4]] """ from .reshard import Resharder split_indices_list = [] for process in process_group: partition_index = Resharder.compute_partition_index( process, complete_shape, dims_mapping, process_shape, process_group ) if split_indices_list: for dim in range(len(partition_index)): split_indices_list[dim].extend(partition_index[dim]) else: split_indices_list = partition_index split_indices_list = list( map( lambda x, y: list(set(x) - {y} - {0}), split_indices_list, complete_shape, ) ) split_indices_list = [sorted(x) for x in split_indices_list] return split_indices_list def set_grad_var_shape(program, dist_context): from paddle.distributed.fleet.meta_optimizers.common import OpRole from .operators.common import infer_shape block = program.global_block() vars = block.vars appended_grad_times = 0 grad_var_to_var = dist_context.dist_op_context.grad_var_to_var for idx, op in enumerate(block.ops): if int(op.attr('op_role')) != int(OpRole.Backward): continue if ( int(block.ops[idx - 1].attr('op_role')) == int(OpRole.Forward) or int(block.ops[idx - 1].attr('op_role')) == 257 ): appended_grad_times += 1 if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: break if op.type in ["sum", "concat", "shape"]: continue op_dist_attr = dist_context.get_op_dist_attr_for_program(op) assert op_dist_attr is not None for var_name in op.output_arg_names: if "@GRAD" not in var_name: continue if var_name in grad_var_to_var[appended_grad_times]: forward_var_name = grad_var_to_var[appended_grad_times][ var_name ] else: forward_var_name = var_name[: var_name.find("@GRAD")] if op.type in [ "c_allreduce_sum", "c_identity", "scale", "cast", "fill_any_like", ]: forward_var_name = op.input_arg_names[0] elif ( op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad" ): forward_var_name = None for output_name in op.output_names: if var_name in op.output(output_name): assert "@GRAD" in output_name input_name = output_name[: output_name.find("@GRAD")] assert len(op.input(input_name)) == 1 forward_var_name = op.input(input_name)[0] assert forward_var_name is not None need_set_shape_list = [ "reshape2_grad", "softmax_with_cross_entropy_grad", "transpose2_grad", "softmax_grad", "cross_entropy_grad2", "dropout_grad", "tanh_grad", "slice", "assign", "matmul_v2_triple_grad", "elementwise_add_triple_grad", "fill_constant", "sqrt_grad", "fused_softmax_mask_upper_triangle_grad", "flatten_contiguous_range_grad", "relu_grad", "exp_grad", "sigmoid_grad", "unsqueeze2_grad", "fused_dropout_add_grad", ] forward_list = [ "reshape2", "softmax_with_cross_entropy", "transpose2", "softmax", "cross_entropy2", "dropout", "tanh", ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad", "elementwise_add_grad_grad", "shape", "sqrt", "fused_softmax_mask_upper_triangle", "flatten_contiguous_range", "relu", "exp", "sigmoid", "unsqueeze2", "fused_dropout_add", ] if op.type in need_set_shape_list: for forward_op in block.ops: idx = need_set_shape_list.index(op.type) forward_op_name = forward_list[idx] if ( forward_op.type in forward_op_name and forward_var_name in forward_op.input_arg_names ): op_dist_attr = ( dist_context.get_op_dist_attr_for_program( forward_op ) ) break forward_input_dist_attr = op_dist_attr.get_input_dist_attr( forward_var_name ) assert ( forward_input_dist_attr is not None ), f"{forward_var_name, str(op)}" forward_var = vars[forward_var_name] forward_var_dist_attr = ( dist_context.get_tensor_dist_attr_for_program(forward_var) ) assert forward_var_dist_attr is not None grad_var = vars[var_name] ref_shape = infer_shape( block, forward_var, forward_var_dist_attr, forward_input_dist_attr, ) if list(grad_var.shape) != ref_shape: grad_var.desc.set_shape(ref_shape) def is_forward_op(op): op_role = int(op.attr('op_role')) return OP_ROLE_KEY in op.attr_names and ( op_role == int(OpRole.Forward) or op_role == int(OpRole.Loss) ) def is_backward_op(op): return OP_ROLE_KEY in op.attr_names and int( op.all_attrs()[OP_ROLE_KEY] ) & int(OpRole.Backward) def is_optimize_op(op): return OP_ROLE_KEY in op.attr_names and int( op.all_attrs()[OP_ROLE_KEY] ) & int(OpRole.Optimize) def is_lr_sched_op(op): return OP_ROLE_KEY in op.attr_names and int( op.all_attrs()[OP_ROLE_KEY] ) & int(OpRole.Optimize.LRSched) def is_loss_op(op): return OP_ROLE_KEY in op.attr_names and int( op.all_attrs()[OP_ROLE_KEY] ) == (int(OpRole.Forward) | int(OpRole.Loss)) def is_loss_grad_op(op): if OP_ROLE_KEY not in op.attr_names: return False op_role = int(op.all_attrs()[OP_ROLE_KEY]) return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss) def is_gradient_clip_op(op): return op.desc.has_attr("op_namescope") and op.desc.attr( "op_namescope" ).startswith("/gradient_clip") def is_prim_op(op): return op.type.endswith("_p") def is_comm_op(op): return op.has_attr("ring_id") def get_loss_op(block): loss_ops = [] for op in block.ops: if is_loss_op(op): assert ( len(op.desc.output_arg_names()) == 1 ), "loss op should only output loss var" loss_ops.append(op) assert len(loss_ops) == 1, "num of loss op is not equal to one" return loss_ops[0] def set_var_dist_attr(dist_context, var, dims_mapping, process_mesh, **kwargs): tensor_dist_attr = TensorDistAttr() tensor_dist_attr.dims_mapping = dims_mapping # TODO get global mesh group if isinstance(process_mesh, (list, np.ndarray)): tensor_dist_attr.process_mesh = ProcessMesh(process_mesh) elif isinstance(process_mesh, core.ProcessMesh): tensor_dist_attr.process_mesh = process_mesh else: raise ValueError( "{} must be a instance of ProcessMesh or list, but receive {}".format( process_mesh, type(process_mesh) ) ) if "mark_annotated" in kwargs and kwargs["mark_annotated"]: tensor_dist_attr.mark_annotated("dims_mapping") tensor_dist_attr.mark_annotated("process_mesh") dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) return tensor_dist_attr def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( new_op, process_mesh, ref_mapping, ctx ): assert process_mesh is not None assert ref_mapping is not None new_op_dist_attr = OperatorDistAttr() for input_varname in new_op.desc.input_arg_names(): new_op_dist_attr.set_input_dims_mapping(input_varname, ref_mapping) for output_varname in new_op.desc.output_arg_names(): new_op_dist_attr.set_output_dims_mapping(output_varname, ref_mapping) new_op_dist_attr.process_mesh = process_mesh ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) def naive_set_dist_op_attr_for_program_by_mesh( new_op, process_mesh, ctx, is_recompute=False ): assert process_mesh is not None new_op_dist_attr = OperatorDistAttr() for input_varname in new_op.desc.input_arg_names(): var = ctx.serial_main_program.global_block().var(input_varname) mapping = ctx.get_tensor_dist_attr_for_program(var).dims_mapping new_op_dist_attr.set_input_dims_mapping(input_varname, mapping) for output_varname in new_op.desc.output_arg_names(): var = ctx.serial_main_program.global_block().var(output_varname) mapping = ctx.get_tensor_dist_attr_for_program(var).dims_mapping new_op_dist_attr.set_output_dims_mapping(output_varname, mapping) new_op_dist_attr.process_mesh = process_mesh new_op_dist_attr.is_recompute = is_recompute ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) def update_op_dims_mapping_by_default_dist_impl(dist_op): changed = False op_dist_attr = dist_op.dist_attr op_desc = dist_op.serial_op.desc # The following statement will be replaced by a more elegent way if op_desc.type() == "shape" or op_desc.type() == "slice": return False output_names = op_desc.output_names() xshape_arg_names = [] if "XShape" in output_names: xshape_arg_names = op_desc.output("XShape") batch_dim_mappings = [] for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if len(dims_mapping) > 1: for idx, mapping in enumerate(dims_mapping[1:]): assert ( mapping == -1 ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format( op_desc.type(), idx, mapping ) if len(dims_mapping) >= 1: batch_dim_mappings.append(dims_mapping[0]) for arg_name in op_desc.output_arg_names(): serial_tensor = dist_op.get_serial_output(arg_name) if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if arg_name not in xshape_arg_names: if len(dims_mapping) > 1: for idx, mapping in enumerate(dims_mapping[1:]): assert ( mapping == -1 ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format( op_desc.type(), idx, mapping ) if len(dims_mapping) >= 1: batch_dim_mappings.append(dims_mapping[0]) else: assert ( dims_mapping[0] == -1 ), "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part.".format( op_desc.type(), mapping ) if len(dims_mapping) > 2: for idx, mapping in enumerate(dims_mapping[2:]): assert ( mapping == -1 ), "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part.".format( op_desc.type(), idx, mapping ) batch_dim_mappings.append(dims_mapping[1]) compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings) assert ( compatible_dim_mapping is not None ), "There is no compatible dim mapping." for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if len(dims_mapping) >= 1 and compatible_dim_mapping != dims_mapping[0]: dims_mapping[0] = compatible_dim_mapping changed = True for arg_name in op_desc.output_arg_names(): serial_tensor = dist_op.get_serial_output(arg_name) if serial_tensor.is_parameter: continue dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if arg_name not in xshape_arg_names: if ( len(dims_mapping) >= 1 and compatible_dim_mapping != dims_mapping[0] ): dims_mapping[0] = compatible_dim_mapping changed = True else: if compatible_dim_mapping != dims_mapping[1]: dims_mapping[1] = compatible_dim_mapping changed = True return changed def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_op): changed = False op_dist_attr = dist_op.dist_attr op_desc = dist_op.serial_op.desc input_arg_names = op_desc.input_arg_names() input_dims_mapping_dict = {} input_dims_mapping_lens = {} max_dims_mapping_len = -1 for arg_name in input_arg_names: dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if max_dims_mapping_len < len(dims_mapping): max_dims_mapping_len = len(dims_mapping) input_dims_mapping_dict[arg_name] = dims_mapping input_dims_mapping_lens[arg_name] = len(dims_mapping) dims_mapping_list = [] for arg_name in input_arg_names: if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] for i in range(input_dims_mapping_lens[arg_name]): new_idx = ( max_dims_mapping_len - input_dims_mapping_lens[arg_name] ) + i new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i] dims_mapping_list.append(new_dims_mapping) else: dims_mapping_list.append(input_dims_mapping_dict[arg_name]) output_arg_names = op_desc.output_arg_names() for arg_name in output_arg_names: dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) assert len(dims_mapping) == max_dims_mapping_len dims_mapping_list.append(dims_mapping) compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list) assert ( compatible_dims_mapping is not None ), "There is no compatible dim mapping." for arg_name in input_arg_names: if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: new_dims_mapping = [ -1 for _ in range(input_dims_mapping_lens[arg_name]) ] for i in range(input_dims_mapping_lens[arg_name]): new_idx = ( max_dims_mapping_len - input_dims_mapping_lens[arg_name] ) + i new_dims_mapping[i] = compatible_dims_mapping[new_idx] if new_dims_mapping != input_dims_mapping_dict[arg_name]: op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) changed = True else: if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: op_dist_attr.set_input_dims_mapping( arg_name, compatible_dims_mapping ) changed = True for arg_name in output_arg_names: dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if compatible_dims_mapping != dims_mapping: op_dist_attr.set_output_dims_mapping( arg_name, compatible_dims_mapping ) changed = True return changed def get_all_distributed_main_program( serial_program_info, dist_context, parallelizer ): "Get all distributed main programs by dist_context." from .dist_context import DistributedOperatorContext cluster = serial_program_info.cluster copied_parallelizer = copy.deepcopy(parallelizer) all_dist_main_program = [] ranks = ( paddle.distributed.get_world_size() if cluster is None else len(cluster.get_all_devices("GPU")) ) for rank_id in range(ranks): used_dist_context = copy.deepcopy(dist_context) used_dist_context._dist_op_context = DistributedOperatorContext() ( _, _, dist_startup_program, dist_main_program, _, ) = copied_parallelizer._get_dist_program(rank_id, used_dist_context) all_dist_main_program.append(dist_main_program) return all_dist_main_program class SerialProgramInfo: def __init__( self, train_program, satrtup_program, loss, optimizer, cluster=None ): self._train_program = train_program self._startup_program = satrtup_program self._loss = loss self._optimizer = optimizer self._cluster = cluster @property def train_program(self): return self._train_program @property def startup_program(self): return self._startup_program @property def loss(self): return self._loss @property def optimizer(self): return self._optimizer @property def cluster(self): return self._cluster def get_standalone_cost_data(distributed_programs): def _compute_runtime(op_cost, op, vars): runtime = 0 try: runtime = float(op_cost["op_time"]) except: return runtime op_config = op_cost["config"] total_static_input_size = 0 total_actual_input_size = 0 parsed_info = op_config.split("\n") variable = "(Variable)" for info in parsed_info: variable = ( "(Variable)" if "(Variable)" in info else "(list" ) if variable in info: arg_name_lower = info[: info.find(variable) - 1] shape_left_boundary = info.find("[") shape_right_boundary = info.find("]") assert ( shape_left_boundary > 0 and shape_right_boundary > 0 and shape_right_boundary > shape_left_boundary ), "Get shape failed." shape = info[ shape_left_boundary + 1 : shape_right_boundary ].split(",") shape = [int(x.strip()) for x in shape] dtype_factor = 1 total_static_input_size += reduce(lambda x, y: x * y, shape, 1) if op.type == "c_embedding": arg_name_lower = ( "w" if arg_name_lower == "weight" else "ids" ) for arg_name in op.input_names: if arg_name.lower() == arg_name_lower: for var_name in op.input(arg_name): var = vars[var_name] total_actual_input_size += reduce( lambda x, y: x * y, var.shape ) break assert ( total_static_input_size > 0 and total_actual_input_size > 0 ), "Get input size failed." actual_runtime = ( total_actual_input_size / total_static_input_size * runtime ) return actual_runtime import paddle.cost_model as cm cost_model = cm.CostModel() cost_model.static_cost_data() DEFAULT_MULTIPLE = 2 OP_NAME_MAPPING = { "c_embedding": "embedding", "matmul_v2": "matmul", "transpose2": "transpose", "reshape2": "reshape", "unsqueeze2": "unsqueeze", "reduce_sum": "sum", "elementwise_div": "divide", } standalone_cost_data = [] # skip ops not_enum_ops = [ "create_py_reader", "create_double_buffer_reader", "read", "assign", ] for distributed_program in distributed_programs: cost_data = {} vars = distributed_program.global_block().vars for op in distributed_program.global_block().ops: runtime = 0 if op.type in not_enum_ops: cost_data[op.desc.id()] = runtime continue dtype = ( str(vars[op.input_arg_names[0]].dtype) if op.input_arg_names else "float32" ) if int(op.attr('op_role')) == int(OpRole.Backward): if "_grad" in op.type: forward_op_name = op.type[:-5] if forward_op_name in OP_NAME_MAPPING.keys(): forward_op_name = OP_NAME_MAPPING[forward_op_name] op_cost = cost_model.get_static_op_time( forward_op_name, forward=False, dtype=dtype ) if op_cost: runtime = _compute_runtime(op_cost, op, vars) else: op_cost = cost_model.get_static_op_time( forward_op_name, dtype=dtype ) if op_cost: runtime = 2 * _compute_runtime(op_cost, op, vars) elif int(op.attr('op_role')) == int(OpRole.Forward): op_name = ( OP_NAME_MAPPING[op.type] if op.type in OP_NAME_MAPPING.keys() else op.type ) op_cost = cost_model.get_static_op_time(op_name) if op_cost: runtime = _compute_runtime(op_cost, op, vars) cost_data[op.desc.id()] = runtime standalone_cost_data.append(cost_data) return standalone_cost_data def set_dist_op_desc_original_id(dist_op_desc, op_desc, dist_context): op_id = op_desc.id() op_original_id = op_desc.original_id() # First, try to set the original id to the id of the op_desc if op_id in dist_context._dist_ops_for_program: dist_op_desc.set_original_id(op_id) return # Second, try to set the original id to the original_id of the op_desc elif op_original_id in dist_context._dist_ops_for_program: dist_op_desc.set_original_id(op_original_id) return # Third, print error infomation if we cannot find the original id else: raise AssertionError( "Cannot find the original id in the distributed context" ) def to_list(value): if value is None: return value if isinstance(value, (list, tuple)): return list(value) return [value] def debug_program(program, path, name): filename = os.path.join( path, name + '_program' + ".%d" % (paddle.distributed.get_rank()) ) with open(filename, 'w') as f: f.write(str(program)) def ring_id_to_process_group(ring_id): from .process_group import get_all_process_groups for g in get_all_process_groups(): if g.id == ring_id: return g return None def find_higher_order_backward_op(program): higher_order_op_suffix = ['_grad_grad', 'triple_grad'] for block in program.blocks: for op in block.ops: for suffix in higher_order_op_suffix: if suffix in op.type: return True return False def get_var_numel(var): """ input: - var: variable return: number of elemnet in var """ assert isinstance(var, Variable) assert -1 not in var.shape return reduce(lambda x, y: x * y, var.shape, 1) def get_lr(optimizer): if isinstance(optimizer, paddle.optimizer.Optimizer): return optimizer.get_lr() elif isinstance(optimizer, paddle.static.Optimizer): if isinstance(optimizer._learning_rate, float): return optimizer._learning_rate else: return optimizer._learning_rate() else: raise TypeError( "'optimizer' must be object of class `paddle.optimizer.Optimizer`" " or `paddle.static.Optimizer`, but got {}.".format(type(optimizer)) ) def initialize_pg_in_full_mode(all_process_groups, cur_rank): import socket from ...collective import _get_global_env has_recv_by_socket = [] # This is a magic number magic_num = 500 genv = _get_global_env() cur_rank_ip, cur_rank_port = genv.current_endpoint.split(":") cur_rank_recv_port = int(cur_rank_port) + magic_num server_socket = None # Large enough for recv rank buff_size = 1024 server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.bind((cur_rank_ip, cur_rank_recv_port)) # The 10 is an empirical value server_socket.listen(10) client_sockets = {} for process_group in all_process_groups: if cur_rank not in process_group.ranks: continue if len(process_group.ranks) == 2: index = process_group.ranks.index(cur_rank) is_send = True if index == 0 else False if is_send: recv_rank = process_group.ranks[1] recv_rank_ip, recv_rank_port = genv.trainer_endpoints[ recv_rank ].split(":") connect_port = int(recv_rank_port) + magic_num client_socket = socket.socket( socket.AF_INET, socket.SOCK_STREAM ) client_socket.connect((recv_rank_ip, connect_port)) client_socket.send(str(cur_rank).encode('utf-8')) rank = client_socket.recv(buff_size).decode('utf-8') rank = int(rank) if rank != recv_rank: raise ValueError( "Please check comm pair, the recv rank should be {} but got {}.".format( recv_rank, rank ) ) else: print( "It is able to instantiate {} as sender now.".format( process_group.ranks ) ) client_socket.close() else: send_rank = process_group.ranks[0] while True: if send_rank not in has_recv_by_socket: client_socket, recv_addr = server_socket.accept() rank = int(client_socket.recv(buff_size).decode()) client_sockets[rank] = client_socket has_recv_by_socket.append(rank) else: client_sockets[send_rank].send( str(cur_rank).encode("utf-8") ) client_sockets[send_rank].close() print( "It is able to instantiate {} as recver now.".format( process_group.ranks ) ) break process_group.instantiate() server_socket.close() def is_recompute_op(op): return op.has_attr('op_namescope') and "/auto_parallel/rc" in op.attr( 'op_namescope' ) def set_recompute_segments(model, losses, strategy, program): from ...passes.auto_parallel_recompute import RecomputeState if not losses: return recompute = strategy.recompute if not recompute.enable: return # NOTE: hack to enable recompute in engine api for GPT-3 # TODO support more PaddleNLP/CV models here # extract ckpts by specific model ckpts = [] if isinstance(model, paddle.nn.Layer): if ( hasattr(model, "gpt") and model.__class__.__name__ in [ 'GPTForPretraining', 'GPTForPretrainingAuto', ] and hasattr(model.gpt, "checkpoints") ): ckpts = model.gpt.checkpoints # last recompute segment is not need to recompute if len(ckpts) > 2: ckpts.pop() else: ckpts = recompute.checkpoints else: ckpts = recompute.checkpoints if not ckpts: return block = program.global_block() rc_state = RecomputeState(block, block.ops) rc_state.build_stats() checkpoints = rc_state.sort_checkpoints(ckpts) segments = [] start_idx = -1 pre_segment_end_idx = -1 while start_idx + 1 < len(checkpoints): if start_idx == -1: ckpt_name = checkpoints[start_idx + 1] if ckpt_name not in rc_state.var_op_deps: start_idx += 1 continue op_idx_list = rc_state.var_op_deps[ckpt_name]["var_as_output_ops"] if op_idx_list and max(op_idx_list) > 0: segments.append([0, max(op_idx_list) + 1]) else: flag, min_idx, max_idx = rc_state.is_subgraph( [checkpoints[start_idx]], [checkpoints[start_idx + 1]] ) if flag: min_idx = rc_state._update_segment_start( min_idx, pre_segment_end_idx ) segments.append([min_idx, max_idx + 1]) else: logging.debug( "Could not recompute op range [{}] - [{}] ".format( min_idx, max_idx + 1 ) ) start_idx += 1 for i, segment in enumerate(segments): for j in range(segment[0], segment[1]): block.ops[j]._set_attr( 'op_namescope', "/auto_parallel/rc_" + str(i) ) def get_input_split_info(cur_rank, var, dist_context): # deduce how the input data is split among the cluster tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var) process_mesh = tensor_dist_attr.process_mesh dims_mapping = tensor_dist_attr.dims_mapping if cur_rank not in process_mesh.process_ids: rank_id = _get_corresponding_rank(dist_context, process_mesh, cur_rank) else: rank_id = cur_rank batch_size_axis = dims_mapping[0] if batch_size_axis > -1 and process_mesh.shape[batch_size_axis] > 1: group_ranks = _get_comm_group( process_mesh.process_ids, process_mesh.shape, batch_size_axis, rank_id, ) return len(group_ranks), group_ranks.index(rank_id) return 1, 0 def validate_opt(optimizer): if optimizer is not None: optimizer._parameter_list = None optimizer._param_groups = None return optimizer def set_data_parallel(x): from ..interface import ProcessMesh, shard_tensor from .process_group import get_world_process_group world_ranks = get_world_process_group().ranks process_mesh = ProcessMesh(world_ranks, ['dp']) shard_spec = ['dp' if len(world_ranks) > 1 else None] + [ None for _ in range(len(x.shape) - 1) ] return shard_tensor(x, process_mesh, shard_spec) def is_naive_data_parallel(dist_context): # Navie data parallel only completes dist_attr once from the front to back. if not dist_context.data_parallel: return False ops_type = [ op.type for op in dist_context._original_serial_main_program.global_block().ops ] if ( not set(ops_type) & set(__not_naive_data_parallel_op__) ) and dist_context.data_parallel: return True return False def _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr): py_process_mesh = py_dist_attr.process_mesh if py_process_mesh is not None: cpp_dist_attr.process_mesh = core.ProcessMesh( py_process_mesh.shape, py_process_mesh.process_ids, ["d" + str(i) for i in range(len(py_process_mesh.shape))], ) cpp_dist_attr.dims_mapping = py_dist_attr.dims_mapping cpp_dist_attr.annotated = py_dist_attr.annotated def _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr): from ..process_mesh import ProcessMesh cpp_process_mesh = cpp_dist_attr.process_mesh if cpp_process_mesh is not None: py_dist_attr.process_mesh = ProcessMesh( shape=cpp_process_mesh.shape, process_ids=cpp_process_mesh.process_ids, ) py_dist_attr.dims_mapping = cpp_dist_attr.dims_mapping py_dist_attr.annotated = cpp_dist_attr.annotated def _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr): py_process_mesh = py_dist_attr.process_mesh if py_process_mesh is not None: cpp_dist_attr.process_mesh = core.ProcessMesh( py_process_mesh.shape, py_process_mesh.process_ids, ["d" + str(i) for i in range(len(py_process_mesh.shape))], ) cpp_dist_attr.impl_type = py_dist_attr.impl_type cpp_dist_attr.impl_idx = py_dist_attr.impl_idx cpp_dist_attr.is_recompute = py_dist_attr.is_recompute cpp_dist_attr.annotated = py_dist_attr.annotated for name, py_tensor_dist_attr in py_dist_attr.inputs_dist_attrs.items(): cpp_tensor_dist_attr = cpp_dist_attr.get_input_dist_attr(name) _copy_tensor_dist_attr_to_cpp(cpp_tensor_dist_attr, py_tensor_dist_attr) for name, py_tensor_dist_attr in py_dist_attr.outputs_dist_attrs.items(): cpp_tensor_dist_attr = cpp_dist_attr.get_output_dist_attr(name) _copy_tensor_dist_attr_to_cpp(cpp_tensor_dist_attr, py_tensor_dist_attr) def _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr): from ..process_mesh import ProcessMesh cpp_process_mesh = cpp_dist_attr.process_mesh if cpp_process_mesh is not None: py_dist_attr.process_mesh = ProcessMesh( shape=cpp_process_mesh.shape, process_ids=cpp_process_mesh.process_ids, ) py_dist_attr.impl_type = cpp_dist_attr.impl_type py_dist_attr.impl_idx = cpp_dist_attr.impl_idx py_dist_attr.is_recompute = cpp_dist_attr.is_recompute py_dist_attr.annotated = cpp_dist_attr.annotated for name, cpp_tensor_dist_attr in cpp_dist_attr.inputs_dist_attrs.items(): py_tensor_dist_attr = py_dist_attr.get_input_dist_attr(name) _copy_tensor_dist_attr_from_cpp( cpp_tensor_dist_attr, py_tensor_dist_attr ) for name, cpp_tensor_dist_attr in cpp_dist_attr.outputs_dist_attrs.items(): py_tensor_dist_attr = py_dist_attr.get_output_dist_attr(name) _copy_tensor_dist_attr_from_cpp( cpp_tensor_dist_attr, py_tensor_dist_attr ) def _copy_dist_attr_to_cpp(dist_context): for dist_tensor in dist_context._dist_tensors_for_program.values(): _copy_tensor_dist_attr_to_cpp( dist_tensor.serial_tensor.dist_attr, dist_tensor.dist_attr ) for dist_op in dist_context._dist_ops_for_program.values(): _copy_op_dist_attr_to_cpp( dist_op.serial_op.dist_attr, dist_op.dist_attr ) def _copy_dist_attr_from_cpp(dist_context): for dist_tensor in dist_context._dist_tensors_for_program.values(): _copy_tensor_dist_attr_from_cpp( dist_tensor.serial_tensor.dist_attr, dist_tensor.dist_attr ) for dist_op in dist_context._dist_ops_for_program.values(): _copy_op_dist_attr_from_cpp( dist_op.serial_op.dist_attr, dist_op.dist_attr ) def _copy_dist_attr_to_cpp_for_graph(dist_context): for node in dist_context.serial_ordered_nodes: if node.is_var() and node.var() is not None: py_dist_attr = dist_context.get_tensor_dist_attr_for_graph(node) cpp_dist_attr = node.var().dist_attr _copy_tensor_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr) if node.is_op() and node.op() is not None: py_dist_attr = dist_context.get_op_dist_attr_for_graph(node) cpp_dist_attr = node.op().dist_attr _copy_op_dist_attr_to_cpp(cpp_dist_attr, py_dist_attr) def _copy_dist_attr_from_cpp_for_graph(dist_context): for node in dist_context.serial_ordered_nodes: if node.is_var() and node.var() is not None: py_dist_attr = dist_context.get_tensor_dist_attr_for_graph(node) cpp_dist_attr = node.var().dist_attr _copy_tensor_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr) if node.is_op() and node.op() is not None: py_dist_attr = dist_context.get_op_dist_attr_for_graph(node) cpp_dist_attr = node.op().dist_attr _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr) def insert_dependencies_for_two_ops( block, idx, prior_op, posterior_op, dist_context, is_recompute=False, sync=False, op_namescope=None, ): """ dependency: prior_op should be run before posterior_op """ assert ( len(prior_op.output_arg_names) >= 1 ), "first op of dependency should at least have one output. [{}]".format( str(prior_op) ) assert ( len(posterior_op.input_arg_names) >= 1 ), "second op of dependency should at least have one input. [{}]".format( str(posterior_op) ) prior_op_mesh = dist_context.get_op_dist_attr_for_program( prior_op ).process_mesh posterior_mesh = dist_context.get_op_dist_attr_for_program( posterior_op ).process_mesh assert ( prior_op_mesh == posterior_mesh ), "two ops of dependency should have same mesh but got [{}] and [{}]".format( str(prior_op_mesh), str(posterior_mesh) ) def _select_best_depend_var(vars): # parameter should not be dep var since it maybe partition in sharding pass vars = [var for var in vars if not var.is_parameter] assert len(vars) > 0 vars_with_numels = [(var, get_var_numel(var)) for var in vars] vars_with_numels.sort(key=lambda x: x[1]) return vars_with_numels[-1][0] first_var = _select_best_depend_var( [block.var(name) for name in prior_op.output_arg_names] ) second_var = _select_best_depend_var( [block.var(name) for name in posterior_op.input_arg_names] ) return insert_dependencies_for_vars( block, idx, first_var, second_var, dist_context, OpRole.Backward, process_mesh=prior_op_mesh, is_recompute=is_recompute, sync=sync, op_namescope=op_namescope, use_nop=False, ) def insert_dependencies_for_vars( block, idx, prior_vars, post_vars, dist_context, oprole, process_mesh=None, is_recompute=False, sync=False, op_namescope=None, use_nop=False, ): """ dependency: op that generates prior_vars should be run before op that generates post_vars """ if isinstance(prior_vars, Variable): prior_vars = [prior_vars] if isinstance(post_vars, Variable): post_vars = [post_vars] for prior_var in prior_vars: assert block.has_var(prior_var.name) for post_var in post_vars: assert block.has_var(post_var.name) if process_mesh is None: process_mesh = dist_context.get_tensor_dist_attr_for_program( post_vars[0] ).process_mesh assert process_mesh is not None use_nop = True if use_nop: depend_op = block._insert_op_without_sync( idx, type='nop', inputs={ "X": prior_vars, }, outputs={"Out": post_vars}, ) else: depend_op = block._insert_op_without_sync( idx, type='depend', inputs={ "X": post_vars, "Dep": prior_vars, }, outputs={"Out": post_vars}, ) depend_op._set_attr(OP_ROLE_KEY, oprole) # TODO: condition can be removed when add correct dist_attr for coalesce vars and ops in sharding_pass if is_recompute or process_mesh != [-1]: depend_op_dist_attr = OperatorDistAttr() depend_op_dist_attr.impl_idx = 0 depend_op_dist_attr.impl_type = "default" depend_op_dist_attr.process_mesh = process_mesh depend_op_dist_attr.is_recompute = is_recompute for input_varname in depend_op.desc.input_arg_names(): var = block.var(input_varname) mapping = dist_context.get_tensor_dist_attr_for_program( var ).dims_mapping depend_op_dist_attr.set_input_dims_mapping(input_varname, mapping) for output_varname in depend_op.desc.output_arg_names(): var = block.var(output_varname) mapping = dist_context.get_tensor_dist_attr_for_program( var ).dims_mapping depend_op_dist_attr.set_output_dims_mapping(output_varname, mapping) dist_context.set_op_dist_attr_for_program( depend_op, depend_op_dist_attr ) if op_namescope is not None: depend_op._set_attr('op_namescope', f"/{op_namescope}") if sync: block._sync_with_cpp() return depend_op def is_dep_skip_op(op): if "c_" in op.type: return True return False def _dygraph_guard_(func): def __impl__(*args, **kwargs): if paddle.framework.in_dynamic_mode(): return func(*args, **kwargs) else: with paddle.fluid.dygraph.guard(): return func(*args, **kwargs) return __impl__ dygraph_guard = wrap_decorator(_dygraph_guard_) def use_new_executor(): new_executor_micro_batching = os.environ.get( 'FLAGS_new_executor_micro_batching', None ) return new_executor_micro_batching in [ 1, '1', True, 'True', 'true', ] def use_new_ir(): enable_new_ir_in_executor = os.environ.get( 'FLAGS_enable_new_ir_in_executor', None ) return enable_new_ir_in_executor in [ 1, '1', True, 'True', 'true', ] def get_pp_stage(dist_context, rank): pp_idx = None for idx, process_mesh in enumerate(dist_context.process_meshes): if rank in process_mesh.process_ids: pp_idx = idx break return pp_idx def wrap_data_for_completion( dist_op, input_names: list, output_names: list, attr_names: list ): """ Get data used in inferring distributed attributes, including: 1. DistTensorSpec for each input and output tensor of this dist_op. 2. Operator attributes of this dist_op, e.g. transpose_x in matmul op. Args: dist_op: the DistributedOperator input_names: list, name of the dist_op's input tensors output_names: list, name of the dist_op's output tensors attr_names: list, attribute name of the dist_op's corresponding serial op Returns: input_specs: list, DistTensorSpec for each input tensor of the dist_op output_specs: list, DistTensorSpec for each output tensor of the dist_op attrs: dict, attribute map of the dist op Examples: .. code-block:: python >>> # doctest: +SKIP('Depends on other ops.') >>> from paddle.distributed.auto_parallel.static.utils import wrap_data_for_completion >>> op_desc = dist_op.serial_op.desc >>> input_name_list = [] >>> output_name_list = [] >>> input_name_list.append(op_desc.input('X')[0]) # 'X' is the arg name for op >>> input_name_list.append(op_desc.input('Y')[0]) >>> output_name_list.append(op_desc.output('Out')[0]) >>> attr_name_list = ['trans_x', 'trans_y'] >>> input_specs, output_specs, attrs = wrap_data_for_completion( ... dist_op, ... input_name_list, ... output_name_list, ... attr_name_list) """ input_specs = [] output_specs = [] attrs = {} serial_op = dist_op.serial_op # Construct each input tensor's DistTensorSpec with shape and dist_attr for name in input_names: tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(name) var = serial_op.block._var_recursive(name) tensor_shape = var.shape dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr) input_specs.append(dist_spec) # Construct each output tensor's DistTensorSpec with shape and dist_attr for name in output_names: tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(name) var = serial_op.block._var_recursive(name) tensor_shape = var.shape dist_spec = DistTensorSpec(tensor_shape, tensor_dist_attr) output_specs.append(dist_spec) for attr_name in attr_names: attrs[attr_name] = serial_op.desc.attr(attr_name) return input_specs, output_specs, attrs