utils.py 2.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
16
from paddle import _legacy_C_ops
17

18
__all__ = []
19

20 21 22 23
# Floating-point paddle dtypes mapped to their string names.
# is_float_tensor() tests membership in this table and get_tensor_dtype()
# looks names up in it.
FLOAT_TYPE_DICT = {
    paddle.float16: "float16",
    paddle.float32: "float32",
    paddle.float64: "float64",
    paddle.bfloat16: "bfloat16",
}

# Paddle dtypes mapped to small integer codes (see paddle_2_number()).
# The codes line up with the keys of NUMBER_TO_DTYPE, which maps each code
# to the dtype's string name, so the two tables must be edited together.
PADDLE_TO_NUMBER = {
    paddle.float16: 0,
    paddle.float32: 1,
    paddle.float64: 2,
    paddle.int32: 3,
    paddle.int64: 4,
    paddle.bfloat16: 5,
}

# Integer dtype codes mapped to dtype string names (see number_2_dtype()).
# The codes are the values used in PADDLE_TO_NUMBER; keep both tables in sync.
NUMBER_TO_DTYPE = {
    0: "float16",
    1: "float32",
    2: "float64",
    3: "int32",
    4: "int64",
    5: "bfloat16",
}
44 45 46 47


def is_float_tensor(tensor):
    """Is a float tensor.

    Returns True when ``tensor.dtype`` is one of the keys of
    FLOAT_TYPE_DICT, False otherwise.
    """
    return tensor.dtype in FLOAT_TYPE_DICT


def get_tensor_dtype(dtype):
    """Return the string name of a floating-point paddle dtype.

    ``dtype`` must be a key of FLOAT_TYPE_DICT; this precondition is
    checked with an ``assert``.
    """
    assert dtype in FLOAT_TYPE_DICT
    dtype_name = FLOAT_TYPE_DICT[dtype]
    return dtype_name


def paddle_2_number(dtype):
    """Return the integer code registered for ``dtype`` in PADDLE_TO_NUMBER.

    ``dtype`` must be a key of PADDLE_TO_NUMBER; this precondition is
    checked with an ``assert``.
    """
    assert dtype in PADDLE_TO_NUMBER
    code = PADDLE_TO_NUMBER[dtype]
    return code


def number_2_dtype(number):
    """Return the dtype string name registered for ``number``.

    ``number`` must be a key of NUMBER_TO_DTYPE; this precondition is
    checked with an ``assert``.
    """
    assert number in NUMBER_TO_DTYPE
    dtype_name = NUMBER_TO_DTYPE[number]
    return dtype_name
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81


def get_tensor_bytes(tensor):
    """Get the bytes a tensor occupied.

    Args:
        tensor: object exposing a paddle ``dtype`` attribute and a
            ``numel()`` method.

    Returns:
        ``tensor.numel()`` multiplied by the per-element size in bytes of
        ``tensor.dtype``.

    Raises:
        ValueError: if ``tensor.dtype`` is not one of float64, int64,
            float32, int32, float16, bfloat16 or int8.
    """
    dtype = tensor.dtype
    if dtype == paddle.float64 or dtype == paddle.int64:
        elem_size = 8
    elif dtype == paddle.float32 or dtype == paddle.int32:
        elem_size = 4
    elif dtype == paddle.float16 or dtype == paddle.bfloat16:
        # bfloat16 is a 16-bit format and is treated as a float dtype by
        # FLOAT_TYPE_DICT, so it is sized like float16 instead of raising.
        elem_size = 2
    elif dtype == paddle.int8:
        elem_size = 1
    else:
        raise ValueError(f"unknown data type: {tensor.dtype}")
    return tensor.numel() * elem_size
84 85 86 87


def _all_gather(tensor, group=None, use_calc_stream=True):
    """
    The main difference with paddle.distributed.all_gather:
    no need to pass in tensor_list, the returned tensor is spliced

    Args:
        tensor: the tensor handed to the ``c_allgather`` op.
        group: optional communication group. When ``None``, ring id 0 and
            the rank count of the global group are used; otherwise the
            group's own ``id`` and ``nranks`` are used.
        use_calc_stream: forwarded to the op as its ``use_calc_stream``
            attribute.

    Returns:
        The result of ``_legacy_C_ops.c_allgather``, or ``None`` when a
        group is given and ``group.is_member()`` is false.
    """
    if group is not None and not group.is_member():
        return None

    if group is None:
        ring_id = 0
        nranks = paddle.distributed.collective._get_global_group().nranks
    else:
        ring_id = group.id
        nranks = group.nranks

    # The legacy op takes its attributes as alternating name/value arguments.
    return _legacy_C_ops.c_allgather(
        tensor,
        'use_calc_stream',
        use_calc_stream,
        'ring_id',
        ring_id,
        'nranks',
        nranks,
    )