Unverified · Commit d0ecff1b authored by littletomatodonkey, committed by GitHub

fix save distillation model (#567)

Parent: 0e6fe6f1
@@ -73,13 +73,13 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
 def load_distillation_model(model, pretrained_model, load_static_weights):
     logger.info("In distillation mode, teacher model will be "
                 "loaded firstly before student model.")
-    assert len(pretrained_model
-               ) == 2, "pretrained_model length should be 2 but got {}".format(
-                   len(pretrained_model))
-    assert len(
-        load_static_weights
-    ) == 2, "load_static_weights length should be 2 but got {}".format(
-        len(load_static_weights))
+
+    if not isinstance(pretrained_model, list):
+        pretrained_model = [pretrained_model]
+
+    if not isinstance(load_static_weights, list):
+        load_static_weights = [load_static_weights]
+
     teacher = model.teacher if hasattr(model,
                                        "teacher") else model._layers.teacher
     student = model.student if hasattr(model,
@@ -88,16 +88,16 @@ def load_distillation_model(model, pretrained_model, load_static_weights):
         teacher,
         path=pretrained_model[0],
         load_static_weights=load_static_weights[0])
-    logger.info(
-        logger.coloring("Finish initing teacher model from {}".format(
-            pretrained_model), "HEADER"))
-
-    load_dygraph_pretrain(
-        student,
-        path=pretrained_model[1],
-        load_static_weights=load_static_weights[1])
-    logger.info(
-        logger.coloring("Finish initing student model from {}".format(
-            pretrained_model), "HEADER"))
+    logger.info("Finish initing teacher model from {}".format(
+        pretrained_model))
+    # load student model
+    if len(pretrained_model) >= 2:
+        load_dygraph_pretrain(
+            student,
+            path=pretrained_model[1],
+            load_static_weights=load_static_weights[1])
+        logger.info("Finish initing student model from {}".format(
+            pretrained_model))
+
 
 def init_model(config, net, optimizer=None):
@@ -123,11 +123,7 @@ def init_model(config, net, optimizer=None):
     load_static_weights = config.get('load_static_weights', False)
     use_distillation = config.get('use_distillation', False)
     if pretrained_model:
-        if isinstance(pretrained_model,
-                      list):  # load distillation pretrained model
-            if not isinstance(load_static_weights, list):
-                load_static_weights = [load_static_weights] * len(
-                    pretrained_model)
+        if use_distillation:  # load distillation pretrained model
             load_distillation_model(net, pretrained_model, load_static_weights)
         else:  # common load
             load_dygraph_pretrain(
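In effect, the patch replaces the hard length-2 asserts with list normalization and guards the student load, so a teacher-only checkpoint becomes a valid input. Below is a minimal standalone sketch of that control flow; the helper name normalize_pretrained_args and the print statements standing in for the real weight loading are hypothetical, for illustration only:

# Standalone sketch of the commit's normalization logic; the helper name
# and the prints (stand-ins for load_dygraph_pretrain) are hypothetical.
def normalize_pretrained_args(pretrained_model, load_static_weights):
    # A bare path or flag is wrapped into a one-element list, so a
    # teacher-only checkpoint no longer trips the old length-2 assert.
    if not isinstance(pretrained_model, list):
        pretrained_model = [pretrained_model]
    if not isinstance(load_static_weights, list):
        load_static_weights = [load_static_weights]
    return pretrained_model, load_static_weights


paths, statics = normalize_pretrained_args("./teacher_pretrained", False)
assert paths == ["./teacher_pretrained"] and statics == [False]

print("load teacher from", paths[0])
# The student branch mirrors `if len(pretrained_model) >= 2:` in the patch.
if len(paths) >= 2:
    print("load student from", paths[1])
else:
    print("teacher-only checkpoint; student keeps its fresh weights")

On the init_model side, branching on the explicit use_distillation config flag rather than on isinstance(pretrained_model, list) means a distillation run can pass a single pretrained path and still be routed through load_distillation_model.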