initializer.py 37.0 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

17
import math
18
from . import framework
19
from . import core
20
from .framework import in_dygraph_mode, default_main_program
21
import numpy as np
22
from .core import VarDesc
W
Wu Yi 已提交
23
from . import unique_name
24
from .data_feeder import check_variable_and_dtype, check_type, check_dtype
25

26
# Names exported by ``from <this module> import *``.
__all__ = [
    'Constant',
    'Uniform',
    'Normal',
    'TruncatedNormal',
    'Xavier',
    'Bilinear',
    'MSRA',
    'ConstantInitializer',
    'UniformInitializer',
    'NormalInitializer',
    'TruncatedNormalInitializer',
    'XavierInitializer',
    'BilinearInitializer',
    'MSRAInitializer',
    'NumpyArrayInitializer',
    'set_global_initializer',
]

# Module-level default initializers for weights and biases; both start
# unset (None).
# NOTE(review): presumably assigned by set_global_initializer, which is not
# visible in this part of the file - confirm.
_global_weight_initializer_ = None
_global_bias_initializer_ = None

36 37 38 39 40 41 42 43 44 45

class Initializer(object):
    """Abstract base for variable initializers.

    An initializer is a callable that adds the operations needed to
    initialize a variable to the init program. This class only fixes that
    calling convention and offers helpers shared by the concrete
    initializers; users should use one of its implementations instead of
    this class directly.
    """

    def __init__(self):
        pass

    def __call__(self, param, block=None):
        """Add corresponding initialization operations to the network.

        Must be overridden by subclasses; the base implementation always
        raises NotImplementedError.
        """
        raise NotImplementedError()

    def _check_block(self, block):
        """Resolve the block that initialization ops should be added to.

        Args:
            block: an explicit block, or None.

        Returns:
            ``block`` itself when it is not None; otherwise, in dygraph
            mode, the global block of the default main program.

        Raises:
            ValueError: if ``block`` is None in static graph mode.
        """
        if block is not None:
            return block
        if not in_dygraph_mode():
            raise ValueError(
                "The parameter 'block' is needed in static graph mode.")
        return default_main_program().global_block()

    def _compute_fans(self, var):
        """Estimate fan_in and fan_out from the shape of a variable.

        The estimate is exact for matrix multiply and convolution weights;
        it is not possible to perfectly estimate it in general.

        Args:
            var: variable for which fan_in and fan_out have to be computed

        Returns:
            tuple of two integers (fan_in, fan_out)
        """
        dims = var.shape
        # A missing or empty shape is handled like a scalar.
        rank = len(dims) if dims else 0

        if rank == 0:
            return (1, 1)
        if rank == 1:
            return (dims[0], dims[0])
        if rank == 2:
            # Simple matrix multiply: [fan_in, fan_out]
            return (dims[0], dims[1])

        # Otherwise assume a convolutional kernel. In PaddlePaddle its shape
        # is [num_filters, num_filter_channels, ...], where the remaining
        # dimensions are the filter size.
        kernel_area = np.prod(dims[2:])
        return (dims[1] * kernel_area, dims[0] * kernel_area)

99 100 101

class ConstantInitializer(Initializer):
    """Implements the constant initializer

    Args:
        value (float32): constant value to initialize the variable
        force_cpu (bool): passed through as the ``force_cpu`` attribute of
            the ``fill_constant`` op. Default False.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            paddle.enable_static()
            x = fluid.data(name="data", shape=[8, 32, 32], dtype="float32")
            fc = fluid.layers.fc(
                input=x,
                size=10,
                param_attr=fluid.initializer.Constant(value=2.0))

    """

    def __init__(self, value=0.0, force_cpu=False):
        assert value is not None
        super(ConstantInitializer, self).__init__()
        self._value = value
        self._force_cpu = force_cpu

    def __call__(self, var, block=None):
        """Initialize the input tensor with constant.

        Args:
            var(Tensor): Tensor that needs to be initialized.
            block(Block, optional): The block in which initialization ops
                   should be added. Used in static graph only, default None.

        Returns:
            The initialization op
        """
        block = self._check_block(block)

        assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)

        # FP16 variables are filled through a temporary FP32 variable and
        # cast back afterwards.
        via_fp32 = var.dtype == VarDesc.VarType.FP16
        if via_fp32:
            fill_dtype = VarDesc.VarType.FP32
            tmp_name = unique_name.generate(".".join(
                ['constant_init', var.name, 'tmp']))
            target = block.create_var(
                name=tmp_name,
                shape=var.shape,
                dtype=fill_dtype,
                type=VarDesc.VarType.LOD_TENSOR,
                persistable=False)
        else:
            fill_dtype = var.dtype
            target = var

        # "str_value" is set in addition to "value" to preserve precision.
        fill_attrs = {
            "shape": var.shape,
            "dtype": int(fill_dtype),
            "value": float(self._value),
            'str_value': str(float(self._value)),
            'force_cpu': self._force_cpu
        }
        op = block.append_op(
            type="fill_constant",
            outputs={"Out": target},
            attrs=fill_attrs,
            stop_gradient=True)

        if via_fp32:
            cast_attrs = {"in_dtype": target.dtype, "out_dtype": var.dtype}
            block.append_op(
                type="cast",
                inputs={"X": target},
                outputs={"Out": var},
                attrs=cast_attrs)

        # The op is recorded on the variable only in static graph mode.
        if not framework.in_dygraph_mode():
            var.op = op
        return op


