Unverified commit c35bfe89, authored by Reza Yazdani, committed by GitHub

fix ds-inference without policy (#2247)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent: fae896ef
@@ -653,11 +653,10 @@ def replace_transformer_layer(orig_layer_impl,
                 new_bias = torch.empty((weight_shape[0]),
                                        device=child.weight.device,
                                        dtype=child.weight.dtype)
                 if z_inference:
                     with deepspeed.zero.GatheredParameters(child.bias, modifier_rank=0):
                         new_bias.data.copy_(child.bias.data)
-                else:
+                elif child.bias is not None:
                     new_bias.data.copy_(child.bias.data)
                 return LinearAllreduce(data, child.bias if child.bias is None else \
                             torch.nn.parameter.Parameter(new_bias.to(torch.cuda.current_device())), mp_group)
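
For context, this hunk guards the bias copy: outside of ZeRO-Inference, child.bias can simply be None (for example, a linear layer built with bias=False), so the old unconditional else: branch fails on such layers. Below is a minimal sketch of the same guard, where `child` and `z_inference` are hypothetical stand-ins for the variables in the diff and the ZeRO gather branch is elided.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the variables appearing in the diff.
child = nn.Linear(16, 16, bias=False)   # a layer without any bias tensor
z_inference = False                     # True only when params are ZeRO-3 partitioned

new_bias = torch.empty(child.out_features,
                       device=child.weight.device,
                       dtype=child.weight.dtype)

if z_inference:
    # Under ZeRO-Inference the bias would first be gathered with
    # deepspeed.zero.GatheredParameters before copying (omitted here).
    pass
elif child.bias is not None:
    # Copy only when the layer actually owns a bias tensor.
    new_bias.data.copy_(child.bias.data)
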
@@ -704,7 +703,9 @@ def replace_transformer_layer(orig_layer_impl,
                                          child.weight.shape[1] // mp_size),
                                         device=child.weight.device,
                                         dtype=child.weight.dtype)
-                data = mp_replace.copy(new_weight, child.weight.ds_tensor.data)
+                data = mp_replace.copy(new_weight,
+                                       child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \
+                                       child.weight.data)
                 new_embedding = nn.Embedding(child.weight.shape[0],
                                              child.weight.shape[1] // mp_size)
                 new_embedding.weight.data.copy_(data)
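
For context, this hunk covers embeddings that are not ZeRO-3 partitioned: only partitioned parameters carry a ds_tensor attribute holding the local shard, so the no-policy path has to fall back to the parameter's own data. The sketch below isolates that attribute check in a hypothetical helper, pick_embedding_source, and leaves out the tensor-parallel slicing done by mp_replace.copy.

import torch
import torch.nn as nn

def pick_embedding_source(weight: torch.Tensor) -> torch.Tensor:
    # ZeRO-3 partitioned parameters expose their local shard as `ds_tensor`;
    # plain parameters (the case fixed here) do not, so fall back to their data.
    return weight.ds_tensor.data if hasattr(weight, 'ds_tensor') else weight.data

# A plain (non-ZeRO) embedding has no `ds_tensor`, so the fallback path is taken.
embedding = nn.Embedding(32, 8)
src = pick_embedding_source(embedding.weight)
assert src.data_ptr() == embedding.weight.data_ptr()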