From 694a1c8067e94099619f2f2e6554532ea631ec1e Mon Sep 17 00:00:00 2001 From: chenzomi Date: Thu, 25 Jun 2020 09:52:37 +0800 Subject: [PATCH] fix checkpoint evaliaction. --- mindspore/train/serialization.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index c4fba9ba8..ff1b8c312 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -186,9 +186,10 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None): raise ValueError(e.__str__()) parameter_dict = {} - if model_type != checkpoint_list.model_type: - raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( - checkpoint_list.model_type, model_type)) + if checkpoint_list.model_type: + if model_type != checkpoint_list.model_type: + raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( + checkpoint_list.model_type, model_type)) try: for element in checkpoint_list.value: data = element.tensor.tensor_content -- GitLab