提交 5202a02d 编写于 作者: N Nick 提交者: Waleed

Add no augmentation sources

Add the possibility to exclude some sources from augmentation by passing a list of sources. This is useful when you want to retrain a model having few images.
上级 23c82fd6
......@@ -1636,7 +1636,8 @@ def generate_random_rois(image_shape, count, gt_class_ids, gt_boxes):
def data_generator(dataset, config, shuffle=True, augment=False, augmentation=None,
random_rois=0, batch_size=1, detection_targets=False):
random_rois=0, batch_size=1, detection_targets=False,
no_augmentation_sources=[]):
"""A generator that returns images and corresponding target class ids,
bounding box deltas, and masks.
......@@ -1673,6 +1674,8 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No
outputs list: Usually empty in regular training. But if detection_targets
is True then the outputs list contains target class_ids, bbox deltas,
and masks.
no_augmentation_sources: (list) Optional. List of sources to be skipped for augmentation
"""
b = 0 # batch item index
image_index = -1
......@@ -1698,10 +1701,18 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No
# Get GT bounding boxes and masks for image.
image_id = image_ids[image_index]
image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
# If the image source is not to be augmented pass None as augmentation
if dataset.image_info[image_id]['source'] in no_augmentation_sources:
image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
load_image_gt(dataset, config, image_id, augment=augment,
augmentation=augmentation,
augmentation=None,
use_mini_mask=config.USE_MINI_MASK)
else:
image, image_meta, gt_class_ids, gt_boxes, gt_masks = \
load_image_gt(dataset, config, image_id, augment=augment,
augmentation=augmentation,
use_mini_mask=config.USE_MINI_MASK)
# Skip images that have no instances. This can happen in cases
# where we train on a subset of classes and the image doesn't
......@@ -2272,7 +2283,7 @@ class MaskRCNN():
"*epoch*", "{epoch:04d}")
def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
augmentation=None, custom_callbacks=[]):
augmentation=None, custom_callbacks=[], no_augmentation_sources=[]):
"""Train the model.
train_dataset, val_dataset: Training and validation Dataset objects.
learning_rate: The learning rate to train with
......@@ -2299,8 +2310,10 @@ class MaskRCNN():
imgaug.augmenters.Fliplr(0.5),
imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
])
custom_callbacks: (list) Optional. Add custom callbacks to be called
with the keras fit_generator method. Must be list of type keras.callbacks.
custom_callbacks: (list) Optional. Add custom callbacks to be called
with the keras fit_generator method. Must be list of type keras.callbacks.
no_augmentation_sources: (list) Optional. List of sources to be skipped for augmentation
"""
......@@ -2323,9 +2336,11 @@ class MaskRCNN():
# Data generators
train_generator = data_generator(train_dataset, self.config, shuffle=True,
augmentation=augmentation,
batch_size=self.config.BATCH_SIZE)
batch_size=self.config.BATCH_SIZE,
no_augmentation_sources=no_augmentation_sources)
val_generator = data_generator(val_dataset, self.config, shuffle=True,
batch_size=self.config.BATCH_SIZE)
batch_size=self.config.BATCH_SIZE,
no_augmentation_sources=no_augmentation_sources)
# Callbacks
callbacks = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册