From 833066b540ca097e111dfcd4bf493cb4f0902e15 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 5 Jan 2020 20:07:03 -0800 Subject: [PATCH] A few minor things in SplitBN --- timm/models/split_batchnorm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/timm/models/split_batchnorm.py b/timm/models/split_batchnorm.py index 0ed30d7..327c35b 100644 --- a/timm/models/split_batchnorm.py +++ b/timm/models/split_batchnorm.py @@ -6,9 +6,9 @@ import torch.nn.functional as F class SplitBatchNorm2d(torch.nn.BatchNorm2d): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, - track_running_stats=True, num_splits=1): + track_running_stats=True, num_splits=2): super().__init__(num_features, eps, momentum, affine, track_running_stats) - assert num_splits >= 2, 'Should have at least one aux BN layer (num_splits at least 2)' + assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' self.num_splits = num_splits self.aux_bn = nn.ModuleList([ nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) @@ -35,8 +35,7 @@ def convert_splitbn_model(module, num_splits=2): num_splits: number of separate batchnorm layers to split input across Example:: >>> # model is an instance of torch.nn.Module - >>> import apex - >>> sync_bn_model = timm.models.convert_splitbn_model(model, num_splits=2) + >>> model = timm.models.convert_splitbn_model(model, num_splits=2) """ mod = module if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): -- GitLab