未验证 提交 b146aa35 编写于 作者: S Stas Bekman 提交者: GitHub

[ds-inference] fix progress bar (#2286)

when loading the non-sharded checkpoint update the progress bar (fix by @RezaYazdaniAminabadi) - I've just tested it to work.
Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
上级 53182531
......@@ -857,7 +857,6 @@ def replace_transformer_layer(orig_layer_impl,
desc=f"Loading {len(checkpoint)} checkpoint shards")
for i in range(len(checkpoint)):
sd = [
torch.load(os.path.join(base_dir1,
checkpoint[i]),
......@@ -870,6 +869,7 @@ def replace_transformer_layer(orig_layer_impl,
ckpt_type,
quantizer,
)
pbar.update(1)
else:
import gc
num_checkpoints = len(ckpt_list) // ckpt_mp_size
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册