提交 485c7d5c 编写于 作者: D dengkaipeng

add resume weights in train.py

上级 932be2e9
...@@ -28,6 +28,7 @@ from config import * ...@@ -28,6 +28,7 @@ from config import *
import models import models
from datareader import get_reader from datareader import get_reader
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout) logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -143,8 +143,6 @@ class ModelBase(object): ...@@ -143,8 +143,6 @@ class ModelBase(object):
return path return path
def load_pretrain_params(self, exe, pretrain, prog, place): def load_pretrain_params(self, exe, pretrain, prog, place):
def if_exist(var):
return os.path.exists(os.path.join(pretrained_base, var.name))
fluid.io.load_params(exe, pretrain, main_program=prog) fluid.io.load_params(exe, pretrain, main_program=prog)
def get_config_from_sec(self, sec, item, default=None): def get_config_from_sec(self, sec, item, default=None):
......
...@@ -25,6 +25,7 @@ import models ...@@ -25,6 +25,7 @@ import models
from datareader import get_reader from datareader import get_reader
from metrics import get_metrics from metrics import get_metrics
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -26,6 +26,7 @@ from config import * ...@@ -26,6 +26,7 @@ from config import *
from datareader import get_reader from datareader import get_reader
from metrics import get_metrics from metrics import get_metrics
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s' FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -59,6 +60,12 @@ def parse_args(): ...@@ -59,6 +60,12 @@ def parse_args():
default=None, default=None,
help='path to pretrain weights. None to use default weights path in ~/.paddle/weights.' help='path to pretrain weights. None to use default weights path in ~/.paddle/weights.'
) )
parser.add_argument(
'--resume-weights',
type=str,
default=None,
help='path to resume weights. If not None, only resume weigths will be loaded.'
)
parser.add_argument( parser.add_argument(
'--use-gpu', type=bool, default=True, help='default use gpu.') '--use-gpu', type=bool, default=True, help='default use gpu.')
parser.add_argument( parser.add_argument(
...@@ -141,12 +148,19 @@ def train(args): ...@@ -141,12 +148,19 @@ def train(args):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
if args.pretrain: if args.resume_weights:
assert os.path.exists(args.pretrain), \ # if resume weights is given, load resume weights directly
"Given pretrain weight dir {} not exist.".format(args.pretrain) assert os.path.exists(args.resume_weights), \
pretrain = args.pretrain or train_model.get_pretrain_weights() "Given resume weight dir {} not exist.".format(args.resume_weights)
if pretrain: fluid.io.load_params(exe, args.resume_weights, main_program=train_prog)
train_model.load_pretrain_params(exe, pretrain, train_prog, place) else:
# if not in resume mode, load pretrain weights
if args.pretrain:
assert os.path.exists(args.pretrain), \
"Given pretrain weight dir {} not exist.".format(args.pretrain)
pretrain = args.pretrain or train_model.get_pretrain_weights()
if pretrain:
train_model.load_pretrain_params(exe, pretrain, train_prog, place)
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu, use_cuda=args.use_gpu,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册