未验证 提交 4eac85c6 编写于 作者: D dzhwinter 提交者: GitHub

"add init seed" (#6221)

* "add init seed"

* "fix compile error"

* "add program level seed setting"

* "fixed based on comments"
上级 a0c1190f
...@@ -512,6 +512,7 @@ class Program(object): ...@@ -512,6 +512,7 @@ class Program(object):
self.desc = core.ProgramDesc() self.desc = core.ProgramDesc()
self.blocks = [Block(self, 0)] self.blocks = [Block(self, 0)]
self.current_block_idx = 0 self.current_block_idx = 0
self._seed = 0
def __str__(self): def __str__(self):
return self.to_string(True) return self.to_string(True)
...@@ -564,6 +565,16 @@ class Program(object): ...@@ -564,6 +565,16 @@ class Program(object):
p.sync_with_cpp() p.sync_with_cpp()
return p return p
@property
def random_seed(self):
return self._seed
@random_seed.setter
def random_seed(self, seed):
if not isinstance(seed, int):
raise ValueError("Seed must be a integer.")
self._seed = seed
def __repr__(self): def __repr__(self):
return str(self) return str(self)
......
...@@ -132,6 +132,8 @@ class UniformInitializer(Initializer): ...@@ -132,6 +132,8 @@ class UniformInitializer(Initializer):
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
op = block.prepend_op( op = block.prepend_op(
type="uniform_random", type="uniform_random",
outputs={"Out": var}, outputs={"Out": var},
...@@ -180,6 +182,8 @@ class NormalInitializer(Initializer): ...@@ -180,6 +182,8 @@ class NormalInitializer(Initializer):
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
op = block.prepend_op( op = block.prepend_op(
type="gaussian_random", type="gaussian_random",
outputs={"Out": var}, outputs={"Out": var},
...@@ -255,6 +259,9 @@ class XavierInitializer(Initializer): ...@@ -255,6 +259,9 @@ class XavierInitializer(Initializer):
fan_in = f_in if self._fan_in is None else self._fan_in fan_in = f_in if self._fan_in is None else self._fan_in
fan_out = f_out if self._fan_out is None else self._fan_out fan_out = f_out if self._fan_out is None else self._fan_out
if self._seed == 0:
self._seed = block.program.random_seed
if self._uniform: if self._uniform:
limit = np.sqrt(6.0 / float(fan_in + fan_out)) limit = np.sqrt(6.0 / float(fan_in + fan_out))
op = block.prepend_op( op = block.prepend_op(
...@@ -338,6 +345,9 @@ class MSRAInitializer(Initializer): ...@@ -338,6 +345,9 @@ class MSRAInitializer(Initializer):
# If fan_in is passed, use it # If fan_in is passed, use it
fan_in = f_in if self._fan_in is None else self._fan_in fan_in = f_in if self._fan_in is None else self._fan_in
if self._seed == 0:
self._seed = block.program.random_seed
if self._uniform: if self._uniform:
limit = np.sqrt(6.0 / float(fan_in)) limit = np.sqrt(6.0 / float(fan_in))
op = block.prepend_op( op = block.prepend_op(
......
...@@ -60,6 +60,29 @@ class TestUniformInitializer(unittest.TestCase): ...@@ -60,6 +60,29 @@ class TestUniformInitializer(unittest.TestCase):
self.assertAlmostEqual(init_op.attr('max'), 1.0, delta=DELTA) self.assertAlmostEqual(init_op.attr('max'), 1.0, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0) self.assertEqual(init_op.attr('seed'), 0)
def test_uniform_initializer_random_seed(self):
"""Test the uniform initializer with manually setting seed
"""
program = framework.Program()
program.random_seed = 123
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.UniformInitializer())
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.UniformInitializer(seed=456))
init_op = block.ops[1]
self.assertEqual(init_op.attr("seed"), 123)
init_op1 = block.ops[0]
self.assertEqual(init_op1.attr("seed"), 456)
def test_uniform_initializer(self): def test_uniform_initializer(self):
"""Test uniform initializer with supplied attributes """Test uniform initializer with supplied attributes
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册