提交 8f59bf52 编写于 作者: W wuzewu

Add pretrained param load log

上级 85645767
...@@ -261,6 +261,7 @@ def train(cfg): ...@@ -261,6 +261,7 @@ def train(cfg):
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL): elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL):
print('Pretrained model dir:', cfg.TRAIN.PRETRAINED_MODEL) print('Pretrained model dir:', cfg.TRAIN.PRETRAINED_MODEL)
load_vars = [] load_vars = []
load_fail_vars = []
def var_shape_matched(var, shape): def var_shape_matched(var, shape):
""" """
...@@ -271,12 +272,7 @@ def train(cfg): ...@@ -271,12 +272,7 @@ def train(cfg):
if var_exist: if var_exist:
var_shape = parse_shape_from_file( var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL, var.name)) os.path.join(cfg.TRAIN.PRETRAINED_MODEL, var.name))
if var_shape == shape: return var_shape == shape
return True
else:
print(
"Variable[{}] shape does not match current network, skip"
" to load it.".format(var.name))
return False return False
for x in train_prog.list_vars(): for x in train_prog.list_vars():
...@@ -285,13 +281,22 @@ def train(cfg): ...@@ -285,13 +281,22 @@ def train(cfg):
x.name).get_tensor().shape()) x.name).get_tensor().shape())
if var_shape_matched(x, shape): if var_shape_matched(x, shape):
load_vars.append(x) load_vars.append(x)
else:
load_fail_vars.append(x)
if cfg.MODEL.FP16: if cfg.MODEL.FP16:
# If open FP16 training mode, load FP16 var separate # If open FP16 training mode, load FP16 var separate
load_fp16_vars(exe, cfg.TRAIN.PRETRAINED_MODEL, train_prog) load_fp16_vars(exe, cfg.TRAIN.PRETRAINED_MODEL, train_prog)
else: else:
fluid.io.load_vars( fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL, vars=load_vars) exe, dirname=cfg.TRAIN.PRETRAINED_MODEL, vars=load_vars)
print("Pretrained model loaded successfully!") for var in load_vars:
print("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print("Parameter[{}] shape does not match current network, skip"
" to load it.".format(var.name))
print("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
else: else:
print('Pretrained model dir {} not exists, training from scratch...'. print('Pretrained model dir {} not exists, training from scratch...'.
format(cfg.TRAIN.PRETRAINED_MODEL)) format(cfg.TRAIN.PRETRAINED_MODEL))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册