From 00b4073c5c6ca9aaf76080fb157e5678a24d0465 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Sun, 27 Sep 2020 16:10:51 +0800 Subject: [PATCH] support not passing block --- python/paddle/fluid/initializer.py | 122 +++++++++++++++++------------ 1 file changed, 72 insertions(+), 50 deletions(-) diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index ce975ea8423..d98506fc94f 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -16,7 +16,7 @@ from __future__ import print_function from . import framework from . import core -from .framework import in_dygraph_mode +from .framework import in_dygraph_mode, default_main_program import numpy as np from .core import VarDesc from . import unique_name @@ -45,11 +45,21 @@ class Initializer(object): def __init__(self): pass - def __call__(self, param, block): + def __call__(self, param, block=None): """Add corresponding initialization operations to the network """ raise NotImplementedError() + def _check_block(self, block): + if block is None: + if in_dygraph_mode(): + block = default_main_program().global_block() + else: + raise ValueError( + "The parameter 'block' is needed in static graph mode.") + + return block + def _compute_fans(self, var): """Compute the fan_in and the fan_out for layers @@ -108,17 +118,19 @@ class ConstantInitializer(Initializer): self._value = value self._force_cpu = force_cpu - def __call__(self, var, block): - """Add constant initialization ops for a variable + def __call__(self, var, block=None): + """Initialize the input tensor with constant. Args: - var: Variable that needs to be initialized - block: The block in which initialization ops - should be added + var(Tensor): Tensor that needs to be initialized. + block(Block, optional): The block in which initialization ops + should be added. Used in static graph only, default None. Returns: - the initialization op + The initialization op """ + block = self._check_block(block) + assert isinstance(var, framework.Variable) assert isinstance(block, framework.Block) @@ -208,17 +220,19 @@ class UniformInitializer(Initializer): self._diag_step = diag_step self._diag_val = diag_val - def __call__(self, var, block): - """Add uniform distribution initialization ops for a variable + def __call__(self, var, block=None): + """Initialize the input tensor with Uniform distribution. Args: - var: Variable that needs to be initialized - block: The block in which initialization ops - should be added + var(Tensor): Tensor that needs to be initialized. + block(Block, optional): The block in which initialization ops + should be added. Used in static graph only, default None. Returns: - the initialization op + The initialization op """ + block = self._check_block(block) + assert isinstance(block, framework.Block) check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"], "uniform_random") @@ -297,17 +311,19 @@ class NormalInitializer(Initializer): self._std_dev = scale self._seed = seed - def __call__(self, var, block): - """Add normal distribution initialization ops for a variable + def __call__(self, var, block=None): + """Initialize the input tensor with Normal distribution. Args: - var: Variable that needs to be initialized - block: The block in which initialization ops - should be added + var(Tensor): Tensor that needs to be initialized. + block(Block, optional): The block in which initialization ops + should be added. Used in static graph only, default None. Returns: - the initialization op + The initialization op """ + block = self._check_block(block) + assert isinstance(block, framework.Block) check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"], @@ -381,17 +397,19 @@ class TruncatedNormalInitializer(Initializer): self._std_dev = scale self._seed = seed - def __call__(self, var, block): - """Add truncated normal distribution initialization ops for a variable + def __call__(self, var, block=None): + """Initialize the input tensor with TruncatedNormal distribution. Args: - var: Variable that needs to be initialized - block: The block in which initialization ops - should be added + var(Tensor): Tensor that needs to be initialized. + block(Block, optional): The block in which initialization ops + should be added. Used in static graph only, default None. Returns: - the initialization op + The initialization op """ + block = self._check_block(block) + assert isinstance(var, framework.Variable) assert isinstance(block, framework.Block) # Initialization Ops should be prepended and not appended @@ -490,17 +508,19 @@ class XavierInitializer(Initializer): self._fan_out = fan_out self._seed = seed - def __call__(self, var, block): - """Add xavier initialization ops for a variable + def __call__(self, var, block=None): + """Initialize the input tensor with Xavier initialization. Args: - var: Variable that needs to be initialized - block: The block in which initialization ops - should be added + var(Tensor): Tensor that needs to be initialized. + block(Block, optional): The block in which initialization ops + should be added. Used in static graph only, default None. Returns: - the initialization op + The initialization op """ + block = self._check_block(block) + assert isinstance(block, framework.Block) check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"], "xavier_init") @@ -620,17 +640,19 @@ class MSRAInitializer(Initializer): self._fan_in = fan_in self._seed = seed - def __call__(self, var, block): - """Add MSRA initialization ops for a variable + def __call__(self, var, block=None): + """Initialize the input tensor with MSRA initialization. Args: - var: Variable that needs to be initialized - block: The block in which initialization ops - should be added + var(Tensor): Tensor that needs to be initialized. + block(Block, optional): The block in which initialization ops + should be added. Used in static graph only, default None. Returns: - the initialization op + The initialization op """ + block = self._check_block(block) + assert isinstance(var, framework.Variable) assert isinstance(block, framework.Block) f_in, f_out = self._compute_fans(var) @@ -745,22 +767,19 @@ class BilinearInitializer(Initializer): """ super(BilinearInitializer, self).__init__() - def __call__(self, var, block): - """Add bilinear initialization ops for a variable + def __call__(self, var, block=None): + """Initialize the input tensor with Bilinear initialization. Args: - var (Variable): Variable that needs to be initialized. - block (Block): The block in which initialization ops should - be added. + var(Tensor): Tensor that needs to be initialized. + block(Block, optional): The block in which initialization ops + should be added. Used in static graph only, default None. Returns: - Operator: the initialization op - - Raises: - ValueError: If type of `var` and `block` is not right. - If the shape of `var` size is not 4 and - var.shape[2] != var.shape[3]. + The initialization op """ + block = self._check_block(block) + if not isinstance(var, framework.Variable): raise ValueError("var must be framework.Variable.") @@ -855,7 +874,7 @@ class NumpyArrayInitializer(Initializer): super(NumpyArrayInitializer, self).__init__() self._value = value - def __call__(self, var, block): + def __call__(self, var, block=None): """Add constant initialization ops for a variable Args: @@ -866,6 +885,9 @@ class NumpyArrayInitializer(Initializer): Returns: the initialization op """ + if block is None: + block = default_main_program().global_block() + assert isinstance(var, framework.Variable) assert isinstance(block, framework.Block) -- GitLab