未验证 提交 73ac7c1d 编写于 作者: W wangxinxin08 提交者: GitHub

[Dygraph]refine resume process and eval process (#1848)

* refine resume process and modify code to support eval while batch_size is larger than 1

* modify code according to review
上级 56a673a0
...@@ -141,7 +141,7 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map): ...@@ -141,7 +141,7 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
det_res = [] det_res = []
k = 0 k = 0
for i in range(len(bbox_nums)): for i in range(len(bbox_nums)):
image_id = int(image_id[i][0]) cur_image_id = int(image_id[i][0])
det_nums = bbox_nums[i] det_nums = bbox_nums[i]
for j in range(det_nums): for j in range(det_nums):
dt = bboxes[k] dt = bboxes[k]
...@@ -152,7 +152,7 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map): ...@@ -152,7 +152,7 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
h = ymax - ymin + 1 h = ymax - ymin + 1
bbox = [xmin, ymin, w, h] bbox = [xmin, ymin, w, h]
dt_res = { dt_res = {
'image_id': image_id, 'image_id': cur_image_id,
'category_id': category_id, 'category_id': category_id,
'bbox': bbox, 'bbox': bbox,
'score': score 'score': score
...@@ -166,7 +166,7 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map): ...@@ -166,7 +166,7 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map):
seg_res = [] seg_res = []
k = 0 k = 0
for i in range(len(mask_nums)): for i in range(len(mask_nums)):
image_id = int(image_id[i][0]) cur_image_id = int(image_id[i][0])
det_nums = mask_nums[i] det_nums = mask_nums[i]
for j in range(det_nums): for j in range(det_nums):
dt = masks[k] dt = masks[k]
...@@ -177,7 +177,7 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map): ...@@ -177,7 +177,7 @@ def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map):
if 'counts' in sg: if 'counts' in sg:
sg['counts'] = sg['counts'].decode("utf8") sg['counts'] = sg['counts'].decode("utf8")
sg_res = { sg_res = {
'image_id': image_id, 'image_id': cur_image_id,
'category_id': cat_id, 'category_id': cat_id,
'segmentation': sg, 'segmentation': sg,
'score': score 'score': score
......
...@@ -91,12 +91,16 @@ def load_weight(model, weight, optimizer=None): ...@@ -91,12 +91,16 @@ def load_weight(model, weight, optimizer=None):
model.set_dict(param_state_dict) model.set_dict(param_state_dict)
if optimizer is not None and os.path.exists(path + '.pdopt'): if optimizer is not None and os.path.exists(path + '.pdopt'):
last_epoch = 0
optim_state_dict = paddle.load(path + '.pdopt') optim_state_dict = paddle.load(path + '.pdopt')
# to slove resume bug, will it be fixed in paddle 2.0 # to slove resume bug, will it be fixed in paddle 2.0
for key in optimizer.state_dict().keys(): for key in optimizer.state_dict().keys():
if not key in optim_state_dict.keys(): if not key in optim_state_dict.keys():
optim_state_dict[key] = optimizer.state_dict()[key] optim_state_dict[key] = optimizer.state_dict()[key]
if 'last_epoch' in optim_state_dict:
last_epoch = optim_state_dict.pop('last_epoch')
optimizer.set_state_dict(optim_state_dict) optimizer.set_state_dict(optim_state_dict)
return last_epoch
return return
...@@ -143,10 +147,12 @@ def load_pretrain_weight(model, ...@@ -143,10 +147,12 @@ def load_pretrain_weight(model,
return return
def save_model(model, optimizer, save_dir, save_name): def save_model(model, optimizer, save_dir, save_name, last_epoch):
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
save_path = os.path.join(save_dir, save_name) save_path = os.path.join(save_dir, save_name)
paddle.save(model.state_dict(), save_path + ".pdparams") paddle.save(model.state_dict(), save_path + ".pdparams")
paddle.save(optimizer.state_dict(), save_path + ".pdopt") state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch
paddle.save(state_dict, save_path + ".pdopt")
logger.info("Save checkpoint: {}".format(save_dir)) logger.info("Save checkpoint: {}".format(save_dir))
...@@ -128,8 +128,9 @@ def run(FLAGS, cfg, place): ...@@ -128,8 +128,9 @@ def run(FLAGS, cfg, place):
optimizer = create('OptimizerBuilder')(lr, model.parameters()) optimizer = create('OptimizerBuilder')(lr, model.parameters())
# Init Model & Optimzer # Init Model & Optimzer
start_epoch = 0
if FLAGS.weight_type == 'resume': if FLAGS.weight_type == 'resume':
load_weight(model, cfg.pretrain_weights, optimizer) start_epoch = load_weight(model, cfg.pretrain_weights, optimizer)
else: else:
load_pretrain_weight(model, cfg.pretrain_weights, load_pretrain_weight(model, cfg.pretrain_weights,
cfg.get('load_static_weights', False), cfg.get('load_static_weights', False),
...@@ -153,9 +154,7 @@ def run(FLAGS, cfg, place): ...@@ -153,9 +154,7 @@ def run(FLAGS, cfg, place):
start_time = time.time() start_time = time.time()
end_time = time.time() end_time = time.time()
# Run Train # Run Train
start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch'] for cur_eid in range(start_epoch, int(cfg.epoch)):
for epoch_id in range(int(cfg.epoch)):
cur_eid = epoch_id + start_epoch
train_loader.dataset.epoch = cur_eid train_loader.dataset.epoch = cur_eid
for iter_id, data in enumerate(train_loader): for iter_id, data in enumerate(train_loader):
start_time = end_time start_time = end_time
...@@ -185,7 +184,7 @@ def run(FLAGS, cfg, place): ...@@ -185,7 +184,7 @@ def run(FLAGS, cfg, place):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
# Log state # Log state
if epoch_id == 0 and iter_id == 0: if cur_eid == start_epoch and iter_id == 0:
train_stats = TrainingStats(cfg.log_iter, outputs.keys()) train_stats = TrainingStats(cfg.log_iter, outputs.keys())
train_stats.update(outputs) train_stats.update(outputs)
logs = train_stats.log() logs = train_stats.log()
...@@ -203,7 +202,7 @@ def run(FLAGS, cfg, place): ...@@ -203,7 +202,7 @@ def run(FLAGS, cfg, place):
save_name = str(cur_eid) if cur_eid + 1 != int( save_name = str(cur_eid) if cur_eid + 1 != int(
cfg.epoch) else "model_final" cfg.epoch) else "model_final"
save_dir = os.path.join(cfg.save_dir, cfg_name) save_dir = os.path.join(cfg.save_dir, cfg_name)
save_model(model, optimizer, save_dir, save_name) save_model(model, optimizer, save_dir, save_name, cur_eid + 1)
def main(): def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册