Unverified commit 42d2915a authored by ceci3, committed by GitHub

[2.0 API] add SyncBatchNorm.convert_sync_batch_norm (#26688)

* add convert, test=develop
Parent 68e0560c
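This commit adds `SyncBatchNorm.convert_sync_batchnorm`, a classmethod that walks a model and swaps every `BatchNorm*d` sublayer for a `SyncBatchNorm`. A minimal usage sketch, mirroring the docstring example added further down and assuming the 2.0-beta layer names used throughout this diff (`paddle.nn.Conv2d`, `paddle.nn.BatchNorm2d`):

import paddle
import paddle.nn as nn

paddle.disable_static()

# A small model containing a BatchNorm2d sublayer.
model = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5))

# Recursively replace every BatchNorm*d sublayer with SyncBatchNorm.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(sync_model)

The conversion itself is a structural rewrite; actually running the converted model goes through the sync_batch_norm operator, which is only available in CUDA builds. That is why the new unit test below returns early when core.is_compiled_with_cuda() is false.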
@@ -85,10 +85,35 @@ class TestBatchNorm(unittest.TestCase):
                    y = bn(fluid.dygraph.to_variable(x))
                    return y.numpy()

+            def compute_v3(x, is_test, trainable_statistics):
+                with fluid.dygraph.guard(p):
+                    bn = fluid.dygraph.BatchNorm(
+                        shape[1],
+                        is_test=is_test,
+                        param_attr=fluid.ParamAttr(
+                            initializer=fluid.initializer.Constant(1.0),
+                            trainable=False),
+                        bias_attr=fluid.ParamAttr(
+                            initializer=fluid.initializer.Constant(0.0),
+                            trainable=False),
+                        trainable_statistics=trainable_statistics)
+                    y = bn(fluid.dygraph.to_variable(x))
+                    return y.numpy()
+
+            def compute_v4(x):
+                with fluid.dygraph.guard(p):
+                    bn = paddle.nn.BatchNorm2d(
+                        shape[1], weight_attr=False, bias_attr=False)
+                    y = bn(fluid.dygraph.to_variable(x))
+                    return y.numpy()
+
            x = np.random.randn(*shape).astype("float32")
            y1 = compute_v1(x, False, False)
            y2 = compute_v2(x)
+            y3 = compute_v3(x, False, False)
+            y4 = compute_v4(x)
            self.assertTrue(np.allclose(y1, y2))
+            self.assertTrue(np.allclose(y3, y4))

    def test_static(self):
        places = [fluid.CPUPlace()]
......
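The new compute_v3/compute_v4 pair asserts that BatchNorm2d constructed with weight_attr=False and bias_attr=False matches a fluid BatchNorm whose scale and bias are constant (1.0 and 0.0) and non-trainable. A standalone sketch of the same comparison, with a hypothetical input shape and CPU execution assumed:

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.disable_static()
shape = [4, 10, 4, 4]  # hypothetical NCHW shape
x = np.random.randn(*shape).astype("float32")

# Reference: scale fixed to 1.0, bias fixed to 0.0, both non-trainable.
ref_bn = fluid.dygraph.BatchNorm(
    shape[1],
    param_attr=fluid.ParamAttr(
        initializer=fluid.initializer.Constant(1.0), trainable=False),
    bias_attr=fluid.ParamAttr(
        initializer=fluid.initializer.Constant(0.0), trainable=False))

# Behaviour under test: weight_attr=False / bias_attr=False.
bn = paddle.nn.BatchNorm2d(shape[1], weight_attr=False, bias_attr=False)

y_ref = ref_bn(fluid.dygraph.to_variable(x)).numpy()
y_new = bn(fluid.dygraph.to_variable(x)).numpy()
print(np.allclose(y_ref, y_new))  # expected: True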
@@ -299,7 +299,7 @@ class TestLayer(LayerTest):
            my_syncbn = paddle.nn.SyncBatchNorm(3)
            dy_ret = my_syncbn(base.to_variable(t))
            dy_ret_value = dy_ret.numpy()
-        self.assertTrue(np.array_equal(static_ret, static_ret))
+        self.assertTrue(np.array_equal(static_ret, dy_ret_value))

    def test_relu(self):
        with self.static_graph():
......
@@ -221,5 +221,21 @@ class TestDygraphSyncBatchNormAPIError(unittest.TestCase):
            self.assertRaises(TypeError, my_sync_batch_norm, x2)


+class TestConvertSyncBatchNorm(unittest.TestCase):
+    def test_convert(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        with program_guard(Program(), Program()):
+            model = paddle.nn.Sequential(
+                paddle.nn.Conv2d(3, 5, 3), paddle.nn.BatchNorm2d(5))
+            sync_model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            for idx, sublayer in enumerate(model.sublayers()):
+                if isinstance(sublayer, paddle.nn.BatchNorm2d):
+                    self.assertEqual(
+                        isinstance(sync_model[idx], paddle.nn.SyncBatchNorm),
+                        True)
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -27,6 +27,7 @@
# TODO: define normalization api
import six
from ...fluid.dygraph.nn import InstanceNorm
from ...fluid.dygraph import BatchNorm #DEFINE_ALIAS
@@ -36,7 +37,6 @@ from ...fluid.dygraph import BatchNorm #DEFINE_ALIAS
from ...fluid.dygraph import SpectralNorm #DEFINE_ALIAS
from ...fluid.dygraph import layers
from ...framework import get_default_dtype, set_default_dtype
from ...fluid.framework import in_dygraph_mode
@@ -50,6 +50,7 @@ from ..functional import batch_norm, layer_norm, instance_norm
import numpy as np
import numbers
import warnings
from ...fluid.dygraph.base import no_grad

__all__ = [
    'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm',
@@ -566,17 +567,28 @@ class _BatchNormBase(layers.Layer):
        param_shape = [num_features]

        # create parameter
-        self.weight = self.create_parameter(
-            attr=self._weight_attr,
-            shape=param_shape,
-            default_initializer=Constant(1.0))
-        self.weight.stop_gradient = (self._weight_attr is False) or (
-            self._weight_attr and self._weight_attr.learning_rate == 0.)
+        if weight_attr == False:
+            self.weight = self.create_parameter(
+                attr=None, shape=param_shape, default_initializer=Constant(1.0))
+            self.weight.stop_gradient = True
+        else:
+            self.weight = self.create_parameter(
+                attr=self._weight_attr,
+                shape=param_shape,
+                default_initializer=Constant(1.0))
+            self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.

-        self.bias = self.create_parameter(
-            attr=self._bias_attr, shape=param_shape, is_bias=True)
-        self.bias.stop_gradient = (self._bias_attr is False) or (
-            self._bias_attr and self._bias_attr.learning_rate == 0.)
+        if bias_attr == False:
+            self.bias = self.create_parameter(
+                attr=None,
+                shape=param_shape,
+                default_initializer=Constant(0.0),
+                is_bias=True)
+            self.bias.stop_gradient = True
+        else:
+            self.bias = self.create_parameter(
+                attr=self._bias_attr, shape=param_shape, is_bias=True)
+            self.bias.stop_gradient = self._bias_attr != None and self._bias_attr.learning_rate == 0.

        moving_mean_name = None
        moving_variance_name = None
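With this change, passing weight_attr=False or bias_attr=False to any _BatchNormBase subclass still creates the parameter, but with a constant initializer and stop_gradient=True, so it stays frozen during training. A quick sketch of inspecting that (BatchNorm2d stands in for any subclass; layer names as used in this PR):

import paddle

paddle.disable_static()
bn = paddle.nn.BatchNorm2d(8, weight_attr=False, bias_attr=False)

# Both parameters exist but are frozen: constant values, excluded from gradients.
print(bn.weight.numpy()[:3], bn.weight.stop_gradient)  # roughly [1. 1. 1.] True
print(bn.bias.numpy()[:3], bn.bias.stop_gradient)      # roughly [0. 0. 0.] True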
@@ -611,6 +623,7 @@ class _BatchNormBase(layers.Layer):
        self._epsilon = epsilon
        self._fuse_with_relu = False
        self._track_running_stats = track_running_stats
+        self._name = name

    def _check_input_dim(self, input):
        raise NotImplementedError("BatchNorm Base error")
@@ -898,7 +911,7 @@ class BatchNorm3d(_BatchNormBase):
                    len(input.shape)))


-class SyncBatchNorm(layers.Layer):
+class SyncBatchNorm(_BatchNormBase):
    """
    This interface is used to construct a callable object of the ``SyncBatchNorm`` class.
    It implements the function of the Cross-GPU Synchronized Batch Normalization Layer, and can
@@ -984,72 +997,16 @@ class SyncBatchNorm(layers.Layer):
    def __init__(self,
                 num_features,
-                 epsilon=1e-05,
                 momentum=0.9,
-                 track_running_stats=True,
+                 epsilon=1e-05,
                 weight_attr=None,
                 bias_attr=None,
                 data_format='NCHW',
+                 track_running_stats=True,
                 name=None):
-        super(SyncBatchNorm, self).__init__()
-        self._weight_attr = weight_attr
-        self._bias_attr = bias_attr
-        self._num_features = num_features
-        self._data_layout = data_format
-        self._momentum = momentum
-        self._epsilon = epsilon
-        self._track_running_stats = track_running_stats
-        if self._track_running_stats == False:
-            warnings.warn(
-                "moving mean and moving variance will be calculated whether `track_running_stats` is set to `True` or `False`, we will fix it in the next version."
-            )
-        param_shape = [self._num_features]
-        # create parameter
-        if weight_attr == False:
-            self.weight = self.create_parameter(
-                attr=None, shape=param_shape, default_initializer=Constant(1.0))
-            self.weight.stop_gradient = True
-        else:
-            self.weight = self.create_parameter(
-                attr=self._weight_attr,
-                shape=param_shape,
-                default_initializer=Constant(1.0))
-            self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
-        if bias_attr == False:
-            self.bias = self.create_parameter(
-                attr=None,
-                shape=param_shape,
-                default_initializer=Constant(0.0),
-                is_bias=True)
-            self.bias.stop_gradient = True
-        else:
-            self.bias = self.create_parameter(
-                attr=self._bias_attr, shape=param_shape, is_bias=True)
-            self.bias.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
-        self._mean = self.create_parameter(
-            attr=ParamAttr(
-                name=None,
-                initializer=Constant(0.0),
-                trainable=False,
-                do_model_average=True),
-            shape=param_shape,
-            dtype=self._dtype)
-        self._mean.stop_gradient = True
-        self._variance = self.create_parameter(
-            attr=ParamAttr(
-                name=None,
-                initializer=Constant(1.0),
-                trainable=False,
-                do_model_average=True),
-            shape=param_shape,
-            dtype=self._dtype)
-        self._variance.stop_gradient = True
+        super(SyncBatchNorm,
+              self).__init__(num_features, momentum, epsilon, weight_attr,
+                             bias_attr, data_format, track_running_stats, name)

    def forward(self, x):
        # create output
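SyncBatchNorm now delegates parameter and statistics setup to _BatchNormBase, and its constructor was reordered to match the base signature (momentum before epsilon, track_running_stats after data_format). Passing keyword arguments sidesteps the reordering; a sketch (construction should work on any build, while the forward pass still needs a CUDA build):

import paddle

paddle.disable_static()
sbn = paddle.nn.SyncBatchNorm(
    num_features=16,
    momentum=0.9,
    epsilon=1e-5,
    data_format='NCHW')
# weight, bias, _mean and _variance are now created by _BatchNormBase.
print(sbn.weight.shape, sbn.bias.shape)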
@@ -1063,7 +1020,7 @@ class SyncBatchNorm(layers.Layer):
        if in_dygraph_mode():
            attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
                     "is_test", not self.training, "data_layout",
-                     self._data_layout, "use_mkldnn", False, "fuse_with_relu",
+                     self._data_format, "use_mkldnn", False, "fuse_with_relu",
                     False, "use_global_stats", False, 'trainable_statistics',
                     False)
            sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
@@ -1073,13 +1030,13 @@ class SyncBatchNorm(layers.Layer):
            return sync_batch_norm_out

        check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
-                                 'BatchNorm')
+                                 'SyncBatchNorm')

        attrs = {
            "momentum": self._momentum,
            "epsilon": self._epsilon,
            "is_test": not self.training,
-            "data_layout": self._data_layout,
+            "data_layout": self._data_format,
            "use_mkldnn": False,
            "fuse_with_relu": False,
            "use_global_stats": False,
@@ -1112,3 +1069,45 @@ class SyncBatchNorm(layers.Layer):
        self._helper.append_op(
            type="sync_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
        return sync_batch_norm_out
+
+    @classmethod
+    def convert_sync_batchnorm(cls, layer):
+        """
+        Helper function to convert :class: `paddle.nn.BatchNorm*d` layers in the model to :class: `paddle.nn.SyncBatchNorm` layers.
+
+        Parameters:
+            layer(paddle.nn.Layer): model containing one or more `BatchNorm*d` layers.
+
+        Returns:
+            The original model with converted SyncBatchNorm layers. If BatchNorm*d layer in the model, use SyncBatchNorm layer instead.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+                import paddle.nn as nn
+
+                paddle.disable_static()
+                model = nn.Sequential(nn.Conv2d(3, 5, 3), nn.BatchNorm2d(5))
+                sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+        """
+        layer_output = layer
+        if isinstance(layer, _BatchNormBase):
+            layer_output = SyncBatchNorm(layer._num_features, layer._epsilon,
+                                         layer._momentum, layer._weight_attr,
+                                         layer._bias_attr, layer._data_format,
+                                         layer._name)
+
+            if layer._weight_attr != False and layer._bias_attr != False:
+                with no_grad():
+                    layer_output.weight = layer.weight
+                    layer_output.bias = layer.bias
+            layer_output._mean = layer._mean
+            layer_output._variance = layer._variance
+
+        for name, sublayer in layer.named_sublayers():
+            layer_output.add_sublayer(name,
+                                      cls.convert_sync_batchnorm(sublayer))
+        del layer
+        return layer_output
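Because convert_sync_batchnorm recurses through named_sublayers() and re-registers each converted child with add_sublayer, BatchNorm layers nested inside containers are replaced as well, and the converted layer takes over the original running statistics (and, when the parameter attrs are not False, the weight and bias tensors too). A sketch with a nested Sequential (layer names as used in this PR; running the result still requires a CUDA build):

import paddle
import paddle.nn as nn

paddle.disable_static()

# BatchNorm2d nested one level down is also converted.
model = nn.Sequential(
    nn.Conv2d(3, 5, 3),
    nn.Sequential(nn.BatchNorm2d(5), nn.ReLU()))
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

for name, sublayer in sync_model.named_sublayers():
    print(name, type(sublayer).__name__)
# The BatchNorm2d entry is expected to appear as SyncBatchNorm.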