From 7720faa49b2ee2d73c1ff7b0c51ffa12b910a27a Mon Sep 17 00:00:00 2001
From: ceci3
Date: Thu, 1 Jul 2021 16:49:14 +0800
Subject: [PATCH] fix superbn states (#805)

* fix superbn
---
 paddleslim/nas/ofa/layers.py     | 69 +++++++++++++++++++++++---------
 paddleslim/nas/ofa/layers_old.py | 22 ++++++++--
 tests/test_ofa_layers.py         | 22 +++++++++-
 tests/test_ofa_layers_old.py     | 20 +++++++++
 4 files changed, 109 insertions(+), 24 deletions(-)

diff --git a/paddleslim/nas/ofa/layers.py b/paddleslim/nas/ofa/layers.py
index c602d4ee..188f38cb 100644
--- a/paddleslim/nas/ofa/layers.py
+++ b/paddleslim/nas/ofa/layers.py
@@ -954,25 +954,45 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
     def forward(self, input):
         self._check_data_format(self._data_format)
         self._check_input_dim(input)
-
         feature_dim = int(input.shape[1])
         weight = self.weight[:feature_dim]
         bias = self.bias[:feature_dim]
         mean = self._mean[:feature_dim]
         variance = self._variance[:feature_dim]
+
+        mean_out = self._mean
+        variance_out = self._variance
+        mean_out_tmp = mean
+        variance_out_tmp = variance
+
+        if self._use_global_stats == None:
+            self._use_global_stats = not self.training
+            trainable_statistics = False
+        else:
+            trainable_statistics = not self._use_global_stats
+
+        attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
+                 "is_test", not self.training, "data_layout", self._data_format,
+                 "use_mkldnn", False, "fuse_with_relu", False,
+                 "use_global_stats", self._use_global_stats,
+                 "trainable_statistics", trainable_statistics)
+
+        if feature_dim != self._mean.shape[0]:
+            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                 variance, mean_out_tmp,
+                                                 variance_out_tmp, *attrs)
+            self._mean[:feature_dim] = mean
+            self._variance[:feature_dim] = variance
+            mean_out[:feature_dim] = mean_out_tmp
+            variance_out[:feature_dim] = variance_out_tmp
+        else:
+            batch_norm_out = core.ops.batch_norm(input, weight, bias,
+                                                 self._mean, self._variance,
+                                                 mean_out, variance_out, *attrs)
+
         self.cur_config = {'prune_dim': feature_dim}
-        return F.batch_norm(
-            input,
-            mean,
-            variance,
-            weight=weight,
-            bias=bias,
-            training=self.training,
-            momentum=self._momentum,
-            epsilon=self._epsilon,
-            data_format=self._data_format,
-            use_global_stats=self._use_global_stats)
+        return batch_norm_out[0]
 
 
 class SuperSyncBatchNorm(nn.SyncBatchNorm):
@@ -990,7 +1010,7 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
         self.cur_config = None
 
     def forward(self, input):
-
+        self._check_data_format()
         feature_dim = int(input.shape[1])
         weight = self.weight[:feature_dim]
@@ -998,24 +1018,35 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
         mean = self._mean[:feature_dim]
         variance = self._variance[:feature_dim]
 
-        mean_out = mean
-        # variance and variance out share the same memory
-        variance_out = variance
+        mean_out = self._mean
+        variance_out = self._variance
+        mean_out_tmp = mean
+        variance_out_tmp = variance
         self.cur_config = {'prune_dim': feature_dim}
         attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                  "is_test", not self.training, "data_layout", self._data_format,
                  "use_mkldnn", False, "fuse_with_relu", False,
                  "use_global_stats", False, 'trainable_statistics', False)
-        sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
-            input, weight, bias, mean, variance, mean_out, variance_out, *attrs)
+        if feature_dim != self._mean.shape[0]:
+            sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
+                input, weight, bias, mean, variance, mean_out_tmp,
+                variance_out_tmp, *attrs)
+            self._mean[:feature_dim] = mean
+            self._variance[:feature_dim] = variance
+            mean_out[:feature_dim] = mean_out_tmp
+            variance_out[:feature_dim] = variance_out_tmp
+        else:
+            sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
+                input, weight, bias, self._mean, self._variance, mean_out,
+                variance_out, *attrs)
 
         return sync_batch_norm_out
 
 
 class SuperInstanceNorm2D(nn.InstanceNorm2D):
     """
-    This interface is used to construct a callable object of the ``SuperBatchNorm2D`` class.
+    This interface is used to construct a callable object of the ``SuperInstanceNorm2D`` class.
 
     Parameters:
         num_features(int): Indicate the number of channels of the input ``Tensor``.
diff --git a/paddleslim/nas/ofa/layers_old.py b/paddleslim/nas/ofa/layers_old.py
index 4c1dd51e..7ce58df8 100644
--- a/paddleslim/nas/ofa/layers_old.py
+++ b/paddleslim/nas/ofa/layers_old.py
@@ -879,16 +879,30 @@ class SuperBatchNorm(fluid.dygraph.BatchNorm):
         mean = self._mean[:feature_dim]
         variance = self._variance[:feature_dim]
 
-        mean_out = mean
-        variance_out = variance
+        mean_out = self._mean
+        variance_out = self._variance
+        mean_out_tmp = mean
+        variance_out_tmp = variance
         attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                  "is_test", not self.training, "data_layout", self._data_layout,
                  "use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu,
                  "use_global_stats", self._use_global_stats,
                  'trainable_statistics', self._trainable_statistics)
-        batch_norm_out = core.ops.batch_norm(
-            input, weight, bias, mean, variance, mean_out, variance_out, *attrs)
+
+        if feature_dim != self._mean.shape[0]:
+            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                 variance, mean_out_tmp,
+                                                 variance_out_tmp, *attrs)
+            self._mean[:feature_dim] = mean
+            self._variance[:feature_dim] = variance
+            mean_out[:feature_dim] = mean_out_tmp
+            variance_out[:feature_dim] = variance_out_tmp
+        else:
+            batch_norm_out = core.ops.batch_norm(input, weight, bias,
+                                                 self._mean, self._variance,
+                                                 mean_out, variance_out, *attrs)
+
         return dygraph_utils._append_activation_in_dygraph(
             batch_norm_out[0], act=self._act)
diff --git a/tests/test_ofa_layers.py b/tests/test_ofa_layers.py
index 41137544..11c550e3 100644
--- a/tests/test_ofa_layers.py
+++ b/tests/test_ofa_layers.py
@@ -21,8 +21,8 @@ import paddle.nn as nn
 from paddle.nn import ReLU
 from paddleslim.nas import ofa
 from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
-from paddleslim.nas.ofa.convert_super import supernet
 from paddleslim.nas.ofa.layers import *
+from paddleslim.nas.ofa.layers_base import Block
 
 
 class ModelCase1(nn.Layer):
@@ -51,6 +51,16 @@ class ModelCase1(nn.Layer):
         return self.models(inputs)
 
 
+class ModelCase2(nn.Layer):
+    def __init__(self):
+        super(ModelCase2, self).__init__()
+        models = [SuperSyncBatchNorm(4)]
+        self.models = paddle.nn.Sequential(*models)
+
+    def forward(self, inputs):
+        return self.models(inputs)
+
+
 class TestCase(unittest.TestCase):
     def setUp(self):
         self.model = ModelCase1()
@@ -62,5 +72,15 @@ class TestCase(unittest.TestCase):
         out = self.model(self.data)
 
 
+class TestCase2(TestCase):
+    def setUp(self):
+        self.model = ModelCase2()
+        data_np = np.random.random((1, 3, 64, 64)).astype(np.float32)
+        self.data = paddle.to_tensor(data_np)
+
+    def test_ofa(self):
+        out = self.model(self.data)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/test_ofa_layers_old.py b/tests/test_ofa_layers_old.py
index 69cf0167..4d66019f 100644
--- a/tests/test_ofa_layers_old.py
+++ b/tests/test_ofa_layers_old.py
@@ -122,6 +122,16 @@ class ModelCase3(nn.Layer):
         return inputs
 
 
+class ModelCase4(nn.Layer):
+    def __init__(self):
+        super(ModelCase4, self).__init__()
+        models = [SuperBatchNorm(4)]
+        self.models = paddle.nn.Sequential(*models)
+
+    def forward(self, inputs):
+        return self.models(inputs)
+
+
 class TestCase(unittest.TestCase):
     def setUp(self):
         self.model = ModelCase1()
@@ -147,5 +157,15 @@ class TestCase3(TestCase):
         self.data = paddle.to_tensor(data_np)
 
 
+class TestCase4(TestCase):
+    def setUp(self):
+        self.model = ModelCase4()
+        data_np = np.random.random((1, 3, 64, 64)).astype(np.float32)
+        self.data = paddle.to_tensor(data_np)
+
+    def test_ofa(self):
+        out = self.model(self.data)
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
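
Note (not part of the patch itself): the new test cases exercise the branch added above by building the super layer at its full width (4 channels) and feeding it a 3-channel input, so `feature_dim` no longer matches `self._mean.shape[0]` and only the first `feature_dim` entries of the running statistics are written back. A minimal standalone sketch of the same check, assuming `SuperBatchNorm2D(4)` is constructed the same way as the `SuperSyncBatchNorm(4)` / `SuperBatchNorm(4)` layers in the new tests; the surrounding script is illustrative only:

    import numpy as np
    import paddle
    from paddleslim.nas.ofa.layers import SuperBatchNorm2D

    # Super layer built at full width (4 channels) but fed a 3-channel input,
    # so the pruned-statistics branch added in this patch is taken.
    bn = SuperBatchNorm2D(4)
    x = paddle.to_tensor(np.random.random((1, 3, 64, 64)).astype(np.float32))
    out = bn(x)

    print(bn.cur_config)   # {'prune_dim': 3}
    print(bn._mean.shape)  # running-mean buffer keeps its full size: [4]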