class UniformInitializer(Initializer):
    """Implements the random uniform distribution initializer

    Args:
        low (float): lower boundary of the uniform distribution
        high (float): upper boundary of the uniform distribution
        seed (int): random seed
        diag_num (int): the number of diagonal elements to initialize.
            If set to 0, diagonal initialization will be not performed.
        diag_step (int): Step size between two diagonal elements,
            which is generally the width of the square matrix.
        diag_val (float): the value of the diagonal element to be initialized,
            default 1.0. It takes effect only if the diag_num is greater than 0.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            x = fluid.data(name='x', shape=[None, 1], dtype='float32')
            fc = fluid.layers.fc(input=x, size=10,
                param_attr=fluid.initializer.Uniform(low=-0.5, high=0.5))
    """

    def __init__(self,
                 low=-1.0,
                 high=1.0,
                 seed=0,
                 diag_num=0,
                 diag_step=0,
                 diag_val=1.0):
        assert low is not None
        assert high is not None
        assert high >= low
        assert seed is not None
        assert diag_num is not None
        assert diag_step is not None
        assert diag_val is not None
        # diag_num and diag_step must be enabled together or not at all.
        if diag_num > 0 or diag_step > 0:
            assert diag_num > 0 and diag_step > 0
        super(UniformInitializer, self).__init__()
        self._low, self._high = low, high
        self._seed = seed
        self._diag_num = diag_num
        self._diag_step = diag_step
        self._diag_val = diag_val

    def __call__(self, var, block=None):
        """Initialize the input tensor with Uniform distribution.

        Args:
            var(Tensor): Tensor that needs to be initialized.
            block(Block, optional): The block in which initialization ops
                   should be added. Used in static graph only, default None.

        Returns:
            The initialization op
        """
        block = self._check_block(block)

        assert isinstance(block, framework.Block)
        supported_dtypes = ["uint16", "float16", "float32", "float64"]
        check_variable_and_dtype(var, "Out", supported_dtypes,
                                 "uniform_random")

        # A seed of 0 is replaced by (and remembered as) the program's seed.
        if self._seed == 0:
            self._seed = block.program.random_seed

        # FP16 variables are sampled into a temporary FP32 variable and cast
        # back afterwards.
        via_fp32 = var.dtype == VarDesc.VarType.FP16
        if via_fp32:
            sample_dtype = VarDesc.VarType.FP32
            tmp_name = unique_name.generate(".".join(
                ['uniform_random', var.name, 'tmp']))
            target = block.create_var(
                name=tmp_name,
                shape=var.shape,
                dtype=sample_dtype,
                type=VarDesc.VarType.LOD_TENSOR,
                persistable=False)
        else:
            sample_dtype = var.dtype
            target = var

        sample_attrs = {
            "shape": var.shape,
            "dtype": sample_dtype,
            "min": self._low,
            "max": self._high,
            "seed": self._seed,
            "diag_num": self._diag_num,
            "diag_step": self._diag_step,
            "diag_val": self._diag_val
        }
        op = block.append_op(
            type="uniform_random",
            inputs={},
            outputs={"Out": target},
            attrs=sample_attrs,
            stop_gradient=True)

        if via_fp32:
            cast_attrs = {"in_dtype": target.dtype, "out_dtype": var.dtype}
            block.append_op(
                type="cast",
                inputs={"X": target},
                outputs={"Out": var},
                attrs=cast_attrs)

        # The op is recorded on the variable only in static graph mode.
        if not framework.in_dygraph_mode():
            var.op = op
        return op
