From f137929f1abc40278f0b53571d3a3dd7d564e21b Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 12 Nov 2019 10:31:54 +0800 Subject: [PATCH] refine focal loss (#23) --- ppdet/modeling/anchor_heads/retina_head.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ppdet/modeling/anchor_heads/retina_head.py b/ppdet/modeling/anchor_heads/retina_head.py index 41246e8b6..cb6cb1cfb 100644 --- a/ppdet/modeling/anchor_heads/retina_head.py +++ b/ppdet/modeling/anchor_heads/retina_head.py @@ -389,6 +389,7 @@ class RetinaHead(object): im_info=im_info, num_classes=self.num_classes - 1) fg_num = fluid.layers.reduce_sum(fg_num, name='fg_num') + score_tgt = fluid.layers.cast(score_tgt, 'int32') loss_cls = fluid.layers.sigmoid_focal_loss( x=score_pred, label=score_tgt, -- GitLab