From 32edc345b5bb5c4cc68dac1b54ecd85d5009139b Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Fri, 29 Jan 2021 17:49:34 +0800 Subject: [PATCH] [Dygraph] fix load_weight and no bbox output (#2135) * fix load_weight and no bbox output --- dygraph/ppdet/modeling/post_process.py | 4 +++- dygraph/ppdet/utils/checkpoint.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/dygraph/ppdet/modeling/post_process.py b/dygraph/ppdet/modeling/post_process.py index a3619a3a8..6c5033116 100644 --- a/dygraph/ppdet/modeling/post_process.py +++ b/dygraph/ppdet/modeling/post_process.py @@ -50,7 +50,9 @@ class BBoxPostProcess(object): including labels, scores and bboxes. The size of bboxes are corresponding to the original image. """ - assert bboxes.shape[0] > 0, 'There is no detection output' + if bboxes.shape[0] == 0: + return paddle.to_tensor([[0, 0.0, 0.0, 0.0, 0.0, 0.0]]) + origin_shape = paddle.floor(im_shape / scale_factor + 0.5) origin_shape_list = [] diff --git a/dygraph/ppdet/utils/checkpoint.py b/dygraph/ppdet/utils/checkpoint.py index fca3d47d7..1f4562233 100644 --- a/dygraph/ppdet/utils/checkpoint.py +++ b/dygraph/ppdet/utils/checkpoint.py @@ -91,7 +91,25 @@ def load_weight(model, weight, optimizer=None): "exists.".format(pdparam_path)) param_state_dict = paddle.load(pdparam_path) - model.set_dict(param_state_dict) + model_dict = model.state_dict() + + model_weight = {} + incorrect_keys = 0 + + for key in model_dict.keys(): + if key in param_state_dict.keys(): + model_weight[key] = param_state_dict[key] + else: + logger.info('Unmatched key: {}'.format(key)) + incorrect_keys += 1 + + assert incorrect_keys == 0, "Load weight {} incorrectly, \ + {} keys unmatched, please check again.".format(weight, + incorrect_keys) + logger.info('Finish loading model weight parameter: {}'.format( + pdparam_path)) + + model.set_dict(model_weight) last_epoch = 0 if optimizer is not None and os.path.exists(path + '.pdopt'): -- GitLab