提交 00b4073c 编写于 作者: Z zhiqiu

support not passing block

上级 6b727e08
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
from . import framework from . import framework
from . import core from . import core
from .framework import in_dygraph_mode from .framework import in_dygraph_mode, default_main_program
import numpy as np import numpy as np
from .core import VarDesc from .core import VarDesc
from . import unique_name from . import unique_name
...@@ -45,11 +45,21 @@ class Initializer(object): ...@@ -45,11 +45,21 @@ class Initializer(object):
def __init__(self): def __init__(self):
pass pass
def __call__(self, param, block): def __call__(self, param, block=None):
"""Add corresponding initialization operations to the network """Add corresponding initialization operations to the network
""" """
raise NotImplementedError() raise NotImplementedError()
def _check_block(self, block):
if block is None:
if in_dygraph_mode():
block = default_main_program().global_block()
else:
raise ValueError(
"The parameter 'block' is needed in static graph mode.")
return block
def _compute_fans(self, var): def _compute_fans(self, var):
"""Compute the fan_in and the fan_out for layers """Compute the fan_in and the fan_out for layers
...@@ -108,17 +118,19 @@ class ConstantInitializer(Initializer): ...@@ -108,17 +118,19 @@ class ConstantInitializer(Initializer):
self._value = value self._value = value
self._force_cpu = force_cpu self._force_cpu = force_cpu
def __call__(self, var, block): def __call__(self, var, block=None):
"""Add constant initialization ops for a variable """Initialize the input tensor with constant.
Args: Args:
var: Variable that needs to be initialized var(Tensor): Tensor that needs to be initialized.
block: The block in which initialization ops block(Block, optional): The block in which initialization ops
should be added should be added. Used in static graph only, default None.
Returns: Returns:
the initialization op The initialization op
""" """
block = self._check_block(block)
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
...@@ -208,17 +220,19 @@ class UniformInitializer(Initializer): ...@@ -208,17 +220,19 @@ class UniformInitializer(Initializer):
self._diag_step = diag_step self._diag_step = diag_step
self._diag_val = diag_val self._diag_val = diag_val
def __call__(self, var, block): def __call__(self, var, block=None):
"""Add uniform distribution initialization ops for a variable """Initialize the input tensor with Uniform distribution.
Args: Args:
var: Variable that needs to be initialized var(Tensor): Tensor that needs to be initialized.
block: The block in which initialization ops block(Block, optional): The block in which initialization ops
should be added should be added. Used in static graph only, default None.
Returns: Returns:
the initialization op The initialization op
""" """
block = self._check_block(block)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"], check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
"uniform_random") "uniform_random")
...@@ -297,17 +311,19 @@ class NormalInitializer(Initializer): ...@@ -297,17 +311,19 @@ class NormalInitializer(Initializer):
self._std_dev = scale self._std_dev = scale
self._seed = seed self._seed = seed
def __call__(self, var, block): def __call__(self, var, block=None):
"""Add normal distribution initialization ops for a variable """Initialize the input tensor with Normal distribution.
Args: Args:
var: Variable that needs to be initialized var(Tensor): Tensor that needs to be initialized.
block: The block in which initialization ops block(Block, optional): The block in which initialization ops
should be added should be added. Used in static graph only, default None.
Returns: Returns:
the initialization op The initialization op
""" """
block = self._check_block(block)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"], check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
...@@ -381,17 +397,19 @@ class TruncatedNormalInitializer(Initializer): ...@@ -381,17 +397,19 @@ class TruncatedNormalInitializer(Initializer):
self._std_dev = scale self._std_dev = scale
self._seed = seed self._seed = seed
def __call__(self, var, block): def __call__(self, var, block=None):
"""Add truncated normal distribution initialization ops for a variable """Initialize the input tensor with TruncatedNormal distribution.
Args: Args:
var: Variable that needs to be initialized var(Tensor): Tensor that needs to be initialized.
block: The block in which initialization ops block(Block, optional): The block in which initialization ops
should be added should be added. Used in static graph only, default None.
Returns: Returns:
the initialization op The initialization op
""" """
block = self._check_block(block)
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
...@@ -490,17 +508,19 @@ class XavierInitializer(Initializer): ...@@ -490,17 +508,19 @@ class XavierInitializer(Initializer):
self._fan_out = fan_out self._fan_out = fan_out
self._seed = seed self._seed = seed
def __call__(self, var, block): def __call__(self, var, block=None):
"""Add xavier initialization ops for a variable """Initialize the input tensor with Xavier initialization.
Args: Args:
var: Variable that needs to be initialized var(Tensor): Tensor that needs to be initialized.
block: The block in which initialization ops block(Block, optional): The block in which initialization ops
should be added should be added. Used in static graph only, default None.
Returns: Returns:
the initialization op The initialization op
""" """
block = self._check_block(block)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"], check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
"xavier_init") "xavier_init")
...@@ -620,17 +640,19 @@ class MSRAInitializer(Initializer): ...@@ -620,17 +640,19 @@ class MSRAInitializer(Initializer):
self._fan_in = fan_in self._fan_in = fan_in
self._seed = seed self._seed = seed
def __call__(self, var, block): def __call__(self, var, block=None):
"""Add MSRA initialization ops for a variable """Initialize the input tensor with MSRA initialization.
Args: Args:
var: Variable that needs to be initialized var(Tensor): Tensor that needs to be initialized.
block: The block in which initialization ops block(Block, optional): The block in which initialization ops
should be added should be added. Used in static graph only, default None.
Returns: Returns:
the initialization op The initialization op
""" """
block = self._check_block(block)
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
f_in, f_out = self._compute_fans(var) f_in, f_out = self._compute_fans(var)
...@@ -745,22 +767,19 @@ class BilinearInitializer(Initializer): ...@@ -745,22 +767,19 @@ class BilinearInitializer(Initializer):
""" """
super(BilinearInitializer, self).__init__() super(BilinearInitializer, self).__init__()
def __call__(self, var, block): def __call__(self, var, block=None):
"""Add bilinear initialization ops for a variable """Initialize the input tensor with Bilinear initialization.
Args: Args:
var (Variable): Variable that needs to be initialized. var(Tensor): Tensor that needs to be initialized.
block (Block): The block in which initialization ops should block(Block, optional): The block in which initialization ops
be added. should be added. Used in static graph only, default None.
Returns: Returns:
Operator: the initialization op The initialization op
Raises:
ValueError: If type of `var` and `block` is not right.
If the shape of `var` size is not 4 and
var.shape[2] != var.shape[3].
""" """
block = self._check_block(block)
if not isinstance(var, framework.Variable): if not isinstance(var, framework.Variable):
raise ValueError("var must be framework.Variable.") raise ValueError("var must be framework.Variable.")
...@@ -855,7 +874,7 @@ class NumpyArrayInitializer(Initializer): ...@@ -855,7 +874,7 @@ class NumpyArrayInitializer(Initializer):
super(NumpyArrayInitializer, self).__init__() super(NumpyArrayInitializer, self).__init__()
self._value = value self._value = value
def __call__(self, var, block): def __call__(self, var, block=None):
"""Add constant initialization ops for a variable """Add constant initialization ops for a variable
Args: Args:
...@@ -866,6 +885,9 @@ class NumpyArrayInitializer(Initializer): ...@@ -866,6 +885,9 @@ class NumpyArrayInitializer(Initializer):
Returns: Returns:
the initialization op the initialization op
""" """
if block is None:
block = default_main_program().global_block()
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册