Unverified · Commit 16699d83 authored by Stas Bekman, committed by GitHub

[ds-inference] checkpoint loading => tqdm (#2107)

* [ds-inference] checkpoint loading => tqdm

solves 2 issues:
- less noise: a tqdm progress bar instead of a print per shard
- more informative: tells users how long to wait and how many shards are being loaded

New way:

```
Loading 72 checkpoints:  12%|█▎        | 9/72 [01:12<08:39,  8.25s/it]
```

* write only from one process

* style
Parent aa88137b
```diff
 import copy
 import torch
+import tqdm
 import deepspeed
 import deepspeed.ops.transformer as transformer_inference
 from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, HFGPTJLayerPolicy, BLOOMLayerPolicy
@@ -765,9 +766,11 @@ def replace_transformer_layer(orig_layer_impl,
                               _replace_policy=policy)
 
     if checkpoint is not None:
+        pbar = tqdm.tqdm(total=len(checkpoint),
+                         desc=f"Loading {len(checkpoint)} checkpoint shards")
         for i in range(len(checkpoint)):
             if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
-                print(f"loading checkpoint ({i})")
+                pbar.update(1)
             sd = torch.load(checkpoint[i], map_location='cpu')
             load_model_with_checkpoint(replaced_module, sd, mp_replace)
     return replaced_module
```
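For context, here is a minimal, standalone sketch of the same pattern outside the DeepSpeed code base. It assumes a plain PyTorch model and a list of shard paths; `load_shard_into_model` is a hypothetical stand-in for `load_model_with_checkpoint`, and the sketch uses tqdm's `disable` flag rather than guarding `pbar.update(1)` as the diff does, which has the same effect of writing the bar from only one process.

```python
# Standalone sketch: load checkpoint shards with a tqdm bar that is only
# rendered by the main process, so multi-GPU runs do not print interleaved bars.
# `load_shard_into_model` is a hypothetical placeholder for the real per-shard
# loading logic (e.g. DeepSpeed's load_model_with_checkpoint).
import torch
import torch.distributed as dist
import tqdm


def is_main_process() -> bool:
    # Same rank check as in the diff: "main" means torch.distributed is not
    # initialized (single process) or this process has rank 0.
    return not dist.is_initialized() or dist.get_rank() == 0


def load_shard_into_model(model, state_dict):
    # Placeholder: merge one shard's tensors into the model.
    model.load_state_dict(state_dict, strict=False)


def load_sharded_checkpoint(model, shard_paths):
    # Every rank loads every shard; only the main process draws the bar.
    pbar = tqdm.tqdm(total=len(shard_paths),
                     desc=f"Loading {len(shard_paths)} checkpoint shards",
                     disable=not is_main_process())
    for path in shard_paths:
        sd = torch.load(path, map_location="cpu")
        load_shard_into_model(model, sd)
        pbar.update(1)
    pbar.close()
```

In a single-process run without a distributed backend, `is_main_process()` returns True and this behaves like an ordinary tqdm loop.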