# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import _legacy_C_ops

__all__ = []

# Floating-point dtypes supported by this module, mapped to their string names.
FLOAT_TYPE_DICT = {
    paddle.float16: "float16",
    paddle.float32: "float32",
    paddle.float64: "float64",
    paddle.bfloat16: "bfloat16",
}

# Stable integer codes for dtypes, used when a dtype must be transmitted
# as a plain number (e.g. across processes) and decoded via NUMBER_TO_DTYPE.
PADDLE_TO_NUMBER = {
    paddle.float16: 0,
    paddle.float32: 1,
    paddle.float64: 2,
    paddle.int32: 3,
    paddle.int64: 4,
    paddle.bfloat16: 5,
}

# Inverse of PADDLE_TO_NUMBER: integer code -> dtype string name.
NUMBER_TO_DTYPE = {
    0: "float16",
    1: "float32",
    2: "float64",
    3: "int32",
    4: "int64",
    5: "bfloat16",
}

# Element size in bytes for each dtype get_tensor_bytes() supports.
# NOTE: bfloat16 was previously missing here even though the module lists it
# as a supported float type above; bfloat16 is a 16-bit (2-byte) format.
_DTYPE_ELEM_SIZE = {
    paddle.float64: 8,
    paddle.int64: 8,
    paddle.float32: 4,
    paddle.int32: 4,
    paddle.float16: 2,
    paddle.bfloat16: 2,
    paddle.int8: 1,
}


def is_float_tensor(tensor):
    """Return True if ``tensor`` has a floating-point dtype (incl. bfloat16)."""
    return tensor.dtype in FLOAT_TYPE_DICT.keys()


def get_tensor_dtype(dtype):
    """Return the string name of a supported float ``dtype``.

    Raises:
        AssertionError: if ``dtype`` is not a supported float dtype.
    """
    assert dtype in FLOAT_TYPE_DICT.keys()
    return FLOAT_TYPE_DICT[dtype]


def paddle_2_number(dtype):
    """Encode a paddle ``dtype`` as its stable integer code.

    Raises:
        AssertionError: if ``dtype`` has no registered code.
    """
    assert dtype in PADDLE_TO_NUMBER.keys()
    return PADDLE_TO_NUMBER[dtype]


def number_2_dtype(number):
    """Decode an integer code back to its dtype string name.

    Raises:
        AssertionError: if ``number`` is not a registered code.
    """
    assert number in NUMBER_TO_DTYPE.keys()
    return NUMBER_TO_DTYPE[number]


def get_tensor_bytes(tensor):
    """Get the bytes a tensor occupied.

    Computed as ``numel() * element size``; see ``_DTYPE_ELEM_SIZE`` for the
    supported dtypes (now including bfloat16, fixing the inconsistency with
    ``FLOAT_TYPE_DICT``).

    Raises:
        ValueError: if ``tensor.dtype`` is not a supported dtype.
    """
    try:
        elem_size = _DTYPE_ELEM_SIZE[tensor.dtype]
    except KeyError:
        # Preserve the original error type/message for unknown dtypes.
        raise ValueError(f"unknown data type: {tensor.dtype}") from None
    return tensor.numel() * elem_size


def _all_gather(tensor, group=None, use_calc_stream=True):
    """
    The main difference with paddle.distributed.all_gather: no need to pass in
    tensor_list, the returned tensor is spliced.
    """
    # Non-members of an explicit group take no part in the collective.
    if group is not None and not group.is_member():
        return
    ring_id = 0 if group is None else group.id
    nranks = (
        paddle.distributed.collective._get_global_group().nranks
        if group is None
        else group.nranks
    )
    return _legacy_C_ops.c_allgather(
        tensor,
        'use_calc_stream',
        use_calc_stream,
        'ring_id',
        ring_id,
        'nranks',
        nranks,
    )