未验证 提交 32edc345 编写于 作者: F Feng Ni 提交者: GitHub

[Dygraph] fix load_weight and no bbox output (#2135)

* fix load_weight and no bbox output
上级 f87637de
...@@ -50,7 +50,9 @@ class BBoxPostProcess(object): ...@@ -50,7 +50,9 @@ class BBoxPostProcess(object):
including labels, scores and bboxes. The size of including labels, scores and bboxes. The size of
bboxes are corresponding to the original image. bboxes are corresponding to the original image.
""" """
assert bboxes.shape[0] > 0, 'There is no detection output' if bboxes.shape[0] == 0:
return paddle.to_tensor([[0, 0.0, 0.0, 0.0, 0.0, 0.0]])
origin_shape = paddle.floor(im_shape / scale_factor + 0.5) origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
origin_shape_list = [] origin_shape_list = []
......
...@@ -91,7 +91,25 @@ def load_weight(model, weight, optimizer=None): ...@@ -91,7 +91,25 @@ def load_weight(model, weight, optimizer=None):
"exists.".format(pdparam_path)) "exists.".format(pdparam_path))
param_state_dict = paddle.load(pdparam_path) param_state_dict = paddle.load(pdparam_path)
model.set_dict(param_state_dict) model_dict = model.state_dict()
model_weight = {}
incorrect_keys = 0
for key in model_dict.keys():
if key in param_state_dict.keys():
model_weight[key] = param_state_dict[key]
else:
logger.info('Unmatched key: {}'.format(key))
incorrect_keys += 1
assert incorrect_keys == 0, "Load weight {} incorrectly, \
{} keys unmatched, please check again.".format(weight,
incorrect_keys)
logger.info('Finish loading model weight parameter: {}'.format(
pdparam_path))
model.set_dict(model_weight)
last_epoch = 0 last_epoch = 0
if optimizer is not None and os.path.exists(path + '.pdopt'): if optimizer is not None and os.path.exists(path + '.pdopt'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册