diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index df1a79fb132b2d41083d7d01a09fc1c6ffbcabd8..6ab45358eb40c5ba86ede8de459d6f0cc3a3db99 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -258,16 +258,17 @@ def _load_dismatch_prefix_params(net, parameter_dict, param_not_load): longest_name = param_not_load[0] while prefix_name != longest_name and param_not_load: logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load))) - longest_name = sorted(param_not_load, key=len, reverse=True)[0] prefix_name = longest_name for net_param_name in param_not_load: for dict_name in parameter_dict: if dict_name.endswith(net_param_name): - tmp_name = dict_name[:-len(net_param_name)] - prefix_name = prefix_name if len(prefix_name) < len(tmp_name) else tmp_name + prefix_name = dict_name[:-len(net_param_name)] + break + if prefix_name != longest_name: + break if prefix_name != longest_name: - logger.info("Remove parameter prefix name: {}, continue to load.".format(prefix_name)) + logger.warning("Remove parameter prefix name: {}, continue to load.".format(prefix_name)) for _, param in net.parameters_and_names(): new_param_name = prefix_name + param.name if param.name in param_not_load and new_param_name in parameter_dict: