From 133458b114225b2661a364451a457a38bd9cd2fd Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Wed, 30 Jan 2019 07:47:49 +0000 Subject: [PATCH] use load_params --- fluid/PaddleCV/video/.gitignore | 5 +++++ fluid/PaddleCV/video/models/model.py | 23 +++++++++------------- fluid/PaddleCV/video/models/stnet/stnet.py | 11 ++--------- fluid/PaddleCV/video/test.py | 1 + fluid/PaddleCV/video/train.py | 2 +- 5 files changed, 18 insertions(+), 24 deletions(-) create mode 100644 fluid/PaddleCV/video/.gitignore diff --git a/fluid/PaddleCV/video/.gitignore b/fluid/PaddleCV/video/.gitignore new file mode 100644 index 00000000..3ee384cd --- /dev/null +++ b/fluid/PaddleCV/video/.gitignore @@ -0,0 +1,5 @@ +data +checkpoints +output* +*.py +*.swp diff --git a/fluid/PaddleCV/video/models/model.py b/fluid/PaddleCV/video/models/model.py index eb4d581d..7fdade3c 100755 --- a/fluid/PaddleCV/video/models/model.py +++ b/fluid/PaddleCV/video/models/model.py @@ -13,6 +13,7 @@ #limitations under the License. import os +import logging try: from configparser import ConfigParser except: @@ -25,6 +26,8 @@ from .utils import download, AttrDict WEIGHT_DIR = os.path.expanduser("~/.paddle/weights") +logger = logging.getLogger(__name__) + class NotImplementError(Exception): "Error: model function not implement" @@ -163,15 +166,14 @@ class ModelBase(object): "get model weight default path and download url" raise NotImplementError(self, self.weights_info) - def get_weights(self, logger=None): + def get_weights(self): "get model weight file path, download weight from Paddle if not exist" path, url = self.weights_info() path = os.path.join(WEIGHT_DIR, path) if os.path.exists(path): return path - if logger: - logger.info("Download weights of {} from {}".format(self.name, url)) + logger.info("Download weights of {} from {}".format(self.name, url)) download(url, path) return path @@ -186,7 +188,7 @@ class ModelBase(object): "get pretrain base model directory" return (None, None) - def get_pretrain_weights(self, logger=None): + def get_pretrain_weights(self): "get model weight file path, download weight from Paddle if not exist" path, url = self.pretrain_info() if not path: @@ -196,22 +198,15 @@ class ModelBase(object): if os.path.exists(path): return path - if logger: - logger.info("Download pretrain weights of {} from {}".format( + logger.info("Download pretrain weights of {} from {}".format( self.name, url)) utils.download(url, path) return path - def load_pretrained_params(self, exe, pretrained_base, prog, place): + def load_pretrain_params(self, exe, pretrain, prog): def if_exist(var): return os.path.exists(os.path.join(pretrained_base, var.name)) - - inference_program = prog.clone(for_test=True) - fluid.io.load_vars( - exe, - pretrained_base, - predicate=if_exist, - main_program=inference_program) + fluid.io.load_params(exe, pretrain, main_program=prog) def get_config_from_sec(self, sec, item, default=None): cfg_item = self._config.get_config_from_sec(sec.upper(), diff --git a/fluid/PaddleCV/video/models/stnet/stnet.py b/fluid/PaddleCV/video/models/stnet/stnet.py index e60db862..773b1f8e 100644 --- a/fluid/PaddleCV/video/models/stnet/stnet.py +++ b/fluid/PaddleCV/video/models/stnet/stnet.py @@ -153,15 +153,8 @@ class STNET(ModelBase): def create_metrics_args(self): return {} - def load_pretrained_params(self, exe, pretrain_base, prog, place): - def is_parameter(var): - if isinstance(var, fluid.framework.Parameter): - return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \ - and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name)) - - inference_program = prog.clone(for_test=True) - vars = filter(is_parameter, inference_program.list_vars()) - fluid.io.load_vars(exe, pretrain_base, vars=vars) + def load_pretrain_params(self, exe, pretrain, prog): + fluid.io.load_params(exe, pretrain, main_program=prog) param_tensor = fluid.global_scope().find_var( "conv1_weights").get_tensor() diff --git a/fluid/PaddleCV/video/test.py b/fluid/PaddleCV/video/test.py index 13cc0e2b..2d5d8074 100755 --- a/fluid/PaddleCV/video/test.py +++ b/fluid/PaddleCV/video/test.py @@ -21,6 +21,7 @@ import numpy as np import paddle.fluid as fluid import models +logging.root.handlers = [] FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) logger = logging.getLogger(__name__) diff --git a/fluid/PaddleCV/video/train.py b/fluid/PaddleCV/video/train.py index 1bb49119..000b7bad 100755 --- a/fluid/PaddleCV/video/train.py +++ b/fluid/PaddleCV/video/train.py @@ -152,7 +152,7 @@ def train(train_model, valid_model, args): "Given pretrain weight dir {} not exist.".format(args.pretrain) pretrain = args.pretrain or train_model.get_pretrain_weights() if pretrain: - train_model.load_pretrained_params(exe, pretrain, train_prog, place) + train_model.load_pretrain_params(exe, pretrain, train_prog) if args.no_parallel: train_exe = exe -- GitLab