# compressed_ar.py
# Launch with: python -m torch.distributed.launch --nproc_per_node=1 compressed_ar.py

import torch
import os
import cupy
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack

def torch2cupy(tensor):
    # Zero-copy PyTorch -> CuPy conversion via the DLPack protocol.
    return cupy.fromDlpack(to_dlpack(tensor))


def cupy2torch(cupy_tensor):
    # Zero-copy CuPy -> PyTorch conversion via the DLPack protocol.
    return from_dlpack(cupy_tensor.toDlpack())
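# Note (added remark): these DLPack round trips share GPU memory (zero-copy)
# and require CUDA tensors; CPU tensors cannot be exchanged with CuPy this way.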


def decompose_cupy(tensor):
    # CuPy fallback for torch < 1.9, which lacks torch.frexp: split into a
    # normalized mantissa and an integer exponent, then narrow them to fp16
    # and int8 (24 bits per element in total).
    mantissa, exponent = cupy.frexp(torch2cupy(tensor.float()))
    return cupy2torch(mantissa).half(), cupy2torch(exponent).to(torch.int8)


def decompose(t):
    if (TORCH_VERSION_MAJOR, TORCH_VERSION_MINOR) < (1, 9):
        raise RuntimeError('torch >= 1.9.0 is needed for compressed_ar.decompose')
    # Split into a normalized mantissa in [0.5, 1) and an integer exponent,
    # then narrow them to fp16 and int8 (24 bits per element in total).
    mantissa, exponent = torch.frexp(t.float())
    return mantissa.half(), exponent.to(torch.int8)
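# For example, torch.frexp(torch.tensor(6.0)) returns mantissa 0.75 and
# exponent 3, since 0.75 * 2 ** 3 == 6.0.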


def reconstruct(mantissa, exponent, original_dtype=torch.bfloat16):
    # Inverse of the frexp split: mantissa * 2 ** exponent, cast back to the
    # caller's original dtype.
    return torch.ldexp(mantissa, exponent).to(original_dtype)


def compressed_all_reduce_torch(tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
    original_dtype = tensor.dtype
    m, e = decompose(tensor)
    work_m = torch.distributed.all_reduce(m, op=op, group=group, async_op=async_op)
    work_e = torch.distributed.all_reduce(e, op=op, group=group, async_op=async_op)
    if async_op:
        # reconstruct() needs the reduced values, so block on both collectives
        # before recombining mantissa and exponent.
        work_m.wait()
        work_e.wait()
    return reconstruct(m, e, original_dtype)


def compressed_all_reduce_cupy(tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
    original_dtype = tensor.dtype
    m, e = decompose_cupy(tensor)
    work_m = torch.distributed.all_reduce(m, op=op, group=group, async_op=async_op)
    work_e = torch.distributed.all_reduce(e, op=op, group=group, async_op=async_op)
    if async_op:
        # As above, wait for both collectives before reconstructing.
        work_m.wait()
        work_e.wait()
    return reconstruct(m, e, original_dtype)

# torch.frexp is only available in torch >= 1.9; older versions fall back to
# the CuPy implementation.
version = torch.__version__.split('.')
TORCH_VERSION_MAJOR = int(version[0])
TORCH_VERSION_MINOR = int(version[1])
if (TORCH_VERSION_MAJOR, TORCH_VERSION_MINOR) < (1, 9):
    compressed_all_reduce = compressed_all_reduce_cupy
else:
    compressed_all_reduce = compressed_all_reduce_torch
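

# A minimal usage sketch (an illustrative addition, not part of the original
# module): assumes a launcher that sets LOCAL_RANK and the rendezvous
# environment (e.g. torchrun or torch.distributed.launch --use_env); the NCCL
# backend and the tensor shape are arbitrary choices.
if __name__ == '__main__':
    torch.distributed.init_process_group(backend='nccl')
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)
    # Tensors must live on the GPU for both the NCCL collectives and the CuPy path.
    grad = torch.randn(1024, device='cuda')
    reduced = compressed_all_reduce(grad)  # roughly 24 bits per element on the wire
    print(f'rank {torch.distributed.get_rank()}: reduced[:4] = {reduced[:4]}')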