未验证 提交 c7b5ac4b 编写于 作者: Z zhang wenhui 提交者: GitHub

fix norm bug, test=develop (#26827)

* fix norm bug, test=develop

* fix norm bug, test=develop

* fix norm bug, test=develop

* fix norm bug, test=develop

* fix norm bug, test=develop
上级 9ee4e3dc
...@@ -43,6 +43,21 @@ class TestBatchNorm(unittest.TestCase): ...@@ -43,6 +43,21 @@ class TestBatchNorm(unittest.TestCase):
x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32') x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32')
def error1d_dataformat():
x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
batch_norm1d = paddle.nn.BatchNorm1d(1, data_format='NCDHW')
batch_norm1d(fluid.dygraph.to_variable(x_data_4))
def error2d_dataformat():
x_data_3 = np.random.random(size=(2, 1, 3)).astype('float32')
batch_norm2d = paddle.nn.BatchNorm2d(1, data_format='NCDHW')
batch_norm2d(fluid.dygraph.to_variable(x_data_3))
def error3d_dataformat():
x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
batch_norm3d = paddle.nn.BatchNorm3d(1, data_format='NCL')
batch_norm3d(fluid.dygraph.to_variable(x_data_4))
def error1d(): def error1d():
x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32') x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
batch_norm1d = paddle.nn.BatchNorm1d(1) batch_norm1d = paddle.nn.BatchNorm1d(1)
...@@ -62,6 +77,9 @@ class TestBatchNorm(unittest.TestCase): ...@@ -62,6 +77,9 @@ class TestBatchNorm(unittest.TestCase):
self.assertRaises(ValueError, error1d) self.assertRaises(ValueError, error1d)
self.assertRaises(ValueError, error2d) self.assertRaises(ValueError, error2d)
self.assertRaises(ValueError, error3d) self.assertRaises(ValueError, error3d)
self.assertRaises(ValueError, error1d_dataformat)
self.assertRaises(ValueError, error2d_dataformat)
self.assertRaises(ValueError, error3d_dataformat)
def test_dygraph(self): def test_dygraph(self):
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
......
...@@ -35,24 +35,33 @@ class TestDygraphGroupNormv2(unittest.TestCase): ...@@ -35,24 +35,33 @@ class TestDygraphGroupNormv2(unittest.TestCase):
def compute_v1(x): def compute_v1(x):
with fluid.dygraph.guard(p): with fluid.dygraph.guard(p):
gn = fluid.dygraph.GroupNorm(channels=2, groups=2) gn = fluid.dygraph.GroupNorm(channels=6, groups=2)
y = gn(fluid.dygraph.to_variable(x)) y = gn(fluid.dygraph.to_variable(x))
return y.numpy() return y.numpy()
def compute_v2(x): def compute_v2(x):
with fluid.dygraph.guard(p): with fluid.dygraph.guard(p):
gn = paddle.nn.GroupNorm(num_channels=2, num_groups=2) gn = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
y = gn(fluid.dygraph.to_variable(x)) y = gn(fluid.dygraph.to_variable(x))
return y.numpy() return y.numpy()
def test_weight_bias_false():
with fluid.dygraph.guard(p):
gn = paddle.nn.GroupNorm(
num_channels=6,
num_groups=2,
weight_attr=False,
bias_attr=False)
x = np.random.randn(*shape).astype("float32") x = np.random.randn(*shape).astype("float32")
y1 = compute_v1(x) y1 = compute_v1(x)
y2 = compute_v2(x) y2 = compute_v2(x)
self.assertTrue(np.allclose(y1, y2)) self.assertTrue(np.allclose(y1, y2))
test_weight_bias_false()
def test_static(self): def test_static(self):
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("layer_norm"): if core.is_compiled_with_cuda() and core.op_support_gpu("group_norm"):
places.append(fluid.CUDAPlace(0)) places.append(fluid.CUDAPlace(0))
for p in places: for p in places:
exe = fluid.Executor(p) exe = fluid.Executor(p)
...@@ -60,7 +69,7 @@ class TestDygraphGroupNormv2(unittest.TestCase): ...@@ -60,7 +69,7 @@ class TestDygraphGroupNormv2(unittest.TestCase):
def compute_v1(x_np): def compute_v1(x_np):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
gn = fluid.dygraph.GroupNorm(channels=2, groups=2) gn = fluid.dygraph.GroupNorm(channels=6, groups=2)
x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
y = gn(x) y = gn(x)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -69,7 +78,7 @@ class TestDygraphGroupNormv2(unittest.TestCase): ...@@ -69,7 +78,7 @@ class TestDygraphGroupNormv2(unittest.TestCase):
def compute_v2(x_np): def compute_v2(x_np):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
gn = paddle.nn.GroupNorm(num_channels=2, num_groups=2) gn = paddle.nn.GroupNorm(num_channels=6, num_groups=2)
x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype) x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
y = gn(x) y = gn(x)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
......
...@@ -48,7 +48,13 @@ class TestInstanceNorm(unittest.TestCase): ...@@ -48,7 +48,13 @@ class TestInstanceNorm(unittest.TestCase):
instance_norm3d = paddle.nn.BatchNorm3d(1) instance_norm3d = paddle.nn.BatchNorm3d(1)
instance_norm3d(fluid.dygraph.to_variable(x_data_4)) instance_norm3d(fluid.dygraph.to_variable(x_data_4))
def weight_bias_false():
x_data_4 = np.random.random(size=(2, 1, 3, 3)).astype('float32')
instance_norm3d = paddle.nn.BatchNorm3d(
1, weight_attr=False, bias_attr=False)
with fluid.dygraph.guard(p): with fluid.dygraph.guard(p):
weight_bias_false()
self.assertRaises(ValueError, error1d) self.assertRaises(ValueError, error1d)
self.assertRaises(ValueError, error2d) self.assertRaises(ValueError, error2d)
self.assertRaises(ValueError, error3d) self.assertRaises(ValueError, error3d)
......
...@@ -165,7 +165,7 @@ def batch_norm(x, ...@@ -165,7 +165,7 @@ def batch_norm(x,
w = paddle.to_tensor(weight_data) w = paddle.to_tensor(weight_data)
b = paddle.to_tensor(bias_data) b = paddle.to_tensor(bias_data)
batch_norm_out = paddle.nn.functional.batch_norm(x, rm, rv, w, b) batch_norm_out = paddle.nn.functional.batch_norm(x, rm, rv, w, b)
print batch_norm_out print(batch_norm_out.numpy())
""" """
assert len(x.shape) >= 2, "input dim must be larger than 1" assert len(x.shape) >= 2, "input dim must be larger than 1"
...@@ -176,6 +176,15 @@ def batch_norm(x, ...@@ -176,6 +176,15 @@ def batch_norm(x,
mean_out = running_mean mean_out = running_mean
variance_out = running_var variance_out = running_var
true_data_format = ['NC', 'NCL', 'NCHW', 'NCWH', 'NCDHW']
if data_format not in true_data_format:
raise ValueError(
"data_format must be one of 'NC', 'NCL', 'NCHW', 'NCWH', 'NCDHW', but receive {}".
format(data_format))
if data_format != 'NCWH':
data_format = 'NCHW'
if in_dygraph_mode(): if in_dygraph_mode():
# for dygraph need tuple # for dygraph need tuple
attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout", attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout",
...@@ -270,7 +279,7 @@ def layer_norm(x, ...@@ -270,7 +279,7 @@ def layer_norm(x,
layer_norm = paddle.nn.functional.layer_norm(x, x.shape[1:]) layer_norm = paddle.nn.functional.layer_norm(x, x.shape[1:])
layer_norm_out = layer_norm(x) layer_norm_out = layer_norm(x)
print(layer_norm_out.numpy) print(layer_norm_out.numpy())
""" """
input_shape = list(x.shape) input_shape = list(x.shape)
input_ndim = len(input_shape) input_ndim = len(input_shape)
...@@ -302,10 +311,10 @@ def layer_norm(x, ...@@ -302,10 +311,10 @@ def layer_norm(x,
# create output # create output
helper = LayerHelper('layer_norm', **locals()) helper = LayerHelper('layer_norm', **locals())
mean_out = helper.create_variable_for_type_inference( mean_out = helper.create_variable_for_type_inference(
dtype=x.type, stop_gradient=True) dtype=x.dtype, stop_gradient=True)
variance_out = helper.create_variable_for_type_inference( variance_out = helper.create_variable_for_type_inference(
dtype=x.type, stop_gradient=True) dtype=x.dtype, stop_gradient=True)
layer_norm_out = helper.create_variable_for_type_inference(x.type) layer_norm_out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op( helper.append_op(
type="layer_norm", type="layer_norm",
...@@ -362,7 +371,7 @@ def instance_norm(x, ...@@ -362,7 +371,7 @@ def instance_norm(x,
x = paddle.to_tensor(x_data) x = paddle.to_tensor(x_data)
instance_norm_out = paddle.nn.functional.instancenorm(x) instance_norm_out = paddle.nn.functional.instancenorm(x)
print(instance_norm_out.numpy) print(instance_norm_out.numpy())
""" """
......
...@@ -78,7 +78,7 @@ class _InstanceNormBase(layers.Layer): ...@@ -78,7 +78,7 @@ class _InstanceNormBase(layers.Layer):
super(_InstanceNormBase, self).__init__() super(_InstanceNormBase, self).__init__()
if weight_attr == False or bias_attr == False: if weight_attr == False or bias_attr == False:
assert weight_attr == param_attr, "weight_attr and bias_attr must be set to Fasle at the same time in InstanceNorm" assert weight_attr == bias_attr, "weight_attr and bias_attr must be set to Fasle at the same time in InstanceNorm"
self._epsilon = epsilon self._epsilon = epsilon
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._bias_attr = bias_attr self._bias_attr = bias_attr
...@@ -176,7 +176,7 @@ class InstanceNorm1d(_InstanceNormBase): ...@@ -176,7 +176,7 @@ class InstanceNorm1d(_InstanceNormBase):
instance_norm = paddle.nn.InstanceNorm1d(2) instance_norm = paddle.nn.InstanceNorm1d(2)
instance_norm_out = instance_norm(x) instance_norm_out = instance_norm(x)
print(instance_norm_out.numpy) print(instance_norm_out.numpy())
""" """
...@@ -253,7 +253,7 @@ class InstanceNorm2d(_InstanceNormBase): ...@@ -253,7 +253,7 @@ class InstanceNorm2d(_InstanceNormBase):
instance_norm = paddle.nn.InstanceNorm2d(2) instance_norm = paddle.nn.InstanceNorm2d(2)
instance_norm_out = instance_norm(x) instance_norm_out = instance_norm(x)
print(instance_norm_out.numpy) print(instance_norm_out.numpy())
""" """
def _check_input_dim(self, input): def _check_input_dim(self, input):
...@@ -329,7 +329,7 @@ class InstanceNorm3d(_InstanceNormBase): ...@@ -329,7 +329,7 @@ class InstanceNorm3d(_InstanceNormBase):
instance_norm = paddle.nn.InstanceNorm3d(2) instance_norm = paddle.nn.InstanceNorm3d(2)
instance_norm_out = instance_norm(x) instance_norm_out = instance_norm(x)
print(instance_norm_out.numpy) print(instance_norm_out.numpy())
""" """
def _check_input_dim(self, input): def _check_input_dim(self, input):
...@@ -346,8 +346,8 @@ class GroupNorm(layers.Layer): ...@@ -346,8 +346,8 @@ class GroupNorm(layers.Layer):
Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`_ . Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`_ .
Parameters: Parameters:
num_channels(int): The number of channels of input.
num_groups(int): The number of groups that divided from channels. num_groups(int): The number of groups that divided from channels.
num_channels(int): The number of channels of input.
epsilon(float, optional): The small value added to the variance to prevent epsilon(float, optional): The small value added to the variance to prevent
division by zero. Default: 1e-05. division by zero. Default: 1e-05.
weight_attr(ParamAttr|bool, optional): The parameter attribute for the learnable weight_attr(ParamAttr|bool, optional): The parameter attribute for the learnable
...@@ -375,19 +375,19 @@ class GroupNorm(layers.Layer): ...@@ -375,19 +375,19 @@ class GroupNorm(layers.Layer):
np.random.seed(123) np.random.seed(123)
x_data = np.random.random(size=(2, 6, 2, 2)).astype('float32') x_data = np.random.random(size=(2, 6, 2, 2)).astype('float32')
x = paddle.to_tensor(x_data) x = paddle.to_tensor(x_data)
group_norm = paddle.nn.GroupNorm(num_channels=3, num_groups=6) group_norm = paddle.nn.GroupNorm(num_channels=6, num_groups=6)
group_norm_out = group_norm(x) group_norm_out = group_norm(x)
print(group_norm_out.numpy) print(group_norm_out.numpy())
""" """
def __init__(self, def __init__(self,
num_channels,
num_groups, num_groups,
num_channels,
epsilon=1e-05, epsilon=1e-05,
weight_attr=None, weight_attr=None,
bias_attr=None, bias_attr=None,
data_layout='NCHW', data_format='NCHW',
name=None): name=None):
super(GroupNorm, self).__init__() super(GroupNorm, self).__init__()
self._weight_attr = weight_attr self._weight_attr = weight_attr
...@@ -395,18 +395,33 @@ class GroupNorm(layers.Layer): ...@@ -395,18 +395,33 @@ class GroupNorm(layers.Layer):
self._epsilon = epsilon self._epsilon = epsilon
self._num_channels = num_channels self._num_channels = num_channels
self._num_groups = num_groups self._num_groups = num_groups
if data_layout != 'NCHW': if data_format != 'NCHW':
raise ValueError("unsupported data layout:" + data_layout) raise ValueError("unsupported data layout:" + data_layout)
param_shape = [self._num_channels] param_shape = [self._num_channels]
self.weight = self.create_parameter( if weight_attr == False:
attr=self._weight_attr or False, self.weight = self.create_parameter(
shape=param_shape, attr=None, shape=param_shape, default_initializer=Constant(1.0))
default_initializer=Constant(1.0)) self.weight.stop_gradient = True
else:
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=param_shape,
default_initializer=Constant(1.0))
self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
self.bias = self.create_parameter( if bias_attr == False:
attr=self._weight_attr or False, shape=param_shape, is_bias=True) self.bias = self.create_parameter(
attr=None,
shape=param_shape,
default_initializer=Constant(0.0),
is_bias=True)
self.bias.stop_gradient = True
else:
self.bias = self.create_parameter(
attr=self._bias_attr, shape=param_shape, is_bias=True)
self.bias.stop_gradient = self._bias_attr != None and self._bias_attr.learning_rate == 0.
def forward(self, input): def forward(self, input):
inputs = {'X': input} inputs = {'X': input}
...@@ -500,7 +515,7 @@ class LayerNorm(layers.Layer): ...@@ -500,7 +515,7 @@ class LayerNorm(layers.Layer):
layer_norm = paddle.nn.LayerNorm(x_data.shape[1:]) layer_norm = paddle.nn.LayerNorm(x_data.shape[1:])
layer_norm_out = layer_norm(x) layer_norm_out = layer_norm(x)
print(layer_norm_out.numpy) print(layer_norm_out.numpy())
""" """
def __init__(self, def __init__(self,
...@@ -603,8 +618,7 @@ class _BatchNormBase(layers.Layer): ...@@ -603,8 +618,7 @@ class _BatchNormBase(layers.Layer):
initializer=Constant(0.0), initializer=Constant(0.0),
trainable=False, trainable=False,
do_model_average=True), do_model_average=True),
shape=param_shape, shape=param_shape)
dtype=self._dtype)
self._mean.stop_gradient = True self._mean.stop_gradient = True
self._variance = self.create_parameter( self._variance = self.create_parameter(
...@@ -613,8 +627,7 @@ class _BatchNormBase(layers.Layer): ...@@ -613,8 +627,7 @@ class _BatchNormBase(layers.Layer):
initializer=Constant(1.0), initializer=Constant(1.0),
trainable=False, trainable=False,
do_model_average=True), do_model_average=True),
shape=param_shape, shape=param_shape)
dtype=self._dtype)
self._variance.stop_gradient = True self._variance.stop_gradient = True
self._data_format = data_format self._data_format = data_format
...@@ -628,8 +641,13 @@ class _BatchNormBase(layers.Layer): ...@@ -628,8 +641,13 @@ class _BatchNormBase(layers.Layer):
def _check_input_dim(self, input): def _check_input_dim(self, input):
raise NotImplementedError("BatchNorm Base error") raise NotImplementedError("BatchNorm Base error")
def _check_data_format(self, input):
raise NotImplementedError("BatchNorm Base data format error")
def forward(self, input): def forward(self, input):
self._check_data_format(self._data_format)
self._check_input_dim(input) self._check_input_dim(input)
if not self.training and not self._track_running_stats: if not self.training and not self._track_running_stats:
...@@ -730,9 +748,15 @@ class BatchNorm1d(_BatchNormBase): ...@@ -730,9 +748,15 @@ class BatchNorm1d(_BatchNormBase):
batch_norm = paddle.nn.BatchNorm1d(1) batch_norm = paddle.nn.BatchNorm1d(1)
batch_norm_out = batch_norm(x) batch_norm_out = batch_norm(x)
print(batch_norm_out.numpy) print(batch_norm_out.numpy())
""" """
def _check_data_format(self, input):
if input == 'NCHW' or input == 'NC' or input == 'NCL':
self._data_format = 'NCHW'
else:
raise ValueError('expected NC , NCL or None for data_format input')
def _check_input_dim(self, input): def _check_input_dim(self, input):
if len(input.shape) != 2 and len(input.shape) != 3: if len(input.shape) != 2 and len(input.shape) != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'.format( raise ValueError('expected 2D or 3D input (got {}D input)'.format(
...@@ -816,9 +840,15 @@ class BatchNorm2d(_BatchNormBase): ...@@ -816,9 +840,15 @@ class BatchNorm2d(_BatchNormBase):
batch_norm = paddle.nn.BatchNorm2d(1) batch_norm = paddle.nn.BatchNorm2d(1)
batch_norm_out = batch_norm(x) batch_norm_out = batch_norm(x)
print(batch_norm_out.numpy) print(batch_norm_out.numpy())
""" """
def _check_data_format(self, input):
if input == 'NCHW' or input == 'NCWH':
self._data_format = input
else:
raise ValueError('expected NCHW or NCWH for data_format input')
def _check_input_dim(self, input): def _check_input_dim(self, input):
if len(input.shape) != 4: if len(input.shape) != 4:
raise ValueError('expected 4D input (got {}D input)'.format( raise ValueError('expected 4D input (got {}D input)'.format(
...@@ -902,9 +932,15 @@ class BatchNorm3d(_BatchNormBase): ...@@ -902,9 +932,15 @@ class BatchNorm3d(_BatchNormBase):
batch_norm = paddle.nn.BatchNorm3d(1) batch_norm = paddle.nn.BatchNorm3d(1)
batch_norm_out = batch_norm(x) batch_norm_out = batch_norm(x)
print(batch_norm_out.numpy) print(batch_norm_out.numpy())
""" """
def _check_data_format(self, input):
if input == 'NCHW' or input == 'NCDHW':
self._data_format = 'NCHW'
else:
raise ValueError('expected NCDHW or None for data_format input')
def _check_input_dim(self, input): def _check_input_dim(self, input):
if len(input.shape) != 5: if len(input.shape) != 5:
raise ValueError('expected 5D input (got {}D input)'.format( raise ValueError('expected 5D input (got {}D input)'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册