# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import warnings import numpy as np import paddle from ...utils.log_utils import get_logger class Converter: """ Converter is a class object for auto parallel to convert tensors from one parallel strategy to another one. Tensors will merge and slice value with their strategy when strategies are different. """ def __init__(self, tensors_dict, pre_strategy, cur_strategy): """ Args: tensors_dict(dict): tensors' value of all ranks that to be converted. key is tensor's name(str), value is all ranks' data(list(numpy.ndarray)) pre_strategy(dict): tensors' distributed attribute of last training process. key is tensor's name(str), value is tensor's distributed attribute in last training process. cur_strategy(dict): tensors' distributed attribute of current rank. key is tensor's name(str), value is tensor's distributed attribute in current rank. """ self._tensors_dict = self._check_tensor_dict(tensors_dict) self._pre_strategy = self._check_pre_strategy(pre_strategy) self._cur_strategy = self._check_cur_strategy(cur_strategy) self._logger = get_logger(logging.INFO) def _check_tensor_dict(self, tensors_dict): if not tensors_dict: raise ValueError( "'tensors_dict' is None, " "the tensors to be converted cannot be None." ) if not isinstance(tensors_dict, dict): raise TypeError( "The type of 'tensors_dict' should be 'dict', but got '{}'.".format( str(type(tensors_dict)) ) ) return tensors_dict def _check_pre_strategy(self, pre_strategy): if not pre_strategy: raise ValueError( "'pre_strategy' is None, " "there are not tensors in pre process." ) if not isinstance(pre_strategy, dict): raise TypeError( "The type of 'pre_strategy' should be 'dict', " "but got '{}'.".format(str(type(pre_strategy))) ) return pre_strategy def _check_cur_strategy(self, cur_strategy): if not cur_strategy: warnings.warn( "'cur_strategy' is None, " "there are not tensors in cur process" ) if not isinstance(cur_strategy, dict): raise TypeError( "The type of 'cur_strategy' should be 'dict', " "but got '{}'.".format(str(type(cur_strategy))) ) return cur_strategy def convert(self, strict=True): """ Convert tensors Args: strict(bool): whether to strict convert tensor with tensor's name. If False, it will convert tensors by prefix matching. Otherwise, tensors will be converted with their name strictly. Returns: converted tensors(dict) Examples: .. code-block:: python >>> # doctest: +REQUIRES(env:DISTRIBUTED) >>> import numpy as np >>> from paddle.distributed.auto_parallel.static.converter import Converter >>> complete_tensors = np.arange(4).reshape([2, 2]) >>> partitial_tensors = np.split(complete_tensors, 2, axis=0) >>> name = "tmp_0" >>> tensors_dict = {name: partitial_tensors} >>> strategy_1 = { ... name: { ... "process_shape": [2], ... "process_group": [0, 1], ... "dims_mapping": [0, -1] ... } ... } >>> strategy_2 = { ... name: { ... "process_shape": [2], ... "process_group": [0, 1], ... "dims_mapping": [-1, -1] ... } ... } >>> converter = Converter(tensors_dict, strategy_1, strategy_2) >>> result = converter.convert() >>> # the result's value is equal to `complete_tensors` """ tensors_dict = {} # the name which is in cur_process but not in pre_process tensor_not_in_pre = [] # the name which is in pre_process but not in cur_process tensor_not_in_cur = [] # the name which is in strategy but not in ckpt files tensor_not_in_ckpt = [] self._logger.info("Start to convert tensors.") for tensor_name in self._cur_strategy: if tensor_name not in self._pre_strategy: tensor_not_in_pre.append(tensor_name) continue if tensor_name not in self._tensors_dict: tensor_not_in_ckpt.append(tensor_name) continue self._pre_name = tensor_name self._cur_name = tensor_name tensor_list = self._tensors_dict[tensor_name] pre_dist_attr = self._pre_strategy[tensor_name] cur_dist_attr = self._cur_strategy[tensor_name] try: tensors_dict[tensor_name] = Converter.merge_and_slice( tensor_list, pre_dist_attr, cur_dist_attr ) except ValueError as err: raise ValueError( f"Fail to convert tensor '{str(tensor_name)}'. " + str(err) ) for tensor_name in self._pre_strategy: if tensor_name not in self._cur_strategy: tensor_not_in_cur.append(tensor_name) if not strict: ( tensors_dict, tensor_match_with_pre, tensor_match_with_cur, ) = self.convert_with_prefix_match( tensors_dict, tensor_not_in_pre, tensor_not_in_cur ) else: tensors_dict, tensor_match_with_pre, tensor_match_with_cur = ( tensors_dict, [], [], ) tensor_not_in_pre = set(tensor_not_in_pre) - set(tensor_match_with_pre) tensor_not_in_cur = set(tensor_not_in_cur) - set(tensor_match_with_cur) if tensor_not_in_pre: warnings.warn( "tensors [{}] are not found in last training strategy.".format( str(tensor_not_in_pre) ) ) if tensor_not_in_cur: warnings.warn( "tensors [{}] are not found in current training strategy.".format( str(tensor_not_in_cur) ) ) if tensor_not_in_ckpt: warnings.warn( "tensors [{}] are found in pre_strategy, but are not found" "in checkpoint files, please check your checkpoint files.".format( str(tensor_not_in_ckpt) ) ) return tensors_dict def convert_with_prefix_match( self, tensors_dict, tensor_not_in_pre, tensor_not_in_cur ): # the name which in cur_process and can match with pre_process tensor_match_with_pre = [] # the name which in pre_process and can match with cur_process tensor_match_with_cur = [] for cur_name in tensor_not_in_pre: prefix_name = cur_name while prefix_name.find("_") != -1: prefix_name = prefix_name[: prefix_name.rfind("_")] for pre_name in tensor_not_in_cur: if prefix_name in pre_name: # 'cur_name' of cur_process can match with 'pre_name' of pre_process self._pre_name = pre_name self._cur_name = cur_name pre_tensor_list = self._tensors_dict[pre_name] pre_dist_attr = self._pre_strategy[pre_name] cur_dist_attr = self._cur_strategy[cur_name] try: tensors_dict[cur_name] = Converter.merge_and_slice( pre_tensor_list, pre_dist_attr, cur_dist_attr ) except ValueError as err: raise ValueError( "Fail to convert tensor '{}' by '{}'. ".format( str(cur_name), str(pre_name) ) + str(err) ) self._logger.info( "tensor [{}] is matched with tensor [{}]".format( cur_name, pre_name ) ) tensor_match_with_pre.append(cur_name) tensor_match_with_cur.append(pre_name) break break return tensors_dict, tensor_match_with_pre, tensor_match_with_cur @staticmethod def merge_and_slice(tensor_list, pre_dist_attr, cur_dist_attr): """ Merge tensors with previous dist_attr and slice tensors with current dist_attr Returns: tensor(numpy.narray): a tensor's value of current rank. """ assert isinstance(tensor_list, list) assert all(isinstance(p, np.ndarray) for p in tensor_list) if pre_dist_attr == cur_dist_attr: # skip merge and slice tensor rank_id = paddle.distributed.get_rank() index = cur_dist_attr["process_group"].index(rank_id) tensor = tensor_list[index] else: pre_dims_mapping = pre_dist_attr["dims_mapping"] cur_dims_mapping = cur_dist_attr["dims_mapping"] if len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping: # merge tensor tensor = Converter.merge_with_dist_attr( tensor_list, pre_dist_attr ) else: # skip merge tensor tensor = tensor_list[0] if len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping: # slice tensor tensor = Converter.slice_with_dist_attr(tensor, cur_dist_attr) return tensor @staticmethod def merge_with_dist_attr(tensor_list, dist_attr): """Merge tensor with distributed attribute""" from .reshard import Resharder dims_mapping = dist_attr["dims_mapping"] process_shape = dist_attr["process_shape"] process_group = dist_attr["process_group"] # get the complete shape of the tensor complete_shape = Resharder.compute_complete_shape( tensor_list[0].shape, process_shape, dims_mapping ) # merge the tensor with dist_attr partition_tensor_list = [] merged_partiton = [] for process in process_group: partition_index = Resharder.compute_partition_index( process, complete_shape, dims_mapping, process_shape, process_group, ) index = process_group.index(process) if partition_index not in merged_partiton: merged_partiton.append(partition_index) Converter.merge( partition_tensor_list, tensor_list[index], partition_index, complete_shape, ) if len(partition_tensor_list) != 1: raise ValueError( "Fail to merge tensor with dist_attr '{}'.".format( str(dist_attr) ) ) complete_tensor = partition_tensor_list[0][0] return complete_tensor @staticmethod def slice_with_dist_attr(tensor, dist_attr): """Slice tensor with distributed attribute""" dims_mapping = dist_attr["dims_mapping"] process_shape = dist_attr["process_shape"] process_group = dist_attr["process_group"] # slice the tensor with dist_attr partition_index_list = Converter._get_split_indices( tensor.shape, dims_mapping, process_shape, process_group ) sliced_tensor_list = Converter.split( tensor, partition_index_list, len(partition_index_list) ) # get the current tensor's index in sliced_tensor_list rank_id = paddle.distributed.get_rank() sliced_tensor_index = Converter._get_sliced_index( rank_id, tensor.shape, dims_mapping, process_shape, process_group ) if sliced_tensor_index not in range(len(sliced_tensor_list)): raise ValueError( "Fail to slice tensor with dist_attr '{}'.".format( str(dist_attr) ) ) sliced_tensor = sliced_tensor_list[sliced_tensor_index] return sliced_tensor @staticmethod def merge(partition_tensor_list, tensor, partition_index, complete_shape): """ Merge partitial tensors to a complete. Returns: None Examples: .. code-block:: python >>> # doctest: +REQUIRES(env:DISTRIBUTED) >>> import numpy as np >>> import paddle >>> from paddle.distributed.auto_parallel.static.converter import Converter >>> partition_tensor_list = [(np.array([[[1.11, 1.12]]]), [[0,1],[0,1],[0,2]])] >>> tensor = np.array([[[1.13, 1.14]]]) >>> partition_index = [[0,1],[0,1],[2,4]] >>> complete_shape = [3, 2] >>> Converter.merge(partition_tensor_list, tensor, partition_index, complete_shape) >>> print(partition_tensor_list) [(array([[[1.11, 1.12, 1.13, 1.14]]]), [[0, 1], [0, 1], [0, 4]])] """ from .reshard import Resharder if len(partition_tensor_list) == 1: is_complete_data = True for idx, item in enumerate(partition_tensor_list[0][1]): if item[0] != 0 or item[1] != complete_shape[idx]: is_complete_data = False break if is_complete_data: return if not partition_tensor_list: partition_tensor_list.append((tensor, partition_index)) else: i = 0 while i < len(partition_tensor_list): ( concat_axis, first_order, new_partition, ) = Resharder.compute_concat_info( partition_tensor_list[i][1], partition_index ) if concat_axis != -1: if first_order == 0: new_tensor = np.concatenate( (partition_tensor_list[i][0], tensor), axis=concat_axis, ) else: new_tensor = np.concatenate( (tensor, partition_tensor_list[i][0]), axis=concat_axis, ) partition_tensor_list.pop(i) Converter.merge( partition_tensor_list, new_tensor, new_partition, complete_shape, ) break i += 1 @staticmethod def split(complete_tensor, partition_index_list, length): """ Slice a complete tensor. Returns: sliced_tensor_list(list): sliced tensors with 'partition_index_list' Examples: .. code-block:: python >>> # doctest: +REQUIRES(env:DISTRIBUTED) >>> import numpy as np >>> from paddle.distributed.auto_parallel.static.converter import Converter >>> complete_tensor = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) >>> rank = 2 >>> complete_shape = [1, 1, 6] >>> dims_mapping = [-1, -1, 0] >>> process_shape = [3] >>> process_group = [0, 1, 2] >>> sliced_tensor_list = Converter.split(complete_tensor, [[], [], [2, 4]], 3) >>> print(sliced_tensor_list) [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] """ sliced_tensor_list = [] axis = len(complete_tensor.shape) - length sliced_tensor = np.split( complete_tensor, partition_index_list[axis], axis=axis ) if length == 1: return sliced_tensor for tensor in sliced_tensor: sliced_tensor_list.extend( Converter.split(tensor, partition_index_list, length - 1) ) return sliced_tensor_list @staticmethod def _get_split_indices( complete_shape, dims_mapping, process_shape, process_group ): """ Get split indices of every dimension. Returns: split_indices_list(list): the split indices of every dimension of the tensor Examples: .. code-block:: python >>> # doctest: +REQUIRES(env:DISTRIBUTED) >>> import numpy as np >>> from paddle.distributed.auto_parallel.static.utils import _get_split_indices >>> complete_tensor = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) >>> complete_shape = [1, 1, 6] >>> dims_mapping = [-1, -1, 0] >>> process_shape = [3] >>> process_group = [0, 1, 2] >>> index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group) >>> print(index) [[], [], [2, 4]] """ from .reshard import Resharder split_indices_list = [] for process in process_group: partition_index = Resharder.compute_partition_index( process, complete_shape, dims_mapping, process_shape, process_group, ) if split_indices_list: for dim in range(len(partition_index)): split_indices_list[dim].extend(partition_index[dim]) else: split_indices_list = partition_index split_indices_list = list( map( lambda x, y: list(set(x) - {y} - {0}), split_indices_list, complete_shape, ) ) split_indices_list = [sorted(x) for x in split_indices_list] return split_indices_list @staticmethod def _get_sliced_index( rank_id, complete_shape, dims_mapping, process_shape, process_group ): """ Get sliced_tensor's index of current rank in all sliced tensors list. Returns: sliced_tensor_index(int): the index of sliced tensor in sliced_tensor_list Examples: .. code-block:: python >>> # doctest: +REQUIRES(env:DISTRIBUTED) >>> import numpy as np >>> from paddle.distributed.auto_parallel.static.converter import Converter >>> complete_tensor = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) >>> rank = 2 >>> complete_shape = [1, 1, 6] >>> dims_mapping = [-1, -1, 0] >>> process_shape = [3] >>> process_group = [0, 1, 2] >>> index = Converter._get_sliced_index(rank, complete_shape, dims_mapping, ... process_shape, process_group) >>> print(index) 2 """ from .reshard import Resharder partition_index = Resharder.compute_partition_index( rank_id, complete_shape, dims_mapping, process_shape, process_group ) sliced_index = 0 for i, shape in enumerate(complete_shape): if dims_mapping[i] == -1: slice_shape = shape else: slice_shape = shape // process_shape[dims_mapping[i]] if slice_shape == 1: index = partition_index[i][0] else: index = (partition_index[i][0] + 1) // slice_shape sliced_index = sliced_index * (shape // slice_shape) + index return sliced_index