From fd66d76231d6fa3245f72b7d36199565864c05ef Mon Sep 17 00:00:00 2001 From: ceci3 Date: Thu, 20 Aug 2020 19:10:48 +0800 Subject: [PATCH] add weight_norm & remove_weight_norm (#26131) * add weight_norm, test=develop --- python/paddle/fluid/param_attr.py | 4 + .../unittests/test_dygraph_weight_norm.py | 183 ++++++++++++++ python/paddle/nn/__init__.py | 2 + python/paddle/nn/utils/__init__.py | 16 ++ python/paddle/nn/utils/weight_norm_hook.py | 225 ++++++++++++++++++ python/setup.py.in | 1 + 6 files changed, 431 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py create mode 100644 python/paddle/nn/utils/__init__.py create mode 100644 python/paddle/nn/utils/weight_norm_hook.py diff --git a/python/paddle/fluid/param_attr.py b/python/paddle/fluid/param_attr.py index a45443632b0..8e0470beded 100644 --- a/python/paddle/fluid/param_attr.py +++ b/python/paddle/fluid/param_attr.py @@ -204,6 +204,9 @@ class WeightNormParamAttr(ParamAttr): """ :api_attr: Static Graph + Note: + Please use 'paddle.nn.utils.weight_norm' in dygraph mode. + Parameter of weight Norm. Weight Norm is a reparameterization of the weight vectors in a neural network that decouples the magnitude of those weight vectors from their direction. Weight Norm has been implemented as discussed in this @@ -216,6 +219,7 @@ class WeightNormParamAttr(ParamAttr): It is recommended to use ``minimize(loss, grad_clip=clip)`` to clip gradient. There are three clipping strategies: :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` . + Args: dim(int): Dimension over which to compute the norm. Dim is a non-negative diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py b/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py new file mode 100644 index 00000000000..f33334d536d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dygraph_weight_norm.py @@ -0,0 +1,183 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy +import collections +from functools import reduce +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.nn.utils import weight_norm, remove_weight_norm + + +class TestDygraphWeightNorm(unittest.TestCase): + def setUp(self): + self.init_test_case() + self.set_data() + + def init_test_case(self): + self.batch_size = 3 + self.data_desc = (['x', [2, 3, 3]], ) + self.dim = None + + def set_data(self): + self.data = collections.OrderedDict() + for desc in self.data_desc: + data_name = desc[0] + data_shape = desc[1] + data_value = numpy.random.random( + size=[self.batch_size] + data_shape).astype('float32') + self.data[data_name] = data_value + + def norm_except_dim(self, w, dim=None): + shape = w.shape + ndims = len(shape) + shape_numel = reduce(lambda x, y: x * y, shape) + if dim == -1: + return numpy.linalg.norm(w, axis=None, keepdims=True) + elif dim == 0: + tile_shape = list(w.shape) + tile_shape[0] = 1 + w_matrix = numpy.reshape(w, (shape[0], shape_numel // shape[0])) + return numpy.linalg.norm(w_matrix, axis=1, keepdims=True) + elif dim == (ndims - 1): + w_matrix = numpy.reshape(w, (shape_numel // shape[-1], shape[-1])) + return numpy.linalg.norm(w_matrix, axis=0, keepdims=True) + else: + perm = list(range(ndims)) + perm_ori = list(range(ndims)) + perm[0] = dim + perm[dim] = 0 + p_transposed = numpy.transpose(w, perm) + return self.norm_except_dim(p_transposed, 0) + + def weight_normalize(self, w, dim=None): + shape = w.shape + ndims = len(shape) + shape_numel = reduce(lambda x, y: x * y, shape) + v = w + g = self.norm_except_dim(w, dim) + g_mul = g + + if dim == -1: + v_norm = v / (numpy.linalg.norm(v, axis=None, keepdims=True)) + elif dim == 0: + w_matrix = numpy.reshape(w, (shape[0], shape_numel // shape[0])) + v_norm = v / numpy.linalg.norm(w_matrix, axis=1) + v_norm = numpy.reshape(v_norm, shape) + g = numpy.squeeze(g, axis=1) + elif dim == (ndims - 1): + w_matrix = numpy.reshape(w, (shape_numel // shape[-1], shape[-1])) + v_norm = v / numpy.linalg.norm(w_matrix, axis=0, keepdims=True) + v_norm = numpy.reshape(v_norm, shape) + else: + perm = list(range(ndims)) + perm[0] = dim + perm[dim] = 0 + p_transposed = numpy.transpose(v, perm) + transposed_shape = p_transposed.shape + transposed_shape_numel = reduce(lambda x, y: x * y, + transposed_shape) + p_matrix = numpy.reshape( + p_transposed, (p_transposed.shape[0], + transposed_shape_numel // p_transposed.shape[0])) + v_norm = v / numpy.expand_dims( + numpy.expand_dims( + numpy.linalg.norm( + p_matrix, axis=1, keepdims=True), axis=0), + axis=(ndims - 1)) + v_norm = numpy.reshape(v_norm, transposed_shape) + v_norm = numpy.transpose(v_norm, perm) + g = numpy.squeeze(g, axis=1) + if dim == 1: + eaxis = 2 + elif dim == 2: + eaxis = 1 + g_mul = numpy.expand_dims( + numpy.expand_dims( + numpy.expand_dims( + g, axis=0), axis=eaxis), + axis=(ndims - 1)) + w = g_mul * v_norm + return g, v + + def test_check_output(self): + fluid.enable_imperative() + linear = paddle.nn.Conv2D(2, 3, 3) + before_weight = linear.weight.numpy() + if self.dim == None: + self.dim = -1 + wn = weight_norm(linear, dim=self.dim) + outputs = [] + for name, data in self.data.items(): + output = linear(fluid.dygraph.to_variable(data)) + outputs.append(output.numpy()) + after_weight = linear.weight + self.actual_outputs = [linear.weight_g.numpy(), linear.weight_v.numpy()] + + expect_output = self.weight_normalize(before_weight, self.dim) + + for expect, actual in zip(expect_output, self.actual_outputs): + self.assertTrue( + numpy.allclose( + numpy.array(actual), expect, atol=0.001)) + + +class TestDygraphWeightNormCase1(TestDygraphWeightNorm): + def init_test_case(self): + self.batch_size = 3 + self.data_desc = (['x', [2, 3, 3]], ) + self.dim = 0 + + +class TestDygraphWeightNormCase2(TestDygraphWeightNorm): + def init_test_case(self): + self.batch_size = 3 + self.data_desc = (['x', [2, 3, 3]], ) + self.dim = 1 + + +class TestDygraphWeightNormCase3(TestDygraphWeightNorm): + def init_test_case(self): + self.batch_size = 3 + self.data_desc = (['x', [2, 3, 3]], ) + self.dim = 3 + + +class TestDygraphRemoveWeightNorm(unittest.TestCase): + def setUp(self): + self.init_test_case() + + def init_test_case(self): + self.batch_size = 3 + self.data_desc = (['x', [2, 3, 3]], ) + self.dim = None + + def test_check_output(self): + fluid.enable_imperative() + linear = paddle.nn.Conv2D(2, 3, 3) + before_weight = linear.weight + wn = weight_norm(linear, dim=self.dim) + rwn = remove_weight_norm(linear) + after_weight = linear.weight + self.assertTrue( + numpy.allclose( + before_weight.numpy(), after_weight.numpy(), atol=0.001)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 72a9f8d5c93..3b75629ede9 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -18,6 +18,7 @@ from .layer import norm from .functional import extension from .layer import common +from .utils import weight_norm_hook from . import initializer @@ -25,6 +26,7 @@ __all__ = [] __all__ += norm.__all__ __all__ += extension.__all__ __all__ += common.__all__ +__all__ += weight_norm_hook.__all__ # TODO: define alias in nn directory # from .clip import ErrorClipByValue #DEFINE_ALIAS diff --git a/python/paddle/nn/utils/__init__.py b/python/paddle/nn/utils/__init__.py new file mode 100644 index 00000000000..6562ac35e1e --- /dev/null +++ b/python/paddle/nn/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import weight_norm_hook +from .weight_norm_hook import weight_norm, remove_weight_norm diff --git a/python/paddle/nn/utils/weight_norm_hook.py b/python/paddle/nn/utils/weight_norm_hook.py new file mode 100644 index 00000000000..ad53bf39466 --- /dev/null +++ b/python/paddle/nn/utils/weight_norm_hook.py @@ -0,0 +1,225 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from ... import fluid +from ...fluid import dygraph +from ...fluid import layers as F +from ...fluid.layer_helper import LayerHelper +from ...fluid.data_feeder import check_variable_and_dtype +from ...tensor.math import multiply + +__all__ = ['weight_norm', 'remove_weight_norm'] + + +def l2_norm(x, axis, epsilon=1e-12, name=None): + if len(x.shape) == 1: + axis = 0 + check_variable_and_dtype(x, "X", ("float32", "float64"), "norm") + + helper = LayerHelper("l2_normalize", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + norm = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type="norm", + inputs={"X": x}, + outputs={"Out": out, + "Norm": norm}, + attrs={ + "axis": 1 if axis is None else axis, + "epsilon": epsilon, + }) + return F.squeeze(norm, axes=[axis]) + + +def norm_except_dim(p, dim): + shape = p.shape + ndims = len(shape) + if dim == -1: + return F.sqrt(F.reduce_sum(F.square(p)) + 1e-12) + elif dim == 0: + p_matrix = F.reshape(p, (shape[0], -1)) + return l2_norm(p_matrix, axis=1) + elif dim == ndims - 1: + p_matrix = F.reshape(p, (-1, shape[-1])) + return l2_norm(p_matrix, axis=0) + else: + perm = list(range(ndims)) + perm[0] = dim + perm[dim] = 0 + p_transposed = F.transpose(p, perm) + return norm_except_dim(p_transposed, 0) + + +def _weight_norm(v, g, dim): + shape = v.shape + ndims = len(shape) + + if dim == -1: + v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12) + elif dim == 0: + p_matrix = F.reshape(v, (shape[0], -1)) + v_normalized = F.l2_normalize(p_matrix, axis=1) + v_normalized = F.reshape(v_normalized, shape) + elif dim == ndims - 1: + p_matrix = F.reshape(v, (-1, shape[-1])) + v_normalized = F.l2_normalize(p_matrix, axis=0) + v_normalized = F.reshape(v_normalized, shape) + else: + perm = list(range(ndims)) + perm[0] = dim + perm[dim] = 0 + p_transposed = F.transpose(v, perm) + transposed_shape = p_transposed.shape + p_matrix = F.reshape(p_transposed, (p_transposed.shape[0], -1)) + v_normalized = F.l2_normalize(p_matrix, axis=1) + v_normalized = F.reshape(v_normalized, transposed_shape) + v_normalized = F.transpose(v_normalized, perm) + weight = multiply(v_normalized, g, axis=dim if dim is not None else -1) + return weight + + +class WeightNorm(object): + def __init__(self, name, dim): + if dim is None: + dim = -1 + self.name = name + self.dim = dim + + def compute_weight(self, layer): + g = getattr(layer, self.name + '_g') + v = getattr(layer, self.name + '_v') + return _weight_norm(v, g, self.dim) + + @staticmethod + def apply(layer, name, dim): + for k, hook in layer._forward_pre_hooks.items(): + if isinstance(hook, WeightNorm) and hook.name == name: + raise RuntimeError("Cannot register two weight_norm hooks on " + "the same parameter {}".format(name)) + + if dim is None: + dim = -1 + + fn = WeightNorm(name, dim) + + w = getattr(layer, name) + del layer._parameters[name] + + g_var = norm_except_dim(w, dim) + v = layer.create_parameter(w.shape, dtype=w.dtype) + layer.add_parameter(name + "_v", v) + g = layer.create_parameter(g_var.shape, dtype=g_var.dtype) + layer.add_parameter(name + '_g', g) + with dygraph.no_grad(): + F.assign(w, v) + F.assign(g_var, g) + setattr(layer, name, fn.compute_weight(layer)) + + layer.register_forward_pre_hook(fn) + return fn + + def remove(self, layer): + w_var = self.compute_weight(layer) + delattr(layer, self.name) + del layer._parameters[self.name + '_g'] + del layer._parameters[self.name + '_v'] + w = layer.create_parameter(w_var.shape, dtype=w_var.dtype) + layer.add_parameter(self.name, w) + with dygraph.no_grad(): + F.assign(w_var, w) + + def __call__(self, layer, inputs): + setattr(layer, self.name, self.compute_weight(layer)) + + +def weight_norm(layer, name='weight', dim=0): + """ + This weight_norm layer applies weight normalization to a parameter according to the + following formula: + + .. math:: + + \mathbf{w} = g \dfrac{v}{\|v\|} + + Weight normalization is a reparameterization of the weight vectors in a neural network that + decouples the magnitude of those weight vectors from their direction. Weight normalization + replaces the parameter specified by `name`(eg: 'weight') with two parameters: one parameter + specifying the magnitude (eg: 'weight_g') and one parameter specifying the direction + (eg: 'weight_v'). Weight normalization has been implemented as discussed in this paper: + `Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks + `_. + + Parameters: + layer(Layer): Layer of paddle, which has weight. + name(str, optional): Name of the weight parameter. Default: 'weight'. + dim(int, optional): Dimension over which to compute the norm. Dim is a non-negative number + which is less than the rank of weight Tensor. For Example, dim can be chosen from 0, + 1, 2, 3 for convolution whose weight shape is [cout, cin, kh, kw] and rank is 4. + If dim is set to None, meaning that all elements will be normalized. Default: 0. + + Returns: + Origin layer with weight norm hook. + + Examples: + .. code-block:: python + + import numpy as np + from paddle.nn import Conv2D + from paddle.nn.utils import weight_norm + + x = np.array([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32') + paddle.disable_static() + conv = Conv2D(3, 5, 3) + wn = weight_norm(conv) + print(conv.weight_g.shape) + # [5] + print(conv.weight_v.shape) + # [5, 3, 3, 3] + """ + WeightNorm.apply(layer, name, dim) + return layer + + +def remove_weight_norm(layer, name='weight'): + """ + remove weight normalization from layer. + + Parameters: + layer(Layer): Layer of paddle, which has weight. + name(str, optional): Name of the weight parameter. Default: 'weight'. + + Returns: + Origin layer without weight norm + + Examples: + .. code-block:: python + import paddle + from paddle.nn import Conv2D + from paddle.nn.utils import weight_norm, remove_weight_norm + + paddle.disable_static() + conv = Conv2D(3, 5, 3) + wn = weight_norm(conv) + remove_weight_norm(conv) + print(conv.weight_g) + # AttributeError: 'Conv2D' object has no attribute 'weight_g' + """ + for k, hook in layer._forward_pre_hooks.items(): + if isinstance(hook, WeightNorm) and hook.name == name: + hook.remove(layer) + del layer._forward_pre_hooks[k] + return layer + + raise ValueError("weight_norm of '{}' not found in {}".format(name, layer)) diff --git a/python/setup.py.in b/python/setup.py.in index 29bc68444e1..4706099c0c3 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -201,6 +201,7 @@ packages=['paddle', 'paddle.nn.functional', 'paddle.nn.layer', 'paddle.nn.initializer', + 'paddle.nn.utils', 'paddle.metric', 'paddle.static', 'paddle.static.nn', -- GitLab