提交 45cc0ce3 编写于 作者: A AUTOMATIC

Merge remote-tracking branch 'origin/master'

......@@ -122,7 +122,11 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
pl_sd = torch.load(checkpoint_file, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
model.load_state_dict(sd, strict=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册