# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import threading

import numpy as np

import paddle.fluid.core as core

# Serializes print_program_with_distributed_attr so that concurrent callers
# do not interleave their swaps of the default distributed context. It must
# live at module level: a lock created inside the function would be private
# to each call and would exclude nothing.
_PRINT_PROGRAM_LOCK = threading.Lock()


def is_valid_list_index(list, index):
    """Return True if ``index`` is a valid (possibly negative) index of ``list``.

    Valid indices satisfy ``-len(list) <= index < len(list)``.
    """
    # NOTE: the parameter name shadows the builtin ``list``; it is kept as-is
    # because callers may pass it by keyword.
    return -len(list) <= index < len(list)


def is_dim_shard(mapping):
    """Return True if a dimension mapping denotes a sharded dimension (not -1)."""
    return mapping != -1


def is_dim_replicate(mapping):
    """Return True if a dimension mapping denotes a replicated dimension (-1)."""
    return mapping == -1


def compute_compatible_dim_mapping(dim_mappings):
    """Merge the mappings of one dimension coming from several sources.

    A mapping of ``-1`` is compatible with any other mapping. Two different
    mappings that are both not ``-1`` conflict.

    Args:
        dim_mappings (list): candidate mappings for the same dimension.

    Returns:
        The merged mapping. Returns None if ``dim_mappings`` is empty or
        contains two different non ``-1`` mappings.
    """
    if not dim_mappings:
        return None
    compatible_mapping = dim_mappings[0]
    for mapping in dim_mappings:
        if compatible_mapping == -1:
            # Nothing fixed yet: adopt whatever this source says.
            compatible_mapping = mapping
        elif mapping == -1 or mapping == compatible_mapping:
            continue
        else:
            return None
    return compatible_mapping


def compute_compatible_dims_mapping(dims_mapping_list):
    """Merge several full dims_mappings element-wise.

    Args:
        dims_mapping_list (list): dims_mappings of equal length; none of them
            may be None.

    Returns:
        list: the element-wise merged dims_mapping. Returns None if
        ``dims_mapping_list`` is empty or any dimension is incompatible
        (see compute_compatible_dim_mapping).

    Raises:
        AssertionError: if an entry is None or the lengths differ.
    """
    if not dims_mapping_list:
        return None
    length = len(dims_mapping_list[0])
    for dims_mapping in dims_mapping_list:
        assert dims_mapping is not None, \
            "Dims mapping must not be None for compatible computation"
        assert len(dims_mapping) == length, \
            "The length of dims_mapping in list must be same for compatible computation."
    compatible_result = []
    # zip(*...) walks the lists column by column, i.e. dimension by dimension.
    for dim_mappings in zip(*dims_mapping_list):
        compatible_dim_mapping = compute_compatible_dim_mapping(
            list(dim_mappings))
        if compatible_dim_mapping is None:
            return None
        compatible_result.append(compatible_dim_mapping)
    return compatible_result


def compute_compatible_process_mesh(process_mesh_list):
    """Return the single process mesh shared by all non-None entries.

    Args:
        process_mesh_list (list): process meshes; None entries are ignored.

    Returns:
        The common process mesh. Returns None if the list is empty or every
        entry is None.

    Raises:
        AssertionError: if two non-None meshes compare unequal.
    """
    compatible_process_mesh = None
    if not process_mesh_list:
        return compatible_process_mesh
    for process_mesh in process_mesh_list:
        if process_mesh is None:
            continue
        if compatible_process_mesh is None:
            compatible_process_mesh = process_mesh
        else:
            assert process_mesh == compatible_process_mesh, \
                "There is no compatible process mesh."
    return compatible_process_mesh


