Unverified commit fd66d762 authored by ceci3, committed by GitHub

add weight_norm & remove_weight_norm (#26131)

* add weight_norm, test=develop
Parent facc0a10
......@@ -204,6 +204,9 @@ class WeightNormParamAttr(ParamAttr):
"""
:api_attr: Static Graph
Note:
Please use 'paddle.nn.utils.weight_norm' in dygraph mode.
Parameter attribute for Weight Norm. Weight Norm is a reparameterization of the weight vectors
in a neural network that decouples the magnitude of those weight vectors from
their direction. Weight Norm has been implemented as discussed in this
......@@ -216,6 +219,7 @@ class WeightNormParamAttr(ParamAttr):
It is recommended to use ``minimize(loss, grad_clip=clip)`` to clip gradients.
There are three clipping strategies: :ref:`api_fluid_clip_GradientClipByGlobalNorm` ,
:ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` .
Args:
dim(int): Dimension over which to compute the norm. Dim is a non-negative
......
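For reference, the static-graph counterpart mentioned in the Note above can be used roughly as follows; this is a minimal sketch, and the data and parameter names are illustrative:

import paddle.fluid as fluid

# Static-graph sketch: reparameterize the fc weight with Weight Norm.
# dim=None means the norm is taken over all elements of the weight.
data = fluid.data(name="data", shape=[None, 32], dtype="float32")
fc = fluid.layers.fc(
    input=data,
    size=64,
    param_attr=fluid.WeightNormParamAttr(dim=None, name="weight_norm_param"))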
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy
import collections
from functools import reduce
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.nn.utils import weight_norm, remove_weight_norm
class TestDygraphWeightNorm(unittest.TestCase):
def setUp(self):
self.init_test_case()
self.set_data()
def init_test_case(self):
self.batch_size = 3
self.data_desc = (['x', [2, 3, 3]], )
self.dim = None
def set_data(self):
self.data = collections.OrderedDict()
for desc in self.data_desc:
data_name = desc[0]
data_shape = desc[1]
data_value = numpy.random.random(
size=[self.batch_size] + data_shape).astype('float32')
self.data[data_name] = data_value
def norm_except_dim(self, w, dim=None):
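# Numpy reference: norm of `w` over all dimensions except `dim` (dim=-1 means a single norm over all elements).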
shape = w.shape
ndims = len(shape)
shape_numel = reduce(lambda x, y: x * y, shape)
if dim == -1:
return numpy.linalg.norm(w, axis=None, keepdims=True)
elif dim == 0:
tile_shape = list(w.shape)
tile_shape[0] = 1
w_matrix = numpy.reshape(w, (shape[0], shape_numel // shape[0]))
return numpy.linalg.norm(w_matrix, axis=1, keepdims=True)
elif dim == (ndims - 1):
w_matrix = numpy.reshape(w, (shape_numel // shape[-1], shape[-1]))
return numpy.linalg.norm(w_matrix, axis=0, keepdims=True)
else:
perm = list(range(ndims))
perm[0] = dim
perm[dim] = 0
p_transposed = numpy.transpose(w, perm)
return self.norm_except_dim(p_transposed, 0)
def weight_normalize(self, w, dim=None):
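# Numpy reference for the expected weight_g (magnitude) and weight_v (direction) parameters of the layer.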
shape = w.shape
ndims = len(shape)
shape_numel = reduce(lambda x, y: x * y, shape)
v = w
g = self.norm_except_dim(w, dim)
g_mul = g
if dim == -1:
v_norm = v / (numpy.linalg.norm(v, axis=None, keepdims=True))
elif dim == 0:
w_matrix = numpy.reshape(w, (shape[0], shape_numel // shape[0]))
v_norm = v / numpy.linalg.norm(w_matrix, axis=1)
v_norm = numpy.reshape(v_norm, shape)
g = numpy.squeeze(g, axis=1)
elif dim == (ndims - 1):
w_matrix = numpy.reshape(w, (shape_numel // shape[-1], shape[-1]))
v_norm = v / numpy.linalg.norm(w_matrix, axis=0, keepdims=True)
v_norm = numpy.reshape(v_norm, shape)
else:
perm = list(range(ndims))
perm[0] = dim
perm[dim] = 0
p_transposed = numpy.transpose(v, perm)
transposed_shape = p_transposed.shape
transposed_shape_numel = reduce(lambda x, y: x * y,
transposed_shape)
p_matrix = numpy.reshape(
p_transposed, (p_transposed.shape[0],
transposed_shape_numel // p_transposed.shape[0]))
v_norm = v / numpy.expand_dims(
numpy.expand_dims(
numpy.linalg.norm(
p_matrix, axis=1, keepdims=True), axis=0),
axis=(ndims - 1))
v_norm = numpy.reshape(v_norm, transposed_shape)
v_norm = numpy.transpose(v_norm, perm)
g = numpy.squeeze(g, axis=1)
if dim == 1:
eaxis = 2
elif dim == 2:
eaxis = 1
g_mul = numpy.expand_dims(
numpy.expand_dims(
numpy.expand_dims(
g, axis=0), axis=eaxis),
axis=(ndims - 1))
w = g_mul * v_norm
return g, v
def test_check_output(self):
fluid.enable_imperative()
linear = paddle.nn.Conv2D(2, 3, 3)
before_weight = linear.weight.numpy()
if self.dim is None:
self.dim = -1
wn = weight_norm(linear, dim=self.dim)
outputs = []
for name, data in self.data.items():
output = linear(fluid.dygraph.to_variable(data))
outputs.append(output.numpy())
after_weight = linear.weight
self.actual_outputs = [linear.weight_g.numpy(), linear.weight_v.numpy()]
expect_output = self.weight_normalize(before_weight, self.dim)
for expect, actual in zip(expect_output, self.actual_outputs):
self.assertTrue(
numpy.allclose(
numpy.array(actual), expect, atol=0.001))
class TestDygraphWeightNormCase1(TestDygraphWeightNorm):
def init_test_case(self):
self.batch_size = 3
self.data_desc = (['x', [2, 3, 3]], )
self.dim = 0
class TestDygraphWeightNormCase2(TestDygraphWeightNorm):
def init_test_case(self):
self.batch_size = 3
self.data_desc = (['x', [2, 3, 3]], )
self.dim = 1
class TestDygraphWeightNormCase3(TestDygraphWeightNorm):
def init_test_case(self):
self.batch_size = 3
self.data_desc = (['x', [2, 3, 3]], )
self.dim = 3
class TestDygraphRemoveWeightNorm(unittest.TestCase):
def setUp(self):
self.init_test_case()
def init_test_case(self):
self.batch_size = 3
self.data_desc = (['x', [2, 3, 3]], )
self.dim = None
def test_check_output(self):
fluid.enable_imperative()
linear = paddle.nn.Conv2D(2, 3, 3)
before_weight = linear.weight
wn = weight_norm(linear, dim=self.dim)
rwn = remove_weight_norm(linear)
after_weight = linear.weight
self.assertTrue(
numpy.allclose(
before_weight.numpy(), after_weight.numpy(), atol=0.001))
if __name__ == '__main__':
unittest.main()
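As a sanity check on the reparameterization the tests above verify, the dim=None case can be reproduced with plain numpy; a minimal sketch (array names are illustrative):

import numpy as np

# dim=None: the magnitude g is the norm over all elements and the direction v is w itself,
# so g * v / ||v|| reconstructs the original weight exactly.
w = np.random.random((3, 2, 3, 3)).astype('float32')
g = np.linalg.norm(w)
v = w.copy()
assert np.allclose(w, g * v / np.linalg.norm(v), atol=1e-6)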
......@@ -18,6 +18,7 @@
from .layer import norm
from .functional import extension
from .layer import common
from .utils import weight_norm_hook
from . import initializer
......@@ -25,6 +26,7 @@ __all__ = []
__all__ += norm.__all__
__all__ += extension.__all__
__all__ += common.__all__
__all__ += weight_norm_hook.__all__
# TODO: define alias in nn directory
# from .clip import ErrorClipByValue #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import weight_norm_hook
from .weight_norm_hook import weight_norm, remove_weight_norm
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ... import fluid
from ...fluid import dygraph
from ...fluid import layers as F
from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype
from ...tensor.math import multiply
__all__ = ['weight_norm', 'remove_weight_norm']
def l2_norm(x, axis, epsilon=1e-12, name=None):
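# Return the L2 norms of `x` computed along `axis` (the 'Norm' output of the norm op), squeezed along `axis`.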
if len(x.shape) == 1:
axis = 0
check_variable_and_dtype(x, "X", ("float32", "float64"), "norm")
helper = LayerHelper("l2_normalize", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
norm = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="norm",
inputs={"X": x},
outputs={"Out": out,
"Norm": norm},
attrs={
"axis": 1 if axis is None else axis,
"epsilon": epsilon,
})
return F.squeeze(norm, axes=[axis])
def norm_except_dim(p, dim):
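# Norm of `p` taken over every dimension except `dim`; dim=-1 means a single norm over all elements.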
shape = p.shape
ndims = len(shape)
if dim == -1:
return F.sqrt(F.reduce_sum(F.square(p)) + 1e-12)
elif dim == 0:
p_matrix = F.reshape(p, (shape[0], -1))
return l2_norm(p_matrix, axis=1)
elif dim == ndims - 1:
p_matrix = F.reshape(p, (-1, shape[-1]))
return l2_norm(p_matrix, axis=0)
else:
perm = list(range(ndims))
perm[0] = dim
perm[dim] = 0
p_transposed = F.transpose(p, perm)
return norm_except_dim(p_transposed, 0)
def _weight_norm(v, g, dim):
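# Reconstruct the weight from direction `v` and magnitude `g`: w = g * v / ||v||, with the norm over all dimensions except `dim`.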
shape = v.shape
ndims = len(shape)
if dim == -1:
v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
elif dim == 0:
p_matrix = F.reshape(v, (shape[0], -1))
v_normalized = F.l2_normalize(p_matrix, axis=1)
v_normalized = F.reshape(v_normalized, shape)
elif dim == ndims - 1:
p_matrix = F.reshape(v, (-1, shape[-1]))
v_normalized = F.l2_normalize(p_matrix, axis=0)
v_normalized = F.reshape(v_normalized, shape)
else:
perm = list(range(ndims))
perm[0] = dim
perm[dim] = 0
p_transposed = F.transpose(v, perm)
transposed_shape = p_transposed.shape
p_matrix = F.reshape(p_transposed, (p_transposed.shape[0], -1))
v_normalized = F.l2_normalize(p_matrix, axis=1)
v_normalized = F.reshape(v_normalized, transposed_shape)
v_normalized = F.transpose(v_normalized, perm)
weight = multiply(v_normalized, g, axis=dim if dim is not None else -1)
return weight
class WeightNorm(object):
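# Forward pre-hook: recomputes `weight` from `weight_g` and `weight_v` before every forward pass.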
def __init__(self, name, dim):
if dim is None:
dim = -1
self.name = name
self.dim = dim
def compute_weight(self, layer):
g = getattr(layer, self.name + '_g')
v = getattr(layer, self.name + '_v')
return _weight_norm(v, g, self.dim)
@staticmethod
def apply(layer, name, dim):
for k, hook in layer._forward_pre_hooks.items():
if isinstance(hook, WeightNorm) and hook.name == name:
raise RuntimeError("Cannot register two weight_norm hooks on "
"the same parameter {}".format(name))
if dim is None:
dim = -1
fn = WeightNorm(name, dim)
w = getattr(layer, name)
del layer._parameters[name]
g_var = norm_except_dim(w, dim)
v = layer.create_parameter(w.shape, dtype=w.dtype)
layer.add_parameter(name + "_v", v)
g = layer.create_parameter(g_var.shape, dtype=g_var.dtype)
layer.add_parameter(name + '_g', g)
with dygraph.no_grad():
F.assign(w, v)
F.assign(g_var, g)
setattr(layer, name, fn.compute_weight(layer))
layer.register_forward_pre_hook(fn)
return fn
def remove(self, layer):
w_var = self.compute_weight(layer)
delattr(layer, self.name)
del layer._parameters[self.name + '_g']
del layer._parameters[self.name + '_v']
w = layer.create_parameter(w_var.shape, dtype=w_var.dtype)
layer.add_parameter(self.name, w)
with dygraph.no_grad():
F.assign(w_var, w)
def __call__(self, layer, inputs):
setattr(layer, self.name, self.compute_weight(layer))
def weight_norm(layer, name='weight', dim=0):
"""
This weight_norm layer applies weight normalization to a parameter according to the
following formula:
.. math::
\mathbf{w} = g \dfrac{v}{\|v\|}
Weight normalization is a reparameterization of the weight vectors in a neural network that
decouples the magnitude of those weight vectors from their direction. Weight normalization
replaces the parameter specified by `name` (e.g. 'weight') with two parameters: one parameter
specifying the magnitude (e.g. 'weight_g') and one parameter specifying the direction
(e.g. 'weight_v'). Weight normalization has been implemented as discussed in this paper:
`Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks
<https://arxiv.org/pdf/1602.07868.pdf>`_.
Parameters:
layer(Layer): The paddle layer to apply weight normalization to; it must have the named weight parameter.
name(str, optional): Name of the weight parameter. Default: 'weight'.
dim(int, optional): Dimension over which to compute the norm. Dim must be a non-negative number
less than the rank of the weight Tensor. For example, dim can be chosen from 0,
1, 2, 3 for a convolution whose weight shape is [cout, cin, kh, kw] and rank is 4.
If dim is set to None, all elements of the weight will be normalized. Default: 0.
Returns:
The origin layer with the weight norm hook registered.
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.nn import Conv2D
from paddle.nn.utils import weight_norm
x = np.array([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32')
paddle.disable_static()
conv = Conv2D(3, 5, 3)
wn = weight_norm(conv)
print(conv.weight_g.shape)
# [5]
print(conv.weight_v.shape)
# [5, 3, 3, 3]
"""
WeightNorm.apply(layer, name, dim)
return layer
def remove_weight_norm(layer, name='weight'):
"""
Remove the weight normalization reparameterization from a layer.
Parameters:
layer(Layer): The paddle layer from which to remove weight normalization.
name(str, optional): Name of the weight parameter. Default: 'weight'.
Returns:
The origin layer with weight normalization removed.
Examples:
.. code-block:: python
import paddle
from paddle.nn import Conv2D
from paddle.nn.utils import weight_norm, remove_weight_norm
paddle.disable_static()
conv = Conv2D(3, 5, 3)
wn = weight_norm(conv)
remove_weight_norm(conv)
print(conv.weight_g)
# AttributeError: 'Conv2D' object has no attribute 'weight_g'
"""
for k, hook in layer._forward_pre_hooks.items():
if isinstance(hook, WeightNorm) and hook.name == name:
hook.remove(layer)
del layer._forward_pre_hooks[k]
return layer
raise ValueError("weight_norm of '{}' not found in {}".format(name, layer))
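Putting the two new APIs together, a minimal dygraph round trip might look like the sketch below (the Conv2D arguments and input shape are illustrative):

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.nn.utils import weight_norm, remove_weight_norm

paddle.disable_static()
conv = paddle.nn.Conv2D(3, 5, 3)
conv = weight_norm(conv, dim=0)    # 'weight' is replaced by 'weight_g' and 'weight_v'
x = fluid.dygraph.to_variable(np.ones((1, 3, 8, 8), dtype='float32'))
y = conv(x)                        # the forward pre-hook recomputes 'weight' from g and v
conv = remove_weight_norm(conv)    # folds g and v back into a single 'weight' parameter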
......@@ -201,6 +201,7 @@ packages=['paddle',
'paddle.nn.functional',
'paddle.nn.layer',
'paddle.nn.initializer',
'paddle.nn.utils',
'paddle.metric',
'paddle.static',
'paddle.static.nn',
......