291 292 293


class NormalInitializer(Initializer):
    """Implements the Random Normal(Gaussian) distribution initializer

    Args:
        loc (float): mean of the normal distribution
        scale (float): standard deviation of the normal distribution
        seed (int): random seed

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            x = fluid.data(name="data", shape=[None, 32, 32], dtype="float32")
            fc = fluid.layers.fc(input=x, size=10,
                param_attr=fluid.initializer.Normal(loc=0.0, scale=2.0))

    """

    def __init__(self, loc=0.0, scale=1.0, seed=0):
        assert loc is not None
        assert scale is not None
        assert seed is not None
        super(NormalInitializer, self).__init__()
        self._mean = loc
        self._std_dev = scale
        self._seed = seed

    def __call__(self, var, block=None):
        """Initialize the input tensor with Normal distribution.

        Args:
            var(Tensor): Tensor that needs to be initialized.
            block(Block, optional): The block in which initialization ops
                   should be added. Used in static graph only, default None.

        Returns:
            The initialization op
        """
        block = self._check_block(block)

        assert isinstance(block, framework.Block)

        # The last argument is the op name reported in dtype error messages;
        # it must match the op appended below ("gaussian_random").
        check_variable_and_dtype(var, "Out",
                                 ["uint16", "float16", "float32", "float64"],
                                 "gaussian_random")

        # A seed of 0 is replaced by (and remembered as) the program's seed.
        if self._seed == 0:
            self._seed = block.program.random_seed

        # FP16/BF16 variables are sampled into a temporary FP32 variable and
        # cast back afterwards.
        via_fp32 = var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]
        if via_fp32:
            out_dtype = VarDesc.VarType.FP32
            out_var = block.create_var(
                name=unique_name.generate(".".join(
                    ['gaussian_random', var.name, 'tmp'])),
                shape=var.shape,
                dtype=out_dtype,
                type=VarDesc.VarType.LOD_TENSOR,
                persistable=False)
        else:
            out_dtype = var.dtype
            out_var = var

        op = block.append_op(
            type="gaussian_random",
            outputs={"Out": out_var},
            attrs={
                "shape": var.shape,
                "dtype": out_dtype,
                "mean": self._mean,
                "std": self._std_dev,
                "seed": self._seed,
                "use_mkldnn": False
            },
            stop_gradient=True)

        if via_fp32:
            block.append_op(
                type="cast",
                inputs={"X": out_var},
                outputs={"Out": var},
                attrs={"in_dtype": out_var.dtype,
                       "out_dtype": var.dtype})

        # The op is recorded on the variable only in static graph mode.
        if not framework.in_dygraph_mode():
            var.op = op
        return op
379 380


381 382 383 384 385 386 387 388 389 390 391
class TruncatedNormalInitializer(Initializer):
    """Implements the Random TruncatedNormal(Gaussian) distribution initializer

    Args:
        loc (float): mean of the normal distribution
        scale (float): standard deviation of the normal distribution
        seed (int): random seed

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            x = fluid.data(name='x', shape=[None, 1], dtype='float32')
            fc = fluid.layers.fc(input=x, size=10,
                param_attr=fluid.initializer.TruncatedNormal(loc=0.0, scale=2.0))
    """

    def __init__(self, loc=0.0, scale=1.0, seed=0):
        assert loc is not None
        assert scale is not None
        assert seed is not None
        super(TruncatedNormalInitializer, self).__init__()
        self._mean, self._std_dev = loc, scale
        self._seed = seed

    def __call__(self, var, block=None):
        """Initialize the input tensor with TruncatedNormal distribution.

        Args:
            var(Tensor): Tensor that needs to be initialized.
            block(Block, optional): The block in which initialization ops
                   should be added. Used in static graph only, default None.

        Returns:
            The initialization op
        """
        block = self._check_block(block)

        assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)

        # A seed of 0 is replaced by (and remembered as) the program's seed.
        if self._seed == 0:
            self._seed = block.program.random_seed

        # FP16/BF16 variables are sampled into a temporary FP32 variable and
        # cast back afterwards.
        low_precision = (VarDesc.VarType.FP16, VarDesc.VarType.BF16)
        via_fp32 = var.dtype in low_precision
        if via_fp32:
            sample_dtype = VarDesc.VarType.FP32
            tmp_name = unique_name.generate(".".join(
                ['truncated_gaussian_random', var.name, 'tmp']))
            target = block.create_var(
                name=tmp_name,
                shape=var.shape,
                dtype=sample_dtype,
                type=VarDesc.VarType.LOD_TENSOR,
                persistable=False)
        else:
            sample_dtype = var.dtype
            target = var

        sample_attrs = {
            "shape": var.shape,
            "dtype": sample_dtype,
            "mean": self._mean,
            "std": self._std_dev,
            "seed": self._seed
        }
        op = block.append_op(
            type="truncated_gaussian_random",
            outputs={"Out": target},
            attrs=sample_attrs,
            stop_gradient=True)

        if via_fp32:
            cast_attrs = {"in_dtype": target.dtype, "out_dtype": var.dtype}
            block.append_op(
                type="cast",
                inputs={"X": target},
                outputs={"Out": var},
                attrs=cast_attrs)

        # The op is recorded on the variable only in static graph mode.
        if not framework.in_dygraph_mode():
            var.op = op
        return op


