diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index c4fba9ba8a5a0258f772cec14a9f8a9be64b85d5..ff1b8c3122334f81c049790d4b290d0f050741ce 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -186,9 +186,10 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None): raise ValueError(e.__str__()) parameter_dict = {} - if model_type != checkpoint_list.model_type: - raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( - checkpoint_list.model_type, model_type)) + if checkpoint_list.model_type: + if model_type != checkpoint_list.model_type: + raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( + checkpoint_list.model_type, model_type)) try: for element in checkpoint_list.value: data = element.tensor.tensor_content