Commit b17fbac3 authored by littletomatodonkey

fix distillation load

Parent 927a7887
@@ -45,10 +45,7 @@ def _mkdir_if_not_exist(path):
         raise OSError('Failed to mkdir {}'.format(path))
 
 
-def load_dygraph_pretrain(
-        model,
-        path=None,
-        load_static_weights=False):
+def load_dygraph_pretrain(model, path=None, load_static_weights=False):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
@@ -74,9 +71,14 @@ def load_dygraph_pretrain(
 def load_distillation_model(model, pretrained_model, load_static_weights):
     logger.info("In distillation mode, teacher model will be "
                 "loaded firstly before student model.")
-    assert len(pretrained_model) == 2, "pretrained_model length should be 2 but got {}".format(len(pretrained_model))
-    assert len(load_static_weights) == 2, "load_static_weights length should be 2 but got {}".format(len(load_static_weights))
+    assert len(pretrained_model
+               ) == 2, "pretrained_model length should be 2 but got {}".format(
+                   len(pretrained_model))
+    assert len(
+        load_static_weights
+    ) == 2, "load_static_weights length should be 2 but got {}".format(
+        len(load_static_weights))
     load_dygraph_pretrain(
         model.teacher,
         path=pretrained_model[0],
@@ -92,6 +94,7 @@ def load_distillation_model(model, pretrained_model, load_static_weights):
         logger.coloring("Finish initing student model from {}".format(
             pretrained_model), "HEADER"))
 
+
 def init_model(config, net, optimizer=None):
     """
     load model from checkpoint or pretrained_model
@@ -114,13 +117,17 @@ def init_model(config, net, optimizer=None):
     load_static_weights = config.get('load_static_weights', False)
     use_distillation = config.get('use_distillation', False)
     if pretrained_model:
-        if isinstance(pretrained_model, list):  # load distillation pretrained model
+        if isinstance(pretrained_model,
+                      list):  # load distillation pretrained model
             if not isinstance(load_static_weights, list):
-                load_static_weights = [load_static_weights] * len(pretrained_model)
+                load_static_weights = [load_static_weights] * len(
+                    pretrained_model)
             load_distillation_model(net, pretrained_model, load_static_weights)
         else:  # common load
             load_dygraph_pretrain(
-                net, path=pretrained_model, load_static_weights=load_static_weights)
+                net,
+                path=pretrained_model,
+                load_static_weights=load_static_weights)
         logger.info(
             logger.coloring("Finish initing model from {}".format(
                 pretrained_model), "HEADER"))
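For context, the calling contract this diff touches: init_model treats a list-valued pretrained_model as a (teacher, student) pair, broadcasts a scalar load_static_weights to match its length, and hands both to load_distillation_model, whose asserts require length 2. The sketch below replays that dispatch on a toy config; the dict shape and paths are illustrative assumptions, not taken from this commit.

# A minimal sketch of init_model's pretrained-model dispatch (assumed
# config shape and hypothetical paths; not part of this commit).
config = {
    'pretrained_model': ['./teacher_weights', './student_weights'],
    'load_static_weights': True,  # scalar: applies to both models
    'use_distillation': True,
}

pretrained_model = config.get('pretrained_model')
load_static_weights = config.get('load_static_weights', False)

if isinstance(pretrained_model, list):  # distillation: (teacher, student)
    # Broadcast a scalar flag so the per-model lists line up, as
    # init_model does before calling load_distillation_model.
    if not isinstance(load_static_weights, list):
        load_static_weights = [load_static_weights] * len(pretrained_model)
    # load_distillation_model then asserts both lists have length 2 and
    # loads the teacher from pretrained_model[0] before the student.
    assert len(pretrained_model) == 2
    assert len(load_static_weights) == 2
else:  # common load: a single path for a single model
    pass  # load_dygraph_pretrain(net, path=pretrained_model, ...)

With a single string path instead of a list, the same config falls through to the common load_dygraph_pretrain branch.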