提交 00b4073c 编写于 作者: Z zhiqiu

support not passing block

上级 6b727e08
......@@ -16,7 +16,7 @@ from __future__ import print_function
from . import framework
from . import core
from .framework import in_dygraph_mode
from .framework import in_dygraph_mode, default_main_program
import numpy as np
from .core import VarDesc
from . import unique_name
......@@ -45,11 +45,21 @@ class Initializer(object):
def __init__(self):
pass
def __call__(self, param, block):
def __call__(self, param, block=None):
"""Add corresponding initialization operations to the network
"""
raise NotImplementedError()
def _check_block(self, block):
if block is None:
if in_dygraph_mode():
block = default_main_program().global_block()
else:
raise ValueError(
"The parameter 'block' is needed in static graph mode.")
return block
def _compute_fans(self, var):
"""Compute the fan_in and the fan_out for layers
......@@ -108,17 +118,19 @@ class ConstantInitializer(Initializer):
self._value = value
self._force_cpu = force_cpu
def __call__(self, var, block):
"""Add constant initialization ops for a variable
def __call__(self, var, block=None):
"""Initialize the input tensor with constant.
Args:
var: Variable that needs to be initialized
block: The block in which initialization ops
should be added
var(Tensor): Tensor that needs to be initialized.
block(Block, optional): The block in which initialization ops
should be added. Used in static graph only, default None.
Returns:
the initialization op
The initialization op
"""
block = self._check_block(block)
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
......@@ -208,17 +220,19 @@ class UniformInitializer(Initializer):
self._diag_step = diag_step
self._diag_val = diag_val
def __call__(self, var, block):
"""Add uniform distribution initialization ops for a variable
def __call__(self, var, block=None):
"""Initialize the input tensor with Uniform distribution.
Args:
var: Variable that needs to be initialized
block: The block in which initialization ops
should be added
var(Tensor): Tensor that needs to be initialized.
block(Block, optional): The block in which initialization ops
should be added. Used in static graph only, default None.
Returns:
the initialization op
The initialization op
"""
block = self._check_block(block)
assert isinstance(block, framework.Block)
check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
"uniform_random")
......@@ -297,17 +311,19 @@ class NormalInitializer(Initializer):
self._std_dev = scale
self._seed = seed
def __call__(self, var, block):
"""Add normal distribution initialization ops for a variable
def __call__(self, var, block=None):
"""Initialize the input tensor with Normal distribution.
Args:
var: Variable that needs to be initialized
block: The block in which initialization ops
should be added
var(Tensor): Tensor that needs to be initialized.
block(Block, optional): The block in which initialization ops
should be added. Used in static graph only, default None.
Returns:
the initialization op
The initialization op
"""
block = self._check_block(block)
assert isinstance(block, framework.Block)
check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
......@@ -381,17 +397,19 @@ class TruncatedNormalInitializer(Initializer):
self._std_dev = scale
self._seed = seed
def __call__(self, var, block):
"""Add truncated normal distribution initialization ops for a variable
def __call__(self, var, block=None):
"""Initialize the input tensor with TruncatedNormal distribution.
Args:
var: Variable that needs to be initialized
block: The block in which initialization ops
should be added
var(Tensor): Tensor that needs to be initialized.
block(Block, optional): The block in which initialization ops
should be added. Used in static graph only, default None.
Returns:
the initialization op
The initialization op
"""
block = self._check_block(block)
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended
......@@ -490,17 +508,19 @@ class XavierInitializer(Initializer):
self._fan_out = fan_out
self._seed = seed
def __call__(self, var, block):
"""Add xavier initialization ops for a variable
def __call__(self, var, block=None):
"""Initialize the input tensor with Xavier initialization.
Args:
var: Variable that needs to be initialized
block: The block in which initialization ops
should be added
var(Tensor): Tensor that needs to be initialized.
block(Block, optional): The block in which initialization ops
should be added. Used in static graph only, default None.
Returns:
the initialization op
The initialization op
"""
block = self._check_block(block)
assert isinstance(block, framework.Block)
check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
"xavier_init")
......@@ -620,17 +640,19 @@ class MSRAInitializer(Initializer):
self._fan_in = fan_in
self._seed = seed
def __call__(self, var, block):
"""Add MSRA initialization ops for a variable
def __call__(self, var, block=None):
"""Initialize the input tensor with MSRA initialization.
Args:
var: Variable that needs to be initialized
block: The block in which initialization ops
should be added
var(Tensor): Tensor that needs to be initialized.
block(Block, optional): The block in which initialization ops
should be added. Used in static graph only, default None.
Returns:
the initialization op
The initialization op
"""
block = self._check_block(block)
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
f_in, f_out = self._compute_fans(var)
......@@ -745,22 +767,19 @@ class BilinearInitializer(Initializer):
"""
super(BilinearInitializer, self).__init__()
def __call__(self, var, block):
"""Add bilinear initialization ops for a variable
def __call__(self, var, block=None):
"""Initialize the input tensor with Bilinear initialization.
Args:
var (Variable): Variable that needs to be initialized.
block (Block): The block in which initialization ops should
be added.
var(Tensor): Tensor that needs to be initialized.
block(Block, optional): The block in which initialization ops
should be added. Used in static graph only, default None.
Returns:
Operator: the initialization op
Raises:
ValueError: If type of `var` and `block` is not right.
If the shape of `var` size is not 4 and
var.shape[2] != var.shape[3].
The initialization op
"""
block = self._check_block(block)
if not isinstance(var, framework.Variable):
raise ValueError("var must be framework.Variable.")
......@@ -855,7 +874,7 @@ class NumpyArrayInitializer(Initializer):
super(NumpyArrayInitializer, self).__init__()
self._value = value
def __call__(self, var, block):
def __call__(self, var, block=None):
"""Add constant initialization ops for a variable
Args:
......@@ -866,6 +885,9 @@ class NumpyArrayInitializer(Initializer):
Returns:
the initialization op
"""
if block is None:
block = default_main_program().global_block()
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册