未验证 提交 db73d68f 编写于 作者: C cnn 提交者: GitHub

fix error (#3843)

上级 1adca6bf
...@@ -551,7 +551,6 @@ class S2ANetHead(nn.Layer): ...@@ -551,7 +551,6 @@ class S2ANetHead(nn.Layer):
fam_cls_score1 = fam_cls_score fam_cls_score1 = fam_cls_score
feat_labels = paddle.to_tensor(feat_labels) feat_labels = paddle.to_tensor(feat_labels)
if (feat_labels >= 0).astype(paddle.int32).sum() > 0:
feat_labels_one_hot = paddle.nn.functional.one_hot( feat_labels_one_hot = paddle.nn.functional.one_hot(
feat_labels, self.cls_out_channels + 1) feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:] feat_labels_one_hot = feat_labels_one_hot[:, 1:]
...@@ -575,8 +574,6 @@ class S2ANetHead(nn.Layer): ...@@ -575,8 +574,6 @@ class S2ANetHead(nn.Layer):
fam_cls = fam_cls * feat_label_weights fam_cls = fam_cls * feat_label_weights
fam_cls_total = paddle.sum(fam_cls) fam_cls_total = paddle.sum(fam_cls)
else:
fam_cls_total = paddle.zeros([0], dtype=fam_cls_score1.dtype)
fam_cls_losses.append(fam_cls_total) fam_cls_losses.append(fam_cls_total)
# step3: regression loss # step3: regression loss
...@@ -673,7 +670,6 @@ class S2ANetHead(nn.Layer): ...@@ -673,7 +670,6 @@ class S2ANetHead(nn.Layer):
odm_cls_score1 = odm_cls_score odm_cls_score1 = odm_cls_score
feat_labels = paddle.to_tensor(feat_labels) feat_labels = paddle.to_tensor(feat_labels)
if (feat_labels >= 0).astype(paddle.int32).sum() > 0:
feat_labels_one_hot = paddle.nn.functional.one_hot( feat_labels_one_hot = paddle.nn.functional.one_hot(
feat_labels, self.cls_out_channels + 1) feat_labels, self.cls_out_channels + 1)
feat_labels_one_hot = feat_labels_one_hot[:, 1:] feat_labels_one_hot = feat_labels_one_hot[:, 1:]
...@@ -696,8 +692,6 @@ class S2ANetHead(nn.Layer): ...@@ -696,8 +692,6 @@ class S2ANetHead(nn.Layer):
odm_cls = odm_cls * feat_label_weights odm_cls = odm_cls * feat_label_weights
odm_cls_total = paddle.sum(odm_cls) odm_cls_total = paddle.sum(odm_cls)
else:
odm_cls_total = paddle.zeros([0], dtype=odm_cls_score1.dtype)
odm_cls_losses.append(odm_cls_total) odm_cls_losses.append(odm_cls_total)
# # step3: regression loss # # step3: regression loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册