提交 30292efc 编写于 作者: W Waleed Abdulla

Improve previous commit to avoid mutable default arguments

上级 5202a02d
......@@ -1637,7 +1637,7 @@ def generate_random_rois(image_shape, count, gt_class_ids, gt_boxes):
def data_generator(dataset, config, shuffle=True, augment=False, augmentation=None,
random_rois=0, batch_size=1, detection_targets=False,
no_augmentation_sources=[]):
no_augmentation_sources=None):
"""A generator that returns images and corresponding target class ids,
bounding box deltas, and masks.
......@@ -1656,6 +1656,9 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No
detection_targets: If True, generate detection targets (class IDs, bbox
deltas, and masks). Typically for debugging or visualizations because
in trainig detection targets are generated by DetectionTargetLayer.
no_augmentation_sources: Optional. List of sources to exclude for
augmentation. A source is string that identifies a dataset and is
defined in the Dataset class.
Returns a Python generator. Upon calling next() on it, the
generator returns two lists, inputs and outputs. The containtes
......@@ -1674,13 +1677,12 @@ def data_generator(dataset, config, shuffle=True, augment=False, augmentation=No
outputs list: Usually empty in regular training. But if detection_targets
is True then the outputs list contains target class_ids, bbox deltas,
and masks.
no_augmentation_sources: (list) Optional. List of sources to be skipped for augmentation
"""
b = 0 # batch item index
image_index = -1
image_ids = np.copy(dataset.image_ids)
error_count = 0
no_augmentation_sources = no_augmentation_sources or []
# Anchors
# [anchor_count, (y1, x1, y2, x2)]
......@@ -2283,7 +2285,7 @@ class MaskRCNN():
"*epoch*", "{epoch:04d}")
def train(self, train_dataset, val_dataset, learning_rate, epochs, layers,
augmentation=None, custom_callbacks=[], no_augmentation_sources=[]):
augmentation=None, custom_callbacks=None, no_augmentation_sources=None):
"""Train the model.
train_dataset, val_dataset: Training and validation Dataset objects.
learning_rate: The learning rate to train with
......@@ -2310,12 +2312,11 @@ class MaskRCNN():
imgaug.augmenters.Fliplr(0.5),
imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
])
custom_callbacks: (list) Optional. Add custom callbacks to be called
custom_callbacks: Optional. Add custom callbacks to be called
with the keras fit_generator method. Must be list of type keras.callbacks.
no_augmentation_sources: (list) Optional. List of sources to be skipped for augmentation
no_augmentation_sources: Optional. List of sources to exclude for
augmentation. A source is string that identifies a dataset and is
defined in the Dataset class.
"""
assert self.mode == "training", "Create model in training mode."
......@@ -2339,8 +2340,7 @@ class MaskRCNN():
batch_size=self.config.BATCH_SIZE,
no_augmentation_sources=no_augmentation_sources)
val_generator = data_generator(val_dataset, self.config, shuffle=True,
batch_size=self.config.BATCH_SIZE,
no_augmentation_sources=no_augmentation_sources)
batch_size=self.config.BATCH_SIZE)
# Callbacks
callbacks = [
......@@ -2351,7 +2351,8 @@ class MaskRCNN():
]
# Add custom callbacks to the list
callbacks+=custom_callbacks
if custom_callbacks:
callbacks += custom_callbacks
# Train
log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册