提交 aa160894 编写于 作者: C chenguowei01

add resume training

上级 40950add
......@@ -26,6 +26,7 @@ import models
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
from utils import resume
from val import evaluate
......@@ -117,12 +118,17 @@ def train(model,
num_epochs=100,
batch_size=2,
pretrained_model=None,
resume_model=None,
save_interval_epochs=1,
num_classes=None,
num_workers=8):
ignore_index = model.ignore_index
nranks = ParallelEnv().nranks
start_epoch = 0
if resume_model is not None:
start_epoch = resume(optimizer, resume_model)
elif pretrained_model is not None:
load_pretrained_model(model, pretrained_model)
if not os.path.isdir(save_dir):
......@@ -144,7 +150,7 @@ def train(model,
return_list=True,
)
for epoch in range(num_epochs):
for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader):
images = data[0]
labels = data[1].astype('int64')
......@@ -158,9 +164,11 @@ def train(model,
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format(
lr = optimizer.current_step_lr()
logging.info(
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={}, lr={}".format(
epoch + 1, num_epochs, step + 1, len(batch_sampler),
loss.numpy()))
loss.numpy(), lr))
if ((epoch + 1) % save_interval_epochs == 0
or epoch == num_epochs - 1) and ParallelEnv().local_rank == 0:
......@@ -170,6 +178,8 @@ def train(model,
os.makedirs(current_save_dir)
fluid.save_dygraph(model.state_dict(),
os.path.join(current_save_dir, 'model'))
fluid.save_dygraph(optimizer.state_dict(),
os.path.join(current_save_dir, 'model'))
if eval_dataset is not None:
evaluate(
......
......@@ -49,7 +49,7 @@ def get_environ_info():
def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None:
logging.info('Load pretrained model!')
logging.info('Load pretrained model from {}'.format(pretrained_model))
if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model')
para_state_dict, _ = fluid.load_dygraph(ckpt_path)
......@@ -78,6 +78,23 @@ def load_pretrained_model(model, pretrained_model):
pretrained_model))
def resume(optimizer, resume_model):
if resume_model is not None:
logging.info('Resume model from {}'.format(resume_model))
if os.path.exists(resume_model):
ckpt_path = os.path.join(resume_model, 'model')
_, opti_state_dict = fluid.load_dygraph(ckpt_path)
optimizer.set_dict(opti_state_dict)
epoch = resume_model.split('_')[-1]
if epoch.isdigit():
epoch = int(epoch)
return epoch
else:
raise ValueError(
'The resume model directory is not Found: {}'.formnat(
resume_model))
def visualize(image, result, save_dir=None, weight=0.6):
"""
Convert segment result to color image, and save added image.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册