464
class XavierInitializer(Initializer):
    r"""
    This class implements the Xavier weight initializer from the paper
    `Understanding the difficulty of training deep feedforward neural
    networks <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
    by Xavier Glorot and Yoshua Bengio.

    This initializer is designed to keep the scale of the gradients
    approximately same in all the layers. In case of Uniform distribution,
    the range is [-x, x], where

    .. math::

        x = \sqrt{\frac{6.0}{fan\_in + fan\_out}}

    In case of Normal distribution, the mean is 0 and the standard deviation
    is

    .. math::

        \sqrt{\frac{2.0}{fan\_in + fan\_out}}


    Args:
        uniform (bool, default True): whether to use the uniform
                distribution; if False, the normal distribution is used.
        fan_in (float, default None): fan_in for Xavier initialization. If
                None, it is inferred from the variable.
        fan_out (float, default None): fan_out for Xavier initialization. If
                None, it is inferred from the variable.
        seed (int): random seed

    Note:
        It is recommended to set fan_in and fan_out to None for most cases.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            queries = fluid.data(name='x', shape=[None,1], dtype='float32')
            fc = fluid.layers.fc(
                input=queries, size=10,
                param_attr=fluid.initializer.Xavier(uniform=False))

    """

    def __init__(self, uniform=True, fan_in=None, fan_out=None, seed=0):
        assert uniform is not None
        assert seed is not None
        super(XavierInitializer, self).__init__()
        self._uniform = uniform
        self._fan_in, self._fan_out = fan_in, fan_out
        self._seed = seed

    def __call__(self, var, block=None):
        """Initialize the input tensor with Xavier initialization.

        Args:
            var(Tensor): Tensor that needs to be initialized.
            block(Block, optional): The block in which initialization ops
                   should be added. Used in static graph only, default None.

        Returns:
            The initialization op
        """
        block = self._check_block(block)

        assert isinstance(block, framework.Block)
        supported_dtypes = ["uint16", "float16", "float32", "float64"]
        check_variable_and_dtype(var, "Out", supported_dtypes, "xavier_init")

        inferred_in, inferred_out = self._compute_fans(var)

        # Explicitly passed fans take precedence over the inferred ones.
        fan_in = inferred_in if self._fan_in is None else self._fan_in
        fan_out = inferred_out if self._fan_out is None else self._fan_out

        # A seed of 0 is replaced by (and remembered as) the program's seed.
        if self._seed == 0:
            self._seed = block.program.random_seed

        # FP16 always, and BF16 when sampling from the normal distribution,
        # go through a temporary FP32 variable and are cast back afterwards.
        via_fp32 = var.dtype == VarDesc.VarType.FP16 or (
            var.dtype == VarDesc.VarType.BF16 and not self._uniform)
        if via_fp32:
            sample_dtype = VarDesc.VarType.FP32
            tmp_name = unique_name.generate(".".join(
                ['xavier_init', var.name, 'tmp']))
            target = block.create_var(
                name=tmp_name,
                shape=var.shape,
                dtype=sample_dtype,
                type=VarDesc.VarType.LOD_TENSOR,
                persistable=False)
        else:
            sample_dtype = var.dtype
            target = var

        if self._uniform:
            bound = np.sqrt(6.0 / float(fan_in + fan_out))
            uniform_attrs = {
                "shape": target.shape,
                "dtype": sample_dtype,
                "min": -bound,
                "max": bound,
                "seed": self._seed
            }
            op = block.append_op(
                type="uniform_random",
                inputs={},
                outputs={"Out": target},
                attrs=uniform_attrs,
                stop_gradient=True)
        else:
            sigma = np.sqrt(2.0 / float(fan_in + fan_out))
            normal_attrs = {
                "shape": target.shape,
                "dtype": sample_dtype,
                "mean": 0.0,
                "std": sigma,
                "seed": self._seed
            }
            op = block.append_op(
                type="gaussian_random",
                outputs={"Out": target},
                attrs=normal_attrs,
                stop_gradient=True)

        if via_fp32:
            cast_attrs = {"in_dtype": target.dtype, "out_dtype": var.dtype}
            block.append_op(
                type="cast",
                inputs={"X": target},
                outputs={"Out": var},
                attrs=cast_attrs)

        # The op is recorded on the variable only in static graph mode.
        if not framework.in_dygraph_mode():
            var.op = op
        return op
