diff --git a/python/paddle/v2/framework/initializer.py b/python/paddle/v2/framework/initializer.py
index 507fd16062af1e2458eb9b45407e91a8d29ea9ce..98a87bfa86efb39f381b9f99b2b1f0d7ec7d9833 100644
--- a/python/paddle/v2/framework/initializer.py
+++ b/python/paddle/v2/framework/initializer.py
@@ -1,6 +1,10 @@
 import paddle.v2.framework.framework as framework
+import numpy as np
 
-__all__ = ['ConstantInitializer', 'UniformInitializer']
+__all__ = [
+    'ConstantInitializer', 'UniformInitializer', 'NormalInitializer',
+    'XavierInitializer'
+]
 
 
 class Initializer(object):
@@ -20,6 +24,41 @@ class Initializer(object):
         """
         raise NotImplementedError()
 
+    def _compute_fans(self, var):
+        """Compute the fan_in and the fan_out for layers
+
+        This method computes the fan_in and the fan_out
+        for neural network layers when they are not specified.
+        They cannot be estimated perfectly in general, but this
+        method computes them correctly for matrix multiplies
+        and convolutions.
+
+        Args:
+            var: variable for which fan_in and fan_out have to be computed
+
+        Returns:
+            tuple of two integers (fan_in, fan_out)
+        """
+        shape = var.shape
+        if not shape:
+            fan_in = fan_out = 1
+        elif len(shape) == 1:
+            fan_in = fan_out = shape[0]
+        elif len(shape) == 2:
+            # This is the case for a simple matrix multiply
+            fan_in = shape[0]
+            fan_out = shape[1]
+        else:
+            # Assume this to be a convolutional kernel
+            # In PaddlePaddle, the shape of the kernel is like:
+            # [num_filters, num_filter_channels, ...] where the remaining
+            # dimensions are the filter_size
+            receptive_field_size = np.prod(shape[2:])
+            fan_in = shape[1] * receptive_field_size
+            fan_out = shape[0] * receptive_field_size
+
+        return (fan_in, fan_out)
+
 
 class ConstantInitializer(Initializer):
     """Implements the constant initializer
@@ -156,3 +195,93 @@ class NormalInitializer(Initializer):
             })
         var.op = op
         return op
