Unverified commit 7720faa4 authored by ceci3, committed by GitHub

fix superbn states (#805)

* fix superbn
Parent 6924c977
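
This commit fixes how the Super* batch-norm layers keep their running statistics when the active sub-network is narrower than the supernet. The forward now slices weight, bias, mean and variance to the input's channel count, runs the norm op on temporaries when feature_dim differs from the stored buffer size, and writes the updated slices back into self._mean and self._variance, so the full-size state is never shadowed by the pruned view. A minimal sketch of that write-back pattern, using plain Paddle ops rather than the PaddleSlim layers (num_features, feature_dim and momentum below are illustrative values, not the library's API):

import paddle

num_features = 4                    # supernet width: size of the stored buffers
feature_dim = 3                     # channels used by the current sub-network
momentum = 0.9

running_mean = paddle.zeros([num_features])   # full-size state kept on the layer
running_var = paddle.ones([num_features])

x = paddle.randn([8, feature_dim])             # activations of the pruned sub-network
batch_mean = paddle.mean(x, axis=0)
batch_var = paddle.var(x, axis=0, unbiased=False)

# Update only the first feature_dim entries and leave the rest untouched,
# which is what the self._mean[:feature_dim] = mean assignments in this diff do.
running_mean[:feature_dim] = momentum * running_mean[:feature_dim] + (1.0 - momentum) * batch_mean
running_var[:feature_dim] = momentum * running_var[:feature_dim] + (1.0 - momentum) * batch_var
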
......@@ -954,25 +954,45 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
def forward(self, input):
self._check_data_format(self._data_format)
self._check_input_dim(input)
feature_dim = int(input.shape[1])
weight = self.weight[:feature_dim]
bias = self.bias[:feature_dim]
mean = self._mean[:feature_dim]
variance = self._variance[:feature_dim]
mean_out = self._mean
variance_out = self._variance
mean_out_tmp = mean
variance_out_tmp = variance
if self._use_global_stats == None:
self._use_global_stats = not self.training
trainable_statistics = False
else:
trainable_statistics = not self._use_global_stats
attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", not self.training, "data_layout", self._data_format,
"use_mkldnn", False, "fuse_with_relu", False,
"use_global_stats", self._use_global_stats,
"trainable_statistics", trainable_statistics)
if feature_dim != self._mean.shape[0]:
batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
variance, mean_out_tmp,
variance_out_tmp, *attrs)
self._mean[:feature_dim] = mean
self._variance[:feature_dim] = variance
mean_out[:feature_dim] = mean_out_tmp
variance_out[:feature_dim] = variance_out_tmp
else:
batch_norm_out = core.ops.batch_norm(input, weight, bias,
self._mean, self._variance,
mean_out, variance_out, *attrs)
self.cur_config = {'prune_dim': feature_dim}
return F.batch_norm(
input,
mean,
variance,
weight=weight,
bias=bias,
training=self.training,
momentum=self._momentum,
epsilon=self._epsilon,
data_format=self._data_format,
use_global_stats=self._use_global_stats)
return batch_norm_out[0]
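
For reference, a short usage sketch of what the new branch enables, mirroring the unit tests added further down in this diff (it assumes SuperBatchNorm2D is exported from paddleslim.nas.ofa.layers, as the test imports suggest, and peeks at the private _mean buffer purely for illustration):

import paddle
from paddleslim.nas.ofa.layers import SuperBatchNorm2D

bn = SuperBatchNorm2D(4)            # supernet layer built with 4 channels
x = paddle.rand([1, 3, 16, 16])     # the current sub-network feeds only 3 channels
y = bn(x)                           # params and running stats are sliced to 3 channels
print(y.shape)                      # expected: [1, 3, 16, 16]
print(bn._mean.shape)               # expected: [4], the full-size buffer is preserved
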
class SuperSyncBatchNorm(nn.SyncBatchNorm):
......@@ -990,7 +1010,7 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
self.cur_config = None
def forward(self, input):
self._check_data_format()
feature_dim = int(input.shape[1])
weight = self.weight[:feature_dim]
......@@ -998,24 +1018,35 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
mean = self._mean[:feature_dim]
variance = self._variance[:feature_dim]
mean_out = mean
# variance and variance out share the same memory
variance_out = variance
mean_out = self._mean
variance_out = self._variance
mean_out_tmp = mean
variance_out_tmp = variance
self.cur_config = {'prune_dim': feature_dim}
attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", not self.training, "data_layout", self._data_format,
"use_mkldnn", False, "fuse_with_relu", False,
"use_global_stats", False, 'trainable_statistics', False)
if feature_dim != self._mean.shape[0]:
sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
input, weight, bias, mean, variance, mean_out_tmp,
variance_out_tmp, *attrs)
self._mean[:feature_dim] = mean
self._variance[:feature_dim] = variance
mean_out[:feature_dim] = mean_out_tmp
variance_out[:feature_dim] = variance_out_tmp
else:
sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
input, weight, bias, mean, variance, mean_out, variance_out, *attrs)
input, weight, bias, self._mean, self._variance, mean_out,
variance_out, *attrs)
return sync_batch_norm_out
class SuperInstanceNorm2D(nn.InstanceNorm2D):
"""
This interface is used to construct a callable object of the ``SuperBatchNorm2D`` class.
This interface is used to construct a callable object of the ``SuperInstanceNorm2D`` class.
Parameters:
num_features(int): Indicate the number of channels of the input ``Tensor``.
......
......@@ -879,16 +879,30 @@ class SuperBatchNorm(fluid.dygraph.BatchNorm):
mean = self._mean[:feature_dim]
variance = self._variance[:feature_dim]
mean_out = mean
variance_out = variance
mean_out = self._mean
variance_out = self._variance
mean_out_tmp = mean
variance_out_tmp = variance
attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", not self.training, "data_layout", self._data_layout,
"use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu,
"use_global_stats", self._use_global_stats,
'trainable_statistics', self._trainable_statistics)
batch_norm_out = core.ops.batch_norm(
input, weight, bias, mean, variance, mean_out, variance_out, *attrs)
if feature_dim != self._mean.shape[0]:
batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
variance, mean_out_tmp,
variance_out_tmp, *attrs)
self._mean[:feature_dim] = mean
self._variance[:feature_dim] = variance
mean_out[:feature_dim] = mean_out_tmp
variance_out[:feature_dim] = variance_out_tmp
else:
batch_norm_out = core.ops.batch_norm(input, weight, bias,
self._mean, self._variance,
mean_out, variance_out, *attrs)
return dygraph_utils._append_activation_in_dygraph(
batch_norm_out[0], act=self._act)
......
......@@ -21,8 +21,8 @@ import paddle.nn as nn
from paddle.nn import ReLU
from paddleslim.nas import ofa
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa.convert_super import supernet
from paddleslim.nas.ofa.layers import *
from paddleslim.nas.ofa.layers_base import Block
class ModelCase1(nn.Layer):
......@@ -51,6 +51,16 @@ class ModelCase1(nn.Layer):
return self.models(inputs)
class ModelCase2(nn.Layer):
def __init__(self):
super(ModelCase2, self).__init__()
models = [SuperSyncBatchNorm(4)]
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs):
return self.models(inputs)
class TestCase(unittest.TestCase):
def setUp(self):
self.model = ModelCase1()
......@@ -62,5 +72,15 @@ class TestCase(unittest.TestCase):
out = self.model(self.data)
class TestCase2(TestCase):
def setUp(self):
self.model = ModelCase2()
data_np = np.random.random((1, 3, 64, 64)).astype(np.float32)
self.data = paddle.to_tensor(data_np)
def test_ofa(self):
out = self.model(self.data)
if __name__ == '__main__':
unittest.main()
......@@ -122,6 +122,16 @@ class ModelCase3(nn.Layer):
return inputs
class ModelCase4(nn.Layer):
def __init__(self):
super(ModelCase4, self).__init__()
models = [SuperBatchNorm(4)]
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs):
return self.models(inputs)
class TestCase(unittest.TestCase):
def setUp(self):
self.model = ModelCase1()
......@@ -147,5 +157,15 @@ class TestCase3(TestCase):
self.data = paddle.to_tensor(data_np)
class TestCase4(TestCase):
def setUp(self):
self.model = ModelCase4()
data_np = np.random.random((1, 3, 64, 64)).astype(np.float32)
self.data = paddle.to_tensor(data_np)
def test_ofa(self):
out = self.model(self.data)
if __name__ == '__main__':
unittest.main()