diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index 4e27f12c1d4830f2f16580bfa976cf3ace78d934..1f3c66d45e8e980ddbaec2f64da46f896cbe1b38 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -105,7 +105,8 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None): net.set_state_dict(para_dict) loss.set_state_dict(para_dict) for i in range(len(optimizer)): - optimizer[i].set_state_dict(opti_dict) + optimizer[i].set_state_dict(opti_dict[i] if isinstance( + opti_dict, list) else opti_dict) logger.info("Finish load checkpoints from {}".format(checkpoints)) return metric_dict @@ -117,7 +118,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None): else: # common load load_dygraph_pretrain(net, path=pretrained_model) logger.info("Finish load pretrained model from {}".format( - pretrained_model)) + pretrained_model)) def save_model(net,