未验证 提交 e5327e01 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #1746 from heavengate/video_resume

add resume weights in train.py
@@ -28,6 +28,7 @@ from config import *
 import models
 from datareader import get_reader
+logging.root.handlers = []
 FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
 logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
 logger = logging.getLogger(__name__)
...
@@ -143,8 +143,6 @@ class ModelBase(object):
         return path

     def load_pretrain_params(self, exe, pretrain, prog, place):
-        def if_exist(var):
-            return os.path.exists(os.path.join(pretrained_base, var.name))
         fluid.io.load_params(exe, pretrain, main_program=prog)

     def get_config_from_sec(self, sec, item, default=None):
...
@@ -132,7 +132,7 @@ class STNET(ModelBase):
                 and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
         vars = filter(is_parameter, prog.list_vars())
-        fluid.io.load_vars(exe, pretrain, vars=vars)
+        fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
         param_tensor = fluid.global_scope().find_var(
             "conv1_weights").get_tensor()
...
@@ -25,6 +25,7 @@ import models
 from datareader import get_reader
 from metrics import get_metrics
+logging.root.handlers = []
 FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
 logger = logging.getLogger(__name__)
...
@@ -26,6 +26,7 @@ from config import *
 from datareader import get_reader
 from metrics import get_metrics
+logging.root.handlers = []
 FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
 logger = logging.getLogger(__name__)
@@ -59,6 +60,13 @@ def parse_args():
         default=None,
         help='path to pretrain weights. None to use default weights path in ~/.paddle/weights.'
     )
+    parser.add_argument(
+        '--resume',
+        type=str,
+        default=None,
+        help='path to resume training based on previous checkpoints. '
+        'None for not resuming any checkpoints.'
+    )
     parser.add_argument(
         '--use-gpu', type=bool, default=True, help='default use gpu.')
     parser.add_argument(
@@ -141,6 +149,15 @@ def train(args):
     exe = fluid.Executor(place)
     exe.run(startup)

+    if args.resume:
+        # if resume weights is given, load resume weights directly
+        assert os.path.exists(args.resume), \
+                "Given resume weight dir {} not exist.".format(args.resume)
+        def if_exist(var):
+            return os.path.exists(os.path.join(args.resume, var.name))
+        fluid.io.load_vars(exe, args.resume, predicate=if_exist, main_program=train_prog)
+    else:
+        # if not in resume mode, load pretrain weights
     if args.pretrain:
         assert os.path.exists(args.pretrain), \
             "Given pretrain weight dir {} not exist.".format(args.pretrain)
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册