提交 2a7bcfc8 编写于 作者: K Kolesnikov Sergey 提交者: Waleed

loss weights

上级 6cfc657c
...@@ -164,6 +164,16 @@ class Config(object): ...@@ -164,6 +164,16 @@ class Config(object):
# Weight decay regularization # Weight decay regularization
WEIGHT_DECAY = 0.0001 WEIGHT_DECAY = 0.0001
# Loss weights for more precise optimization.
# Can be used for R-CNN training setup.
LOSS_WEIGHTS = {
"rpn_class_loss": 1.,
"rpn_bbox_loss": 1.,
"mrcnn_class_loss": 1.,
"mrcnn_bbox_loss": 1.,
"mrcnn_mask_loss": 1.
}
# Use RPN ROIs or externally generated ROIs for training # Use RPN ROIs or externally generated ROIs for training
# Keep this True for most situations. Set to False if you want to train # Keep this True for most situations. Set to False if you want to train
# the head branches on ROI generated by code rather than the ROIs from # the head branches on ROI generated by code rather than the ROIs from
......
...@@ -2122,31 +2122,37 @@ class MaskRCNN(): ...@@ -2122,31 +2122,37 @@ class MaskRCNN():
metrics. Then calls the Keras compile() function. metrics. Then calls the Keras compile() function.
""" """
# Optimizer object # Optimizer object
optimizer = keras.optimizers.SGD(lr=learning_rate, momentum=momentum, optimizer = keras.optimizers.SGD(
clipnorm=self.config.GRADIENT_CLIP_NORM) lr=learning_rate, momentum=momentum,
clipnorm=self.config.GRADIENT_CLIP_NORM)
# Add Losses # Add Losses
# First, clear previously set losses to avoid duplication # First, clear previously set losses to avoid duplication
self.keras_model._losses = [] self.keras_model._losses = []
self.keras_model._per_input_losses = {} self.keras_model._per_input_losses = {}
loss_names = ["rpn_class_loss", "rpn_bbox_loss", loss_names = [
"mrcnn_class_loss", "mrcnn_bbox_loss", "mrcnn_mask_loss"] "rpn_class_loss", "rpn_bbox_loss",
"mrcnn_class_loss", "mrcnn_bbox_loss", "mrcnn_mask_loss"]
for name in loss_names: for name in loss_names:
layer = self.keras_model.get_layer(name) layer = self.keras_model.get_layer(name)
if layer.output in self.keras_model.losses: if layer.output in self.keras_model.losses:
continue continue
self.keras_model.add_loss( loss = (
tf.reduce_mean(layer.output, keep_dims=True)) tf.reduce_mean(layer.output, keep_dims=True)
* self.config.LOSS_WEIGHTS.get(name, 1.))
self.keras_model.add_loss(loss)
# Add L2 Regularization # Add L2 Regularization
# Skip gamma and beta weights of batch normalization layers. # Skip gamma and beta weights of batch normalization layers.
reg_losses = [keras.regularizers.l2(self.config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32) reg_losses = [
for w in self.keras_model.trainable_weights keras.regularizers.l2(self.config.WEIGHT_DECAY)(w) / tf.cast(tf.size(w), tf.float32)
if 'gamma' not in w.name and 'beta' not in w.name] for w in self.keras_model.trainable_weights
if 'gamma' not in w.name and 'beta' not in w.name]
self.keras_model.add_loss(tf.add_n(reg_losses)) self.keras_model.add_loss(tf.add_n(reg_losses))
# Compile # Compile
self.keras_model.compile(optimizer=optimizer, loss=[ self.keras_model.compile(
None] * len(self.keras_model.outputs)) optimizer=optimizer,
loss=[None] * len(self.keras_model.outputs))
# Add metrics for losses # Add metrics for losses
for name in loss_names: for name in loss_names:
...@@ -2154,8 +2160,10 @@ class MaskRCNN(): ...@@ -2154,8 +2160,10 @@ class MaskRCNN():
continue continue
layer = self.keras_model.get_layer(name) layer = self.keras_model.get_layer(name)
self.keras_model.metrics_names.append(name) self.keras_model.metrics_names.append(name)
self.keras_model.metrics_tensors.append(tf.reduce_mean( loss = (
layer.output, keep_dims=True)) tf.reduce_mean(layer.output, keep_dims=True)
* self.config.LOSS_WEIGHTS.get(name, 1.))
self.keras_model.metrics_tensors.append(loss)
def set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1): def set_trainable(self, layer_regex, keras_model=None, indent=0, verbose=1):
"""Sets model layers as trainable if their names match """Sets model layers as trainable if their names match
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册