From 73ac7c1d10c69a3747769a4a538f824c9416c7bc Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Wed, 9 Dec 2020 11:49:14 +0800 Subject: [PATCH] [Dygraph]refine resume process and eval process (#1848) * refine resume process and modify code to support eval while batch_size is larger than 1 * modify code according to review --- ppdet/py_op/post_process.py | 8 ++++---- ppdet/utils/checkpoint.py | 10 ++++++++-- tools/train.py | 11 +++++------ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/ppdet/py_op/post_process.py b/ppdet/py_op/post_process.py index f303ff401..58cfcd2c1 100755 --- a/ppdet/py_op/post_process.py +++ b/ppdet/py_op/post_process.py @@ -141,7 +141,7 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map): det_res = [] k = 0 for i in range(len(bbox_nums)): - image_id = int(image_id[i][0]) + cur_image_id = int(image_id[i][0]) det_nums = bbox_nums[i] for j in range(det_nums): dt = bboxes[k] @@ -152,7 +152,7 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map): h = ymax - ymin + 1 bbox = [xmin, ymin, w, h] dt_res = { - 'image_id': image_id, + 'image_id': cur_image_id, 'category_id': category_id, 'bbox': bbox, 'score': score @@ -166,7 +166,7 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map): seg_res = [] k = 0 for i in range(len(mask_nums)): - image_id = int(image_id[i][0]) + cur_image_id = int(image_id[i][0]) det_nums = mask_nums[i] for j in range(det_nums): dt = masks[k] @@ -177,7 +177,7 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map): if 'counts' in sg: sg['counts'] = sg['counts'].decode("utf8") sg_res = { - 'image_id': image_id, + 'image_id': cur_image_id, 'category_id': cat_id, 'segmentation': sg, 'score': score diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index ffe475dfa..a280cbeed 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -91,12 +91,16 @@ def load_weight(model, weight, optimizer=None): model.set_dict(param_state_dict) if optimizer is not None and os.path.exists(path + '.pdopt'): + last_epoch = 0 optim_state_dict = paddle.load(path + '.pdopt') # to slove resume bug, will it be fixed in paddle 2.0 for key in optimizer.state_dict().keys(): if not key in optim_state_dict.keys(): optim_state_dict[key] = optimizer.state_dict()[key] + if 'last_epoch' in optim_state_dict: + last_epoch = optim_state_dict.pop('last_epoch') optimizer.set_state_dict(optim_state_dict) + return last_epoch return @@ -143,10 +147,12 @@ def load_pretrain_weight(model, return -def save_model(model, optimizer, save_dir, save_name): +def save_model(model, optimizer, save_dir, save_name, last_epoch): if not os.path.exists(save_dir): os.makedirs(save_dir) save_path = os.path.join(save_dir, save_name) paddle.save(model.state_dict(), save_path + ".pdparams") - paddle.save(optimizer.state_dict(), save_path + ".pdopt") + state_dict = optimizer.state_dict() + state_dict['last_epoch'] = last_epoch + paddle.save(state_dict, save_path + ".pdopt") logger.info("Save checkpoint: {}".format(save_dir)) diff --git a/tools/train.py b/tools/train.py index 43727692c..f9383dfa6 100755 --- a/tools/train.py +++ b/tools/train.py @@ -128,8 +128,9 @@ def run(FLAGS, cfg, place): optimizer = create('OptimizerBuilder')(lr, model.parameters()) # Init Model & Optimzer + start_epoch = 0 if FLAGS.weight_type == 'resume': - load_weight(model, cfg.pretrain_weights, optimizer) + start_epoch = load_weight(model, cfg.pretrain_weights, optimizer) else: load_pretrain_weight(model, cfg.pretrain_weights, cfg.get('load_static_weights', False), @@ -153,9 +154,7 @@ def run(FLAGS, cfg, place): start_time = time.time() end_time = time.time() # Run Train - start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch'] - for epoch_id in range(int(cfg.epoch)): - cur_eid = epoch_id + start_epoch + for cur_eid in range(start_epoch, int(cfg.epoch)): train_loader.dataset.epoch = cur_eid for iter_id, data in enumerate(train_loader): start_time = end_time @@ -185,7 +184,7 @@ def run(FLAGS, cfg, place): if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: # Log state - if epoch_id == 0 and iter_id == 0: + if cur_eid == start_epoch and iter_id == 0: train_stats = TrainingStats(cfg.log_iter, outputs.keys()) train_stats.update(outputs) logs = train_stats.log() @@ -203,7 +202,7 @@ def run(FLAGS, cfg, place): save_name = str(cur_eid) if cur_eid + 1 != int( cfg.epoch) else "model_final" save_dir = os.path.join(cfg.save_dir, cfg_name) - save_model(model, optimizer, save_dir, save_name) + save_model(model, optimizer, save_dir, save_name, cur_eid + 1) def main(): -- GitLab