提交 ecad1f9c 编写于 作者: Z Zhi Tian

Merge branch 'fixed_8gpus_bug' into dcn

MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/FAIR/20171220/X-101-32x8d"
RPN_ONLY: True
FCOS_ON: True
BACKBONE:
CONV_BODY: "R-101-FPN-RETINANET"
RESNETS:
STRIDE_IN_1X1: False
BACKBONE_OUT_CHANNELS: 256
NUM_GROUPS: 32
WIDTH_PER_GROUP: 8
RETINANET:
USE_C5: False # FCOS uses P5 instead of C5
FCOS:
# normalizing the regression targets with FPN strides
NORM_REG_TARGETS: True
# positioning centerness on the regress branch.
# Please refer to https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042
CENTERNESS_ON_REG: True
# using center sampling and GIoU.
# Please refer to https://github.com/yqyao/FCOS_PLUS
CENTER_SAMPLING_RADIUS: 1.5
IOU_LOSS_TYPE: "giou"
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
INPUT:
MIN_SIZE_RANGE_TRAIN: (640, 800)
MAX_SIZE_TRAIN: 1333
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1333
DATALOADER:
SIZE_DIVISIBILITY: 32
SOLVER:
BASE_LR: 0.01
WEIGHT_DECAY: 0.0001
STEPS: (120000, 160000)
MAX_ITER: 180000
IMS_PER_BATCH: 16
WARMUP_METHOD: "constant"
......@@ -45,7 +45,7 @@ class IOULoss(nn.Module):
raise NotImplementedError
if weight is not None and weight.sum() > 0:
return (losses * weight).sum() / weight.sum()
return (losses * weight).sum()
else:
assert losses.numel() != 0
return losses.mean()
return losses.sum()
......@@ -6,7 +6,7 @@ file
import torch
from torch.nn import functional as F
from torch import nn
import os
from ..utils import concat_box_prediction_layers
from fcos_core.layers import IOULoss
from fcos_core.layers import SigmoidFocalLoss
......@@ -19,6 +19,17 @@ from fcos_core.structures.boxlist_ops import cat_boxlist
INF = 100000000
def get_num_gpus():
return int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
def reduce_sum(tensor):
import torch.distributed as dist
tensor = tensor.clone()
dist.all_reduce(tensor, op=dist.reduce_op.SUM)
return tensor
class FCOSLossComputation(object):
"""
This class computes the FCOS losses.
......@@ -37,7 +48,7 @@ class FCOSLossComputation(object):
# we make use of IOU Loss for bounding boxes regression,
# but we found that L1 in log scale can yield a similar performance
self.box_reg_loss_func = IOULoss(self.iou_loss_type)
self.centerness_loss_func = nn.BCEWithLogitsLoss()
self.centerness_loss_func = nn.BCEWithLogitsLoss(reduction="sum")
def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1.0):
'''
......@@ -229,28 +240,39 @@ class FCOSLossComputation(object):
reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)
pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)
cls_loss = self.cls_loss_func(
box_cls_flatten,
labels_flatten.int()
) / (pos_inds.numel() + N) # add N to avoid dividing by a zero
box_regression_flatten = box_regression_flatten[pos_inds]
reg_targets_flatten = reg_targets_flatten[pos_inds]
centerness_flatten = centerness_flatten[pos_inds]
num_pos_per_gpu = pos_inds.numel()
num_gpus = get_num_gpus()
if num_gpus > 1:
total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_per_gpu])).item()
else:
total_num_pos = num_pos_per_gpu
cls_loss = self.cls_loss_func(
box_cls_flatten,
labels_flatten.int()
) / max(total_num_pos / float(num_gpus), 1.0)
if pos_inds.numel() > 0:
centerness_targets = self.compute_centerness_targets(reg_targets_flatten)
sum_centerness_targets = centerness_targets.sum()
sum_centerness_targets = reduce_sum(sum_centerness_targets).item()
reg_loss = self.box_reg_loss_func(
box_regression_flatten,
reg_targets_flatten,
centerness_targets
)
) / (sum_centerness_targets / float(num_gpus))
centerness_loss = self.centerness_loss_func(
centerness_flatten,
centerness_targets
)
) / max(total_num_pos / float(num_gpus), 1.0)
else:
reg_loss = box_regression_flatten.sum()
reduce_sum(centerness_flatten.new_tensor([0.0]))
centerness_loss = centerness_flatten.sum()
return cls_loss, reg_loss, centerness_loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册