From b30fd0df413d2f14a23dd5103c86ca3412438c12 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 18 Feb 2019 06:05:27 +0000 Subject: [PATCH] add logger in load_pretrain_weights --- fluid/PaddleCV/video/models/model.py | 1 + fluid/PaddleCV/video/models/stnet/stnet.py | 4 ++++ fluid/PaddleCV/video/models/tsn/tsn.py | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/fluid/PaddleCV/video/models/model.py b/fluid/PaddleCV/video/models/model.py index 54c454eb..44f888ef 100755 --- a/fluid/PaddleCV/video/models/model.py +++ b/fluid/PaddleCV/video/models/model.py @@ -143,6 +143,7 @@ class ModelBase(object): return path def load_pretrain_params(self, exe, pretrain, prog, place): + logger.info("Load pretrain weights from {}".format(pretrain)) fluid.io.load_params(exe, pretrain, main_program=prog) def get_config_from_sec(self, sec, item, default=None): diff --git a/fluid/PaddleCV/video/models/stnet/stnet.py b/fluid/PaddleCV/video/models/stnet/stnet.py index 0d710ee7..e20ad0bd 100644 --- a/fluid/PaddleCV/video/models/stnet/stnet.py +++ b/fluid/PaddleCV/video/models/stnet/stnet.py @@ -17,6 +17,9 @@ import paddle.fluid as fluid from ..model import ModelBase from .stnet_res_model import StNet_ResNet +import logging +logger = logging.getLogger(__name__) + __all__ = ["STNET"] @@ -131,6 +134,7 @@ class STNET(ModelBase): return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \ and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name)) + logger.info("Load pretrain weights from {}, exclude fc, batch_norm, xception, conv3d layers.".format(pretrain)) vars = filter(is_parameter, prog.list_vars()) fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog) diff --git a/fluid/PaddleCV/video/models/tsn/tsn.py b/fluid/PaddleCV/video/models/tsn/tsn.py index be145858..5bc8aba3 100644 --- a/fluid/PaddleCV/video/models/tsn/tsn.py +++ b/fluid/PaddleCV/video/models/tsn/tsn.py @@ -18,6 +18,9 @@ from paddle.fluid import ParamAttr from ..model import ModelBase from .tsn_res_model import TSN_ResNet +import logging +logger = logging.getLogger(__name__) + __all__ = ["TSN"] @@ -133,6 +136,7 @@ class TSN(ModelBase): def is_parameter(var): return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) + logger.info("Load pretrain weights from {}, exclude fc layer.".format(pretrain)) vars = filter(is_parameter, prog.list_vars()) fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog) -- GitLab