diff --git a/official/vision/losses/maskrcnn_losses.py b/official/vision/losses/maskrcnn_losses.py
index 63c9a17a6c090b87bcf329ce5c6c7a03ad444376..dd5f5571755ca250f1afe04fc9b2bcebfcec82d3 100644
--- a/official/vision/losses/maskrcnn_losses.py
+++ b/official/vision/losses/maskrcnn_losses.py
@@ -136,8 +136,16 @@ class RpnBoxLoss(object):
           box_targets, box_outputs, sample_weight=valid_mask)
       # The loss is normalized by the sum of non-zero weights and additional
       # normalizer provided by the function caller. Using + 0.01 here to avoid
-      # division by zero.
-      box_loss /= normalizer * (tf.reduce_sum(valid_mask) + 0.01)
+      # division by zero. For each replica, get the sum of non-zero masks. Then
+      # get the mean of sums from all replicas. Note there is an extra division
+      # by `num_replicas` in train_step(). So it is equivalent to normalizing
+      # the box loss by the global sum of non-zero masks.
+      replica_context = tf.distribute.get_replica_context()
+      valid_mask = tf.reduce_sum(valid_mask)
+      valid_mask_mean = replica_context.all_reduce(
+          tf.distribute.ReduceOp.MEAN, valid_mask
+      )
+      box_loss /= normalizer * (valid_mask_mean + 0.01)
       return box_loss
 
 
@@ -291,8 +299,16 @@ class FastrcnnBoxLoss(object):
       box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
       # The loss is normalized by the number of ones in mask,
       # additional normalizer provided by the user and using 0.01 here to avoid
-      # division by 0.
-      box_loss /= normalizer * (tf.reduce_sum(mask) + 0.01)
+      # division by 0. For each replica, get the sum of non-zero masks. Then
+      # get the mean of sums from all replicas. Note there is an extra division
+      # by `num_replicas` in train_step(). So it is equivalent to normalizing
+      # the box loss by the global sum of non-zero masks.
+      replica_context = tf.distribute.get_replica_context()
+      mask = tf.reduce_sum(mask)
+      mask_mean = replica_context.all_reduce(
+          tf.distribute.ReduceOp.MEAN, mask
+      )
+      box_loss /= normalizer * (mask_mean + 0.01)
       return box_loss
 
 
@@ -341,7 +357,15 @@ class MaskrcnnLoss(object):
       mask_outputs = tf.expand_dims(mask_outputs, axis=-1)
       mask_loss = self._binary_crossentropy(mask_targets, mask_outputs,
                                             sample_weight=weights)
-
+      # For each replica, get the sum of non-zero weights. Then get the mean of
+      # sums from all replicas. Note there is an extra division by
+      # `num_replicas` in train_step(). So it is equivalent to normalizing the
+      # mask loss by the global sum of non-zero weights.
+      replica_context = tf.distribute.get_replica_context()
+      weights = tf.reduce_sum(weights)
+      weights_mean = replica_context.all_reduce(
+          tf.distribute.ReduceOp.MEAN, weights
+      )
       # The loss is normalized by the number of 1's in weights and
       # + 0.01 is used to avoid division by zero.
-      return mask_loss / (tf.reduce_sum(weights) + 0.01)
+      return mask_loss / (weights_mean + 0.01)
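
For context, the sketch below (not part of the patch) demonstrates why the `ReduceOp.MEAN` all-reduce, combined with the extra division by `num_replicas` in `train_step()`, is equivalent to normalizing by the global sum of non-zero weights. The two-replica CPU setup, the `normalized_loss` helper, and the toy anchor counts are illustrative assumptions, not code from this repository:

```python
import tensorflow as tf

# Split the host CPU into two logical devices so MirroredStrategy can
# simulate two replicas on a single machine.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices('CPU')[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()])
strategy = tf.distribute.MirroredStrategy(['CPU:0', 'CPU:1'])
num_replicas = strategy.num_replicas_in_sync  # 2


def normalized_loss():
  """Per-replica normalization following the pattern in the patch above."""
  ctx = tf.distribute.get_replica_context()
  # Toy data: replica 0 has 2 valid anchors, replica 1 has 4 (global sum: 6).
  local_sum = tf.cast(2 * (ctx.replica_id_in_sync_group + 1), tf.float32)
  raw_loss = local_sum  # stand-in for the per-replica Huber loss sum
  # MEAN of the per-replica sums == global_sum / num_replicas == 3.
  mean_sum = ctx.all_reduce(tf.distribute.ReduceOp.MEAN, local_sum)
  return raw_loss / (mean_sum + 0.01)


per_replica_losses = strategy.run(normalized_loss)
# train_step() applies one more division by num_replicas, so the effective
# denominator is num_replicas * (global_sum / num_replicas) == global_sum.
total_loss = strategy.reduce(
    tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) / num_replicas
print(total_loss.numpy())  # ~1.0, i.e. (2 + 4) / (6 + epsilon)
```

The point of the MEAN (rather than SUM) all-reduce is that the trailing division by `num_replicas` in `train_step()` then cancels exactly, leaving the global sum of non-zero masks as the denominator. The previous code divided each replica's loss by only its local sum, which over-weights replicas that happen to sample few positive anchors.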