提交 fa94261b 编写于 作者: Q qiaolongfei

avoid duplicated init op for one parameter

上级 56b50ee4
...@@ -62,6 +62,13 @@ def init_on_cpu(): ...@@ -62,6 +62,13 @@ def init_on_cpu():
_force_init_on_cpu_ = pre_state _force_init_on_cpu_ = pre_state
def _is_inited_by(block, var, init_op_type):
for op in block.ops:
if var.name in op.output_arg_names and op.type == init_op_type:
return op
return None
class Initializer(object): class Initializer(object):
"""Base class for variable initializers """Base class for variable initializers
...@@ -147,6 +154,9 @@ class ConstantInitializer(Initializer): ...@@ -147,6 +154,9 @@ class ConstantInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'uniform_random')
if init_op is not None:
return init_op
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
op = block._prepend_op( op = block._prepend_op(
type="fill_constant", type="fill_constant",
...@@ -199,6 +209,9 @@ class UniformInitializer(Initializer): ...@@ -199,6 +209,9 @@ class UniformInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'uniform_random')
if init_op is not None:
return init_op
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
...@@ -253,6 +266,9 @@ class NormalInitializer(Initializer): ...@@ -253,6 +266,9 @@ class NormalInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'uniform_random')
if init_op is not None:
return init_op
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
...@@ -335,6 +351,10 @@ class XavierInitializer(Initializer): ...@@ -335,6 +351,10 @@ class XavierInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'uniform_random')
if init_op is not None:
return init_op
f_in, f_out = self._compute_fans(var) f_in, f_out = self._compute_fans(var)
# If fan_in and fan_out are passed, use them # If fan_in and fan_out are passed, use them
...@@ -434,6 +454,10 @@ class MSRAInitializer(Initializer): ...@@ -434,6 +454,10 @@ class MSRAInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'uniform_random')
if init_op is not None:
return init_op
f_in, f_out = self._compute_fans(var) f_in, f_out = self._compute_fans(var)
# If fan_in is passed, use it # If fan_in is passed, use it
...@@ -532,6 +556,10 @@ class BilinearInitializer(Initializer): ...@@ -532,6 +556,10 @@ class BilinearInitializer(Initializer):
if not isinstance(block, framework.Block): if not isinstance(block, framework.Block):
raise ValueError("block must be framework.Block.") raise ValueError("block must be framework.Block.")
init_op = _is_inited_by(block, var, 'uniform_random')
if init_op is not None:
return init_op
shape = var.shape shape = var.shape
if len(shape) != 4: if len(shape) != 4:
raise ValueError("the length of shape must be 4.") raise ValueError("the length of shape must be 4.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册