From d0ecff1b5abd89703b0d561419a490aae39e64f2 Mon Sep 17 00:00:00 2001
From: littletomatodonkey <2120160898@bit.edu.cn>
Date: Tue, 19 Jan 2021 18:49:30 +0800
Subject: [PATCH] fix save distillation model (#567)

---
 ppcls/utils/save_load.py | 40 ++++++++++++++++++----------------------
 1 file changed, 18 insertions(+), 22 deletions(-)

diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index f83920b0..36f5236d 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -73,13 +73,13 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
 def load_distillation_model(model, pretrained_model, load_static_weights):
     logger.info("In distillation mode, teacher model will be "
                 "loaded firstly before student model.")
-    assert len(pretrained_model
-               ) == 2, "pretrained_model length should be 2 but got {}".format(
-                   len(pretrained_model))
-    assert len(
-        load_static_weights
-    ) == 2, "load_static_weights length should be 2 but got {}".format(
-        len(load_static_weights))
+
+    if not isinstance(pretrained_model, list):
+        pretrained_model = [pretrained_model]
+
+    if not isinstance(load_static_weights, list):
+        load_static_weights = [load_static_weights]
+
     teacher = model.teacher if hasattr(model,
                                        "teacher") else model._layers.teacher
     student = model.student if hasattr(model,
@@ -88,16 +88,16 @@ def load_distillation_model(model, pretrained_model, load_static_weights):
         teacher,
         path=pretrained_model[0],
         load_static_weights=load_static_weights[0])
-    logger.info(
-        logger.coloring("Finish initing teacher model from {}".format(
-            pretrained_model), "HEADER"))
-    load_dygraph_pretrain(
-        student,
-        path=pretrained_model[1],
-        load_static_weights=load_static_weights[1])
-    logger.info(
-        logger.coloring("Finish initing student model from {}".format(
-            pretrained_model), "HEADER"))
+    logger.info("Finish initing teacher model from {}".format(
+        pretrained_model))
+    # load student model
+    if len(pretrained_model) >= 2:
+        load_dygraph_pretrain(
+            student,
+            path=pretrained_model[1],
+            load_static_weights=load_static_weights[1])
+        logger.info("Finish initing student model from {}".format(
+            pretrained_model))
 
 
 def init_model(config, net, optimizer=None):
@@ -123,11 +123,7 @@ def init_model(config, net, optimizer=None):
     load_static_weights = config.get('load_static_weights', False)
     use_distillation = config.get('use_distillation', False)
     if pretrained_model:
-        if isinstance(pretrained_model,
-                      list):  # load distillation pretrained model
-            if not isinstance(load_static_weights, list):
-                load_static_weights = [load_static_weights] * len(
-                    pretrained_model)
+        if use_distillation:
             load_distillation_model(net, pretrained_model, load_static_weights)
         else:  # common load
             load_dygraph_pretrain(
--
GitLab
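
For illustration, here is a minimal stand-alone sketch of the loading logic after this patch: scalar arguments are wrapped into lists, the teacher is always initialized from the first entry, and the student is initialized only when a second entry is supplied. The `load_weights` callable and the names in this sketch are hypothetical stand-ins (for ppcls' `load_dygraph_pretrain`), not code from the patch itself.

# Minimal sketch of the patched behavior; `load_weights(module, path, static)`
# is a hypothetical stand-in for ppcls' load_dygraph_pretrain.
def load_distillation_weights(teacher, student, pretrained_model,
                              load_static_weights, load_weights):
    # Scalars are normalized to one-element lists, so a single
    # teacher path is now accepted instead of failing an assertion.
    if not isinstance(pretrained_model, list):
        pretrained_model = [pretrained_model]
    if not isinstance(load_static_weights, list):
        load_static_weights = [load_static_weights]

    # The teacher is always loaded from the first entry.
    load_weights(teacher, pretrained_model[0], load_static_weights[0])

    # The student is loaded only when a second entry exists; otherwise
    # it keeps its freshly initialized (trainable) parameters.
    if len(pretrained_model) >= 2:
        load_weights(student, pretrained_model[1], load_static_weights[1])

With this change, a config that supplies only a teacher checkpoint (for example, pretrained_model set to a single path rather than a two-element list) no longer trips the old length-2 assertions, which is what made saving and reloading distillation models fail before this commit.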