diff --git a/fluid/PaddleCV/video/.gitignore b/fluid/PaddleCV/video/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..3ee384cd9bef7718581d538c7a95941f8227ddf9 --- /dev/null +++ b/fluid/PaddleCV/video/.gitignore @@ -0,0 +1,5 @@ +data +checkpoints +output* +*.py +*.swp diff --git a/fluid/PaddleCV/video/models/model.py b/fluid/PaddleCV/video/models/model.py index eb4d581d69142b82e39efa7c31f6053f164991f4..7fdade3c7ee488478a7b342cf2bf3fb90ce4a7a9 100755 --- a/fluid/PaddleCV/video/models/model.py +++ b/fluid/PaddleCV/video/models/model.py @@ -13,6 +13,7 @@ #limitations under the License. import os +import logging try: from configparser import ConfigParser except: @@ -25,6 +26,8 @@ from .utils import download, AttrDict WEIGHT_DIR = os.path.expanduser("~/.paddle/weights") +logger = logging.getLogger(__name__) + class NotImplementError(Exception): "Error: model function not implement" @@ -163,15 +166,14 @@ class ModelBase(object): "get model weight default path and download url" raise NotImplementError(self, self.weights_info) - def get_weights(self, logger=None): + def get_weights(self): "get model weight file path, download weight from Paddle if not exist" path, url = self.weights_info() path = os.path.join(WEIGHT_DIR, path) if os.path.exists(path): return path - if logger: - logger.info("Download weights of {} from {}".format(self.name, url)) + logger.info("Download weights of {} from {}".format(self.name, url)) download(url, path) return path @@ -186,7 +188,7 @@ class ModelBase(object): "get pretrain base model directory" return (None, None) - def get_pretrain_weights(self, logger=None): + def get_pretrain_weights(self): "get model weight file path, download weight from Paddle if not exist" path, url = self.pretrain_info() if not path: @@ -196,22 +198,15 @@ class ModelBase(object): if os.path.exists(path): return path - if logger: - logger.info("Download pretrain weights of {} from {}".format( + logger.info("Download pretrain weights of {} from {}".format( self.name, url)) utils.download(url, path) return path - def load_pretrained_params(self, exe, pretrained_base, prog, place): + def load_pretrain_params(self, exe, pretrain, prog): def if_exist(var): return os.path.exists(os.path.join(pretrained_base, var.name)) - - inference_program = prog.clone(for_test=True) - fluid.io.load_vars( - exe, - pretrained_base, - predicate=if_exist, - main_program=inference_program) + fluid.io.load_params(exe, pretrain, main_program=prog) def get_config_from_sec(self, sec, item, default=None): cfg_item = self._config.get_config_from_sec(sec.upper(), diff --git a/fluid/PaddleCV/video/models/stnet/stnet.py b/fluid/PaddleCV/video/models/stnet/stnet.py index e60db862ae9b6e4d113b785317cc886c7f83acbb..773b1f8e4f134a4dc4b8169085b2763580254632 100644 --- a/fluid/PaddleCV/video/models/stnet/stnet.py +++ b/fluid/PaddleCV/video/models/stnet/stnet.py @@ -153,15 +153,8 @@ class STNET(ModelBase): def create_metrics_args(self): return {} - def load_pretrained_params(self, exe, pretrain_base, prog, place): - def is_parameter(var): - if isinstance(var, fluid.framework.Parameter): - return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \ - and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name)) - - inference_program = prog.clone(for_test=True) - vars = filter(is_parameter, inference_program.list_vars()) - fluid.io.load_vars(exe, pretrain_base, vars=vars) + def load_pretrain_params(self, exe, pretrain, prog): + fluid.io.load_params(exe, pretrain, main_program=prog) param_tensor = fluid.global_scope().find_var( "conv1_weights").get_tensor() diff --git a/fluid/PaddleCV/video/test.py b/fluid/PaddleCV/video/test.py index 13cc0e2b88e3aa58fe5cb3d32773b8bcbaae0643..2d5d8074f074cda2f6fe879de5623484fad7c2c1 100755 --- a/fluid/PaddleCV/video/test.py +++ b/fluid/PaddleCV/video/test.py @@ -21,6 +21,7 @@ import numpy as np import paddle.fluid as fluid import models +logging.root.handlers = [] FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) logger = logging.getLogger(__name__) diff --git a/fluid/PaddleCV/video/train.py b/fluid/PaddleCV/video/train.py index 1bb491195b64f978e3ad7b24cd7d24c186b3d466..000b7bad541ea2acb1436dc6bd6efcca145e2a63 100755 --- a/fluid/PaddleCV/video/train.py +++ b/fluid/PaddleCV/video/train.py @@ -152,7 +152,7 @@ def train(train_model, valid_model, args): "Given pretrain weight dir {} not exist.".format(args.pretrain) pretrain = args.pretrain or train_model.get_pretrain_weights() if pretrain: - train_model.load_pretrained_params(exe, pretrain, train_prog, place) + train_model.load_pretrain_params(exe, pretrain, train_prog) if args.no_parallel: train_exe = exe