未验证 提交 39bdc141 编写于 作者: R Reza Yazdani 提交者: GitHub

fixing the checkpoint loading at inference-engine (#2429)

Co-authored-by: NAmmar Ahmad Awan <ammar.awan@microsoft.com>
上级 10e9d04c
......@@ -449,7 +449,8 @@ class InferenceEngine(Module):
ckpt_list = self._get_all_ckpt_names(load_dir, tag)
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
else:
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)
sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
self.checkpoint_engine)
if type(sd_loader) is list:
self.sd = torch.load(sd_loader[0], map_location='cpu')
......@@ -492,7 +493,6 @@ class InferenceEngine(Module):
self.module.load_state_dict(
state_dict=checkpoint[self._choose_module_key(checkpoint)],
checkpoint_engine=self.checkpoint_engine,
strict=load_module_strict)
def _choose_module_key(self, sd):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册