未验证 提交 e5327e01 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #1746 from heavengate/video_resume

add resume weights in train.py
......@@ -28,6 +28,7 @@ from config import *
import models
from datareader import get_reader
# Reset any handlers installed by previously-imported modules so that the
# basicConfig call below takes effect (basicConfig is a no-op if handlers exist).
logging.root.handlers = []
# Log format: severity, source file, 4-wide line number, then the message.
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
# DEBUG level here (the other entry scripts use INFO); output goes to stdout.
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
# Module-level logger following the standard getLogger(__name__) convention.
logger = logging.getLogger(__name__)
......
......@@ -143,8 +143,6 @@ class ModelBase(object):
return path
def load_pretrain_params(self, exe, pretrain, prog, place):
    """Load pretrained parameters into ``prog``.

    Args:
        exe: fluid Executor used to perform the load.
        pretrain: directory path containing the pretrained weight files.
        prog: the fluid Program whose parameters are to be restored.
        place: device placement; unused here but kept for interface
            compatibility with subclass overrides (e.g. STNET).
    """
    # NOTE(review): a previous per-variable `if_exist` predicate referenced
    # the undefined name `pretrained_base` and was never invoked (dead code,
    # would have raised NameError if used); load all of `prog`'s parameters
    # from the pretrain directory directly instead.
    fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None):
......
......@@ -132,7 +132,7 @@ class STNET(ModelBase):
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars)
fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
param_tensor = fluid.global_scope().find_var(
"conv1_weights").get_tensor()
......
......@@ -25,6 +25,7 @@ import models
from datareader import get_reader
from metrics import get_metrics
# Clear handlers set up by imported modules so basicConfig below is honored
# (logging.basicConfig does nothing when the root logger already has handlers).
logging.root.handlers = []
# Log format: severity, source file, 4-wide line number, then the message.
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
# INFO-level logging routed to stdout for this entry script.
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
# Module-level logger following the standard getLogger(__name__) convention.
logger = logging.getLogger(__name__)
......
......@@ -26,6 +26,7 @@ from config import *
from datareader import get_reader
from metrics import get_metrics
# Drop handlers attached by earlier imports so the basicConfig call below
# actually configures the root logger (it is a no-op if handlers exist).
logging.root.handlers = []
# Log format: severity, source file, 4-wide line number, then the message.
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
# INFO-level logging routed to stdout for this entry script.
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
# Module-level logger following the standard getLogger(__name__) convention.
logger = logging.getLogger(__name__)
......@@ -59,6 +60,13 @@ def parse_args():
default=None,
help='path to pretrain weights. None to use default weights path in ~/.paddle/weights.'
)
parser.add_argument(
'--resume',
type=str,
default=None,
help='path to resume training based on previous checkpoints. '
'None for not resuming any checkpoints.'
)
parser.add_argument(
'--use-gpu', type=bool, default=True, help='default use gpu.')
parser.add_argument(
......@@ -141,12 +149,21 @@ def train(args):
exe = fluid.Executor(place)
exe.run(startup)
if args.pretrain:
assert os.path.exists(args.pretrain), \
"Given pretrain weight dir {} not exist.".format(args.pretrain)
pretrain = args.pretrain or train_model.get_pretrain_weights()
if pretrain:
train_model.load_pretrain_params(exe, pretrain, train_prog, place)
if args.resume:
# if resume weights is given, load resume weights directly
assert os.path.exists(args.resume), \
"Given resume weight dir {} not exist.".format(args.resume)
def if_exist(var):
return os.path.exists(os.path.join(args.resume, var.name))
fluid.io.load_vars(exe, args.resume, predicate=if_exist, main_program=train_prog)
else:
# if not in resume mode, load pretrain weights
if args.pretrain:
assert os.path.exists(args.pretrain), \
"Given pretrain weight dir {} not exist.".format(args.pretrain)
pretrain = args.pretrain or train_model.get_pretrain_weights()
if pretrain:
train_model.load_pretrain_params(exe, pretrain, train_prog, place)
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册