601 602 603


class MSRAInitializer(Initializer):
    r"""Implements the MSRA initializer a.k.a. Kaiming Initializer

    This class implements the weight initialization from the paper
    `Delving Deep into Rectifiers: Surpassing Human-Level Performance on
    ImageNet Classification <https://arxiv.org/abs/1502.01852>`_
    by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. This is a
    robust initialization method that particularly considers the rectifier
    nonlinearities. In case of Uniform distribution, the range is [-x, x], where

    .. math::

        x = \sqrt{\frac{6.0}{fan\_in}}

    In case of Normal distribution, the mean is 0 and the standard deviation
    is

    .. math::

        \sqrt{\frac{2.0}{fan\_in}}

    Args:
        uniform (bool): whether to use uniform or normal distribution
        fan_in (float32|None): fan_in for MSRAInitializer. If None, it is
            inferred from the variable. default is None.
        seed (int32): random seed

    Note:
        It is recommended to set fan_in to None for most cases.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            paddle.enable_static()
            x = fluid.data(name="data", shape=[8, 32, 32], dtype="float32")
            fc = fluid.layers.fc(input=x, size=10,
                param_attr=fluid.initializer.MSRA(uniform=False))

    """

    def __init__(self, uniform=True, fan_in=None, seed=0):
        """Constructor for MSRAInitializer
        """
        assert uniform is not None
        assert seed is not None
        super(MSRAInitializer, self).__init__()
        self._uniform = uniform
        self._fan_in = fan_in
        self._seed = seed

    def __call__(self, var, block=None):
        """Initialize the input tensor with MSRA initialization.

        Args:
            var(Tensor): Tensor that needs to be initialized.
            block(Block, optional): The block in which initialization ops
                   should be added. Used in static graph only, default None.

        Returns:
            The initialization op
        """
        block = self._check_block(block)

        assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)

        # Only fan_in matters for MSRA; an explicitly passed value takes
        # precedence over the one inferred from the variable's shape.
        inferred_in, _ = self._compute_fans(var)
        fan_in = inferred_in if self._fan_in is None else self._fan_in

        # A seed of 0 is replaced by (and remembered as) the program's seed.
        if self._seed == 0:
            self._seed = block.program.random_seed

        # FP16 always, and BF16 when sampling from the normal distribution,
        # go through a temporary FP32 variable and are cast back afterwards.
        via_fp32 = var.dtype == VarDesc.VarType.FP16 or (
            var.dtype == VarDesc.VarType.BF16 and not self._uniform)
        if via_fp32:
            sample_dtype = VarDesc.VarType.FP32
            tmp_name = unique_name.generate(".".join(
                ['masra_init', var.name, 'tmp']))
            target = block.create_var(
                name=tmp_name,
                shape=var.shape,
                dtype=sample_dtype,
                type=VarDesc.VarType.LOD_TENSOR,
                persistable=False)
        else:
            sample_dtype = var.dtype
            target = var

        if self._uniform:
            bound = np.sqrt(6.0 / float(fan_in))
            uniform_attrs = {
                "shape": target.shape,
                "dtype": int(sample_dtype),
                "min": -bound,
                "max": bound,
                "seed": self._seed
            }
            op = block.append_op(
                type="uniform_random",
                inputs={},
                outputs={"Out": target},
                attrs=uniform_attrs,
                stop_gradient=True)
        else:
            sigma = np.sqrt(2.0 / float(fan_in))
            normal_attrs = {
                "shape": target.shape,
                "dtype": int(sample_dtype),
                "mean": 0.0,
                "std": sigma,
                "seed": self._seed
            }
            op = block.append_op(
                type="gaussian_random",
                outputs={"Out": target},
                attrs=normal_attrs,
                stop_gradient=True)

        if via_fp32:
            cast_attrs = {"in_dtype": target.dtype, "out_dtype": var.dtype}
            block.append_op(
                type="cast",
                inputs={"X": target},
                outputs={"Out": var},
                attrs=cast_attrs)

        # The op is recorded on the variable only in static graph mode.
        if not framework.in_dygraph_mode():
            var.op = op
        return op
