From bbb9b28a546334550b6aa28b6b6812344238d4c3 Mon Sep 17 00:00:00 2001
From: Zhou Wei <1183042833@qq.com>
Date: Thu, 25 Nov 2021 15:45:40 +0800
Subject: [PATCH] add new API paddle.nn.initializer.Dirac (#37389)

* add new API paddle.nn.initializer.Dirac

* fix doc
---
 python/paddle/fluid/initializer.py             |  13 +-
 .../fluid/tests/unittests/test_initializer.py  | 113 +++++++++
 python/paddle/nn/initializer/__init__.py       |   3 +
 python/paddle/nn/initializer/dirac.py          | 223 ++++++++++++++++++
 4 files changed, 346 insertions(+), 6 deletions(-)
 create mode 100644 python/paddle/nn/initializer/dirac.py

diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index e7fd12df3d..930995cee6 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -1039,9 +1039,10 @@ def calculate_gain(nonlinearity, param=None):
     Get the recommended gain value of some nonlinearity function.
 
     Args:
-        nonlinearity(str): nonlinearity function.
-        param(bool|int|float, optional): optional parameter for somme nonlinearity function. Now, it only applies to 'leaky_relu'. Default: None,
-            it will be calculated as 0.01 in the formula.
+        nonlinearity(str): name of the nonlinearity activation function. If it is a linear function, i.e. one of
+            "linear/conv1d/conv2d/conv3d/conv1d_transpose/conv2d_transpose/conv3d_transpose", 1.0 will be returned.
+        param(bool|int|float, optional): optional parameter for some nonlinearity functions. Now, it only applies to
+            'leaky_relu'. Default: None, it will be calculated as 0.01 in the formula.
 
     Returns:
         The recommended gain value for nonlinearity function.
@@ -1065,9 +1066,9 @@ def calculate_gain(nonlinearity, param=None):
         'conv1d': 1,
         'conv2d': 1,
         'conv3d': 1,
-        'conv_transpose1d': 1,
-        'conv_transpose2d': 1,
-        'conv_transpose3d': 1,
+        'conv1d_transpose': 1,
+        'conv2d_transpose': 1,
+        'conv3d_transpose': 1,
         'tanh': 5.0 / 3,
         'relu': math.sqrt(2.0),
         'leaky_relu': math.sqrt(2.0 / (1 + param**2)),
diff --git a/python/paddle/fluid/tests/unittests/test_initializer.py b/python/paddle/fluid/tests/unittests/test_initializer.py
index a3982ab3e4..6fdad811ee 100644
--- a/python/paddle/fluid/tests/unittests/test_initializer.py
+++ b/python/paddle/fluid/tests/unittests/test_initializer.py
@@ -915,5 +915,118 @@ class TestOrthogonalInitializer6(TestOrthogonalInitializer4):
         self.assertTrue(np.allclose(np.matmul(a, a.T), np.eye(36), atol=1.e-6))
 
 
+# initialize Conv1D weight
+class TestDiracInitializer1(unittest.TestCase):
+    def config(self):
+        self.weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.Dirac())
+        self.dtype = "float64"
+        self.in_channels = 3
+        self.out_channels = 2
+        self.kernel_size = 3
+        self.input_shape = [8, self.in_channels, 10]
+        self.conv_layer = paddle.nn.Conv1D
+        self.num_ops = 8  # fill_constant*2, reshape*2, assign_value*2, scatter, cast
+
+    def check_result(self, w_dygraph, w_static, conv_in, conv_out):
+        self.assertTrue(np.array_equal(w_dygraph, w_static))
+        self.assertTrue(np.array_equal(conv_out, conv_in[:, 0:2, 1:9]))
+
+    def test_dirac(self):
+        self.config()
+        paddle.set_default_dtype(self.dtype)
+
+        paddle.disable_static()
+        conv = self.conv_layer(
+            self.in_channels,
+            self.out_channels,
+            self.kernel_size,
+            weight_attr=self.weight_attr)
+        weight_dygraph = conv.weight.numpy()
+
+        paddle.enable_static()
+        start_prog = paddle.static.Program()
+        main_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, start_prog):
+            inp = paddle.rand(self.input_shape)
+            conv = self.conv_layer(
+                self.in_channels,
+                self.out_channels,
+                self.kernel_size,
+                weight_attr=self.weight_attr)
+
+            output = conv(inp)
+            block = start_prog.global_block()
+            self.assertEqual(len(block.ops), self.num_ops)
+            self.assertEqual(block.ops[0].type, 'fill_constant')
+            self.assertEqual(block.ops[1].type, 'reshape')
+            self.assertEqual(block.ops[2].type, 'assign_value')
+            self.assertEqual(block.ops[3].type, 'assign_value')
+            self.assertEqual(block.ops[4].type, 'scatter')
+            self.assertEqual(block.ops[5].type, 'reshape')
+
+            exe = paddle.static.Executor()
+            exe.run(start_prog)
+            fetch = exe.run(main_prog, fetch_list=[inp, output, conv.weight])
+            conv_input = fetch[0]
+            conv_output = fetch[1]
+            weight_static = fetch[2]
+
+        self.check_result(weight_dygraph, weight_static, conv_input,
+                          conv_output)
+
+
+# initialize Conv2D weight
+class TestDiracInitializer2(TestDiracInitializer1):
+    def config(self):
+        self.weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.Dirac(groups=1))
+        self.dtype = "float64"
+        self.in_channels = 4
+        self.out_channels = 8
+        self.kernel_size = (3, 3)
+        self.input_shape = [8, self.in_channels, 10, 10]
+        self.conv_layer = paddle.nn.Conv2D
+        self.num_ops = 8
+
+    def check_result(self, w_dygraph, w_static, conv_in, conv_out):
+        self.assertTrue(np.array_equal(w_dygraph, w_static))
+        self.assertTrue(
+            np.array_equal(conv_out[:, 0:4, :, :], conv_in[:, :, 1:9, 1:9]))
+        self.assertTrue(
+            np.array_equal(conv_out[:, 4:8, :, :], np.zeros([8, 4, 8, 8])))
+
+
+# initialize Conv3D weight
+class TestDiracInitializer3(TestDiracInitializer1):
+    def config(self):
+        self.weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.Dirac(groups=2))
+        self.dtype = "float32"
+        self.in_channels = 5
+        self.out_channels = 10
+        self.kernel_size = (3, 3, 3)
+        self.input_shape = [8, self.in_channels, 10, 10, 10]
+        self.conv_layer = paddle.nn.Conv3D
+        self.num_ops = 7  # no cast op, since dtype is already float32
+
+    def check_result(self, w_dygraph, w_static, conv_in, conv_out):
+        self.assertTrue(np.array_equal(w_dygraph, w_static))
+        self.assertTrue(
+            np.array_equal(conv_out[:, 0:5, :, :, :], conv_in[:, :, 1:9, 1:9,
+                                                              1:9]))
+        self.assertTrue(
+            np.array_equal(conv_out[:, 5:10, :, :, :], conv_in[:, :, 1:9, 1:9,
+                                                               1:9]))
+
+    def test_error(self):
+        self.config()
+        with self.assertRaises(AssertionError):
+            paddle.nn.Linear(10, 10, weight_attr=self.weight_attr)
+
+        with self.assertRaises(AssertionError):
+            paddle.nn.Conv2D(5, 9, (3, 3), weight_attr=self.weight_attr)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/nn/initializer/__init__.py b/python/paddle/nn/initializer/__init__.py
index e2b83fa8ce..e048ee2b1e 100644
--- a/python/paddle/nn/initializer/__init__.py
+++ b/python/paddle/nn/initializer/__init__.py
@@ -34,6 +34,8 @@ from .uniform import Uniform  # noqa: F401
 
 from .orthogonal import Orthogonal  # noqa: F401
 
+from .dirac import Dirac  # noqa: F401
+
 __all__ = [  #noqa
     'Bilinear',
     'Constant',
@@ -46,6 +48,7 @@ __all__ = [  #noqa
     'TruncatedNormal',
     'Uniform',
     'Orthogonal',
+    'Dirac',
     'set_global_initializer',
     'calculate_gain'
 ]
diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py
new file mode 100644
index 0000000000..55765782e5
--- /dev/null
+++ b/python/paddle/nn/initializer/dirac.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...fluid.initializer import Initializer
+from ...fluid.data_feeder import check_variable_and_dtype
+from ...fluid.core import VarDesc
+from ...fluid import unique_name, framework
+
+__all__ = []
+
+
+class Dirac(Initializer):
+    """Initialize the 3D/4D/5D Tensor with the Dirac delta function.
+
+    It preserves the identity of the convolution layer's input: as many
+    input channels as possible are passed through to the output unchanged.
+
+    In this initialization method, the element at the center of each
+    convolution kernel is set to 1. The formula can be described as:
+
+    $ Assuming:  N=min(in\_channels, out\_channels)$
+
+    $ X[d, d, shape[2]//2, shape[3]//2, ...]=1,  \  d=0,1...N-1$
+
+    Args:
+        groups(int, optional): The 0th dimension of the Tensor is divided by `groups`; each group receives the same initialization. Default: 1.
+        name(str, optional): The default value is None. Normally there is no need for user to set this
+            property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Dirac initializer instance.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            # 1. When kernel_size is an odd number:
+
+            attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Dirac())
+            conv = paddle.nn.Conv1D(3, 2, 3, weight_attr=attr)
+            conv.weight
+            # Tensor(shape=[2, 3, 3], dtype=float32, place=CPUPlace, stop_gradient=False,
+            #        [[[0., 1., 0.],
+            #          [0., 0., 0.],
+            #          [0., 0., 0.]],
+            #
+            #         [[0., 0., 0.],
+            #          [0., 1., 0.],
+            #          [0., 0., 0.]]])
+
+            input = paddle.rand([8, 3, 10])
+            output = conv(input)
+            output == input[:, 0:2, 1:9]
+            # output.shape is [8, 2, 8]; output is almost identical to input, with 2 channels preserved
+
+
+            # 2. When kernel_size is an even number:
+            attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Dirac())
+            conv = paddle.nn.Conv1D(3, 2, 4, weight_attr=attr)
+            conv.weight
+            # Tensor(shape=[2, 3, 4], dtype=float32, place=CPUPlace, stop_gradient=False,
+            #        [[[0., 0., 1., 0.],
+            #          [0., 0., 0., 0.],
+            #          [0., 0., 0., 0.]],
+            #
+            #         [[0., 0., 0., 0.],
+            #          [0., 0., 1., 0.],
+            #          [0., 0., 0., 0.]]])
+    """
+
+    def __init__(self, groups=1, name=None):
+        assert groups > 0 and isinstance(
+            groups, int), " 'groups' must be a positive integer. "
+        super(Dirac, self).__init__()
+        self._groups = groups
+
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with the Dirac initializer.
+
+        Args:
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                should be added. Used in static graph only, default None.
+
+        Returns:
+            The critical OP (scatter) of this initializer; the whole
+            initialization appends 7~8 ops in total.
+        """
+        block = self._check_block(block)
+        assert isinstance(var, framework.Parameter)
+        assert isinstance(block, framework.Block)
+        check_variable_and_dtype(
+            var, "Out", ['float16', 'bfloat16', 'float32', 'float64'], 'Dirac')
+
+        assert len(var.shape) in [
+            3, 4, 5
+        ], "Only Tensor with 3/4/5 dimensions can be initialized by Dirac"
+        assert (var.shape[0] % self._groups
+                ) == 0, "Tensor 0-dimension must be divisible by groups"
+
+        if var.dtype != VarDesc.VarType.FP32:
+            out_var = block.create_var(
+                name=unique_name.generate(".".join(['dirac', var.name, 'tmp'])),
+                shape=var.shape,
+                dtype=VarDesc.VarType.FP32,
+                type=VarDesc.VarType.LOD_TENSOR,
+                persistable=False)
+        else:
+            out_var = var
+
+        block.append_op(
+            type='fill_constant',
+            inputs={},
+            outputs={'Out': out_var},
+            attrs={
+                'value': float(0),
+                'dtype': out_var.dtype,
+                'shape': out_var.shape,
+            },
+            stop_gradient=True)
+
+        origin_shape = var.shape
+        num_per_group = origin_shape[0] // self._groups
+        min_shape = min(num_per_group, origin_shape[1])
+
+        idx_list = []
+        value_list = []
+        strides = []
+        prod = 1
+        for dim in reversed(origin_shape):
+            strides.insert(0, prod)
+            prod *= dim
+        for i in range(self._groups):
+            for j in range(min_shape):
+                value_list.append(1.0)
+                offset = 0
+                for (k, stride) in enumerate(strides):
+                    if (k == 0):
+                        offset += (j + i * num_per_group) * stride
+                    elif (k == 1):
+                        offset += j * stride
+                    else:
+                        offset += origin_shape[k] // 2 * stride
+                idx_list.append(offset)
+
+        block.append_op(
+            type="reshape",
+            inputs={"X": out_var},
+            attrs={'shape': [-1]},
+            outputs={"Out": out_var},
+            stop_gradient=True)
+
+        index_tensor = block.create_var(
+            name=unique_name.generate('scatter_index'),
+            persistable=False,
+            stop_gradient=True)
+
+        block.append_op(
+            type='assign_value',
+            outputs={'Out': index_tensor},
+            attrs={
+                'dtype': VarDesc.VarType.INT64,
+                'shape': [len(idx_list)],
+                'int64_values': idx_list
+            },
+            stop_gradient=True)
+
+        value_tensor = block.create_var(
+            name=unique_name.generate('scatter_value'),
+            persistable=False,
+            stop_gradient=True)
+
+        block.append_op(
+            type='assign_value',
+            outputs={'Out': value_tensor},
+            attrs={
+                'dtype': VarDesc.VarType.FP32,
+                'shape': [len(value_list)],
+                'fp32_values': value_list
+            },
+            stop_gradient=True)
+
+        op = block.append_op(
+            type="scatter",
+            inputs={
+                "X": out_var,
+                "Ids": index_tensor,
+                "Updates": value_tensor
+            },
+            attrs={'overwrite': True},
+            outputs={"Out": out_var},
+            stop_gradient=True)
+
+        block.append_op(
+            type="reshape",
+            inputs={"X": out_var},
+            attrs={'shape': origin_shape},
+            outputs={"Out": out_var},
+            stop_gradient=True)
+
+        if var.dtype != VarDesc.VarType.FP32:
+            block.append_op(
+                type="cast",
+                inputs={"X": out_var},
+                outputs={"Out": var},
+                attrs={"in_dtype": out_var.dtype,
+                       "out_dtype": var.dtype},
+                stop_gradient=True)
+
+        if not framework.in_dygraph_mode():
+            var.op = op
+        return op
-- 
GitLab
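A quick check of the renamed gain keys in calculate_gain (a minimal sketch; it assumes the public alias paddle.nn.initializer.calculate_gain exported by the __init__.py hunk above, and the expected values follow the gain dictionary in this patch):

    import math
    import paddle

    # Linear/convolution keys carry gain 1.0; the renamed transpose keys
    # ('conv1d_transpose', ...) now resolve instead of the old
    # 'conv_transpose1d' spelling.
    assert paddle.nn.initializer.calculate_gain('conv2d_transpose') == 1.0

    # 'leaky_relu' uses sqrt(2 / (1 + param**2)); param defaults to 0.01.
    gain = paddle.nn.initializer.calculate_gain('leaky_relu', param=0.2)
    assert math.isclose(gain, math.sqrt(2.0 / (1 + 0.2**2)))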
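And a usage sketch of the new Dirac initializer itself, mirroring the Conv2D test case above (assumed setup: 4 input channels, 8 output channels, a 3x3 kernel, groups=1; with the default zero-initialized bias, the first min(in, out) output channels should reproduce the input's interior and the rest stay zero):

    import numpy as np
    import paddle

    # Dirac initialization: the center element of kernel (d, d, :, :) is 1
    # for d in 0..min(in_channels, out_channels)-1; all other weights are 0.
    attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Dirac(groups=1))
    conv = paddle.nn.Conv2D(4, 8, (3, 3), weight_attr=attr)

    x = paddle.rand([8, 4, 10, 10])
    y = conv(x)

    # Channels 0-3 pass the input through (cropped by the 'valid' border);
    # channels 4-7 have no matching input channel and remain zero.
    assert np.array_equal(y.numpy()[:, 0:4, :, :], x.numpy()[:, :, 1:9, 1:9])
    assert np.array_equal(y.numpy()[:, 4:8, :, :], np.zeros([8, 4, 8, 8]))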