diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index ce776d6821226f881995ec4fac873367e7a2eb13..cd9217559994b95f388bb09f24b598c510a6afd2 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -187,9 +187,10 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None): raise ValueError(e.__str__()) parameter_dict = {} - if model_type != checkpoint_list.model_type: - raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( - checkpoint_list.model_type, model_type)) + if checkpoint_list.model_type: + if model_type != checkpoint_list.model_type: + raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( + checkpoint_list.model_type, model_type)) try: for element in checkpoint_list.value: data = element.tensor.tensor_content