diff --git a/fluid/PaddleCV/video/infer.py b/fluid/PaddleCV/video/infer.py
index 8f20c5ede0486aa0ca05df8b147d3631069b165b..43470cede76a39f7b7ffdcb43c0481e25aeca11f 100755
--- a/fluid/PaddleCV/video/infer.py
+++ b/fluid/PaddleCV/video/infer.py
@@ -28,6 +28,7 @@ from config import *
 import models
 from datareader import get_reader
 
+logging.root.handlers = []
 FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
 logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
 logger = logging.getLogger(__name__)
diff --git a/fluid/PaddleCV/video/models/model.py b/fluid/PaddleCV/video/models/model.py
index c4362763f7791264b1403003c8290fd8ffa9ed4a..54c454eb0232dc0e7c5061c7d30cb53d58a0e13e 100755
--- a/fluid/PaddleCV/video/models/model.py
+++ b/fluid/PaddleCV/video/models/model.py
@@ -143,8 +143,6 @@ class ModelBase(object):
         return path
 
     def load_pretrain_params(self, exe, pretrain, prog, place):
-        def if_exist(var):
-            return os.path.exists(os.path.join(pretrained_base, var.name))
         fluid.io.load_params(exe, pretrain, main_program=prog)
 
     def get_config_from_sec(self, sec, item, default=None):
diff --git a/fluid/PaddleCV/video/models/stnet/stnet.py b/fluid/PaddleCV/video/models/stnet/stnet.py
index 03c5507b81417c84fd842f05b9e9dd74abf68224..0d710ee7edbd98a70a741370a3c200dc92e81490 100644
--- a/fluid/PaddleCV/video/models/stnet/stnet.py
+++ b/fluid/PaddleCV/video/models/stnet/stnet.py
@@ -132,7 +132,7 @@ class STNET(ModelBase):
                    and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
 
         vars = filter(is_parameter, prog.list_vars())
-        fluid.io.load_vars(exe, pretrain, vars=vars)
+        fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
 
         param_tensor = fluid.global_scope().find_var(
             "conv1_weights").get_tensor()
diff --git a/fluid/PaddleCV/video/test.py b/fluid/PaddleCV/video/test.py
index 11ba12e9882f602d87bd6801ba11b6670ecd007a..9698caecc21a26dc38256b145dd54d04a2e13c88 100755
--- a/fluid/PaddleCV/video/test.py
+++ b/fluid/PaddleCV/video/test.py
@@ -25,6 +25,7 @@ import models
 from datareader import get_reader
 from metrics import get_metrics
 
+logging.root.handlers = []
 FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
 logger = logging.getLogger(__name__)
diff --git a/fluid/PaddleCV/video/train.py b/fluid/PaddleCV/video/train.py
index 18fd188bea5acb5f7860e9258ecd4eab63ce4240..154c51edd431286555b0e11d42a2c7a50ff4ee42 100755
--- a/fluid/PaddleCV/video/train.py
+++ b/fluid/PaddleCV/video/train.py
@@ -26,6 +26,7 @@ from config import *
 from datareader import get_reader
 from metrics import get_metrics
 
+logging.root.handlers = []
 FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
 logger = logging.getLogger(__name__)
@@ -59,6 +60,13 @@ def parse_args():
         default=None,
         help='path to pretrain weights. None to use default weights path in ~/.paddle/weights.'
     )
+    parser.add_argument(
+        '--resume',
+        type=str,
+        default=None,
+        help='path to resume training based on previous checkpoints. '
+        'None for not resuming any checkpoints.'
+    )
     parser.add_argument(
         '--use-gpu', type=bool, default=True, help='default use gpu.')
     parser.add_argument(
@@ -141,12 +149,21 @@ def train(args):
     exe = fluid.Executor(place)
     exe.run(startup)
 
-    if args.pretrain:
-        assert os.path.exists(args.pretrain), \
-            "Given pretrain weight dir {} not exist.".format(args.pretrain)
-    pretrain = args.pretrain or train_model.get_pretrain_weights()
-    if pretrain:
-        train_model.load_pretrain_params(exe, pretrain, train_prog, place)
+    if args.resume:
+        # if resume weights is given, load resume weights directly
+        assert os.path.exists(args.resume), \
+            "Given resume weight dir {} not exist.".format(args.resume)
+        def if_exist(var):
+            return os.path.exists(os.path.join(args.resume, var.name))
+        fluid.io.load_vars(exe, args.resume, predicate=if_exist, main_program=train_prog)
+    else:
+        # if not in resume mode, load pretrain weights
+        if args.pretrain:
+            assert os.path.exists(args.pretrain), \
+                "Given pretrain weight dir {} not exist.".format(args.pretrain)
+        pretrain = args.pretrain or train_model.get_pretrain_weights()
+        if pretrain:
+            train_model.load_pretrain_params(exe, pretrain, train_prog, place)
 
     train_exe = fluid.ParallelExecutor(
         use_cuda=args.use_gpu,
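
Taken together, the train.py changes turn weight loading into a two-way branch: an explicit --resume checkpoint directory wins, otherwise pretrain weights (given or default) are loaded through the model. Below is a minimal self-contained sketch of that branch, assuming the PaddlePaddle 1.x fluid API used in the diff; the helper name restore_or_pretrain is hypothetical and the exe, args, train_model, train_prog, and place objects mirror those in train.py. It is an illustration of the logic, not part of the patch.

import os
import paddle.fluid as fluid


def restore_or_pretrain(exe, args, train_model, train_prog, place):
    """Hypothetical helper: resume from a checkpoint dir if given,
    otherwise fall back to pretrain weights (sketch only)."""
    if args.resume:
        # Resume mode: restore exactly those variables whose files exist
        # in the checkpoint directory, scoped to the training program.
        assert os.path.exists(args.resume), \
            "Given resume weight dir {} not exist.".format(args.resume)

        def if_exist(var):
            return os.path.exists(os.path.join(args.resume, var.name))

        fluid.io.load_vars(
            exe, args.resume, predicate=if_exist, main_program=train_prog)
    else:
        # Pretrain mode: let the model decide how to map pretrained weights
        # (see ModelBase.load_pretrain_params and the STNet override above).
        if args.pretrain:
            assert os.path.exists(args.pretrain), \
                "Given pretrain weight dir {} not exist.".format(args.pretrain)
        pretrain = args.pretrain or train_model.get_pretrain_weights()
        if pretrain:
            train_model.load_pretrain_params(exe, pretrain, train_prog, place)

Resuming a run is then a matter of passing --resume with a checkpoint directory; when the flag is omitted, behavior is unchanged from the pretrain path.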