From bed6332688bbae0f2e325bd4044abe008eeaf402 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Thu, 25 Jun 2020 09:43:27 +0800 Subject: [PATCH] fix checkpoint evaliaction. --- mindspore/train/serialization.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index ce776d682..cd9217559 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -187,9 +187,10 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None): raise ValueError(e.__str__()) parameter_dict = {} - if model_type != checkpoint_list.model_type: - raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( - checkpoint_list.model_type, model_type)) + if checkpoint_list.model_type: + if model_type != checkpoint_list.model_type: + raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( + checkpoint_list.model_type, model_type)) try: for element in checkpoint_list.value: data = element.tensor.tensor_content -- GitLab