未验证 提交 3fba4778 编写于 作者: W wangguanzhong 提交者: GitHub

minor fix for cornernet & yolov4 (#621)

上级 52ecf506
...@@ -23,12 +23,8 @@ from paddle.fluid.initializer import Constant ...@@ -23,12 +23,8 @@ from paddle.fluid.initializer import Constant
from ..backbones.hourglass import _conv_norm, kaiming_init from ..backbones.hourglass import _conv_norm, kaiming_init
from ppdet.core.workspace import register from ppdet.core.workspace import register
import numpy as np import numpy as np
try: import logging
import cornerpool_lib logger = logging.getLogger(__name__)
except:
print(
"warning: cornerpool_lib not found, compile in ext_op at first if needed"
)
__all__ = ['CornerHead'] __all__ = ['CornerHead']
...@@ -247,6 +243,10 @@ class CornerHead(object): ...@@ -247,6 +243,10 @@ class CornerHead(object):
ae_threshold=1, ae_threshold=1,
num_dets=1000, num_dets=1000,
top_k=100): top_k=100):
try:
import cornerpool_lib
except:
logger.error("cornerpool_lib not found, compile in ext_op at first")
self.train_batch_size = train_batch_size self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size self.test_batch_size = test_batch_size
self.num_classes = num_classes self.num_classes = num_classes
......
...@@ -59,7 +59,7 @@ class CornerNetSqueeze(object): ...@@ -59,7 +59,7 @@ class CornerNetSqueeze(object):
body_feats = self.backbone(im) body_feats = self.backbone(im)
if self.fpn is not None: if self.fpn is not None:
body_feats, _ = self.fpn.get_output(body_feats) body_feats, _ = self.fpn.get_output(body_feats)
body_feats = [body_feats.values()[-1]] body_feats = [list(body_feats.values())[-1]]
if mode == 'train': if mode == 'train':
target_vars = [ target_vars = [
'tl_heatmaps', 'br_heatmaps', 'tag_masks', 'tl_regrs', 'tl_heatmaps', 'br_heatmaps', 'tag_masks', 'tl_regrs',
......
...@@ -166,7 +166,7 @@ class YOLOv3Loss(object): ...@@ -166,7 +166,7 @@ class YOLOv3Loss(object):
# self.scale_x_y, Sequence) else self.scale_x_y[i] # self.scale_x_y, Sequence) else self.scale_x_y[i]
loss_obj_pos, loss_obj_neg = self._calc_obj_loss( loss_obj_pos, loss_obj_neg = self._calc_obj_loss(
output, obj, tobj, gt_box, self._batch_size, anchors, output, obj, tobj, gt_box, self._batch_size, anchors,
num_classes, downsample, self._ignore_thresh, scale_x_y) num_classes, downsample, self._ignore_thresh)
loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls) loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls)
loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0) loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
...@@ -276,7 +276,7 @@ class YOLOv3Loss(object): ...@@ -276,7 +276,7 @@ class YOLOv3Loss(object):
return (tx, ty, tw, th, tscale, tobj, tcls) return (tx, ty, tw, th, tscale, tobj, tcls)
def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors, def _calc_obj_loss(self, output, obj, tobj, gt_box, batch_size, anchors,
num_classes, downsample, ignore_thresh, scale_x_y): num_classes, downsample, ignore_thresh):
# A prediction bbox overlap any gt_bbox over ignore_thresh, # A prediction bbox overlap any gt_bbox over ignore_thresh,
# objectness loss will be ignored, process as follows: # objectness loss will be ignored, process as follows:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册