Unverified commit c7a19f16, authored by shangliang Xu, committed by GitHub

fix some code (#3710)

Parent: b33838c0
@@ -16,13 +16,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 import math
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from ppdet.core.workspace import register
-import pycocotools.mask as mask_util
-from ..initializer import *
+from ..initializer import linear_init_
 __all__ = ['DETRHead']
@@ -50,9 +50,7 @@ def _no_grad_normal_(tensor, mean=0., std=1.):
 def _no_grad_fill_(tensor, value=0.):
     with paddle.no_grad():
-        v = paddle.rand(shape=tensor.shape, dtype=tensor.dtype)
-        v[...] = value
-        tensor.set_value(v)
+        tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype))
     return tensor
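The old code allocated a random tensor with paddle.rand, overwrote every element, and then copied it in; paddle.full_like builds the constant tensor in one step. A minimal standalone sketch of the patched helper (the parameter w below is an illustrative example, not part of the patch):

import paddle

def _no_grad_fill_(tensor, value=0.):
    # Fill the tensor with a constant, without recording gradients.
    with paddle.no_grad():
        tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype))
    return tensor

# Example: zero-initialize a parameter.
w = paddle.create_parameter(shape=[4, 4], dtype='float32')
_no_grad_fill_(w, 0.)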
@@ -21,7 +21,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from ppdet.core.workspace import register
 from .iou_loss import GIoULoss
-from ..transformers import bbox_cxcywh_to_xyxy, bbox_overlaps, sigmoid_focal_loss
+from ..transformers import bbox_cxcywh_to_xyxy, sigmoid_focal_loss
 __all__ = ['DETRLoss']
@@ -211,7 +211,7 @@ class DETRLoss(nn.Layer):
             num_gts = paddle.clip(
                 num_gts / paddle.distributed.get_world_size(), min=1).item()
         except:
-            num_gts = max(num_gts, 1)
+            num_gts = max(num_gts.item(), 1)
         total_loss = dict()
         total_loss.update(
             self._get_loss_class(logits[-1], gt_class, match_indices,
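In the fallback branch num_gts is still a 0-D tensor, so max(num_gts, 1) compared a Tensor against a Python int; calling .item() first makes both operands plain scalars, mirroring the .item() already used on the distributed path. A small sketch of the two paths (the literal 7. is just an illustrative ground-truth count):

import paddle

num_gts = paddle.to_tensor(7.)  # accumulated ground-truth count, a 0-D tensor
try:
    # Distributed path: average the count across ranks, then take a scalar.
    num_gts = paddle.clip(
        num_gts / paddle.distributed.get_world_size(), min=1).item()
except:
    # Fallback: convert to a Python float before comparing with the int 1.
    num_gts = max(num_gts.item(), 1)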
@@ -23,8 +23,8 @@ import paddle.nn.functional as F
 from ppdet.core.workspace import register
 from ..layers import MultiHeadAttention, _convert_attention_mask
 from .position_encoding import PositionEmbedding
-from .utils import *
-from ..initializer import *
+from .utils import _get_clones
+from ..initializer import linear_init_, conv_init_, xavier_uniform_, normal_
 __all__ = ['DETRTransformer']
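_get_clones is the usual transformer helper for stacking N copies of an encoder or decoder layer. Its body is not shown in this diff, so the sketch below is an assumption based on the common pattern:

import copy
import paddle.nn as nn

def _get_clones(module, N):
    # N independent deep copies, registered as sublayers via LayerList.
    return nn.LayerList([copy.deepcopy(module) for _ in range(N)])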
@@ -54,5 +54,4 @@ def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
     if alpha >= 0:
         alpha_t = alpha * label + (1 - alpha) * (1 - label)
         loss = alpha_t * loss
-    return loss.mean(1).sum() / normalizer if normalizer > 1. else loss.mean(
-        1).sum()
+    return loss.mean(1).sum() / normalizer
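Previously the division was skipped whenever normalizer <= 1., so fractional normalizers were silently ignored; the patch applies the division unconditionally. Only the tail of the function appears in the diff, so this reconstruction of the full body follows the standard sigmoid focal-loss formulation and is an assumption:

import paddle.nn.functional as F

def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
    # Focal loss on raw logits (assumed body; only the tail is in the diff).
    prob = F.sigmoid(logit)
    ce_loss = F.binary_cross_entropy_with_logits(logit, label, reduction='none')
    p_t = prob * label + (1 - prob) * (1 - label)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * label + (1 - alpha) * (1 - label)
        loss = alpha_t * loss
    # After the patch, the normalizer always divides the reduced loss.
    return loss.mean(1).sum() / normalizer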