diff --git a/dygraph/train.py b/dygraph/train.py index 24ddbcc3c52e7be87522ada4d284a77afc054c37..71d6985bf7924b790ab84d8314659a23f3e08193 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -26,6 +26,7 @@ import models import utils.logging as logging from utils import get_environ_info from utils import load_pretrained_model +from utils import resume from val import evaluate @@ -117,13 +118,18 @@ def train(model, num_epochs=100, batch_size=2, pretrained_model=None, + resume_model=None, save_interval_epochs=1, num_classes=None, num_workers=8): ignore_index = model.ignore_index nranks = ParallelEnv().nranks - load_pretrained_model(model, pretrained_model) + start_epoch = 0 + if resume_model is not None: + start_epoch = resume(optimizer, resume_model) + elif pretrained_model is not None: + load_pretrained_model(model, pretrained_model) if not os.path.isdir(save_dir): if os.path.exists(save_dir): @@ -144,7 +150,7 @@ def train(model, return_list=True, ) - for epoch in range(num_epochs): + for epoch in range(start_epoch, num_epochs): for step, data in enumerate(loader): images = data[0] labels = data[1].astype('int64') @@ -158,9 +164,11 @@ def train(model, loss.backward() optimizer.minimize(loss) model.clear_gradients() - logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format( - epoch + 1, num_epochs, step + 1, len(batch_sampler), - loss.numpy())) + lr = optimizer.current_step_lr() + logging.info( + "[TRAIN] Epoch={}/{}, Step={}/{}, loss={}, lr={}".format( + epoch + 1, num_epochs, step + 1, len(batch_sampler), + loss.numpy(), lr)) if ((epoch + 1) % save_interval_epochs == 0 or epoch == num_epochs - 1) and ParallelEnv().local_rank == 0: @@ -170,6 +178,8 @@ def train(model, os.makedirs(current_save_dir) fluid.save_dygraph(model.state_dict(), os.path.join(current_save_dir, 'model')) + fluid.save_dygraph(optimizer.state_dict(), + os.path.join(current_save_dir, 'model')) if eval_dataset is not None: evaluate( diff --git a/dygraph/utils/utils.py 
def resume(optimizer, resume_model, model=None):
    """
    Restore training state from a checkpoint directory.

    The checkpoint is read from ``<resume_model>/model`` and the epoch to
    continue from is parsed from the text after the last ``_`` in the
    directory path (for example ``output/epoch_10`` gives ``10``).

    Args:
        optimizer: Optimizer whose state is restored through ``set_dict``.
        resume_model (str|None): Checkpoint directory. When None, nothing is
            loaded.
        model (optional): When given, the parameters stored in the checkpoint
            are also restored into it through ``set_dict``. Default: None,
            which restores the optimizer state only.

    Returns:
        int|None: The epoch parsed from the directory name, or None when
            ``resume_model`` is None.

    Raises:
        ValueError: If the directory does not exist, its name does not end
            with ``_<digits>``, or the checkpoint holds no optimizer state.
    """
    if resume_model is None:
        return None

    logging.info('Resume model from {}'.format(resume_model))
    if not os.path.exists(resume_model):
        raise ValueError(
            'The resume model directory is not Found: {}'.format(
                resume_model))

    # normpath drops a trailing separator, so 'output/epoch_10/' still parses.
    epoch = os.path.normpath(resume_model).split('_')[-1]
    if not epoch.isdigit():
        # Validate before loading anything so a bad path never leaves the
        # optimizer half-restored, and so the caller never receives a
        # non-integer start epoch.
        raise ValueError(
            'Cannot parse the epoch from the resume model directory, '
            'expected a name ending with "_<epoch>": {}'.format(resume_model))

    ckpt_path = os.path.join(resume_model, 'model')
    para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
    if opti_state_dict is None:
        # Without optimizer state there is nothing to resume; fail here with
        # a clear message instead of inside optimizer.set_dict.
        raise ValueError(
            'No optimizer state found in the resume model directory: '
            '{}'.format(resume_model))

    if model is not None:
        model.set_dict(para_state_dict)
    optimizer.set_dict(opti_state_dict)
    return int(epoch)