+
+
+class XavierInitializer(Initializer):
+    """Implements the Xavier initializer
+
+    This class implements the Xavier weight initializer from the paper
+    Understanding the difficulty of training deep feedforward neural
+    networks[1] by Xavier Glorot and Yoshua Bengio.
+
+    This initializer is designed to keep the scale of the gradients
+    approximately the same in all layers. For the uniform distribution,
+    the range is [-x, x], where x = sqrt(6 / (fan_in + fan_out)).
+    For the normal distribution, the mean is 0 and the standard
+    deviation is sqrt(2 / (fan_in + fan_out)).
+
+    References:
+        [1] Understanding the difficulty of training deep feedforward neural
+            networks. International conference on artificial intelligence and
+            statistics.
+            (http://proceedings.mlr.press/v9/glorot10a.html)
+    """
+
+    def __init__(self, uniform=True, fan_in=None, fan_out=None, seed=0):
+        """Constructor for XavierInitializer
+
+        Args:
+            uniform: whether to use the uniform or the normal distribution
+            fan_in: fan_in for Xavier initialization. If None, it is
+                    inferred from the variable.
+            fan_out: fan_out for Xavier initialization. If None, it is
+                     inferred from the variable.
+            seed: random seed
+
+        Note: It is recommended to leave fan_in and fan_out as None in
+              most cases.
+ """ + assert uniform is not None + assert seed is not None + super(XavierInitializer, self).__init__() + self._uniform = uniform + self._fan_in = fan_in + self._fan_out = fan_out + self._seed = seed + + def __call__(self, var, block): + """Add xavier initialization ops for a variable + + Args: + var: Variable that needs to be initialized + block: The block in which initialization ops + should be added + + Returns: + the initialization op + """ + assert isinstance(var, framework.Variable) + assert isinstance(block, framework.Block) + f_in, f_out = self._compute_fans(var) + + # If fan_in and fan_out are passed, use them + fan_in = f_in if self._fan_in is None else self._fan_in + fan_out = f_out if self._fan_out is None else self._fan_out + + if self._uniform: + limit = np.sqrt(6.0 / float(fan_in + fan_out)) + op = block.prepend_op( + type="uniform_random", + outputs={"Out": var}, + attrs={ + "shape": var.shape, + "data_type": int(var.data_type), + "min": -limit, + "max": limit, + "seed": self._seed + }) + + else: + std = np.sqrt(2.0 / float(fan_in + fan_out)) + op = block.prepend_op( + type="gaussian_random", + outputs={"Out": var}, + attrs={ + "shape": var.shape, + "data_type": int(var.data_type), + "mean": 0.0, + "std": std, + "seed": self._seed + }) + var.op = op + return op diff --git a/python/paddle/v2/framework/tests/test_initializer.py b/python/paddle/v2/framework/tests/test_initializer.py index f28fc8a86c7c8e683e00249a2f73dbbe6d7be27c..bd4d2e39d770aebb7468d516f463533185ea8680 100644 --- a/python/paddle/v2/framework/tests/test_initializer.py +++ b/python/paddle/v2/framework/tests/test_initializer.py @@ -1,3 +1,4 @@ +import numpy as np import unittest import paddle.v2.framework.framework as framework @@ -116,5 +117,111 @@ class TestNormalInitializer(unittest.TestCase): self.assertEqual(init_op.attr('seed'), 123) +class TestXavierInitializer(unittest.TestCase): + def test_uniform_xavier_initializer(self): + """Test Xavier initializer with uniform distribution on + for matrix multiply. + """ + program = framework.Program() + block = program.global_block() + param = block.create_parameter( + dtype="float32", + shape=[5, 10], + lod_level=0, + name="param", + initializer=initializer.XavierInitializer()) + self.assertEqual(len(block.ops), 1) + init_op = block.ops[0] + self.assertEqual(init_op.type, 'uniform_random') + limit = np.sqrt(6.0 / (param.shape[0] + param.shape[1])) + self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA) + self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA) + self.assertEqual(init_op.attr('seed'), 0) + + def test_uniform_xavier_initializer_conv(self): + """Test Xavier initializer with uniform distribution on + for convolutions. + """ + program = framework.Program() + block = program.global_block() + param = block.create_parameter( + dtype="float32", + shape=[5, 10, 15, 20], + lod_level=0, + name="param", + initializer=initializer.XavierInitializer()) + self.assertEqual(len(block.ops), 1) + init_op = block.ops[0] + self.assertEqual(init_op.type, 'uniform_random') + receptive_field_size = float(15 * 20) + limit = np.sqrt(6.0 / ( + (param.shape[0] + param.shape[1]) * receptive_field_size)) + self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA) + self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA) + self.assertEqual(init_op.attr('seed'), 0) + + def test_normal_xavier_initializer(self): + """Test Xavier initializer with normal distribution on + for matrix multiply. 
+ """ + program = framework.Program() + block = program.global_block() + param = block.create_parameter( + dtype="float32", + shape=[5, 10], + lod_level=0, + name="param", + initializer=initializer.XavierInitializer(uniform=False)) + self.assertEqual(len(block.ops), 1) + init_op = block.ops[0] + self.assertEqual(init_op.type, 'gaussian_random') + std = np.sqrt(2.0 / (param.shape[0] + param.shape[1])) + self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA) + self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA) + self.assertEqual(init_op.attr('seed'), 0) + + def test_normal_xavier_initializer_conv(self): + """Test Xavier initializer with normal distribution on + for convolutions. + """ + program = framework.Program() + block = program.global_block() + param = block.create_parameter( + dtype="float32", + shape=[5, 10, 15, 20], + lod_level=0, + name="param", + initializer=initializer.XavierInitializer(uniform=False)) + self.assertEqual(len(block.ops), 1) + init_op = block.ops[0] + self.assertEqual(init_op.type, 'gaussian_random') + receptive_field_size = float(15 * 20) + std = np.sqrt(2.0 / ( + (param.shape[0] + param.shape[1]) * receptive_field_size)) + self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA) + self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA) + self.assertEqual(init_op.attr('seed'), 0) + + def test_xavier_initializer_supplied_arguments(self): + """Test the Xavier initializer with supplied arguments + """ + program = framework.Program() + block = program.global_block() + block.create_parameter( + dtype="float32", + shape=[5, 10], + lod_level=0, + name="param", + initializer=initializer.XavierInitializer( + fan_in=12, fan_out=23, seed=134)) + self.assertEqual(len(block.ops), 1) + init_op = block.ops[0] + self.assertEqual(init_op.type, 'uniform_random') + limit = np.sqrt(6.0 / (12 + 23)) + self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA) + self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA) + self.assertEqual(init_op.attr('seed'), 134) + + if __name__ == '__main__': unittest.main()