Unverified commit 7528035c authored by hipudding, committed by GitHub

Use device_name instead of device index to support other devices (#3933)

Creating a tensor with device=<integer> always selects CUDA as its
device on the current PyTorch version (2.1); other devices must be
addressed as device="{device}:{index}".

Change get_accelerator().current_device() to
get_accelerator().current_device_name() to support other devices.
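
For illustration, a minimal sketch of the difference on PyTorch 2.1 (the tensor shape, dtype, and the variable names acc/buf are arbitrary, and the "npu" device type is only a hypothetical non-CUDA example):

    import torch
    from deepspeed.accelerator import get_accelerator

    # An integer device is always resolved against the CUDA backend:
    torch.device(0)             # -> device(type='cuda', index=0)

    # Non-CUDA accelerators need the "{device}:{index}" string form instead,
    # which is what current_device_name() returns:
    acc = get_accelerator()
    acc.current_device()        # -> 0 (bare index; implies CUDA when passed to torch)
    acc.current_device_name()   # -> e.g. "cuda:0", or "npu:0" on another accelerator

    # Allocations written against the name are portable across accelerators:
    buf = torch.empty(16, dtype=torch.float16,
                      device=acc.current_device_name(),
                      requires_grad=False)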
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 4d965416
@@ -1043,13 +1043,13 @@ class Init(InsertPostInitMethodToModuleSubClasses):
 param_buffer = torch.empty(
 buffer_size,
 dtype=param.dtype if not quant else torch.int8,
-device=get_accelerator().current_device(),
+device=get_accelerator().current_device_name(),
 requires_grad=False,
 )
 param_ds_tensor = param.ds_secondary_tensor if not forward and param.ds_secondary_tensor is not None else param.ds_tensor
 if not quant:
 handles = _dist_allgather_fn(
-param_ds_tensor.to(get_accelerator().current_device()),
+param_ds_tensor.to(get_accelerator().current_device_name()),
 param_buffer,
 ds_process_group,
 )
@@ -1057,16 +1057,16 @@ class Init(InsertPostInitMethodToModuleSubClasses):
 return AllGatherHandle(handles, param)
 else:
 quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor)
-handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device()), param_buffer,
-ds_process_group)
+handle = _dist_allgather_fn(quantized_param.to(get_accelerator().current_device_name()),
+param_buffer, ds_process_group)
 quant_scale_buffer = torch.empty(
 scales.numel() * world_size,
 dtype=torch.float32,
-device=get_accelerator().current_device(),
+device=get_accelerator().current_device_name(),
 requires_grad=False,
 )
-quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device()),
+quant_handle = _dist_allgather_fn(scales.to(get_accelerator().current_device_name()),
 quant_scale_buffer, ds_process_group)
 quant_info = QuantizationInfo()
@@ -1086,7 +1086,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
 flat_tensor = torch.empty(partition_sz * world_size,
 dtype=get_only_unique_item(p.dtype
 for p in params) if not quant else torch.int8,
-device=get_accelerator().current_device(),
+device=get_accelerator().current_device_name(),
 requires_grad=False)
 if not quant:
 partitions: List[Parameter] = []
@@ -1118,17 +1118,17 @@ class Init(InsertPostInitMethodToModuleSubClasses):
 use_secondary_tensor = True
 quantized_param, scales = self.quantizer_module.quantize(
 instrument_w_nvtx(torch.cat)(
-[p.ds_secondary_tensor.to(get_accelerator().current_device()) for p in params]))
+[p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params]))
 else:
 quantized_param, scales = self.quantizer_module.quantize(
 instrument_w_nvtx(
-torch.cat)([p.ds_tensor.to(get_accelerator().current_device()) for p in params]))
+torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params]))
 handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group)
 quant_info = QuantizationInfo()
 quant_scale_buffer = torch.empty(
 scales.numel() * world_size,
 dtype=torch.float32,
-device=get_accelerator().current_device(),
+device=get_accelerator().current_device_name(),
 requires_grad=False,
 )
 quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group)
...