未验证 提交 c7a19f16 编写于 作者: S shangliang Xu 提交者: GitHub

fix some code (#3710)

上级 b33838c0
...@@ -16,13 +16,12 @@ from __future__ import absolute_import ...@@ -16,13 +16,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import math
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from ppdet.core.workspace import register from ppdet.core.workspace import register
import pycocotools.mask as mask_util import pycocotools.mask as mask_util
from ..initializer import * from ..initializer import linear_init_
__all__ = ['DETRHead'] __all__ = ['DETRHead']
......
...@@ -50,9 +50,7 @@ def _no_grad_normal_(tensor, mean=0., std=1.): ...@@ -50,9 +50,7 @@ def _no_grad_normal_(tensor, mean=0., std=1.):
def _no_grad_fill_(tensor, value=0.): def _no_grad_fill_(tensor, value=0.):
with paddle.no_grad(): with paddle.no_grad():
v = paddle.rand(shape=tensor.shape, dtype=tensor.dtype) tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype))
v[...] = value
tensor.set_value(v)
return tensor return tensor
......
...@@ -21,7 +21,7 @@ import paddle.nn as nn ...@@ -21,7 +21,7 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from ppdet.core.workspace import register from ppdet.core.workspace import register
from .iou_loss import GIoULoss from .iou_loss import GIoULoss
from ..transformers import bbox_cxcywh_to_xyxy, bbox_overlaps, sigmoid_focal_loss from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss
__all__ = ['DETRLoss'] __all__ = ['DETRLoss']
...@@ -211,7 +211,7 @@ class DETRLoss(nn.Layer): ...@@ -211,7 +211,7 @@ class DETRLoss(nn.Layer):
num_gts = paddle.clip( num_gts = paddle.clip(
num_gts / paddle.distributed.get_world_size(), min=1).item() num_gts / paddle.distributed.get_world_size(), min=1).item()
except: except:
num_gts = max(num_gts, 1) num_gts = max(num_gts.item(), 1)
total_loss = dict() total_loss = dict()
total_loss.update( total_loss.update(
self._get_loss_class(logits[-1], gt_class, match_indices, self._get_loss_class(logits[-1], gt_class, match_indices,
......
...@@ -23,8 +23,8 @@ import paddle.nn.functional as F ...@@ -23,8 +23,8 @@ import paddle.nn.functional as F
from ppdet.core.workspace import register from ppdet.core.workspace import register
from ..layers import MultiHeadAttention, _convert_attention_mask from ..layers import MultiHeadAttention, _convert_attention_mask
from .position_encoding import PositionEmbedding from .position_encoding import PositionEmbedding
from .utils import * from .utils import _get_clones
from ..initializer import * from ..initializer import linear_init_, conv_init_, xavier_uniform_, normal_
__all__ = ['DETRTransformer'] __all__ = ['DETRTransformer']
......
...@@ -54,5 +54,4 @@ def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0): ...@@ -54,5 +54,4 @@ def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
if alpha >= 0: if alpha >= 0:
alpha_t = alpha * label + (1 - alpha) * (1 - label) alpha_t = alpha * label + (1 - alpha) * (1 - label)
loss = alpha_t * loss loss = alpha_t * loss
return loss.mean(1).sum() / normalizer if normalizer > 1. else loss.mean( return loss.mean(1).sum() / normalizer
1).sum()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册