Unverified commit 26b61691, authored by Leo Chen, committed by GitHub

Refine initializer to support not passing block in dygraph mode (#27612)

* support not passing block

* fix NumpyArrayInitializer

* add unit test
Parent 0873644c
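With this change, an initializer can be called on a Tensor directly in dygraph mode without passing a block; the block is resolved internally from the default main program. A minimal usage sketch, following the unit test added in this commit (the default Uniform range of (-1, 1) is what the test's histogram check relies on):

    import paddle

    paddle.disable_static()                      # dygraph mode
    tensor = paddle.zeros([1024, 1024])          # tensor to initialize in place
    uniform_ = paddle.nn.initializer.Uniform()   # no block argument needed
    uniform_(tensor)                             # fills tensor with uniformly distributed values
    paddle.enable_static()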
@@ -16,7 +16,7 @@ from __future__ import print_function
 from . import framework
 from . import core
-from .framework import in_dygraph_mode
+from .framework import in_dygraph_mode, default_main_program
 import numpy as np
 from .core import VarDesc
 from . import unique_name
@@ -45,11 +45,21 @@ class Initializer(object):
     def __init__(self):
         pass
 
-    def __call__(self, param, block):
+    def __call__(self, param, block=None):
         """Add corresponding initialization operations to the network
         """
         raise NotImplementedError()
 
+    def _check_block(self, block):
+        if block is None:
+            if in_dygraph_mode():
+                block = default_main_program().global_block()
+            else:
+                raise ValueError(
+                    "The parameter 'block' is needed in static graph mode.")
+
+        return block
+
     def _compute_fans(self, var):
         """Compute the fan_in and the fan_out for layers
@@ -108,17 +118,19 @@ class ConstantInitializer(Initializer):
         self._value = value
         self._force_cpu = force_cpu
 
-    def __call__(self, var, block):
-        """Add constant initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with constant.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                   should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
@@ -208,17 +220,19 @@ class UniformInitializer(Initializer):
         self._diag_step = diag_step
         self._diag_val = diag_val
 
-    def __call__(self, var, block):
-        """Add uniform distribution initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Uniform distribution.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                   should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(block, framework.Block)
         check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
                                  "uniform_random")
@@ -297,17 +311,19 @@ class NormalInitializer(Initializer):
         self._std_dev = scale
         self._seed = seed
 
-    def __call__(self, var, block):
-        """Add normal distribution initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Normal distribution.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                   should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(block, framework.Block)
         check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
@@ -381,17 +397,19 @@ class TruncatedNormalInitializer(Initializer):
         self._std_dev = scale
         self._seed = seed
 
-    def __call__(self, var, block):
-        """Add truncated normal distribution initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with TruncatedNormal distribution.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                   should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
         # Initialization Ops should be prepended and not appended
@@ -490,17 +508,19 @@ class XavierInitializer(Initializer):
         self._fan_out = fan_out
         self._seed = seed
 
-    def __call__(self, var, block):
-        """Add xavier initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Xavier initialization.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                   should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(block, framework.Block)
         check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
                                  "xavier_init")
@@ -620,17 +640,19 @@ class MSRAInitializer(Initializer):
         self._fan_in = fan_in
         self._seed = seed
 
-    def __call__(self, var, block):
-        """Add MSRA initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with MSRA initialization.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                   should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
         f_in, f_out = self._compute_fans(var)
@@ -745,22 +767,19 @@ class BilinearInitializer(Initializer):
         """
         super(BilinearInitializer, self).__init__()
 
-    def __call__(self, var, block):
-        """Add bilinear initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Bilinear initialization.
 
         Args:
-            var (Variable): Variable that needs to be initialized.
-            block (Block): The block in which initialization ops should
-                be added.
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                   should be added. Used in static graph only, default None.
 
         Returns:
-            Operator: the initialization op
-
-        Raises:
-            ValueError: If type of `var` and `block` is not right.
-                If the shape of `var` size is not 4 and
-                var.shape[2] != var.shape[3].
+            The initialization op
         """
+        block = self._check_block(block)
+
         if not isinstance(var, framework.Variable):
             raise ValueError("var must be framework.Variable.")
@@ -855,17 +874,19 @@ class NumpyArrayInitializer(Initializer):
         super(NumpyArrayInitializer, self).__init__()
         self._value = value
 
-    def __call__(self, var, block):
-        """Add constant initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Numpy array.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                   should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
        """
+        block = self._check_block(block)
+
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
...
@@ -17,6 +17,7 @@ from __future__ import print_function
 import numpy as np
 import unittest
 
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.framework as framework
 import paddle.fluid.initializer as initializer
@@ -31,6 +32,14 @@ def check_cast_op(op):
            op.attr('out_dtype') == VarDesc.VarType.FP16
 
 
+def output_hist(out):
+    hist, _ = np.histogram(out, range=(-1, 1))
+    hist = hist.astype("float32")
+    hist /= float(out.size)
+    prob = 0.1 * np.ones((10))
+    return hist, prob
+
+
 class TestConstantInitializer(unittest.TestCase):
     def test_constant_initializer_default_value(self, dtype="float32"):
         """Test the constant initializer with default value
@@ -583,5 +592,31 @@ class TestSetGlobalInitializer(unittest.TestCase):
         fluid.set_global_initializer(None)
 
 
+class TestUniformInitializerDygraph(unittest.TestCase):
+    def test_uniform_initializer(self, dtype="float32"):
+        """
+        In dygraph mode, we can use initializer directly to initialize a tensor.
+        """
+        paddle.disable_static()
+
+        tensor = paddle.zeros([1024, 1024])
+        tensor.stop_gradient = False
+        self.assertTrue(np.allclose(np.zeros((1024, 1024)), tensor.numpy()))
+
+        uniform_ = paddle.nn.initializer.Uniform()
+        uniform_(tensor)
+
+        self.assertEqual(tensor.stop_gradient,
+                         False)  # stop_gradient is not changed
+
+        hist, prob = output_hist(tensor.numpy())
+
+        self.assertTrue(
+            np.allclose(
+                hist, prob, rtol=0, atol=1e-3), "hist: " + str(hist))
+
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
     unittest.main()
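For contrast, a minimal sketch (not part of this commit) of the static-graph behavior implied by _check_block: if block is omitted outside dygraph mode, the initializer raises a ValueError rather than guessing a block. It assumes paddle.nn.initializer.Uniform dispatches to the refactored UniformInitializer, as the test above does; the variable name x is only illustrative.

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    with fluid.program_guard(fluid.Program()):
        x = fluid.data(name='x', shape=[None, 4], dtype='float32')
        try:
            paddle.nn.initializer.Uniform()(x)  # no block passed in static graph mode
        except ValueError as e:
            print(e)  # "The parameter 'block' is needed in static graph mode."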