Unverified commit 7720faa4, authored by ceci3, committed by GitHub

fix superbn states (#805)

* fix superbn
Parent 6924c977
@@ -954,25 +954,45 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
     def forward(self, input):
         self._check_data_format(self._data_format)
         self._check_input_dim(input)
         feature_dim = int(input.shape[1])
         weight = self.weight[:feature_dim]
         bias = self.bias[:feature_dim]
         mean = self._mean[:feature_dim]
         variance = self._variance[:feature_dim]
+        mean_out = self._mean
+        variance_out = self._variance
+        mean_out_tmp = mean
+        variance_out_tmp = variance
+
+        if self._use_global_stats == None:
+            self._use_global_stats = not self.training
+            trainable_statistics = False
+        else:
+            trainable_statistics = not self._use_global_stats
+
+        attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
+                 "is_test", not self.training, "data_layout", self._data_format,
+                 "use_mkldnn", False, "fuse_with_relu", False,
+                 "use_global_stats", self._use_global_stats,
+                 "trainable_statistics", trainable_statistics)
+
+        if feature_dim != self._mean.shape[0]:
+            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                 variance, mean_out_tmp,
+                                                 variance_out_tmp, *attrs)
+            self._mean[:feature_dim] = mean
+            self._variance[:feature_dim] = variance
+            mean_out[:feature_dim] = mean_out_tmp
+            variance_out[:feature_dim] = variance_out_tmp
+        else:
+            batch_norm_out = core.ops.batch_norm(input, weight, bias,
+                                                 self._mean, self._variance,
+                                                 mean_out, variance_out, *attrs)
         self.cur_config = {'prune_dim': feature_dim}
-        return F.batch_norm(
-            input,
-            mean,
-            variance,
-            weight=weight,
-            bias=bias,
-            training=self.training,
-            momentum=self._momentum,
-            epsilon=self._epsilon,
-            data_format=self._data_format,
-            use_global_stats=self._use_global_stats)
+        return batch_norm_out[0]
 
 
 class SuperSyncBatchNorm(nn.SyncBatchNorm):
@@ -990,7 +1010,7 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
         self.cur_config = None
 
     def forward(self, input):
-
+        self._check_data_format()
         feature_dim = int(input.shape[1])
         weight = self.weight[:feature_dim]
@@ -998,24 +1018,35 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
         mean = self._mean[:feature_dim]
         variance = self._variance[:feature_dim]
-        mean_out = mean
-        # variance and variance out share the same memory
-        variance_out = variance
+        mean_out = self._mean
+        variance_out = self._variance
+        mean_out_tmp = mean
+        variance_out_tmp = variance
         self.cur_config = {'prune_dim': feature_dim}
         attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                  "is_test", not self.training, "data_layout", self._data_format,
                  "use_mkldnn", False, "fuse_with_relu", False,
                  "use_global_stats", False, 'trainable_statistics', False)
-        sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
-            input, weight, bias, mean, variance, mean_out, variance_out, *attrs)
+        if feature_dim != self._mean.shape[0]:
+            sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
+                input, weight, bias, mean, variance, mean_out_tmp,
+                variance_out_tmp, *attrs)
+            self._mean[:feature_dim] = mean
+            self._variance[:feature_dim] = variance
+            mean_out[:feature_dim] = mean_out_tmp
+            variance_out[:feature_dim] = variance_out_tmp
+        else:
+            sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
+                input, weight, bias, self._mean, self._variance, mean_out,
+                variance_out, *attrs)
         return sync_batch_norm_out
 
 
 class SuperInstanceNorm2D(nn.InstanceNorm2D):
     """
-    This interface is used to construct a callable object of the ``SuperBatchNorm2D`` class.
+    This interface is used to construct a callable object of the ``SuperInstanceNorm2D`` class.
 
     Parameters:
         num_features(int): Indicate the number of channels of the input ``Tensor``.
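For reference, the following is a minimal standalone sketch (not part of the commit) of the state-handling problem the hunks above address, assuming Paddle's dygraph slicing produces a new tensor rather than a view: statistics that the batch-norm op writes into a sliced buffer never reach the layer's persistent _mean/_variance, which is why the patched code copies the temporaries back whenever feature_dim is smaller than the stored width. All names below (running_mean, mean_out_tmp, the shapes) are illustrative only.

import paddle

# Hypothetical stand-ins for the layer's persistent buffer and the sliced
# statistics handed to the batch_norm op when only 3 of 8 channels are active.
num_features, feature_dim = 8, 3
running_mean = paddle.zeros([num_features])     # plays the role of self._mean
mean = running_mean[:feature_dim]               # sliced statistics passed to the op
mean_out_tmp = mean + 0.5                       # pretend the op produced updated statistics

# Without an explicit copy-back the buffer stays all zeros, because the slice is a
# separate tensor; this line mirrors the write-back added by the commit.
running_mean[:feature_dim] = mean_out_tmp
print(running_mean.numpy())                     # first 3 entries are now 0.5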
@@ -879,16 +879,30 @@ class SuperBatchNorm(fluid.dygraph.BatchNorm):
         mean = self._mean[:feature_dim]
         variance = self._variance[:feature_dim]
-        mean_out = mean
-        variance_out = variance
+        mean_out = self._mean
+        variance_out = self._variance
+        mean_out_tmp = mean
+        variance_out_tmp = variance
+
         attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                  "is_test", not self.training, "data_layout", self._data_layout,
                  "use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu,
                  "use_global_stats", self._use_global_stats,
                  'trainable_statistics', self._trainable_statistics)
-        batch_norm_out = core.ops.batch_norm(
-            input, weight, bias, mean, variance, mean_out, variance_out, *attrs)
+
+        if feature_dim != self._mean.shape[0]:
+            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                 variance, mean_out_tmp,
+                                                 variance_out_tmp, *attrs)
+            self._mean[:feature_dim] = mean
+            self._variance[:feature_dim] = variance
+            mean_out[:feature_dim] = mean_out_tmp
+            variance_out[:feature_dim] = variance_out_tmp
+        else:
+            batch_norm_out = core.ops.batch_norm(input, weight, bias,
+                                                 self._mean, self._variance,
+                                                 mean_out, variance_out, *attrs)
         return dygraph_utils._append_activation_in_dygraph(
             batch_norm_out[0], act=self._act)
@@ -21,8 +21,8 @@ import paddle.nn as nn
 from paddle.nn import ReLU
 from paddleslim.nas import ofa
 from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
-from paddleslim.nas.ofa.convert_super import supernet
 from paddleslim.nas.ofa.layers import *
+from paddleslim.nas.ofa.layers_base import Block
 
 
 class ModelCase1(nn.Layer):
@@ -51,6 +51,16 @@ class ModelCase1(nn.Layer):
         return self.models(inputs)
 
 
+class ModelCase2(nn.Layer):
+    def __init__(self):
+        super(ModelCase2, self).__init__()
+        models = [SuperSyncBatchNorm(4)]
+        self.models = paddle.nn.Sequential(*models)
+
+    def forward(self, inputs):
+        return self.models(inputs)
+
+
 class TestCase(unittest.TestCase):
     def setUp(self):
         self.model = ModelCase1()
@@ -62,5 +72,15 @@ class TestCase(unittest.TestCase):
         out = self.model(self.data)
 
 
+class TestCase2(TestCase):
+    def setUp(self):
+        self.model = ModelCase2()
+        data_np = np.random.random((1, 3, 64, 64)).astype(np.float32)
+        self.data = paddle.to_tensor(data_np)
+
+    def test_ofa(self):
+        out = self.model(self.data)
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -122,6 +122,16 @@ class ModelCase3(nn.Layer):
         return inputs
 
 
+class ModelCase4(nn.Layer):
+    def __init__(self):
+        super(ModelCase4, self).__init__()
+        models = [SuperBatchNorm(4)]
+        self.models = paddle.nn.Sequential(*models)
+
+    def forward(self, inputs):
+        return self.models(inputs)
+
+
 class TestCase(unittest.TestCase):
     def setUp(self):
         self.model = ModelCase1()
@@ -147,5 +157,15 @@ class TestCase3(TestCase):
         self.data = paddle.to_tensor(data_np)
 
 
+class TestCase4(TestCase):
+    def setUp(self):
+        self.model = ModelCase4()
+        data_np = np.random.random((1, 3, 64, 64)).astype(np.float32)
+        self.data = paddle.to_tensor(data_np)
+
+    def test_ofa(self):
+        out = self.model(self.data)
+
+
 if __name__ == '__main__':
     unittest.main()
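The new test cases feed a 3-channel input through a super layer declared with 4 features, so feature_dim != self._mean.shape[0] and the pruned-statistics branch is exercised. Below is a standalone usage sketch along the same lines; it assumes SuperBatchNorm2D takes num_features like nn.BatchNorm2D and is importable from paddleslim.nas.ofa.layers, as the diff suggests, and is illustrative rather than part of the commit.

import numpy as np
import paddle
from paddleslim.nas.ofa.layers import SuperBatchNorm2D

# 4 features in the super layer but only 3 channels in the input,
# so the new write-back branch is taken and running stats are updated in place.
bn = SuperBatchNorm2D(4)
x = paddle.to_tensor(np.random.random((1, 3, 64, 64)).astype(np.float32))
out = bn(x)
print(out.shape, bn.cur_config)   # cur_config records {'prune_dim': 3}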