734 735


736
class BilinearInitializer(Initializer):
    """
    This initializer can be used in transposed convolution operator to
    act as upsampling. Users can upsample a feature map with shape of
    (B, C, H, W) by any integer factor. The usage is:

    Examples:

        .. code-block:: python

            import math

            import paddle
            import paddle.nn as nn
            from paddle.regularizer import L2Decay

            factor = 2
            C = 2
            B = 8
            H = W = 32
            w_attr = paddle.ParamAttr(learning_rate=0.,
                                      regularizer=L2Decay(0.),
                                      initializer=nn.initializer.Bilinear())
            data = paddle.rand([B, 3, H, W], dtype='float32')
            conv_up = nn.Conv2DTranspose(3,
                                         out_channels=C,
                                         kernel_size=2 * factor - factor % 2,
                                         padding=int(
                                             math.ceil((factor - 1) / 2.)),
                                         stride=factor,
                                         weight_attr=w_attr,
                                         bias_attr=False)
            x = conv_up(data)

    Where, `out_channels=C` and `groups=C` means this is channel-wise transposed
    convolution. The filter shape will be (C, 1, K, K) where K is `kernel_size`,
    This initializer will set a (K, K) interpolation kernel for every channel
    of the filter identically. The resulting shape of the output feature map
    will be (B, C, factor * H, factor * W). Note that the learning rate and the
    weight decay are set to 0 in order to keep coefficient values of bilinear
    interpolation unchanged during training.

    """

    def __init__(self):
        """Constructor for BilinearInitializer.
        """
        super(BilinearInitializer, self).__init__()

    def __call__(self, var, block=None):
        """Initialize the input tensor with Bilinear initialization.

        Args:
            var(Tensor): Tensor that needs to be initialized.
            block(Block, optional): The block in which initialization ops
                   should be added. Used in static graph only, default None.

        Returns:
            The initialization op

        Raises:
            ValueError: if ``var`` or ``block`` has the wrong type, if the
                shape of ``var`` is not 4-D with equal last two dimensions,
                or if ``var`` has more than 1024 * 1024 elements.
            TypeError: if the dtype of ``var`` is not one of FP16, BF16,
                FP32 or FP64.
        """
        block = self._check_block(block)

        if not isinstance(var, framework.Variable):
            raise ValueError("var must be framework.Variable.")

        if not isinstance(block, framework.Block):
            raise ValueError("block must be framework.Block.")

        shape = var.shape
        if len(shape) != 4:
            raise ValueError("the length of shape must be 4.")
        if shape[2] != shape[3]:
            raise ValueError("shape[2] must be equal to shape[3].")

        weight = np.zeros(np.prod(var.shape), dtype='float32')
        size = shape[3]
        # factor
        f = np.ceil(size / 2.)
        # center
        c = (2 * f - 1 - f % 2) / (2. * f)
        for i in range(np.prod(shape)):
            # Column and row of flat index i inside its (size, size) kernel.
            x = i % size
            # Floor division is required here: with true division (the
            # Python 3 meaning of "/") the row index becomes fractional and
            # the resulting weights are not a bilinear kernel.
            y = (i // size) % size
            weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
        weight = np.reshape(weight, shape)

        # The values are always emitted as FP32; FP16/BF16/FP64 variables
        # are filled through a temporary FP32 variable and cast afterwards.
        via_fp32 = var.dtype in [
            VarDesc.VarType.FP16, VarDesc.VarType.BF16, VarDesc.VarType.FP64
        ]
        if via_fp32:
            out_dtype = VarDesc.VarType.FP32
            out_var = block.create_var(
                name=unique_name.generate(".".join(
                    ['bilinear_init', var.name, 'tmp'])),
                shape=var.shape,
                dtype=out_dtype,
                type=VarDesc.VarType.LOD_TENSOR,
                persistable=False)
        else:
            out_dtype = var.dtype
            out_var = var

        if out_dtype == VarDesc.VarType.FP32:
            value_name = "fp32_values"
            values = [float(v) for v in weight.flat]
        else:
            # Format the dtype into the message instead of passing it as a
            # second exception argument.
            raise TypeError("Unsupported dtype %s" % var.dtype)

        if np.prod(shape) > 1024 * 1024:
            raise ValueError("The size of input is too big. ")
        op = block.append_op(
            type='assign_value',
            outputs={'Out': [out_var]},
            attrs={
                'dtype': out_dtype,
                'shape': list(shape),
                value_name: values
            })

        if via_fp32:
            block.append_op(
                type="cast",
                inputs={"X": out_var},
                outputs={"Out": var},
                attrs={"in_dtype": out_var.dtype,
                       "out_dtype": var.dtype})

        # The op is recorded on the variable only in static graph mode.
        if not framework.in_dygraph_mode():
            var.op = op
        return op


870 871
class NumpyArrayInitializer(Initializer):
    """Initialize a parameter with a numpy array.

    The values of the array are embedded as attributes of an
    ``assign_value`` op that writes them into the target variable.

    Args:
        value (numpy.ndarray): numpy array used to initialize the variable.

    Returns:
        A Tensor variable initialized by numpy.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import numpy
            x = fluid.data(name="x", shape=[2, 1], dtype='float32')
            fc = fluid.layers.fc(input=x, size=10,
                param_attr=fluid.initializer.NumpyArrayInitializer(numpy.array([1,2])))
    """

    def __init__(self, value):
        import numpy
        assert isinstance(value, numpy.ndarray)
        super(NumpyArrayInitializer, self).__init__()
        self._value = value

    def __call__(self, var, block=None):
        """Initialize the input tensor with the stored numpy array.

        Args:
            var(Tensor): Tensor that needs to be initialized.
            block(Block, optional): The block in which initialization ops
                   should be added. Used in static graph only, default None.

        Returns:
            The ``assign_value`` initialization op. When ``var`` is FP16 or
            BF16 an additional ``cast`` op is appended after it, but the
            ``assign_value`` op is still the one returned.

        Raises:
            ValueError: If the array holds more than 1024 * 1024 * 1024
                elements, or if the dtype to be written is neither FP32
                nor INT32.
        """
        block = self._check_block(block)

        assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)

        # Reject oversized arrays first, before a temporary variable is
        # created and before the data is expanded into a Python list.
        if self._value.size > 1024 * 1024 * 1024:
            raise ValueError("The size of input is too big. Please consider "
                             "saving it to file and 'load_op' to load it")

        # For compatibility with fp16/bf16 variables: the values are written
        # to an FP32 temporary and cast into `var` afterwards.
        needs_cast = var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]
        if needs_cast:
            out_dtype = VarDesc.VarType.FP32
            np_value = self._value.astype("float32")
            out_var = block.create_var(
                name=unique_name.generate(".".join(
                    ['numpy_array_init', var.name, 'tmp'])),
                shape=var.shape,
                dtype=out_dtype,
                type=VarDesc.VarType.LOD_TENSOR,
                persistable=False)
        else:
            out_var = var
            out_dtype = var.dtype
            np_value = self._value

        if out_dtype == VarDesc.VarType.FP32:
            value_name = "fp32_values"
            values = [float(v) for v in np_value.flat]
        elif out_dtype == VarDesc.VarType.INT32:
            value_name = "int32_values"
            values = [int(v) for v in np_value.flat]
        else:
            # Format the dtype into the message; passing it as a second
            # exception argument would render the message as a tuple.
            raise ValueError("Unsupported dtype %s" % self._value.dtype)

        op = block.append_op(
            type='assign_value',
            outputs={'Out': out_var},
            attrs={
                'dtype': out_dtype,
                'shape': list(self._value.shape),
                value_name: values
            },
            stop_gradient=True)

        if needs_cast:
            block.append_op(
                type="cast",
                inputs={"X": out_var},
                outputs={"Out": var},
                attrs={"in_dtype": out_var.dtype,
                       "out_dtype": var.dtype})

        if not framework.in_dygraph_mode():
            var.op = op
        return op


