From 8f59bf52581f00a4bc80f94af2a6eb3d5cf54200 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Wed, 4 Sep 2019 20:21:15 +0800 Subject: [PATCH] Add pretrained param load log --- pdseg/train.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pdseg/train.py b/pdseg/train.py index 6166603c..3238701d 100644 --- a/pdseg/train.py +++ b/pdseg/train.py @@ -261,6 +261,7 @@ def train(cfg): elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL): print('Pretrained model dir:', cfg.TRAIN.PRETRAINED_MODEL) load_vars = [] + load_fail_vars = [] def var_shape_matched(var, shape): """ @@ -271,13 +272,8 @@ def train(cfg): if var_exist: var_shape = parse_shape_from_file( os.path.join(cfg.TRAIN.PRETRAINED_MODEL, var.name)) - if var_shape == shape: - return True - else: - print( - "Variable[{}] shape does not match current network, skip" " to load it.".format(var.name)) - return False + return var_shape == shape + return False for x in train_prog.list_vars(): if isinstance(x, fluid.framework.Parameter): @@ -285,13 +281,22 @@ def train(cfg): x.name).get_tensor().shape()) if var_shape_matched(x, shape): load_vars.append(x) + else: + load_fail_vars.append(x) if cfg.MODEL.FP16: # If open FP16 training mode, load FP16 var separate load_fp16_vars(exe, cfg.TRAIN.PRETRAINED_MODEL, train_prog) else: fluid.io.load_vars( exe, dirname=cfg.TRAIN.PRETRAINED_MODEL, vars=load_vars) - print("Pretrained model loaded successfully!") + for var in load_vars: + print("Parameter[{}] loaded successfully!".format(var.name)) + for var in load_fail_vars: + print("Parameter[{}] shape does not match current network, skip" " to load it.".format(var.name)) + print("{}/{} pretrained parameters loaded successfully!".format( len(load_vars), len(load_vars) + len(load_fail_vars))) else: print('Pretrained model dir {} not exists, training from scratch...'. format(cfg.TRAIN.PRETRAINED_MODEL)) -- GitLab