diff --git a/official/vision/beta/configs/semantic_segmentation.py b/official/vision/beta/configs/semantic_segmentation.py index 3aefcefec705902912f74dd726c7cdaf287ba5f6..a952e3416bdbccc84177633e54fe268c5cdd3510 100644 --- a/official/vision/beta/configs/semantic_segmentation.py +++ b/official/vision/beta/configs/semantic_segmentation.py @@ -63,6 +63,7 @@ class SegmentationHead(hyperparams.Config): num_convs: int = 2 num_filters: int = 256 use_depthwise_convolution: bool = False + kernel_size: int = 3 prediction_kernel_size: int = 1 upsample_factor: int = 1 feature_fusion: Optional[ diff --git a/official/vision/beta/modeling/factory.py b/official/vision/beta/modeling/factory.py index b03d0ea9d55a47c914ea44d046196dfb79bb488a..b75c347a44e7bc8f9e81d2ef0a51d313a76142ab 100644 --- a/official/vision/beta/modeling/factory.py +++ b/official/vision/beta/modeling/factory.py @@ -356,6 +356,7 @@ def build_segmentation_model( num_classes=model_config.num_classes, level=head_config.level, num_convs=head_config.num_convs, + kernel_size=head_config.kernel_size, prediction_kernel_size=head_config.prediction_kernel_size, num_filters=head_config.num_filters, use_depthwise_convolution=head_config.use_depthwise_convolution,