提交 aa160894 编写于 作者: C chenguowei01

add resume training

上级 40950add
...@@ -26,6 +26,7 @@ import models ...@@ -26,6 +26,7 @@ import models
import utils.logging as logging import utils.logging as logging
from utils import get_environ_info from utils import get_environ_info
from utils import load_pretrained_model from utils import load_pretrained_model
from utils import resume
from val import evaluate from val import evaluate
...@@ -117,13 +118,18 @@ def train(model, ...@@ -117,13 +118,18 @@ def train(model,
num_epochs=100, num_epochs=100,
batch_size=2, batch_size=2,
pretrained_model=None, pretrained_model=None,
resume_model=None,
save_interval_epochs=1, save_interval_epochs=1,
num_classes=None, num_classes=None,
num_workers=8): num_workers=8):
ignore_index = model.ignore_index ignore_index = model.ignore_index
nranks = ParallelEnv().nranks nranks = ParallelEnv().nranks
load_pretrained_model(model, pretrained_model) start_epoch = 0
if resume_model is not None:
start_epoch = resume(optimizer, resume_model)
elif pretrained_model is not None:
load_pretrained_model(model, pretrained_model)
if not os.path.isdir(save_dir): if not os.path.isdir(save_dir):
if os.path.exists(save_dir): if os.path.exists(save_dir):
...@@ -144,7 +150,7 @@ def train(model, ...@@ -144,7 +150,7 @@ def train(model,
return_list=True, return_list=True,
) )
for epoch in range(num_epochs): for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader): for step, data in enumerate(loader):
images = data[0] images = data[0]
labels = data[1].astype('int64') labels = data[1].astype('int64')
...@@ -158,9 +164,11 @@ def train(model, ...@@ -158,9 +164,11 @@ def train(model,
loss.backward() loss.backward()
optimizer.minimize(loss) optimizer.minimize(loss)
model.clear_gradients() model.clear_gradients()
logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format( lr = optimizer.current_step_lr()
epoch + 1, num_epochs, step + 1, len(batch_sampler), logging.info(
loss.numpy())) "[TRAIN] Epoch={}/{}, Step={}/{}, loss={}, lr={}".format(
epoch + 1, num_epochs, step + 1, len(batch_sampler),
loss.numpy(), lr))
if ((epoch + 1) % save_interval_epochs == 0 if ((epoch + 1) % save_interval_epochs == 0
or epoch == num_epochs - 1) and ParallelEnv().local_rank == 0: or epoch == num_epochs - 1) and ParallelEnv().local_rank == 0:
...@@ -170,6 +178,8 @@ def train(model, ...@@ -170,6 +178,8 @@ def train(model,
os.makedirs(current_save_dir) os.makedirs(current_save_dir)
fluid.save_dygraph(model.state_dict(), fluid.save_dygraph(model.state_dict(),
os.path.join(current_save_dir, 'model')) os.path.join(current_save_dir, 'model'))
fluid.save_dygraph(optimizer.state_dict(),
os.path.join(current_save_dir, 'model'))
if eval_dataset is not None: if eval_dataset is not None:
evaluate( evaluate(
......
...@@ -49,7 +49,7 @@ def get_environ_info(): ...@@ -49,7 +49,7 @@ def get_environ_info():
def load_pretrained_model(model, pretrained_model): def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None: if pretrained_model is not None:
logging.info('Load pretrained model!') logging.info('Load pretrained model from {}'.format(pretrained_model))
if os.path.exists(pretrained_model): if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model') ckpt_path = os.path.join(pretrained_model, 'model')
para_state_dict, _ = fluid.load_dygraph(ckpt_path) para_state_dict, _ = fluid.load_dygraph(ckpt_path)
...@@ -78,6 +78,23 @@ def load_pretrained_model(model, pretrained_model): ...@@ -78,6 +78,23 @@ def load_pretrained_model(model, pretrained_model):
pretrained_model)) pretrained_model))
def resume(optimizer, resume_model):
if resume_model is not None:
logging.info('Resume model from {}'.format(resume_model))
if os.path.exists(resume_model):
ckpt_path = os.path.join(resume_model, 'model')
_, opti_state_dict = fluid.load_dygraph(ckpt_path)
optimizer.set_dict(opti_state_dict)
epoch = resume_model.split('_')[-1]
if epoch.isdigit():
epoch = int(epoch)
return epoch
else:
raise ValueError(
'The resume model directory is not Found: {}'.formnat(
resume_model))
def visualize(image, result, save_dir=None, weight=0.6): def visualize(image, result, save_dir=None, weight=0.6):
""" """
Convert segment result to color image, and save added image. Convert segment result to color image, and save added image.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册