def compute_compatible_and_update_dim_mapping(dims_mapping_list, index_list):
    """Make one selected dimension agree across several dims_mappings.

    For each pair ``(dims_mapping_list[i], index_list[i])`` the mapping at that
    index is read. The mappings are merged, and every entry is overwritten in
    place with the merged value.

    Args:
        dims_mapping_list (list): dims_mappings (lists); mutated in place.
        index_list (list): one index per dims_mapping; negative indices are
            allowed.

    Returns:
        bool: True if any entry was changed. Returns False if nothing changed
        or the mappings are incompatible; in the incompatible case nothing is
        written.

    Raises:
        AssertionError: if the two lists differ in length or an index is out
            of range.
    """
    assert len(dims_mapping_list) == len(index_list)
    dim_mappings = []
    for dims_mapping, index in zip(dims_mapping_list, index_list):
        assert is_valid_list_index(dims_mapping, index)
        dim_mappings.append(dims_mapping[index])
    compatible_dim_mapping = compute_compatible_dim_mapping(dim_mappings)
    if compatible_dim_mapping is None:
        return False
    changed = False
    for dims_mapping, index in zip(dims_mapping_list, index_list):
        if compatible_dim_mapping != dims_mapping[index]:
            dims_mapping[index] = compatible_dim_mapping
            changed = True
    return changed


def append_distributed_attr_suffix(name):
    """
    Append auto parallel suffix for distributed attribute name.
    """
    return name + core.kAutoParallelSuffix()


def remove_distributed_attr_suffix(name):
    """
    Remove auto parallel suffix from distributed attribute name.

    Only a trailing occurrence of the suffix is removed. Names that do not end
    with the suffix are returned unchanged.
    """
    suffix = core.kAutoParallelSuffix()
    # str.strip(suffix) would treat ``suffix`` as a *set of characters* and
    # strip them from both ends, which can eat into the base name. Remove
    # the exact suffix instead.
    if suffix and name.endswith(suffix):
        return name[:-len(suffix)]
    return name


def check_distributed_attr_for_program(program, dist_context=None):
    """Check that every distributed attribute attached to ``program`` is valid.

    Args:
        program: the program whose blocks, vars and ops are inspected.
        dist_context: the distributed context to query. Defaults to the
            default distributed context.

    Returns:
        bool: False as soon as a tensor or op has a distributed attribute
        whose ``is_valid()`` is false; True otherwise. Missing (None)
        attributes are skipped.

    Raises:
        AssertionError: if the context is not initialized for the program.
    """
    from .context import get_default_distributed_context
    if dist_context is None:
        dist_context = get_default_distributed_context()
    assert dist_context.is_initialized_for_program(), \
        "Distributed attributes must be initialized before check."
    for block in program.blocks:
        for tensor in block.vars.values():
            tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
                tensor)
            if tensor_dist_attr is not None and not tensor_dist_attr.is_valid():
                return False
        for op in block.ops:
            op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
            if op_dist_attr is not None and not op_dist_attr.is_valid():
                return False
    return True


def print_program_with_distributed_attr(program, dist_context=None):
    """
    This function reuses the original program output ability with a
    distributed context.

    When ``dist_context`` is given, it is temporarily installed as the default
    distributed context while ``program`` is printed. The previous default is
    restored afterwards, even if printing raises. A module-level lock stops
    concurrent calls of this function from interleaving those swaps. It does
    not guard other code that changes the default context.
    """
    from .context import get_default_distributed_context
    from .context import set_default_distributed_context
    with _PRINT_PROGRAM_LOCK:
        if dist_context is None:
            # Call kept from the original implementation. It may lazily create
            # the default context (not visible here; verify in .context).
            get_default_distributed_context()
            print(program)
            return
        original_default_context = get_default_distributed_context()
        set_default_distributed_context(dist_context)
        try:
            print(program)
        finally:
            set_default_distributed_context(original_default_context)


