Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
要开心的学习
Mask_RCNN
提交
8075ecc1
M
Mask_RCNN
项目概览
要开心的学习
/
Mask_RCNN
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
Mask_RCNN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
8075ecc1
编写于
7月 02, 2018
作者:
C
Corey Hu
提交者:
Waleed
7月 11, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
edit loss desc
上级
8bed8428
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
31 addition
and
31 deletion
+31
-31
mrcnn/model.py
mrcnn/model.py
+31
-31
未找到文件。
mrcnn/model.py
浏览文件 @
8075ecc1
"""
"""
Mask R-CNN
Mask R-CNN
The main Mask R-CNN model implemen
e
tation.
The main Mask R-CNN model implementation.
Copyright (c) 2017 Matterport, Inc.
Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details)
Licensed under the MIT License (see LICENSE for details)
...
@@ -63,7 +63,7 @@ class BatchNorm(KL.BatchNormalization):
...
@@ -63,7 +63,7 @@ class BatchNorm(KL.BatchNormalization):
Note about training values:
Note about training values:
None: Train BN layers. This is the normal mode
None: Train BN layers. This is the normal mode
False: Freeze BN layers. Good when batch size is small
False: Freeze BN layers. Good when batch size is small
True: (don't use). Set layer in training mode even when
inferencing
True: (don't use). Set layer in training mode even when
making inferences
"""
"""
return
super
(
self
.
__class__
,
self
).
call
(
inputs
,
training
=
training
)
return
super
(
self
.
__class__
,
self
).
call
(
inputs
,
training
=
training
)
...
@@ -97,12 +97,12 @@ def identity_block(input_tensor, kernel_size, filters, stage, block,
...
@@ -97,12 +97,12 @@ def identity_block(input_tensor, kernel_size, filters, stage, block,
"""The identity_block is the block that has no conv layer at shortcut
"""The identity_block is the block that has no conv layer at shortcut
# Arguments
# Arguments
input_tensor: input tensor
input_tensor: input tensor
kernel_size: def
ua
lt 3, the kernel size of middle conv layer at main path
kernel_size: def
au
lt 3, the kernel size of middle conv layer at main path
filters: list of integers, the nb_filters of 3 conv layer at main path
filters: list of integers, the nb_filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
use_bias: Boolean. To use or not use a bias in conv layers.
use_bias: Boolean. To use or not use a bias in conv layers.
train_bn: Boolean. Train or freeze Batch Norm lay
re
s
train_bn: Boolean. Train or freeze Batch Norm lay
er
s
"""
"""
nb_filter1
,
nb_filter2
,
nb_filter3
=
filters
nb_filter1
,
nb_filter2
,
nb_filter3
=
filters
conv_name_base
=
'res'
+
str
(
stage
)
+
block
+
'_branch'
conv_name_base
=
'res'
+
str
(
stage
)
+
block
+
'_branch'
...
@@ -132,12 +132,12 @@ def conv_block(input_tensor, kernel_size, filters, stage, block,
...
@@ -132,12 +132,12 @@ def conv_block(input_tensor, kernel_size, filters, stage, block,
"""conv_block is the block that has a conv layer at shortcut
"""conv_block is the block that has a conv layer at shortcut
# Arguments
# Arguments
input_tensor: input tensor
input_tensor: input tensor
kernel_size: def
ua
lt 3, the kernel size of middle conv layer at main path
kernel_size: def
au
lt 3, the kernel size of middle conv layer at main path
filters: list of integers, the nb_filters of 3 conv layer at main path
filters: list of integers, the nb_filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
use_bias: Boolean. To use or not use a bias in conv layers.
use_bias: Boolean. To use or not use a bias in conv layers.
train_bn: Boolean. Train or freeze Batch Norm lay
re
s
train_bn: Boolean. Train or freeze Batch Norm lay
er
s
Note that from stage 3, the first conv layer at main path is with subsample=(2,2)
Note that from stage 3, the first conv layer at main path is with subsample=(2,2)
And the shortcut should have subsample=(2,2) as well
And the shortcut should have subsample=(2,2) as well
"""
"""
...
@@ -172,7 +172,7 @@ def resnet_graph(input_image, architecture, stage5=False, train_bn=True):
...
@@ -172,7 +172,7 @@ def resnet_graph(input_image, architecture, stage5=False, train_bn=True):
"""Build a ResNet graph.
"""Build a ResNet graph.
architecture: Can be resnet50 or resnet101
architecture: Can be resnet50 or resnet101
stage5: Boolean. If False, stage5 of the network is not created
stage5: Boolean. If False, stage5 of the network is not created
train_bn: Boolean. Train or freeze Batch Norm lay
re
s
train_bn: Boolean. Train or freeze Batch Norm lay
er
s
"""
"""
assert
architecture
in
[
"resnet50"
,
"resnet101"
]
assert
architecture
in
[
"resnet50"
,
"resnet101"
]
# Stage 1
# Stage 1
...
@@ -337,7 +337,7 @@ class ProposalLayer(KE.Layer):
...
@@ -337,7 +337,7 @@ class ProposalLayer(KE.Layer):
############################################################
############################################################
def
log2_graph
(
x
):
def
log2_graph
(
x
):
"""Implementati
n of Log2. TF doesn't have a native implemen
ation."""
"""Implementati
on of Log2. TF doesn't have a native implement
ation."""
return
tf
.
log
(
x
)
/
tf
.
log
(
2.0
)
return
tf
.
log
(
x
)
/
tf
.
log
(
2.0
)
...
@@ -399,7 +399,7 @@ class PyramidROIAlign(KE.Layer):
...
@@ -399,7 +399,7 @@ class PyramidROIAlign(KE.Layer):
ix
=
tf
.
where
(
tf
.
equal
(
roi_level
,
level
))
ix
=
tf
.
where
(
tf
.
equal
(
roi_level
,
level
))
level_boxes
=
tf
.
gather_nd
(
boxes
,
ix
)
level_boxes
=
tf
.
gather_nd
(
boxes
,
ix
)
# Box indic
i
es for crop_and_resize.
# Box indices for crop_and_resize.
box_indices
=
tf
.
cast
(
ix
[:,
0
],
tf
.
int32
)
box_indices
=
tf
.
cast
(
ix
[:,
0
],
tf
.
int32
)
# Keep track of which box is mapped to which level
# Keep track of which box is mapped to which level
...
@@ -457,9 +457,9 @@ def overlaps_graph(boxes1, boxes2):
...
@@ -457,9 +457,9 @@ def overlaps_graph(boxes1, boxes2):
"""Computes IoU overlaps between two sets of boxes.
"""Computes IoU overlaps between two sets of boxes.
boxes1, boxes2: [N, (y1, x1, y2, x2)].
boxes1, boxes2: [N, (y1, x1, y2, x2)].
"""
"""
# 1. Tile boxes2 and repeat
e
boxes1. This allows us to compare
# 1. Tile boxes2 and repeat boxes1. This allows us to compare
# every boxes1 against every boxes2 without loops.
# every boxes1 against every boxes2 without loops.
# TF doesn't have an equivalent to np.repeat
e
() so simulate it
# TF doesn't have an equivalent to np.repeat() so simulate it
# using tf.tile() and tf.reshape.
# using tf.tile() and tf.reshape.
b1
=
tf
.
reshape
(
tf
.
tile
(
tf
.
expand_dims
(
boxes1
,
1
),
b1
=
tf
.
reshape
(
tf
.
tile
(
tf
.
expand_dims
(
boxes1
,
1
),
[
1
,
1
,
tf
.
shape
(
boxes2
)[
0
]]),
[
-
1
,
4
])
[
1
,
1
,
tf
.
shape
(
boxes2
)[
0
]]),
[
-
1
,
4
])
...
@@ -539,7 +539,7 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config)
...
@@ -539,7 +539,7 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config)
crowd_iou_max
=
tf
.
reduce_max
(
crowd_overlaps
,
axis
=
1
)
crowd_iou_max
=
tf
.
reduce_max
(
crowd_overlaps
,
axis
=
1
)
no_crowd_bool
=
(
crowd_iou_max
<
0.001
)
no_crowd_bool
=
(
crowd_iou_max
<
0.001
)
# Determine postive and negative ROIs
# Determine pos
i
tive and negative ROIs
roi_iou_max
=
tf
.
reduce_max
(
overlaps
,
axis
=
1
)
roi_iou_max
=
tf
.
reduce_max
(
overlaps
,
axis
=
1
)
# 1. Positive ROIs are those with >= 0.5 IoU with a GT box
# 1. Positive ROIs are those with >= 0.5 IoU with a GT box
positive_roi_bool
=
(
roi_iou_max
>=
0.5
)
positive_roi_bool
=
(
roi_iou_max
>=
0.5
)
...
@@ -584,7 +584,7 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config)
...
@@ -584,7 +584,7 @@ def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config)
# Compute mask targets
# Compute mask targets
boxes
=
positive_rois
boxes
=
positive_rois
if
config
.
USE_MINI_MASK
:
if
config
.
USE_MINI_MASK
:
# Transform ROI co
r
rdinates from normalized image space
# Transform ROI co
o
rdinates from normalized image space
# to normalized mini-mask space.
# to normalized mini-mask space.
y1
,
x1
,
y2
,
x2
=
tf
.
split
(
positive_rois
,
4
,
axis
=
1
)
y1
,
x1
,
y2
,
x2
=
tf
.
split
(
positive_rois
,
4
,
axis
=
1
)
gt_y1
,
gt_x1
,
gt_y2
,
gt_x2
=
tf
.
split
(
roi_gt_boxes
,
4
,
axis
=
1
)
gt_y1
,
gt_x1
,
gt_y2
,
gt_x2
=
tf
.
split
(
roi_gt_boxes
,
4
,
axis
=
1
)
...
@@ -741,7 +741,7 @@ def refine_detections_graph(rois, probs, deltas, window, config):
...
@@ -741,7 +741,7 @@ def refine_detections_graph(rois, probs, deltas, window, config):
tf
.
gather
(
pre_nms_scores
,
ixs
),
tf
.
gather
(
pre_nms_scores
,
ixs
),
max_output_size
=
config
.
DETECTION_MAX_INSTANCES
,
max_output_size
=
config
.
DETECTION_MAX_INSTANCES
,
iou_threshold
=
config
.
DETECTION_NMS_THRESHOLD
)
iou_threshold
=
config
.
DETECTION_NMS_THRESHOLD
)
# Map indic
i
es
# Map indices
class_keep
=
tf
.
gather
(
keep
,
tf
.
gather
(
ixs
,
class_keep
))
class_keep
=
tf
.
gather
(
keep
,
tf
.
gather
(
ixs
,
class_keep
))
# Pad with -1 so returned tensors have the same shape
# Pad with -1 so returned tensors have the same shape
gap
=
config
.
DETECTION_MAX_INSTANCES
-
tf
.
shape
(
class_keep
)[
0
]
gap
=
config
.
DETECTION_MAX_INSTANCES
-
tf
.
shape
(
class_keep
)[
0
]
...
@@ -844,8 +844,8 @@ def rpn_graph(feature_map, anchors_per_location, anchor_stride):
...
@@ -844,8 +844,8 @@ def rpn_graph(feature_map, anchors_per_location, anchor_stride):
rpn_bbox: [batch, H, W, (dy, dx, log(dh), log(dw))] Deltas to be
rpn_bbox: [batch, H, W, (dy, dx, log(dh), log(dw))] Deltas to be
applied to anchors.
applied to anchors.
"""
"""
# TODO: check if stride of 2 causes alignment issues if the featuremap
# TODO: check if stride of 2 causes alignment issues if the feature
map
#
is not even.
# is not even.
# Shared convolutional base of the RPN
# Shared convolutional base of the RPN
shared
=
KL
.
Conv2D
(
512
,
(
3
,
3
),
padding
=
'same'
,
activation
=
'relu'
,
shared
=
KL
.
Conv2D
(
512
,
(
3
,
3
),
padding
=
'same'
,
activation
=
'relu'
,
strides
=
anchor_stride
,
strides
=
anchor_stride
,
...
@@ -908,12 +908,12 @@ def fpn_classifier_graph(rois, feature_maps, image_meta,
...
@@ -908,12 +908,12 @@ def fpn_classifier_graph(rois, feature_maps, image_meta,
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
coordinates.
coordinates.
feature_maps: List of feature maps from diffent layers of the pyramid,
feature_maps: List of feature maps from diffe
re
nt layers of the pyramid,
[P2, P3, P4, P5]. Each has a different resolution.
[P2, P3, P4, P5]. Each has a different resolution.
- image_meta: [batch, (meta data)] Image details. See compose_image_meta()
- image_meta: [batch, (meta data)] Image details. See compose_image_meta()
pool_size: The width of the square feature map generated from ROI Pooling.
pool_size: The width of the square feature map generated from ROI Pooling.
num_classes: number of classes, which determines the depth of the results
num_classes: number of classes, which determines the depth of the results
train_bn: Boolean. Train or freeze Batch Norm lay
re
s
train_bn: Boolean. Train or freeze Batch Norm lay
er
s
fc_layers_size: Size of the 2 FC layers
fc_layers_size: Size of the 2 FC layers
Returns:
Returns:
...
@@ -962,12 +962,12 @@ def build_fpn_mask_graph(rois, feature_maps, image_meta,
...
@@ -962,12 +962,12 @@ def build_fpn_mask_graph(rois, feature_maps, image_meta,
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
coordinates.
coordinates.
feature_maps: List of feature maps from diffent layers of the pyramid,
feature_maps: List of feature maps from diffe
re
nt layers of the pyramid,
[P2, P3, P4, P5]. Each has a different resolution.
[P2, P3, P4, P5]. Each has a different resolution.
image_meta: [batch, (meta data)] Image details. See compose_image_meta()
image_meta: [batch, (meta data)] Image details. See compose_image_meta()
pool_size: The width of the square feature map generated from ROI Pooling.
pool_size: The width of the square feature map generated from ROI Pooling.
num_classes: number of classes, which determines the depth of the results
num_classes: number of classes, which determines the depth of the results
train_bn: Boolean. Train or freeze Batch Norm lay
re
s
train_bn: Boolean. Train or freeze Batch Norm lay
er
s
Returns: Masks [batch, roi_count, height, width, num_classes]
Returns: Masks [batch, roi_count, height, width, num_classes]
"""
"""
...
@@ -1014,7 +1014,7 @@ def build_fpn_mask_graph(rois, feature_maps, image_meta,
...
@@ -1014,7 +1014,7 @@ def build_fpn_mask_graph(rois, feature_maps, image_meta,
def
smooth_l1_loss
(
y_true
,
y_pred
):
def
smooth_l1_loss
(
y_true
,
y_pred
):
"""Implements Smooth-L1 loss.
"""Implements Smooth-L1 loss.
y_true and y_pred are typicall
l
y: [N, 4], but could be any shape.
y_true and y_pred are typically: [N, 4], but could be any shape.
"""
"""
diff
=
K
.
abs
(
y_true
-
y_pred
)
diff
=
K
.
abs
(
y_true
-
y_pred
)
less_than_one
=
K
.
cast
(
K
.
less
(
diff
,
1.0
),
"float32"
)
less_than_one
=
K
.
cast
(
K
.
less
(
diff
,
1.0
),
"float32"
)
...
@@ -1039,7 +1039,7 @@ def rpn_class_loss_graph(rpn_match, rpn_class_logits):
...
@@ -1039,7 +1039,7 @@ def rpn_class_loss_graph(rpn_match, rpn_class_logits):
# Pick rows that contribute to the loss and filter out the rest.
# Pick rows that contribute to the loss and filter out the rest.
rpn_class_logits
=
tf
.
gather_nd
(
rpn_class_logits
,
indices
)
rpn_class_logits
=
tf
.
gather_nd
(
rpn_class_logits
,
indices
)
anchor_class
=
tf
.
gather_nd
(
anchor_class
,
indices
)
anchor_class
=
tf
.
gather_nd
(
anchor_class
,
indices
)
# Crossentropy loss
# Cross
entropy loss
loss
=
K
.
sparse_categorical_crossentropy
(
target
=
anchor_class
,
loss
=
K
.
sparse_categorical_crossentropy
(
target
=
anchor_class
,
output
=
rpn_class_logits
,
output
=
rpn_class_logits
,
from_logits
=
True
)
from_logits
=
True
)
...
@@ -1129,7 +1129,7 @@ def mrcnn_bbox_loss_graph(target_bbox, target_class_ids, pred_bbox):
...
@@ -1129,7 +1129,7 @@ def mrcnn_bbox_loss_graph(target_bbox, target_class_ids, pred_bbox):
pred_bbox
=
K
.
reshape
(
pred_bbox
,
(
-
1
,
K
.
int_shape
(
pred_bbox
)[
2
],
4
))
pred_bbox
=
K
.
reshape
(
pred_bbox
,
(
-
1
,
K
.
int_shape
(
pred_bbox
)[
2
],
4
))
# Only positive ROIs contribute to the loss. And only
# Only positive ROIs contribute to the loss. And only
# the right class_id of each ROI. Get their indic
i
es.
# the right class_id of each ROI. Get their indices.
positive_roi_ix
=
tf
.
where
(
target_class_ids
>
0
)[:,
0
]
positive_roi_ix
=
tf
.
where
(
target_class_ids
>
0
)[:,
0
]
positive_roi_class_ids
=
tf
.
cast
(
positive_roi_class_ids
=
tf
.
cast
(
tf
.
gather
(
target_class_ids
,
positive_roi_ix
),
tf
.
int64
)
tf
.
gather
(
target_class_ids
,
positive_roi_ix
),
tf
.
int64
)
...
@@ -1194,7 +1194,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
...
@@ -1194,7 +1194,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
use_mini_mask
=
False
):
use_mini_mask
=
False
):
"""Load and return ground truth data for an image (image, mask, bounding boxes).
"""Load and return ground truth data for an image (image, mask, bounding boxes).
augment: (
Depri
cated. Use augmentation instead). If true, apply random
augment: (
depre
cated. Use augmentation instead). If true, apply random
image augmentation. Currently, only horizontal flipping is offered.
image augmentation. Currently, only horizontal flipping is offered.
augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
For example, passing imgaug.augmenters.Fliplr(0.5) flips images
For example, passing imgaug.augmenters.Fliplr(0.5) flips images
...
@@ -1229,7 +1229,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
...
@@ -1229,7 +1229,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
# Random horizontal flips.
# Random horizontal flips.
# TODO: will be removed in a future update in favor of augmentation
# TODO: will be removed in a future update in favor of augmentation
if
augment
:
if
augment
:
logging
.
warning
(
"'augment' is depr
i
cated. Use 'augmentation' instead."
)
logging
.
warning
(
"'augment' is depr
e
cated. Use 'augmentation' instead."
)
if
random
.
randint
(
0
,
1
):
if
random
.
randint
(
0
,
1
):
image
=
np
.
fliplr
(
image
)
image
=
np
.
fliplr
(
image
)
mask
=
np
.
fliplr
(
mask
)
mask
=
np
.
fliplr
(
mask
)
...
@@ -1239,7 +1239,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
...
@@ -1239,7 +1239,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
if
augmentation
:
if
augmentation
:
import
imgaug
import
imgaug
# Augment
o
rs that are safe to apply to masks
# Augment
e
rs that are safe to apply to masks
# Some, such as Affine, have settings that make them unsafe, so always
# Some, such as Affine, have settings that make them unsafe, so always
# test your augmentation on masks
# test your augmentation on masks
MASK_AUGMENTERS
=
[
"Sequential"
,
"SomeOf"
,
"OneOf"
,
"Sometimes"
,
MASK_AUGMENTERS
=
[
"Sequential"
,
"SomeOf"
,
"OneOf"
,
"Sometimes"
,
...
@@ -1248,7 +1248,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
...
@@ -1248,7 +1248,7 @@ def load_image_gt(dataset, config, image_id, augment=False, augmentation=None,
def
hook
(
images
,
augmenter
,
parents
,
default
):
def
hook
(
images
,
augmenter
,
parents
,
default
):
"""Determines which augmenters to apply to masks."""
"""Determines which augmenters to apply to masks."""
return
(
augmenter
.
__class__
.
__name__
in
MASK_AUGMENTERS
)
return
augmenter
.
__class__
.
__name__
in
MASK_AUGMENTERS
# Store shapes before augmentation to compare
# Store shapes before augmentation to compare
image_shape
=
image
.
shape
image_shape
=
image
.
shape
...
@@ -1302,7 +1302,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
...
@@ -1302,7 +1302,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
rpn_rois: [N, (y1, x1, y2, x2)] proposal boxes.
rpn_rois: [N, (y1, x1, y2, x2)] proposal boxes.
gt_class_ids: [instance count] Integer class IDs
gt_class_ids: [instance count] Integer class IDs
gt_boxes: [instance count, (y1, x1, y2, x2)]
gt_boxes: [instance count, (y1, x1, y2, x2)]
gt_masks: [height, width, instance count] Grund truth masks. Can be full
gt_masks: [height, width, instance count] Gr
o
und truth masks. Can be full
size or mini-masks.
size or mini-masks.
Returns:
Returns:
...
@@ -1357,7 +1357,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
...
@@ -1357,7 +1357,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
# Negative ROIs are those with max IoU 0.1-0.5 (hard example mining)
# Negative ROIs are those with max IoU 0.1-0.5 (hard example mining)
# TODO: To hard example mine or not to hard example mine, that's the question
# TODO: To hard example mine or not to hard example mine, that's the question
#
bg_ids = np.where((rpn_roi_iou_max >= 0.1) & (rpn_roi_iou_max < 0.5))[0]
#
bg_ids = np.where((rpn_roi_iou_max >= 0.1) & (rpn_roi_iou_max < 0.5))[0]
bg_ids
=
np
.
where
(
rpn_roi_iou_max
<
0.5
)[
0
]
bg_ids
=
np
.
where
(
rpn_roi_iou_max
<
0.5
)[
0
]
# Subsample ROIs. Aim for 33% foreground.
# Subsample ROIs. Aim for 33% foreground.
...
@@ -1373,7 +1373,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
...
@@ -1373,7 +1373,7 @@ def build_detection_targets(rpn_rois, gt_class_ids, gt_boxes, gt_masks, config):
keep_bg_ids
=
np
.
random
.
choice
(
bg_ids
,
remaining
,
replace
=
False
)
keep_bg_ids
=
np
.
random
.
choice
(
bg_ids
,
remaining
,
replace
=
False
)
else
:
else
:
keep_bg_ids
=
bg_ids
keep_bg_ids
=
bg_ids
# Combine indic
i
es of ROIs to keep
# Combine indices of ROIs to keep
keep
=
np
.
concatenate
([
keep_fg_ids
,
keep_bg_ids
])
keep
=
np
.
concatenate
([
keep_fg_ids
,
keep_bg_ids
])
# Need more?
# Need more?
remaining
=
config
.
TRAIN_ROIS_PER_IMAGE
-
keep
.
shape
[
0
]
remaining
=
config
.
TRAIN_ROIS_PER_IMAGE
-
keep
.
shape
[
0
]
...
@@ -1644,7 +1644,7 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No
...
@@ -1644,7 +1644,7 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No
dataset: The Dataset object to pick data from
dataset: The Dataset object to pick data from
config: The model config object
config: The model config object
shuffle: If True, shuffles the samples before every epoch
shuffle: If True, shuffles the samples before every epoch
augment: (
Depri
cated. Use augmentation instead). If true, apply random
augment: (
depre
cated. Use augmentation instead). If true, apply random
image augmentation. Currently, only horizontal flipping is offered.
image augmentation. Currently, only horizontal flipping is offered.
augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
augmentation: Optional. An imgaug (https://github.com/aleju/imgaug) augmentation.
For example, passing imgaug.augmenters.Fliplr(0.5) flips images
For example, passing imgaug.augmenters.Fliplr(0.5) flips images
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录