未验证 提交 4eac85c6 编写于 作者: D dzhwinter 提交者: GitHub

"add init seed" (#6221)

* "add init seed"

* "fix compile error"

* "add program level seed setting"

* "fixed based on comments"
上级 a0c1190f
......@@ -512,6 +512,7 @@ class Program(object):
self.desc = core.ProgramDesc()
self.blocks = [Block(self, 0)]
self.current_block_idx = 0
self._seed = 0
def __str__(self):
return self.to_string(True)
......@@ -564,6 +565,16 @@ class Program(object):
p.sync_with_cpp()
return p
@property
def random_seed(self):
return self._seed
@random_seed.setter
def random_seed(self, seed):
if not isinstance(seed, int):
raise ValueError("Seed must be a integer.")
self._seed = seed
def __repr__(self):
return str(self)
......
......@@ -132,6 +132,8 @@ class UniformInitializer(Initializer):
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
op = block.prepend_op(
type="uniform_random",
outputs={"Out": var},
......@@ -180,6 +182,8 @@ class NormalInitializer(Initializer):
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
op = block.prepend_op(
type="gaussian_random",
outputs={"Out": var},
......@@ -255,6 +259,9 @@ class XavierInitializer(Initializer):
fan_in = f_in if self._fan_in is None else self._fan_in
fan_out = f_out if self._fan_out is None else self._fan_out
if self._seed == 0:
self._seed = block.program.random_seed
if self._uniform:
limit = np.sqrt(6.0 / float(fan_in + fan_out))
op = block.prepend_op(
......@@ -338,6 +345,9 @@ class MSRAInitializer(Initializer):
# If fan_in is passed, use it
fan_in = f_in if self._fan_in is None else self._fan_in
if self._seed == 0:
self._seed = block.program.random_seed
if self._uniform:
limit = np.sqrt(6.0 / float(fan_in))
op = block.prepend_op(
......
......@@ -60,6 +60,29 @@ class TestUniformInitializer(unittest.TestCase):
self.assertAlmostEqual(init_op.attr('max'), 1.0, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
def test_uniform_initializer_random_seed(self):
"""Test the uniform initializer with manually setting seed
"""
program = framework.Program()
program.random_seed = 123
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.UniformInitializer())
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.UniformInitializer(seed=456))
init_op = block.ops[1]
self.assertEqual(init_op.attr("seed"), 123)
init_op1 = block.ops[0]
self.assertEqual(init_op1.attr("seed"), 456)
def test_uniform_initializer(self):
"""Test uniform initializer with supplied attributes
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册