def _get_comm_group(processes, shape, axis, rank):
    """
    Given a rank and the process mesh the rank belongs to, compute the
    communication peers of the rank along the given axis of the mesh.

    Args:
        processes (list): process ids of the mesh, laid out in row-major order
            over ``shape`` (assumed; see _coordinate2linear_idx).
        shape (list): shape of the mesh.
        axis (int): mesh axis along which peers are collected.
        rank (int): the rank whose peers are computed; must be in ``processes``.

    Returns:
        list: sorted process ids that share every mesh coordinate with
        ``rank`` except the one on ``axis``. ``rank`` itself is included.

    Example:
        16 processes managed in a 4-dimensional mesh with shape [2, 2, 2, 2].
        The communication peers of rank 0 (included) are:
            in axis 0: [0, 8]
            in axis 1: [0, 4]
            in axis 2: [0, 2]
            in axis 3: [0, 1]
    """
    # _linear_idx2coordinate assumes a mesh whose ids start at 0 and are
    # contiguous. Work with *positions* inside ``processes`` and map them back
    # to process ids at the end. This supports meshes whose ids do not start at
    # 0 or are not contiguous.
    rank_relative = processes.index(rank)
    coordinate = _linear_idx2coordinate(shape, rank_relative)

    # Select the comm group: vary only the coordinate on ``axis``.
    coordinates_in_group = []
    for i in range(shape[axis]):
        peer_coordinate = list(coordinate)
        peer_coordinate[axis] = i
        coordinates_in_group.append(peer_coordinate)

    ranks_in_group = [
        processes[_coordinate2linear_idx(shape, peer_coordinate)]
        for peer_coordinate in coordinates_in_group
    ]
    return sorted(ranks_in_group)


def _coordinate2linear_idx(mesh_shape, coordinate):
    """
    Convert a coordinate in a multidimensional mesh into a scalar index in
    linear space.

    Row-major order is used: coordinate[0] is the most significant dimension
    and coordinate[-1] the least significant.

    Let the size of dimension i be S[i] and the index in dimension i be I[i],
    with n dimensions. Then:

        linear_idx = I[0]   * (S[1] * S[2] * ... * S[n-1])
                   + I[1]   * (S[2] * ... * S[n-1])
                   + ...
                   + I[n-2] * S[n-1]
                   + I[n-1]

    Raises:
        AssertionError: if the lengths differ or any index is out of range.
    """
    # NOTE: this function rests on a strong assumption about the processes
    # in the mesh:
    #   1. they start from 0
    #   2. they are contiguous
    # The result is wrong if these conditions do not hold, e.g.
    # process_mesh = {process_groups = [7, 8, 9, 10, 12, 13, 14, 15], mesh = [2, 4]}.
    # For a more general mapping, use a cartesian product.

    assert len(mesh_shape) == len(
        coordinate
    ), "coordinate should have the same size as mesh shape, but got shape: {}, coordinate: {}".format(
        mesh_shape, coordinate)
    for i in range(len(mesh_shape)):
        assert coordinate[
            i] >= 0, "index in dimension [{}] is least than zero. coordinate: {}".format(
                i, coordinate)
        assert coordinate[i] < mesh_shape[
            i], "index beyond extent in dimension [{}]. shape: {}, coordinate: {}".format(
                i, mesh_shape, coordinate)

    # Row-major accumulation, from the least significant dimension upwards.
    base = mesh_shape[-1]
    linear_idx = coordinate[-1]
    for i in range(len(mesh_shape) - 2, -1, -1):
        linear_idx += base * coordinate[i]
        base *= mesh_shape[i]

    return linear_idx


def _linear_idx2coordinate(mesh_shape, linear_idx):
    """
    Map a linear scalar index into the multidimensional mesh and return its
    coordinate. This is the inverse of _coordinate2linear_idx (row-major).

    Let the size of dimension i be S[i] and the index in dimension i be I[i],
    with n dimensions. Then:

        I[n-1] = linear_idx % S[n-1]
        I[n-2] = (linear_idx // S[n-1]) % S[n-2]
        I[n-3] = (linear_idx // (S[n-1] * S[n-2])) % S[n-3]
        ...

    Raises:
        AssertionError: if ``linear_idx`` is negative or not smaller than
            the number of elements in the mesh.
    """

    assert linear_idx >= 0, "linear index [{}] is least than zero".format(
        linear_idx)
    assert linear_idx < np.prod(
        mesh_shape
    ), "linear index beyond the extent of mesh shape. shape: {}, linear index: {}".format(
        mesh_shape, linear_idx)

    base = 1
    coordinate = [-1] * len(mesh_shape)

    for i in reversed(range(len(mesh_shape))):
        # Floor division keeps the arithmetic exact for large indices.
        # True division would go through float and can lose precision.
        offset = linear_idx // base
        coordinate[i] = int(offset % mesh_shape[i])
        base *= mesh_shape[i]

    # row major order
    return coordinate