提交 bed63326 编写于 作者: C chenzomi

fix checkpoint evaliaction.

上级 3b632eac
...@@ -187,9 +187,10 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None): ...@@ -187,9 +187,10 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
raise ValueError(e.__str__()) raise ValueError(e.__str__())
parameter_dict = {} parameter_dict = {}
if model_type != checkpoint_list.model_type: if checkpoint_list.model_type:
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( if model_type != checkpoint_list.model_type:
checkpoint_list.model_type, model_type)) raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
checkpoint_list.model_type, model_type))
try: try:
for element in checkpoint_list.value: for element in checkpoint_list.value:
data = element.tensor.tensor_content data = element.tensor.tensor_content
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册