Unverified · Commit 26b61691 authored by Leo Chen, committed by GitHub

Refine initializer to support not passing block in dygraph mode (#27612)

* support not passing block

* fix NumpyArrayInitializer

* add unit test
Parent 0873644c
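In dygraph mode there is no user-facing block, so `Initializer.__call__` can now be invoked on a tensor directly. A minimal sketch of the usage this change enables, mirroring the new unit test below:

```python
import numpy as np
import paddle

paddle.disable_static()

tensor = paddle.zeros([1024, 1024])
# block=None: in dygraph mode the initializer resolves the block internally.
paddle.nn.initializer.Uniform()(tensor)
print(np.abs(tensor.numpy()).max() <= 1.0)  # values drawn from U(-1, 1)
```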
@@ -16,7 +16,7 @@ from __future__ import print_function
 
 from . import framework
 from . import core
-from .framework import in_dygraph_mode
+from .framework import in_dygraph_mode, default_main_program
 import numpy as np
 from .core import VarDesc
 from . import unique_name
@@ -45,11 +45,21 @@ class Initializer(object):
 
     def __init__(self):
         pass
 
-    def __call__(self, param, block):
+    def __call__(self, param, block=None):
         """Add corresponding initialization operations to the network
         """
         raise NotImplementedError()
 
+    def _check_block(self, block):
+        if block is None:
+            if in_dygraph_mode():
+                block = default_main_program().global_block()
+            else:
+                raise ValueError(
+                    "The parameter 'block' is needed in static graph mode.")
+
+        return block
+
     def _compute_fans(self, var):
         """Compute the fan_in and the fan_out for layers
@@ -108,17 +118,19 @@ class ConstantInitializer(Initializer):
         self._value = value
         self._force_cpu = force_cpu
 
-    def __call__(self, var, block):
-        """Add constant initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with constant.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
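With the default argument in place, any concrete initializer can fill a dygraph tensor in one call; for example, a sketch using the `fluid.initializer.Constant` alias of this class:

```python
import paddle
import paddle.fluid as fluid

paddle.disable_static()
t = paddle.zeros([2, 3])
fluid.initializer.Constant(value=2.0)(t)  # no block argument needed
print(t.numpy())  # expected: a 2x3 array of 2.0
```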
@@ -208,17 +220,19 @@ class UniformInitializer(Initializer):
         self._diag_step = diag_step
         self._diag_val = diag_val
 
-    def __call__(self, var, block):
-        """Add uniform distribution initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Uniform distribution.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(block, framework.Block)
         check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
                                  "uniform_random")
@@ -297,17 +311,19 @@ class NormalInitializer(Initializer):
         self._std_dev = scale
         self._seed = seed
 
-    def __call__(self, var, block):
-        """Add normal distribution initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Normal distribution.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(block, framework.Block)
 
         check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
@@ -381,17 +397,19 @@ class TruncatedNormalInitializer(Initializer):
         self._std_dev = scale
         self._seed = seed
 
-    def __call__(self, var, block):
-        """Add truncated normal distribution initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with TruncatedNormal distribution.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
         # Initialization Ops should be prepended and not appended
@@ -490,17 +508,19 @@ class XavierInitializer(Initializer):
         self._fan_out = fan_out
         self._seed = seed
 
-    def __call__(self, var, block):
-        """Add xavier initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Xavier initialization.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(block, framework.Block)
         check_variable_and_dtype(var, "Out", ["float16", "float32", "float64"],
                                  "xavier_init")
@@ -620,17 +640,19 @@ class MSRAInitializer(Initializer):
         self._fan_in = fan_in
         self._seed = seed
 
-    def __call__(self, var, block):
-        """Add MSRA initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with MSRA initialization.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
         f_in, f_out = self._compute_fans(var)
@@ -745,22 +767,19 @@ class BilinearInitializer(Initializer):
         """
         super(BilinearInitializer, self).__init__()
 
-    def __call__(self, var, block):
-        """Add bilinear initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Bilinear initialization.
 
         Args:
-            var (Variable): Variable that needs to be initialized.
-            block (Block): The block in which initialization ops should
-                be added.
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                should be added. Used in static graph only, default None.
 
         Returns:
-            Operator: the initialization op
-
-        Raises:
-            ValueError: If type of `var` and `block` is not right.
-                If the shape of `var` size is not 4 and
-                var.shape[2] != var.shape[3].
+            The initialization op
         """
+        block = self._check_block(block)
+
         if not isinstance(var, framework.Variable):
             raise ValueError("var must be framework.Variable.")
@@ -855,17 +874,19 @@ class NumpyArrayInitializer(Initializer):
         super(NumpyArrayInitializer, self).__init__()
         self._value = value
 
-    def __call__(self, var, block):
-        """Add constant initialization ops for a variable
+    def __call__(self, var, block=None):
+        """Initialize the input tensor with Numpy array.
 
         Args:
-            var: Variable that needs to be initialized
-            block: The block in which initialization ops
-                   should be added
+            var(Tensor): Tensor that needs to be initialized.
+            block(Block, optional): The block in which initialization ops
+                should be added. Used in static graph only, default None.
 
         Returns:
-            the initialization op
+            The initialization op
         """
+        block = self._check_block(block)
+
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
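`NumpyArrayInitializer` follows the same pattern; a sketch of initializing a dygraph tensor from a NumPy array, assuming the eager assign path behaves like the other initializers above:

```python
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.disable_static()
t = paddle.zeros([2, 2])
fluid.initializer.NumpyArrayInitializer(
    np.arange(4, dtype="float32").reshape(2, 2))(t)
print(t.numpy())  # expected: [[0., 1.], [2., 3.]]
```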
@@ -17,6 +17,7 @@ from __future__ import print_function
 
 import numpy as np
 import unittest
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.framework as framework
 import paddle.fluid.initializer as initializer
@@ -31,6 +32,14 @@ def check_cast_op(op):
            op.attr('out_dtype') == VarDesc.VarType.FP16
 
 
+def output_hist(out):
+    hist, _ = np.histogram(out, range=(-1, 1))
+    hist = hist.astype("float32")
+    hist /= float(out.size)
+    prob = 0.1 * np.ones((10))
+    return hist, prob
+
+
 class TestConstantInitializer(unittest.TestCase):
     def test_constant_initializer_default_value(self, dtype="float32"):
         """Test the constant initializer with default value
@@ -583,5 +592,31 @@ class TestSetGlobalInitializer(unittest.TestCase):
         fluid.set_global_initializer(None)
 
 
+class TestUniformInitializerDygraph(unittest.TestCase):
+    def test_uniform_initializer(self, dtype="float32"):
+        """
+        In dygraph mode, we can use initializer directly to initialize a tensor.
+        """
+        paddle.disable_static()
+
+        tensor = paddle.zeros([1024, 1024])
+        tensor.stop_gradient = False
+        self.assertTrue(np.allclose(np.zeros((1024, 1024)), tensor.numpy()))
+
+        uniform_ = paddle.nn.initializer.Uniform()
+        uniform_(tensor)
+
+        self.assertEqual(tensor.stop_gradient,
+                         False)  # stop_gradient is not changed
+
+        hist, prob = output_hist(tensor.numpy())
+
+        self.assertTrue(
+            np.allclose(
+                hist, prob, rtol=0, atol=1e-3), "hist: " + str(hist))
+
+        paddle.enable_static()
+
+
 if __name__ == '__main__':
     unittest.main()
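The histogram assertion checks uniformity: over 1024x1024 draws from U(-1, 1), each of the 10 equal-width bins produced by `np.histogram` should hold about 10% of the values. A standalone NumPy sketch of the same check:

```python
import numpy as np

sample = np.random.uniform(-1, 1, size=(1024, 1024))
hist, _ = np.histogram(sample, range=(-1, 1))  # 10 equal-width bins by default
hist = hist.astype("float32") / sample.size
print(np.allclose(hist, 0.1 * np.ones(10), rtol=0, atol=1e-3))  # ~True
```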