962 963 964 965 966 967 968
def set_global_initializer(weight_init, bias_init=None):
    """
    This API is used to set up the global model parameter initializers in the framework.

    After this API is invoked, the global initializers take effect in subsequent code.

    The model parameters include ``weight`` and ``bias`` . In the framework, they correspond
    to ``paddle.ParamAttr`` , which is inherited from ``paddle.Tensor`` , and is a persistable Variable.
    This API only takes effect for model parameters, not for variables created through apis such as
    :ref:`api_fluid_layers_create_global_var` , :ref:`api_fluid_layers_create_tensor`.

    If the initializer is also set up by ``param_attr`` or ``bias_attr`` when creating a network layer,
    the global initializer setting here will not take effect because it has a lower priority.

    If you want to cancel the global initializer in the framework, please set the global initializer to ``None`` .

    Both arguments are validated with ``check_type`` before either global is
    modified, so a call rejected by that validation leaves the previously
    configured initializers untouched.

    Args:
        weight_init (Initializer): set the global initializer for ``weight`` of model parameters.
        bias_init (Initializer, optional): set the global initializer for ``bias`` of model parameters.
            Default: None.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn

            nn.initializer.set_global_initializer(nn.initializer.Uniform(), nn.initializer.Constant())
            x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)

            # The weight of conv1 is initialized by Uniform
            # The bias of conv1 is initialized by Constant
            conv1 = nn.Conv2D(4, 6, (3, 3))
            y_var1 = conv1(x_var)

            # If set param_attr/bias_attr too, global initializer will not take effect
            # The weight of conv2 is initialized by Xavier
            # The bias of conv2 is initialized by Normal
            conv2 = nn.Conv2D(4, 6, (3, 3),
                weight_attr=nn.initializer.XavierUniform(),
                bias_attr=nn.initializer.Normal())
            y_var2 = conv2(x_var)

            # Cancel the global initializer in framework, it takes effect in subsequent code
            nn.initializer.set_global_initializer(None)
    """
    global _global_weight_initializer_, _global_bias_initializer_

    # Validate everything up front: assigning the weight initializer before
    # checking bias_init would leave half-applied global state on failure.
    check_type(weight_init, 'weight_init', (Initializer, type(None)),
               'set_global_initializer')
    check_type(bias_init, 'bias_init', (Initializer, type(None)),
               'set_global_initializer')

    _global_weight_initializer_ = weight_init
    _global_bias_initializer_ = bias_init


