提交 8f59bf52 编写于 作者: W wuzewu

Add pretrained param load log

上级 85645767
......@@ -261,6 +261,7 @@ def train(cfg):
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL):
print('Pretrained model dir:', cfg.TRAIN.PRETRAINED_MODEL)
load_vars = []
load_fail_vars = []
def var_shape_matched(var, shape):
"""
......@@ -271,13 +272,8 @@ def train(cfg):
if var_exist:
var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL, var.name))
if var_shape == shape:
return True
else:
print(
"Variable[{}] shape does not match current network, skip"
" to load it.".format(var.name))
return False
return var_shape == shape
return False
for x in train_prog.list_vars():
if isinstance(x, fluid.framework.Parameter):
......@@ -285,13 +281,22 @@ def train(cfg):
x.name).get_tensor().shape())
if var_shape_matched(x, shape):
load_vars.append(x)
else:
load_fail_vars.append(x)
if cfg.MODEL.FP16:
# If open FP16 training mode, load FP16 var separate
load_fp16_vars(exe, cfg.TRAIN.PRETRAINED_MODEL, train_prog)
else:
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL, vars=load_vars)
print("Pretrained model loaded successfully!")
for var in load_vars:
print("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print("Parameter[{}] shape does not match current network, skip"
" to load it.".format(var.name))
print("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
else:
print('Pretrained model dir {} not exists, training from scratch...'.
format(cfg.TRAIN.PRETRAINED_MODEL))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册