提交 485c7d5c 编写于 作者: D dengkaipeng

add resume weights in train.py

上级 932be2e9
......@@ -28,6 +28,7 @@ from config import *
import models
from datareader import get_reader
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
......
......@@ -143,8 +143,6 @@ class ModelBase(object):
return path
def load_pretrain_params(self, exe, pretrain, prog, place):
def if_exist(var):
return os.path.exists(os.path.join(pretrained_base, var.name))
fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None):
......
......@@ -25,6 +25,7 @@ import models
from datareader import get_reader
from metrics import get_metrics
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
......
......@@ -26,6 +26,7 @@ from config import *
from datareader import get_reader
from metrics import get_metrics
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
......@@ -59,6 +60,12 @@ def parse_args():
default=None,
help='path to pretrain weights. None to use default weights path in ~/.paddle/weights.'
)
parser.add_argument(
'--resume-weights',
type=str,
default=None,
help='path to resume weights. If not None, only resume weigths will be loaded.'
)
parser.add_argument(
'--use-gpu', type=bool, default=True, help='default use gpu.')
parser.add_argument(
......@@ -141,12 +148,19 @@ def train(args):
exe = fluid.Executor(place)
exe.run(startup)
if args.pretrain:
assert os.path.exists(args.pretrain), \
"Given pretrain weight dir {} not exist.".format(args.pretrain)
pretrain = args.pretrain or train_model.get_pretrain_weights()
if pretrain:
train_model.load_pretrain_params(exe, pretrain, train_prog, place)
if args.resume_weights:
# if resume weights is given, load resume weights directly
assert os.path.exists(args.resume_weights), \
"Given resume weight dir {} not exist.".format(args.resume_weights)
fluid.io.load_params(exe, args.resume_weights, main_program=train_prog)
else:
# if not in resume mode, load pretrain weights
if args.pretrain:
assert os.path.exists(args.pretrain), \
"Given pretrain weight dir {} not exist.".format(args.pretrain)
pretrain = args.pretrain or train_model.get_pretrain_weights()
if pretrain:
train_model.load_pretrain_params(exe, pretrain, train_prog, place)
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册