未验证 提交 73ac7c1d 编写于 作者: W wangxinxin08 提交者: GitHub

[Dygraph]refine resume process and eval process (#1848)

* refine resume process and modify code to support eval when batch_size is larger than 1

* modify code according to review
上级 56a673a0
......@@ -141,7 +141,7 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
det_res = []
k = 0
for i in range(len(bbox_nums)):
image_id = int(image_id[i][0])
cur_image_id = int(image_id[i][0])
det_nums = bbox_nums[i]
for j in range(det_nums):
dt = bboxes[k]
......@@ -152,7 +152,7 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
dt_res = {
'image_id': image_id,
'image_id': cur_image_id,
'category_id': category_id,
'bbox': bbox,
'score': score
......@@ -166,7 +166,7 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map):
seg_res = []
k = 0
for i in range(len(mask_nums)):
image_id = int(image_id[i][0])
cur_image_id = int(image_id[i][0])
det_nums = mask_nums[i]
for j in range(det_nums):
dt = masks[k]
......@@ -177,7 +177,7 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map):
if 'counts' in sg:
sg['counts'] = sg['counts'].decode("utf8")
sg_res = {
'image_id': image_id,
'image_id': cur_image_id,
'category_id': cat_id,
'segmentation': sg,
'score': score
......
......@@ -91,12 +91,16 @@ def load_weight(model, weight, optimizer=None):
model.set_dict(param_state_dict)
if optimizer is not None and os.path.exists(path + '.pdopt'):
last_epoch = 0
optim_state_dict = paddle.load(path + '.pdopt')
# Work around a resume bug by filling in optimizer state keys missing from the checkpoint (to be fixed in Paddle 2.0)
for key in optimizer.state_dict().keys():
if not key in optim_state_dict.keys():
optim_state_dict[key] = optimizer.state_dict()[key]
if 'last_epoch' in optim_state_dict:
last_epoch = optim_state_dict.pop('last_epoch')
optimizer.set_state_dict(optim_state_dict)
return last_epoch
return
......@@ -143,10 +147,12 @@ def load_pretrain_weight(model,
return
def save_model(model, optimizer, save_dir, save_name):
def save_model(model, optimizer, save_dir, save_name, last_epoch):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_path = os.path.join(save_dir, save_name)
paddle.save(model.state_dict(), save_path + ".pdparams")
paddle.save(optimizer.state_dict(), save_path + ".pdopt")
state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch
paddle.save(state_dict, save_path + ".pdopt")
logger.info("Save checkpoint: {}".format(save_dir))
......@@ -128,8 +128,9 @@ def run(FLAGS, cfg, place):
optimizer = create('OptimizerBuilder')(lr, model.parameters())
# Init Model & Optimizer
start_epoch = 0
if FLAGS.weight_type == 'resume':
load_weight(model, cfg.pretrain_weights, optimizer)
start_epoch = load_weight(model, cfg.pretrain_weights, optimizer)
else:
load_pretrain_weight(model, cfg.pretrain_weights,
cfg.get('load_static_weights', False),
......@@ -153,9 +154,7 @@ def run(FLAGS, cfg, place):
start_time = time.time()
end_time = time.time()
# Run Train
start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch']
for epoch_id in range(int(cfg.epoch)):
cur_eid = epoch_id + start_epoch
for cur_eid in range(start_epoch, int(cfg.epoch)):
train_loader.dataset.epoch = cur_eid
for iter_id, data in enumerate(train_loader):
start_time = end_time
......@@ -185,7 +184,7 @@ def run(FLAGS, cfg, place):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
# Log state
if epoch_id == 0 and iter_id == 0:
if cur_eid == start_epoch and iter_id == 0:
train_stats = TrainingStats(cfg.log_iter, outputs.keys())
train_stats.update(outputs)
logs = train_stats.log()
......@@ -203,7 +202,7 @@ def run(FLAGS, cfg, place):
save_name = str(cur_eid) if cur_eid + 1 != int(
cfg.epoch) else "model_final"
save_dir = os.path.join(cfg.save_dir, cfg_name)
save_model(model, optimizer, save_dir, save_name)
save_model(model, optimizer, save_dir, save_name, cur_eid + 1)
def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册