def _global_weight_initializer():
    """
    Return the global weight initializer, The user doesn't need to use it.
    """
    return _global_weight_initializer_


def _global_bias_initializer():
    """
    Return the global weight initializer, The user doesn't need to use it.
    """
    return _global_bias_initializer_


1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082
def calculate_gain(nonlinearity, param=None):
    """
    Look up the recommended gain value for a nonlinearity function.

    Args:
        nonlinearity(str): name of the nonlinearity function.
        param(bool|int|float, optional): extra parameter for some nonlinearity
            functions. Currently it is only used by 'leaky_relu'. Default: None,
            in which case 0.01 is used in the formula.

    Returns:
        The recommended gain value for the nonlinearity function.

    Raises:
        ValueError: If ``nonlinearity`` is not one of the supported names.

    Examples:
        .. code-block:: python

            import paddle
            gain = paddle.nn.initializer.calculate_gain('tanh') # 5.0 / 3
            gain = paddle.nn.initializer.calculate_gain('leaky_relu', param=1.0) # 1.0 = math.sqrt(2.0 / (1+param^2))

    """
    negative_slope = 0.01
    if param is not None:
        assert isinstance(param, (bool, int, float))
        negative_slope = float(param)

    # All entries are evaluated up front, including the slope-dependent one.
    gain_table = dict.fromkeys(
        ('sigmoid', 'linear', 'conv1d', 'conv2d', 'conv3d',
         'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'), 1)
    gain_table['tanh'] = 5.0 / 3
    gain_table['relu'] = math.sqrt(2.0)
    gain_table['leaky_relu'] = math.sqrt(2.0 / (1 + negative_slope**2))
    gain_table['selu'] = 3.0 / 4

    if nonlinearity not in gain_table:
        raise ValueError("nonlinearity function {} is not suppported now.".
                         format(nonlinearity))
    return gain_table[nonlinearity]


1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094
# Short aliases for the initializer classes. Users reach the initializers
# through the package name, for example:
#
# import paddle.fluid as fluid
#
# hidden = fluid.layers.fc(...,
#                          param_attr=ParamAttr(fluid.initializer.Xavier()))
#
# so there is no need to repeat `Initializer` as a class-name suffix.
Bilinear = BilinearInitializer
Constant = ConstantInitializer
MSRA = MSRAInitializer
Normal = NormalInitializer
TruncatedNormal = TruncatedNormalInitializer
Uniform = UniformInitializer
Xavier = XavierInitializer