Unverified commit 9be38bee, authored by Jianfeng Wang, committed by GitHub

feat(detection): use numerical stable focal loss and EMA normalizer in RetinaNet (#49)

* feat(detection): use numerical stable focal loss and EMA normalizer in
RetinaNet

* chore(detection): rename some variables
Parent 70c98b36
......@@ -23,8 +23,8 @@
```bash
python3 tools/inference.py -f configs/retinanet_res50_coco_1x_800size.py \
-w /path/to/retinanet_weights.pkl
-i ../../assets/cat.jpg \
-w /path/to/retinanet_weights.pkl \
-i ../../assets/cat.jpg
```
The command-line options of `tools/inference.py` are as follows:
......@@ -62,11 +62,12 @@ python3 tools/train.py -f configs/retinanet_res50_coco_1x_800size.py -n 8
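`tools/train.py` provides flexible command-line options, including:
- `-f`, the network description file to train. It can be RetinaNet, Faster R-CNN, and others.
- `-n`, the number of devices (GPUs) used for training; defaults to all available GPUs.
- `-f`, the network description file to train. It can be RetinaNet, Faster R-CNN, and others
- `-n`, the number of devices (GPUs) used for training; defaults to all available GPUs
- `-w`, the path to the pretrained backbone weights.
- `-b`, the `batch size` used for training; default 2, meaning 2 images per GPU.
- `-d`, the parent directory of the COCO2017 dataset; default `/data/datasets`
- `--enable_sublinear`, enables sublinear memory optimization, which allows training large models with limited GPU memory.
By default, the model is saved under the `log-of-<model name>` directory.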
......@@ -88,19 +89,37 @@ nvcc -I $MGE/_internal/include -shared -o lib_nms.so -Xcompiler "-fno-strict-ali
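## How to test
After obtaining the trained and saved model, you can use test.py under tools to evaluate the model on the `COCO2017` validation set
After obtaining the trained and saved model, you can use test.py under tools to evaluate the model on the `COCO2017` validation set
Evaluate a single epoch: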
```bash
python3 tools/test.py -f configs/retinanet_res50_coco_1x_800size.py -n 8 \
-se 15
```
Evaluate several consecutive epochs:
```bash
python3 tools/test.py -f configs/retinanet_res50_coco_1x_800size.py -n 8 \
-w /path/to/retinanet_weights.pt \
-se 15 -ee 17
```
Evaluate a specific weights file:
```bash
python3 tools/test.py -f configs/retinanet_res50_coco_1x_800size.py -n 8 \
-w /path/to/retinanet_weights.pt
```
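The command-line options of `tools/test.py` are as follows:
- `-f`, the network description file to test.
- `-n`, the number of devices (GPUs) used for testing; default 1
- `-w`, the model to test; you can download trained detector weights from the table at the top, or use weights you trained yourself.
- `-n`, the number of devices (GPUs) used for testing; default 1
- `-w`, the model to test; you can download trained detector weights from the table at the top, or use weights you trained yourself. When this option is given, `-se` and `-ee` are ignored
- `-d`, the parent directory of the COCO2017 dataset; default `/data/datasets`
- `-se`, the first epoch to test; defaults to the last epoch. The value must be greater than or equal to 0 and less than the model's maximum epoch count.
- `-ee`, the last epoch to test; defaults to `-se` (i.e. only one epoch is tested). The value must be greater than or equal to `-se` and less than the model's maximum epoch count.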
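
The interaction between `-w`, `-se` and `-ee` described in the list above can be sketched as a small standalone function. This is only an illustration: `resolve_epochs` and its parameter names are hypothetical and do not appear in `tools/test.py`, and `-1` is assumed to be the "not given" sentinel, as in the `main()` hunk of this commit.

```python
from typing import Optional, Tuple


def resolve_epochs(
    max_epoch: int,
    start_epoch: int = -1,
    end_epoch: int = -1,
    weight_file: Optional[str] = None,
) -> Tuple[int, int]:
    """Return the (start, end) epoch range to evaluate, both inclusive."""
    # An explicit weights file overrides any epoch range.
    if weight_file:
        return -1, -1
    # No start epoch given: fall back to the last trained epoch.
    if start_epoch == -1:
        start_epoch = max_epoch - 1
    # No end epoch given: evaluate exactly one epoch.
    if end_epoch == -1:
        end_epoch = start_epoch
    if not 0 <= start_epoch <= end_epoch < max_epoch:
        raise ValueError("require 0 <= start_epoch <= end_epoch < max_epoch")
    return start_epoch, end_epoch
```

For example, with `max_epoch=18`, passing only `start_epoch=15` yields `(15, 15)`, and passing nothing yields `(17, 17)`.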
## References
......
......@@ -85,3 +85,11 @@ def get_padded_tensor(
else:
raise Exception("Not supported tensor dim: %d" % ndim)
return padded_array
def softplus(x: Tensor) -> Tensor:
return F.log(1 + F.exp(-F.abs(x))) + F.relu(x)
def logsigmoid(x: Tensor) -> Tensor:
return -softplus(-x)
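
To see why this formulation helps, here is a scalar sketch in plain Python (no MegEngine). It assumes the identity `log(1 + exp(x)) = log(1 + exp(-|x|)) + max(x, 0)`, which is what the `softplus` above encodes. `naive_logsigmoid` is a hypothetical stand-in for the clamp-then-log pattern that the focal loss used before this commit, and `math.log1p` replaces the `F.log(1 + ...)` of the diff.

```python
import math


def softplus(x: float) -> float:
    # exp() only ever sees a non-positive argument, so it cannot overflow.
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)


def logsigmoid(x: float) -> float:
    # log(sigmoid(x)) = -log(1 + exp(-x)) = -softplus(-x)
    return -softplus(-x)


def naive_logsigmoid(x: float, eps: float = 1e-8) -> float:
    # Clamp the probability, then take the log (the pre-commit pattern).
    p = 1.0 / (1.0 + math.exp(-x))
    return math.log(max(p, eps))


if __name__ == "__main__":
    # The clamp saturates at log(1e-8) while the stable form keeps the true value.
    print(naive_logsigmoid(-50.0), logsigmoid(-50.0))
```

With the clamp, every logit below roughly -18.4 yields the same log value, so the log term stops responding to how wrong the prediction is; the stable form has no such floor.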
......@@ -9,10 +9,12 @@
import megengine.functional as F
from megengine.core import Tensor
from official.vision.detection import layers
def get_focal_loss(
score: Tensor,
label: Tensor,
logits: Tensor,
labels: Tensor,
ignore_label: int = -1,
background: int = 0,
alpha: float = 0.5,
......@@ -27,10 +29,10 @@ def get_focal_loss(
FL(p_t) = -\alpha_t(1-p_t)^\gamma \log(p_t)
Args:
score (Tensor):
the predicted score with the shape of :math:`(B, A, C)`
label (Tensor):
the assigned label of boxes with shape of :math:`(B, A)`
logits (Tensor):
the predicted logits with the shape of :math:`(B, A, C)`
labels (Tensor):
the assigned labels of boxes with shape of :math:`(B, A)`
ignore_label (int):
the value of ignore class. Default: -1
background (int):
......@@ -39,30 +41,31 @@ def get_focal_loss(
parameter to mitigate class imbalance. Default: 0.5
gamma (float):
parameter to mitigate easy/hard loss imbalance. Default: 0
norm_type (str): current support 'fg', 'none':
'fg': loss will be normalized by number of fore-ground samples
'none": not norm
norm_type (str): current support "fg", "none":
"fg": loss will be normalized by number of fore-ground samples
"none": not norm
Returns:
the calculated focal loss.
"""
class_range = F.arange(1, score.shape[2] + 1)
class_range = F.arange(1, logits.shape[2] + 1)
label = F.add_axis(label, axis=2)
pos_part = (1 - score) ** gamma * F.log(F.clamp(score, 1e-8))
neg_part = score ** gamma * F.log(F.clamp(1 - score, 1e-8))
labels = F.add_axis(labels, axis=2)
scores = F.sigmoid(logits)
pos_part = (1 - scores) ** gamma * layers.logsigmoid(logits)
neg_part = scores ** gamma * layers.logsigmoid(-logits)
pos_loss = -(label == class_range) * pos_part * alpha
pos_loss = -(labels == class_range) * pos_part * alpha
neg_loss = (
-(label != class_range) * (label != ignore_label) * neg_part * (1 - alpha)
-(labels != class_range) * (labels != ignore_label) * neg_part * (1 - alpha)
)
loss = pos_loss + neg_loss
loss = (pos_loss + neg_loss).sum()
if norm_type == "fg":
fg_mask = (label != background) * (label != ignore_label)
return loss.sum() / F.maximum(fg_mask.sum(), 1)
fg_mask = (labels != background) * (labels != ignore_label)
return loss / F.maximum(fg_mask.sum(), 1)
elif norm_type == "none":
return loss.sum()
return loss
else:
raise NotImplementedError
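
As a rough cross-check of the logic above, the following plain-Python sketch computes the same per-anchor sum for a single anchor. It assumes labels are encoded as in the diff (0 for background, 1..C for classes, `ignore_label` for skipped anchors); the function name `focal_loss` and the scalar loop are illustrative only and are not part of the repository.

```python
import math
from typing import Sequence


def logsigmoid(x: float) -> float:
    # Stable log(sigmoid(x)); mirrors layers.logsigmoid in the diff.
    return -(math.log1p(math.exp(-abs(x))) + max(-x, 0.0))


def focal_loss(
    logits: Sequence[float],
    label: int,
    alpha: float = 0.25,
    gamma: float = 2.0,
    ignore_label: int = -1,
) -> float:
    """Unnormalized sigmoid focal loss of one anchor over C classes."""
    if label == ignore_label:
        return 0.0
    total = 0.0
    for cls, z in enumerate(logits, start=1):
        p = math.exp(logsigmoid(z))  # sigmoid(z), computed without overflow
        if label == cls:
            total += -alpha * (1.0 - p) ** gamma * logsigmoid(z)
        else:
            total += -(1.0 - alpha) * p ** gamma * logsigmoid(-z)
    return total
```

Because the log terms come from `logsigmoid` rather than from a clamped probability, a confidently wrong logit such as `-100` for the positive class still contributes a loss proportional to its magnitude.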
......@@ -70,7 +73,7 @@ def get_focal_loss(
def get_smooth_l1_loss(
pred_bbox: Tensor,
gt_bbox: Tensor,
label: Tensor,
labels: Tensor,
beta: int = 1,
background: int = 0,
ignore_label: int = -1,
......@@ -83,42 +86,43 @@ def get_smooth_l1_loss(
the predicted bbox with the shape of :math:`(B, A, 4)`
gt_bbox (Tensor):
the ground-truth bbox with the shape of :math:`(B, A, 4)`
label (Tensor):
the assigned label of boxes with shape of :math:`(B, A)`
labels (Tensor):
the assigned labels of boxes with shape of :math:`(B, A)`
beta (int):
the parameter of smooth l1 loss. Default: 1
background (int):
the value of background class. Default: 0
ignore_label (int):
the value of ignore class. Default: -1
norm_type (str): current support 'fg', 'all', 'none':
'fg': loss will be normalized by number of fore-ground samples
'all': loss will be normalized by number of all samples
'none': not norm
norm_type (str): current support "fg", "all", "none":
"fg": loss will be normalized by number of fore-ground samples
"all": loss will be normalized by number of all samples
"none": not norm
Returns:
the calculated smooth l1 loss.
"""
pred_bbox = pred_bbox.reshape(-1, 4)
gt_bbox = gt_bbox.reshape(-1, 4)
label = label.reshape(-1)
labels = labels.reshape(-1)
fg_mask = (label != background) * (label != ignore_label)
fg_mask = (labels != background) * (labels != ignore_label)
losses = get_smooth_l1_base(pred_bbox, gt_bbox, beta)
loss = get_smooth_l1_base(pred_bbox, gt_bbox, beta)
loss = (loss.sum(axis=1) * fg_mask).sum()
if norm_type == "fg":
loss = (losses.sum(axis=1) * fg_mask).sum() / F.maximum(fg_mask.sum(), 1)
loss = loss / F.maximum(fg_mask.sum(), 1)
elif norm_type == "all":
all_mask = label != ignore_label
loss = (losses.sum(axis=1) * fg_mask).sum() / F.maximum(all_mask.sum(), 1)
all_mask = labels != ignore_label
loss = loss / F.maximum(all_mask.sum(), 1)
elif norm_type == "none":
return loss
else:
raise NotImplementedError
return loss
def get_smooth_l1_base(
pred_bbox: Tensor, gt_bbox: Tensor, beta: float,
):
def get_smooth_l1_base(pred_bbox: Tensor, gt_bbox: Tensor, beta: float) -> Tensor:
r"""
Args:
......@@ -147,12 +151,12 @@ def get_smooth_l1_base(
return loss
def softmax_loss(score, label, ignore_label=-1):
max_score = F.zero_grad(score.max(axis=1, keepdims=True))
score -= max_score
log_prob = score - F.log(F.exp(score).sum(axis=1, keepdims=True))
mask = label != ignore_label
vlabel = label * mask
loss = -(F.indexing_one_hot(log_prob, vlabel.astype("int32"), 1) * mask).sum()
def softmax_loss(scores: Tensor, labels: Tensor, ignore_label: int = -1) -> Tensor:
max_scores = F.zero_grad(scores.max(axis=1, keepdims=True))
scores -= max_scores
log_prob = scores - F.log(F.exp(scores).sum(axis=1, keepdims=True))
mask = labels != ignore_label
vlabels = labels * mask
loss = -(F.indexing_one_hot(log_prob, vlabels.astype("int32"), 1) * mask).sum()
loss = loss / F.maximum(mask.sum(), 1)
return loss
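
The max-subtraction trick used by `softmax_loss` can be illustrated with a plain-Python sketch over nested lists. It is a simplified stand-in, not the MegEngine implementation: rows are lists of floats, labels are ints, and rows whose label equals `ignore_label` are skipped, which has the same effect as the mask in the diff.

```python
import math
from typing import Sequence


def softmax_loss(
    scores: Sequence[Sequence[float]],
    labels: Sequence[int],
    ignore_label: int = -1,
) -> float:
    """Mean cross-entropy over the rows whose label is not ignore_label."""
    total, valid = 0.0, 0
    for row, label in zip(scores, labels):
        if label == ignore_label:
            continue
        # Shifting by the row maximum leaves softmax unchanged
        # and keeps every exp() argument <= 0.
        row_max = max(row)
        shifted = [s - row_max for s in row]
        log_z = math.log(sum(math.exp(s) for s in shifted))
        total -= shifted[label] - log_z
        valid += 1
    # Same guard as F.maximum(mask.sum(), 1): avoid dividing by zero.
    return total / max(valid, 1)
```

Without the shift, a logit of `1000.0` would make `math.exp` raise `OverflowError`; with it, the loss for that row is simply close to zero when the label matches.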
......@@ -51,20 +51,20 @@ class RCNN(M.Module):
flatten_feature = F.flatten(pool_features, start_axis=1)
roi_feature = F.relu(self.fc1(flatten_feature))
roi_feature = F.relu(self.fc2(roi_feature))
pred_cls = self.pred_cls(roi_feature)
pred_delta = self.pred_delta(roi_feature)
pred_logits = self.pred_cls(roi_feature)
pred_offsets = self.pred_delta(roi_feature)
if self.training:
# loss for classification
loss_rcnn_cls = layers.softmax_loss(pred_cls, labels)
loss_rcnn_cls = layers.softmax_loss(pred_logits, labels)
# loss for regression
pred_delta = pred_delta.reshape(-1, self.cfg.num_classes + 1, 4)
pred_offsets = pred_offsets.reshape(-1, self.cfg.num_classes + 1, 4)
vlabels = labels.reshape(-1, 1).broadcast((labels.shapeof(0), 4))
pred_delta = F.indexing_one_hot(pred_delta, vlabels, axis=1)
pred_offsets = F.indexing_one_hot(pred_offsets, vlabels, axis=1)
loss_rcnn_loc = layers.get_smooth_l1_loss(
pred_delta,
pred_offsets,
bbox_targets,
labels,
self.cfg.rcnn_smooth_l1_beta,
......@@ -74,14 +74,14 @@ class RCNN(M.Module):
return loss_dict
else:
# slice 1 for removing background
pred_scores = F.softmax(pred_cls, axis=1)[:, 1:]
pred_delta = pred_delta[:, 4:].reshape(-1, 4)
pred_scores = F.softmax(pred_logits, axis=1)[:, 1:]
pred_offsets = pred_offsets[:, 4:].reshape(-1, 4)
target_shape = (rcnn_rois.shapeof(0), self.cfg.num_classes, 4)
# rois (N, 4) -> (N, 1, 4) -> (N, 80, 4) -> (N * 80, 4)
base_rois = (
F.add_axis(rcnn_rois[:, 1:5], 1).broadcast(target_shape).reshape(-1, 4)
)
pred_bbox = self.box_coder.decode(base_rois, pred_delta)
pred_bbox = self.box_coder.decode(base_rois, pred_offsets)
return pred_bbox, pred_scores
def get_ground_truth(self, rpn_rois, im_info, gt_boxes):
......
......@@ -74,8 +74,8 @@ class RetinaNetHead(M.Module):
M.init.fill_(self.cls_score.bias, bias_value)
def forward(self, features: List[Tensor]):
logits, bbox_reg = [], []
logits, offsets = [], []
for feature in features:
logits.append(self.cls_score(self.cls_subnet(feature)))
bbox_reg.append(self.bbox_pred(self.bbox_subnet(feature)))
return logits, bbox_reg
offsets.append(self.bbox_pred(self.bbox_subnet(feature)))
return logits, offsets
......@@ -55,12 +55,12 @@ class RPN(M.Module):
for fm, stride in zip(features, self.stride_list)
]
pred_cls_score_list = []
pred_bbox_offsets_list = []
pred_cls_logit_list = []
pred_bbox_offset_list = []
for x in features:
t = F.relu(self.rpn_conv(x))
scores = self.rpn_cls_score(t)
pred_cls_score_list.append(
pred_cls_logit_list.append(
scores.reshape(
scores.shape[0],
2,
......@@ -70,7 +70,7 @@ class RPN(M.Module):
)
)
bbox_offsets = self.rpn_bbox_offsets(t)
pred_bbox_offsets_list.append(
pred_bbox_offset_list.append(
bbox_offsets.reshape(
bbox_offsets.shape[0],
self.num_cell_anchors,
......@@ -81,19 +81,19 @@ class RPN(M.Module):
)
# sample from the predictions
rpn_rois = self.find_top_rpn_proposals(
pred_bbox_offsets_list, pred_cls_score_list, all_anchors_list, im_info
pred_bbox_offset_list, pred_cls_logit_list, all_anchors_list, im_info
)
if self.training:
rpn_labels, rpn_bbox_targets = self.get_ground_truth(
boxes, im_info, all_anchors_list
)
pred_cls_score, pred_bbox_offsets = self.merge_rpn_score_box(
pred_cls_score_list, pred_bbox_offsets_list
pred_cls_logits, pred_bbox_offsets = self.merge_rpn_score_box(
pred_cls_logit_list, pred_bbox_offset_list
)
# rpn loss
loss_rpn_cls = layers.softmax_loss(pred_cls_score, rpn_labels)
loss_rpn_cls = layers.softmax_loss(pred_cls_logits, rpn_labels)
loss_rpn_loc = layers.get_smooth_l1_loss(
pred_bbox_offsets,
rpn_bbox_targets,
......@@ -107,7 +107,7 @@ class RPN(M.Module):
return rpn_rois
def find_top_rpn_proposals(
self, rpn_bbox_offsets_list, rpn_cls_prob_list, all_anchors_list, im_info
self, rpn_bbox_offset_list, rpn_cls_score_list, all_anchors_list, im_info
):
prev_nms_top_n = (
self.cfg.train_prev_nms_top_n
......@@ -123,37 +123,37 @@ class RPN(M.Module):
batch_per_gpu = self.cfg.batch_per_gpu if self.training else 1
nms_threshold = self.cfg.rpn_nms_threshold
list_size = len(rpn_bbox_offsets_list)
list_size = len(rpn_bbox_offset_list)
return_rois = []
for bid in range(batch_per_gpu):
batch_proposals_list = []
batch_probs_list = []
batch_proposal_list = []
batch_score_list = []
batch_level_list = []
for l in range(list_size):
# get proposals and probs
# get proposals and scores
offsets = (
rpn_bbox_offsets_list[l][bid].dimshuffle(2, 3, 0, 1).reshape(-1, 4)
rpn_bbox_offset_list[l][bid].dimshuffle(2, 3, 0, 1).reshape(-1, 4)
)
all_anchors = all_anchors_list[l]
proposals = self.box_coder.decode(all_anchors, offsets)
probs = rpn_cls_prob_list[l][bid, 1].dimshuffle(1, 2, 0).reshape(1, -1)
scores = rpn_cls_score_list[l][bid, 1].dimshuffle(1, 2, 0).reshape(1, -1)
# prev nms top n
probs, order = F.argsort(probs, descending=True)
num_proposals = F.minimum(probs.shapeof(1), prev_nms_top_n)
probs = probs.reshape(-1)[:num_proposals]
scores, order = F.argsort(scores, descending=True)
num_proposals = F.minimum(scores.shapeof(1), prev_nms_top_n)
scores = scores.reshape(-1)[:num_proposals]
order = order.reshape(-1)[:num_proposals]
proposals = proposals.ai[order, :]
batch_proposals_list.append(proposals)
batch_probs_list.append(probs)
batch_level_list.append(mge.ones(probs.shapeof(0)) * l)
batch_proposal_list.append(proposals)
batch_score_list.append(scores)
batch_level_list.append(mge.ones(scores.shapeof(0)) * l)
proposals = F.concat(batch_proposals_list, axis=0)
scores = F.concat(batch_probs_list, axis=0)
level = F.concat(batch_level_list, axis=0)
proposals = F.concat(batch_proposal_list, axis=0)
scores = F.concat(batch_score_list, axis=0)
levels = F.concat(batch_level_list, axis=0)
proposals = layers.get_clipped_box(proposals, im_info[bid, :])
# filter empty
......@@ -161,19 +161,19 @@ class RPN(M.Module):
_, keep_inds = F.cond_take(keep_mask == 1, keep_mask)
proposals = proposals.ai[keep_inds, :]
scores = scores.ai[keep_inds]
level = level.ai[keep_inds]
levels = levels.ai[keep_inds]
# gather the proposals and probs
# gather the proposals and scores
# sort nms by scores
scores, order = F.argsort(scores.reshape(1, -1), descending=True)
order = order.reshape(-1)
proposals = proposals.ai[order, :]
level = level.ai[order]
levels = levels.ai[order]
# apply total level nms
# apply total levels nms
rois = F.concat([proposals, scores.reshape(-1, 1)], axis=1)
keep_inds = batched_nms(
proposals, scores, level, nms_threshold, post_nms_top_n
proposals, scores, levels, nms_threshold, post_nms_top_n
)
rois = rois.ai[keep_inds]
......@@ -184,34 +184,34 @@ class RPN(M.Module):
return F.zero_grad(F.concat(return_rois, axis=0))
def merge_rpn_score_box(self, rpn_cls_score_list, rpn_bbox_offsets_list):
def merge_rpn_score_box(self, rpn_cls_score_list, rpn_bbox_offset_list):
final_rpn_cls_score_list = []
final_rpn_bbox_offsets_list = []
final_rpn_bbox_offset_list = []
for bid in range(self.cfg.batch_per_gpu):
batch_rpn_cls_score_list = []
batch_rpn_bbox_offsets_list = []
batch_rpn_bbox_offset_list = []
for i in range(len(self.in_features)):
rpn_cls_score = (
rpn_cls_scores = (
rpn_cls_score_list[i][bid].dimshuffle(2, 3, 1, 0).reshape(-1, 2)
)
rpn_bbox_offsets = (
rpn_bbox_offsets_list[i][bid].dimshuffle(2, 3, 0, 1).reshape(-1, 4)
rpn_bbox_offset_list[i][bid].dimshuffle(2, 3, 0, 1).reshape(-1, 4)
)
batch_rpn_cls_score_list.append(rpn_cls_score)
batch_rpn_bbox_offsets_list.append(rpn_bbox_offsets)
batch_rpn_cls_score_list.append(rpn_cls_scores)
batch_rpn_bbox_offset_list.append(rpn_bbox_offsets)
batch_rpn_cls_score = F.concat(batch_rpn_cls_score_list, axis=0)
batch_rpn_bbox_offsets = F.concat(batch_rpn_bbox_offsets_list, axis=0)
batch_rpn_cls_scores = F.concat(batch_rpn_cls_score_list, axis=0)
batch_rpn_bbox_offsets = F.concat(batch_rpn_bbox_offset_list, axis=0)
final_rpn_cls_score_list.append(batch_rpn_cls_score)
final_rpn_bbox_offsets_list.append(batch_rpn_bbox_offsets)
final_rpn_cls_score_list.append(batch_rpn_cls_scores)
final_rpn_bbox_offset_list.append(batch_rpn_bbox_offsets)
final_rpn_cls_score = F.concat(final_rpn_cls_score_list, axis=0)
final_rpn_bbox_offsets = F.concat(final_rpn_bbox_offsets_list, axis=0)
return final_rpn_cls_score, final_rpn_bbox_offsets
final_rpn_cls_scores = F.concat(final_rpn_cls_score_list, axis=0)
final_rpn_bbox_offsets = F.concat(final_rpn_bbox_offset_list, axis=0)
return final_rpn_cls_scores, final_rpn_bbox_offsets
def per_level_gt(self, gt_boxes, im_info, anchors, allow_low_quality_matches=True):
ignore_label = self.cfg.ignore_label
......@@ -292,10 +292,10 @@ class RPN(M.Module):
sample_label_mask = labels == sample_value
num_mask = sample_label_mask.sum()
num_final_samples = F.minimum(num_mask, num_samples)
# here, we use the bernoulli probability to sample the anchors
sample_prob = num_final_samples / num_mask
# here, we use the bernoulli probability to sample the anchors
sample_score = num_final_samples / num_mask
uniform_rng = rand.uniform(sample_label_mask.shapeof(0))
to_ignore_mask = (uniform_rng >= sample_prob) * sample_label_mask
to_ignore_mask = (uniform_rng >= sample_score) * sample_label_mask
labels = labels * (1 - to_ignore_mask) + to_ignore_mask * ignore_label
return labels
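
The sampling step above keeps each matching anchor independently with probability `num_final_samples / num_mask` and marks the rest as ignored. A list-based sketch of the same idea follows. The function name `subsample_labels`, the `rng` parameter, and the early return for an empty mask are assumptions added for illustration; the code in the diff divides by `num_mask` without such a guard.

```python
import random
from typing import List, Optional, Sequence


def subsample_labels(
    labels: Sequence[int],
    sample_value: int,
    num_samples: int,
    ignore_label: int = -1,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Randomly relabel surplus `sample_value` anchors as `ignore_label`."""
    rng = rng or random.Random()
    num_mask = sum(1 for label in labels if label == sample_value)
    if num_mask == 0:  # assumption: nothing to sample, return labels unchanged
        return list(labels)
    keep_prob = min(num_mask, num_samples) / num_mask
    return [
        ignore_label
        if label == sample_value and rng.random() >= keep_prob
        else label
        for label in labels
    ]
```

Since each anchor is drawn independently, the number of kept anchors matches `num_samples` only in expectation, not exactly.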
......@@ -78,6 +78,8 @@ class RetinaNet(M.Module):
),
}
self.loss_normalizer = mge.tensor(100.0)
def preprocess_image(self, image):
normed_image = (
image - np.array(self.cfg.img_mean)[None, :, None, None]
......@@ -89,14 +91,14 @@ class RetinaNet(M.Module):
features = self.backbone(image)
features = [features[f] for f in self.in_features]
box_cls, box_delta = self.head(features)
box_logits, box_offsets = self.head(features)
box_cls_list = [
box_logits_list = [
_.dimshuffle(0, 2, 3, 1).reshape(self.batch_size, -1, self.cfg.num_classes)
for _ in box_cls
for _ in box_logits
]
box_delta_list = [
_.dimshuffle(0, 2, 3, 1).reshape(self.batch_size, -1, 4) for _ in box_delta
box_offsets_list = [
_.dimshuffle(0, 2, 3, 1).reshape(self.batch_size, -1, 4) for _ in box_offsets
]
anchors_list = [
......@@ -104,32 +106,45 @@ class RetinaNet(M.Module):
for i in range(len(features))
]
all_level_box_cls = F.sigmoid(F.concat(box_cls_list, axis=1))
all_level_box_delta = F.concat(box_delta_list, axis=1)
all_level_box_logits = F.concat(box_logits_list, axis=1)
all_level_box_offsets = F.concat(box_offsets_list, axis=1)
all_level_anchors = F.concat(anchors_list, axis=0)
if self.training:
box_gt_cls, box_gt_delta = self.get_ground_truth(
box_gt_scores, box_gt_offsets = self.get_ground_truth(
all_level_anchors,
inputs["gt_boxes"],
inputs["im_info"][:, 4].astype(np.int32),
)
norm_type = "none" if self.cfg.loss_normalizer_momentum > 0.0 else "fg"
rpn_cls_loss = layers.get_focal_loss(
all_level_box_cls,
box_gt_cls,
all_level_box_logits,
box_gt_scores,
alpha=self.cfg.focal_loss_alpha,
gamma=self.cfg.focal_loss_gamma,
norm_type=norm_type,
)
rpn_bbox_loss = (
layers.get_smooth_l1_loss(
all_level_box_delta,
box_gt_delta,
box_gt_cls,
all_level_box_offsets,
box_gt_offsets,
box_gt_scores,
self.cfg.smooth_l1_beta,
norm_type=norm_type,
)
* self.cfg.reg_loss_weight
)
if norm_type == "none":
F.add_update(
self.loss_normalizer,
(box_gt_scores > 0).sum(),
alpha=self.cfg.loss_normalizer_momentum,
beta=1 - self.cfg.loss_normalizer_momentum,
)
rpn_cls_loss = rpn_cls_loss / F.maximum(self.loss_normalizer, 1)
rpn_bbox_loss = rpn_bbox_loss / F.maximum(self.loss_normalizer, 1)
total = rpn_cls_loss + rpn_bbox_loss
loss_dict = {
"total_loss": total,
......@@ -143,7 +158,7 @@ class RetinaNet(M.Module):
assert self.batch_size == 1
transformed_box = self.box_coder.decode(
all_level_anchors, all_level_box_delta[0],
all_level_anchors, all_level_box_offsets[0],
)
transformed_box = transformed_box.reshape(-1, 4)
......@@ -155,7 +170,8 @@ class RetinaNet(M.Module):
clipped_box = layers.get_clipped_box(
transformed_box, inputs["im_info"][0, 2:4]
).reshape(-1, 4)
return all_level_box_cls[0], clipped_box
all_level_box_scores = F.sigmoid(all_level_box_logits)
return all_level_box_scores[0], clipped_box
def get_ground_truth(self, anchors, batched_gt_boxes, batched_valid_gt_box_number):
total_anchors = anchors.shape[0]
......@@ -245,6 +261,7 @@ class RetinaNetConfig:
self.cls_prior_prob = 0.01
# ------------------------ loss cfg -------------------------- #
self.loss_normalizer_momentum = 0.9 # 0.0 means disable EMA normalizer
self.focal_loss_alpha = 0.25
self.focal_loss_gamma = 2
self.smooth_l1_beta = 0 # use L1 loss
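
The EMA normalizer introduced by this commit can be sketched without MegEngine as below. The sketch assumes that `F.add_update(dest, delta, alpha, beta)` performs `dest = alpha * dest + beta * delta`, which is how it is used in the `RetinaNet.forward` hunk; the class `EmaNormalizer` is hypothetical and exists only for this illustration.

```python
class EmaNormalizer:
    """Exponential moving average of the per-batch foreground anchor count."""

    def __init__(self, momentum: float = 0.9, init: float = 100.0) -> None:
        self.momentum = momentum
        self.value = init  # the diff initializes the normalizer to 100.0

    def update(self, num_foreground: float) -> float:
        # value <- momentum * value + (1 - momentum) * current count
        self.value = (
            self.momentum * self.value + (1.0 - self.momentum) * num_foreground
        )
        # Same floor as F.maximum(self.loss_normalizer, 1) in the diff.
        return max(self.value, 1.0)


if __name__ == "__main__":
    normalizer = EmaNormalizer(momentum=0.9, init=100.0)
    summed_loss = 190.0  # hypothetical un-normalized loss for one batch
    print(summed_loss / normalizer.update(50))
```

Dividing by the smoothed count instead of the per-batch count makes the loss scale less sensitive to batches that happen to contain very few foreground anchors.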
......
......@@ -55,9 +55,16 @@ def main():
sys.path.insert(0, os.path.dirname(args.file))
current_network = importlib.import_module(os.path.basename(args.file).split(".")[0])
cfg = current_network.Cfg()
if args.end_epoch == -1:
args.end_epoch = args.start_epoch
if args.weight_file:
args.start_epoch = args.end_epoch = -1
else:
if args.start_epoch == -1:
args.start_epoch = cfg.max_epoch - 1
if args.end_epoch == -1:
args.end_epoch = args.start_epoch
assert 0 <= args.start_epoch <= args.end_epoch < cfg.max_epoch
for epoch_num in range(args.start_epoch, args.end_epoch + 1):
if args.weight_file:
......@@ -86,7 +93,6 @@ def main():
proc.start()
procs.append(proc)
cfg = current_network.Cfg()
num_imgs = dict(coco=5000, objects365=30000)
for _ in tqdm(range(num_imgs[cfg.test_dataset["name"]])):
......