From 62ad3594140567b0aba5ecde1f29ec6a27659b5c Mon Sep 17 00:00:00 2001
From: Zhou Wei <1183042833@qq.com>
Date: Fri, 19 Nov 2021 20:09:15 +0800
Subject: [PATCH] add new API paddle.nn.initializer.Orthogonal and
 calculate_gain (#37163)

* add new API paddle.nn.initializer.Orthogonal and calculate_gain

* fix comment

* fix comment
---
 python/paddle/fluid/initializer.py            |  47 +++++
 .../fluid/tests/unittests/test_initializer.py | 199 ++++++++++++++++++
 python/paddle/nn/initializer/__init__.py      |   7 +-
 python/paddle/nn/initializer/orthogonal.py    | 199 ++++++++++++++++++
 python/paddle/tensor/linalg.py                |  16 +-
 5 files changed, 459 insertions(+), 9 deletions(-)
 create mode 100644 python/paddle/nn/initializer/orthogonal.py

diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index 54ba5f22e53..e7fd12df3d0 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -14,6 +14,7 @@
 from __future__ import print_function
 
+import math
 from . import framework
 from . import core
 from .framework import in_dygraph_mode, default_main_program
@@ -1033,6 +1034,52 @@ def _global_bias_initializer():
     return _global_bias_initializer_
 
 
+def calculate_gain(nonlinearity, param=None):
+    """
+    Get the recommended gain value of some nonlinearity function.
+
+    Args:
+        nonlinearity(str): nonlinearity function.
+        param(bool|int|float, optional): optional parameter for some nonlinearity function. Currently it is only used by 'leaky_relu'. Default: None,
+            in which case 0.01 is used in the formula.
+
+    Returns:
+        The recommended gain value for the given nonlinearity function.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            gain = paddle.nn.initializer.calculate_gain('tanh') # 5.0 / 3
+            gain = paddle.nn.initializer.calculate_gain('leaky_relu', param=1.0) # 1.0 = math.sqrt(2.0 / (1+param^2))
+
+    """
+    if param is None:
+        param = 0.01
+    else:
+        assert isinstance(param, (bool, int, float))
+        param = float(param)
+    recommended_gain = {
+        'sigmoid': 1,
+        'linear': 1,
+        'conv1d': 1,
+        'conv2d': 1,
+        'conv3d': 1,
+        'conv_transpose1d': 1,
+        'conv_transpose2d': 1,
+        'conv_transpose3d': 1,
+        'tanh': 5.0 / 3,
+        'relu': math.sqrt(2.0),
+        'leaky_relu': math.sqrt(2.0 / (1 + param**2)),
+        'selu': 3.0 / 4
+    }
+    if nonlinearity in recommended_gain.keys():
+        return recommended_gain[nonlinearity]
+    else:
+        raise ValueError("nonlinearity function {} is not supported now.".
+                         format(nonlinearity))
+
+
 # We short the class name, since users will use the initializer with the package
 # name. The sample code:
 #
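Taken together with the Orthogonal initializer added further down in this patch, calculate_gain is meant to scale an initializer to the activation that follows the layer. A minimal usage sketch (the layer sizes here are arbitrary):

.. code-block:: python

    import paddle

    # Recommended gain for a ReLU activation, i.e. sqrt(2.0).
    gain = paddle.nn.initializer.calculate_gain('relu')

    # Scale the orthogonal initialization by that gain.
    weight_attr = paddle.ParamAttr(
        initializer=paddle.nn.initializer.Orthogonal(gain=gain))
    linear = paddle.nn.Linear(10, 15, weight_attr=weight_attr)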
diff --git a/python/paddle/fluid/tests/unittests/test_initializer.py b/python/paddle/fluid/tests/unittests/test_initializer.py
index 8ddb7498971..a3982ab3e4b 100644
--- a/python/paddle/fluid/tests/unittests/test_initializer.py
+++ b/python/paddle/fluid/tests/unittests/test_initializer.py
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import numpy as np
+import math
 import unittest
 
 import paddle
@@ -41,6 +42,17 @@ def output_hist(out):
 
 
 class TestConstantInitializer(unittest.TestCase):
+    def test_calculate_gain(self):
+        self.assertEqual(paddle.nn.initializer.calculate_gain('sigmoid'), 1)
+        self.assertEqual(paddle.nn.initializer.calculate_gain('linear'), 1)
+        self.assertEqual(paddle.nn.initializer.calculate_gain('conv2d'), 1)
+        self.assertEqual(paddle.nn.initializer.calculate_gain('tanh'), 5.0 / 3)
+        self.assertEqual(
+            paddle.nn.initializer.calculate_gain('relu'), math.sqrt(2.0))
+        self.assertEqual(
+            paddle.nn.initializer.calculate_gain('leaky_relu', 1), 1)
+        self.assertEqual(paddle.nn.initializer.calculate_gain('selu'), 3.0 / 4)
+
     def test_constant_initializer_default_value(self, dtype="float32"):
         """Test the constant initializer with default value
         """
@@ -716,5 +728,192 @@ class TesetconsistencyOfDynamicAndStaticGraph(unittest.TestCase):
         self.assertTrue(np.array_equal(dynamic_res[1], static_res[1]))
 
 
+# 2-D Parameter with shape: [10, 15]
+class TestOrthogonalInitializer1(unittest.TestCase):
+    """
+    case 1
+    """
+
+    def config(self):
+        self.weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.Orthogonal(gain=3.0))
+        self.dtype = "float64"
+        self.in_features = 10
+        self.out_features = 15
+        self.num_ops = 9
+
+    def check_result(self, a, b):
+        self.assertTrue(np.array_equal(a, b))
+        self.assertTrue(np.allclose(np.matmul(a, a.T), 9 * np.eye(10)))
+
+    def test_orthogonal(self):
+        self.config()
+        paddle.set_default_dtype(self.dtype)
+
+        paddle.disable_static()
+        paddle.seed(2021)
+        linear = paddle.nn.Linear(
+            self.in_features, self.out_features, weight_attr=self.weight_attr)
+        res_dygraph = linear.weight.numpy()
+
+        paddle.enable_static()
+        paddle.seed(2021)
+        start_prog = paddle.static.Program()
+        main_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, start_prog):
+            linear = paddle.nn.Linear(
+                self.in_features,
+                self.out_features,
+                weight_attr=self.weight_attr)
+
+            block = start_prog.global_block()
+            self.assertEqual(len(block.ops), self.num_ops)
+            self.assertEqual(block.ops[0].type, 'gaussian_random')
+            self.assertEqual(block.ops[1].type, 'qr')
+            self.assertEqual(block.ops[2].type, 'diag_v2')
+            self.assertEqual(block.ops[3].type, 'sign')
+            self.assertEqual(block.ops[4].type, 'elementwise_mul')
+            self.assertEqual(block.ops[-3].type, 'reshape2')
+            self.assertEqual(block.ops[-2].type, 'scale')
+
+            exe = paddle.static.Executor()
+            res_static = exe.run(start_prog, fetch_list=[linear.weight])[0]
+
+        self.check_result(res_dygraph, res_static)
+
+
+# 2-D Parameter with shape: [15, 10]
+class TestOrthogonalInitializer2(TestOrthogonalInitializer1):
+    """
+    case 2
+    """
+
+    def config(self):
+        self.weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.Orthogonal(gain=2.0))
+        self.dtype = "float64"
+        self.in_features = 15
+        self.out_features = 10
+        self.num_ops = 8
+
+    def check_result(self, a, b):
+        self.assertTrue(np.array_equal(a, b))
+        self.assertTrue(np.allclose(np.matmul(a.T, a), 4 * np.eye(10)))
+
+
+ """ + + def config(self): + self.weight_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Orthogonal()) + self.dtype = "float32" + self.in_features = 10 + self.out_features = 10 + self.num_ops = 8 + + def check_result(self, a, b): + self.assertTrue(np.array_equal(a, b)) + self.assertTrue(np.allclose(np.matmul(a.T, a), np.eye(10), atol=1.e-6)) + self.assertTrue(np.allclose(np.matmul(a, a.T), np.eye(10), atol=1.e-6)) + + def test_error(self): + self.config() + with self.assertRaises(AssertionError): + paddle.nn.Linear(10, 10, bias_attr=self.weight_attr) + + +# 4-D Parameter with shape: [6, 4, 3, 3] +class TestOrthogonalInitializer4(unittest.TestCase): + """ + case 4 + """ + + def config(self): + self.weight_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Orthogonal(gain=3.0)) + self.dtype = "float64" + self.in_features = 4 + self.out_features = 6 + self.kernel_size = (3, 3) + + def check_result(self, a, b): + self.assertTrue(np.array_equal(a, b)) + a = a.reshape(6, -1) + self.assertTrue(np.allclose(np.matmul(a, a.T), 9 * np.eye(6))) + + def test_orthogonal(self): + self.config() + paddle.set_default_dtype(self.dtype) + + paddle.disable_static() + paddle.seed(2021) + conv2d = paddle.nn.Conv2D( + self.in_features, + self.out_features, + self.kernel_size, + weight_attr=self.weight_attr) + res_dygraph = conv2d.weight.numpy() + + paddle.enable_static() + paddle.seed(2021) + start_prog = paddle.static.Program() + main_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + conv2d = paddle.nn.Conv2D( + self.in_features, + self.out_features, + self.kernel_size, + weight_attr=self.weight_attr) + exe = paddle.static.Executor() + res_static = exe.run(paddle.static.default_startup_program(), + fetch_list=[conv2d.weight])[0] + self.check_result(res_dygraph, res_static) + + +# 4-D Parameter with shape: [50, 4, 3, 3] +class TestOrthogonalInitializer5(TestOrthogonalInitializer4): + """ + case 5 + """ + + def config(self): + self.weight_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Orthogonal(gain=2.0)) + self.dtype = "float64" + self.in_features = 4 + self.out_features = 50 + self.kernel_size = (3, 3) + + def check_result(self, a, b): + self.assertTrue(np.array_equal(a, b)) + a = a.reshape(50, -1) + self.assertTrue(np.allclose(np.matmul(a.T, a), 4 * np.eye(36))) + + +# 4-D Parameter with shape: [36, 4, 3, 3] +class TestOrthogonalInitializer6(TestOrthogonalInitializer4): + """ + case 6 + """ + + def config(self): + self.weight_attr = paddle.ParamAttr( + initializer=paddle.nn.initializer.Orthogonal()) + self.dtype = "float32" + self.in_features = 4 + self.out_features = 36 + self.kernel_size = (3, 3) + + def check_result(self, a, b): + self.assertTrue(np.array_equal(a, b)) + a = a.reshape(36, -1) + self.assertTrue(np.allclose(np.matmul(a.T, a), np.eye(36), atol=1.e-6)) + self.assertTrue(np.allclose(np.matmul(a, a.T), np.eye(36), atol=1.e-6)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/initializer/__init__.py b/python/paddle/nn/initializer/__init__.py index 03e91f80dd1..e2b83fa8ce2 100644 --- a/python/paddle/nn/initializer/__init__.py +++ b/python/paddle/nn/initializer/__init__.py @@ -15,6 +15,7 @@ # TODO: define the initializers to create a Parameter in neural network from ...fluid.initializer import Bilinear # noqa: F401 from ...fluid.initializer import set_global_initializer # noqa: F401 +from ...fluid.initializer import calculate_gain # noqa: F401 from .constant import Constant # noqa: F401 @@ 
diff --git a/python/paddle/nn/initializer/__init__.py b/python/paddle/nn/initializer/__init__.py
index 03e91f80dd1..e2b83fa8ce2 100644
--- a/python/paddle/nn/initializer/__init__.py
+++ b/python/paddle/nn/initializer/__init__.py
@@ -15,6 +15,7 @@
 # TODO: define the initializers to create a Parameter in neural network
 from ...fluid.initializer import Bilinear  # noqa: F401
 from ...fluid.initializer import set_global_initializer  # noqa: F401
+from ...fluid.initializer import calculate_gain  # noqa: F401
 
 from .constant import Constant  # noqa: F401
 
@@ -31,6 +32,8 @@ from .normal import TruncatedNormal  # noqa: F401
 
 from .uniform import Uniform  # noqa: F401
 
+from .orthogonal import Orthogonal  # noqa: F401
+
 __all__ = [     #noqa
     'Bilinear',
     'Constant',
@@ -42,5 +45,7 @@ __all__ = [     #noqa
     'Normal',
     'TruncatedNormal',
     'Uniform',
-    'set_global_initializer'
+    'Orthogonal',
+    'set_global_initializer',
+    'calculate_gain'
 ]
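With the re-exports above, both new symbols are reachable from the public paddle.nn.initializer namespace. A quick sketch of the intended import paths:

.. code-block:: python

    import paddle
    from paddle.nn.initializer import Orthogonal, calculate_gain

    # Combine the two new APIs: gain = 5.0 / 3 for a tanh activation.
    init = Orthogonal(gain=calculate_gain('tanh'))
    weight_attr = paddle.ParamAttr(initializer=init)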
+ """ + block = self._check_block(block) + assert isinstance(var, framework.Parameter) + assert isinstance(block, framework.Block) + # 'qr' op only support float32/float64 now + check_variable_and_dtype(var, "Out", ["float32", "float64"], + "Orthogonal") + + self._seed = block.program.random_seed + + shape = var.shape + assert len( + shape + ) >= 2, "Only Tensor with 2 or more dimensions can be initialized by Orthogonal" + + row = shape[0] + col = 1 + for i in shape[1:]: + col *= i + + flatten_shape = [max(row, col), min(row, col)] + + normal_var = block.create_var( + name=unique_name.generate('.'.join(['gaussian_random', 'tmp'])), + dtype=var.dtype, + persistable=False, + stop_gradient=True) + block.append_op( + type='gaussian_random', + inputs={}, + outputs={'Out': normal_var}, + attrs={ + 'mean': 0.0, + 'std': 1.0, + 'shape': flatten_shape, + 'seed': self._seed, + 'dtype': var.dtype + }, + stop_gradient=True) + + q = block.create_var( + name=unique_name.generate('.'.join(['qr', 'q', 'tmp'])), + dtype=normal_var.dtype, + persistable=False, + stop_gradient=True) + r = block.create_var( + name=unique_name.generate('.'.join(['qr', 'r', 'tmp'])), + dtype=normal_var.dtype, + persistable=False, + stop_gradient=True) + block.append_op( + type='qr', + inputs={'X': [normal_var]}, + outputs={ + 'Q': q, + 'R': r, + }, + attrs={'mode': 'reduced'}, + stop_gradient=True) + + r_diag = block.create_var( + name=unique_name.generate('.'.join(['diag', 'tmp'])), + dtype=r.dtype, + persistable=False, + stop_gradient=True) + block.append_op( + type='diag_v2', + inputs={'X': r}, + outputs={'Out': r_diag}, + attrs={'offset': 0, + 'padding_value': 0}, + stop_gradient=True) + + r_sign = r_diag + block.append_op( + type='sign', + inputs={'X': [r_diag]}, + outputs={'Out': r_sign}, + stop_gradient=True) + + block.append_op( + type='elementwise_mul', + inputs={'X': q, + 'Y': r_sign}, + outputs={'Out': q}, + attrs={}, + stop_gradient=True) + + x_shape = block.create_var( + name=unique_name.generate('.'.join(['transpose', 'shape', 'tmp'])), + dtype=q.dtype, + persistable=False, + stop_gradient=True) + if row < col: + q_transpose = block.create_var( + name=unique_name.generate('.'.join(['transpose', 'tmp'])), + dtype=q.dtype, + persistable=False, + stop_gradient=True) + block.append_op( + type='transpose2', + inputs={'X': q}, + outputs={'Out': q_transpose, + 'XShape': x_shape}, + attrs={'axis': [1, 0]}, + stop_gradient=True) + q = q_transpose + + block.append_op( + type='reshape2', + inputs={'X': q}, + outputs={'Out': q, + "XShape": x_shape}, + attrs={'shape': var.shape}, + stop_gradient=True) + + op = block.append_op( + type='scale', + inputs={'X': q}, + outputs={'Out': var}, + attrs={'scale': self._gain, + 'bias': 0.0}) + + return op diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index abfc72c797a..c5bf19e83de 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -209,28 +209,28 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None): # [[ 0. 1. 2. 3.] [ 4. 5. 6. 7.] [ 8. 9. 10. 11.]]] # compute frobenius norm along last two dimensions. - out_fro = paddle.norm(x, p='fro', axis=[0,1]) + out_fro = paddle.linalg.norm(x, p='fro', axis=[0,1]) # out_fro.numpy() [17.435596 16.911535 16.7332 16.911535] # compute 2-order vector norm along last dimension. 
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index abfc72c797a..c5bf19e83de 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -209,28 +209,28 @@ def norm(x, p='fro', axis=None, keepdim=False, name=None):
             #   [[ 0.   1.   2.   3.] [ 4.   5.   6.   7.] [ 8.   9.  10.  11.]]]
 
             # compute frobenius norm along last two dimensions.
-            out_fro = paddle.norm(x, p='fro', axis=[0,1])
+            out_fro = paddle.linalg.norm(x, p='fro', axis=[0,1])
             # out_fro.numpy() [17.435596 16.911535 16.7332   16.911535]
 
             # compute 2-order vector norm along last dimension.
-            out_pnorm = paddle.norm(x, p=2, axis=-1)
+            out_pnorm = paddle.linalg.norm(x, p=2, axis=-1)
             #out_pnorm.numpy(): [[21.118711  13.190906   5.477226]
             #                    [ 3.7416575 11.224972  19.131126]]
 
             # compute 2-order norm along [0,1] dimension.
-            out_pnorm = paddle.norm(x, p=2, axis=[0,1])
+            out_pnorm = paddle.linalg.norm(x, p=2, axis=[0,1])
             #out_pnorm.numpy(): [17.435596 16.911535 16.7332   16.911535]
 
             # compute inf-order norm
-            out_pnorm = paddle.norm(x, p=np.inf)
+            out_pnorm = paddle.linalg.norm(x, p=np.inf)
             #out_pnorm.numpy() = [12.]
-            out_pnorm = paddle.norm(x, p=np.inf, axis=0)
+            out_pnorm = paddle.linalg.norm(x, p=np.inf, axis=0)
             #out_pnorm.numpy(): [[12. 11. 10. 9.] [8. 7. 6. 7.] [8. 9. 10. 11.]]
 
             # compute -inf-order norm
-            out_pnorm = paddle.norm(x, p=-np.inf)
+            out_pnorm = paddle.linalg.norm(x, p=-np.inf)
             #out_pnorm.numpy(): [0.]
-            out_pnorm = paddle.norm(x, p=-np.inf, axis=0)
+            out_pnorm = paddle.linalg.norm(x, p=-np.inf, axis=0)
             #out_pnorm.numpy(): [[0. 1. 2. 3.] [4. 5. 6. 5.] [4. 3. 2. 1.]]
     """
@@ -1084,7 +1084,7 @@ def cholesky(x, upper=False, name=None):
             a_t = np.transpose(a, [1, 0])
             x_data = np.matmul(a, a_t) + 1e-03
             x = paddle.to_tensor(x_data)
-            out = paddle.cholesky(x, upper=False)
+            out = paddle.linalg.cholesky(x, upper=False)
             print(out)
             # [[1.190523   0.         0.        ]
             #  [0.9906703  0.27676893 0.        ]
-- 
GitLab
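The linalg.py changes only move the docstring examples onto the paddle.linalg namespace (as far as I can tell, paddle.norm and paddle.cholesky remain available as top-level aliases). A minimal sketch of the updated call paths:

.. code-block:: python

    import numpy as np
    import paddle

    x = paddle.arange(24, dtype='float32').reshape([2, 3, 4])
    fro = paddle.linalg.norm(x, p='fro', axis=[0, 1])  # Frobenius norm
    inf = paddle.linalg.norm(x, p=np.inf)              # max absolute value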