提交 8e9f338c 编写于 作者: Q qiaolongfei

add _is_duplicated_init_op

上级 37e069eb
...@@ -62,13 +62,23 @@ def init_on_cpu(): ...@@ -62,13 +62,23 @@ def init_on_cpu():
_force_init_on_cpu_ = pre_state _force_init_on_cpu_ = pre_state
def _is_inited_by(block, var, init_op_type): def _is_inited_by(block, var, init_op_types):
for op in block.ops: for op in block.ops:
if var.name in op.output_arg_names and op.type == init_op_type: if var.name in op.output_arg_names and op.type in init_op_types:
return op return op
return None return None
def _is_duplicated_init_op(op1, op2):
if op1.block == op2.block and \
op1.type == op2.type and \
op1.input_arg_names == op2.output_arg_names and \
op1.idx != op2.idx and \
op1.all_attrs == op2.all_attrs:
return True
return False
class Initializer(object): class Initializer(object):
"""Base class for variable initializers """Base class for variable initializers
...@@ -154,9 +164,7 @@ class ConstantInitializer(Initializer): ...@@ -154,9 +164,7 @@ class ConstantInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'fill_constant') init_op = _is_inited_by(block, var, ['fill_constant'])
if init_op is not None:
return init_op
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
op = block._prepend_op( op = block._prepend_op(
type="fill_constant", type="fill_constant",
...@@ -167,6 +175,9 @@ class ConstantInitializer(Initializer): ...@@ -167,6 +175,9 @@ class ConstantInitializer(Initializer):
"value": float(self._value), "value": float(self._value),
'force_cpu': self._force_cpu or force_init_on_cpu() 'force_cpu': self._force_cpu or force_init_on_cpu()
}) })
if init_op is not None and _is_duplicated_init_op(init_op, op):
block._remove_op(0)
return init_op
var.op = op var.op = op
return op return op
...@@ -209,9 +220,7 @@ class UniformInitializer(Initializer): ...@@ -209,9 +220,7 @@ class UniformInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'uniform_random') init_op = _is_inited_by(block, var, ['uniform_random'])
if init_op is not None:
return init_op
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
...@@ -225,6 +234,9 @@ class UniformInitializer(Initializer): ...@@ -225,6 +234,9 @@ class UniformInitializer(Initializer):
"max": self._high, "max": self._high,
"seed": self._seed "seed": self._seed
}) })
if init_op is not None and _is_duplicated_init_op(init_op, op):
block._remove_op(0)
return init_op
var.op = op var.op = op
return op return op
...@@ -266,9 +278,7 @@ class NormalInitializer(Initializer): ...@@ -266,9 +278,7 @@ class NormalInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'gaussian_random') init_op = _is_inited_by(block, var, ['gaussian_random'])
if init_op is not None:
return init_op
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
...@@ -282,6 +292,9 @@ class NormalInitializer(Initializer): ...@@ -282,6 +292,9 @@ class NormalInitializer(Initializer):
"std": self._std_dev, "std": self._std_dev,
"seed": self._seed "seed": self._seed
}) })
if init_op is not None and _is_duplicated_init_op(init_op, op):
block._remove_op(0)
return init_op
var.op = op var.op = op
return op return op
...@@ -351,9 +364,8 @@ class XavierInitializer(Initializer): ...@@ -351,9 +364,8 @@ class XavierInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'uniform_random') init_op = _is_inited_by(block, var,
if init_op is not None: ['uniform_random', 'gaussian_random'])
return init_op
f_in, f_out = self._compute_fans(var) f_in, f_out = self._compute_fans(var)
...@@ -389,6 +401,9 @@ class XavierInitializer(Initializer): ...@@ -389,6 +401,9 @@ class XavierInitializer(Initializer):
"std": std, "std": std,
"seed": self._seed "seed": self._seed
}) })
if init_op is not None and _is_duplicated_init_op(init_op, op):
block._remove_op(0)
return init_op
var.op = op var.op = op
return op return op
...@@ -454,13 +469,8 @@ class MSRAInitializer(Initializer): ...@@ -454,13 +469,8 @@ class MSRAInitializer(Initializer):
""" """
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
init_op = _is_inited_by(block, var, 'uniform_random') init_op = _is_inited_by(block, var,
if init_op is not None: ['uniform_random', 'gaussian_random'])
return init_op
init_op = _is_inited_by(block, var, 'gaussian_random')
if init_op is not None:
return init_op
f_in, f_out = self._compute_fans(var) f_in, f_out = self._compute_fans(var)
...@@ -495,6 +505,9 @@ class MSRAInitializer(Initializer): ...@@ -495,6 +505,9 @@ class MSRAInitializer(Initializer):
"std": std, "std": std,
"seed": self._seed "seed": self._seed
}) })
if init_op is not None and _is_duplicated_init_op(init_op, op):
block._remove_op(0)
return init_op
var.op = op var.op = op
return op return op
...@@ -561,8 +574,6 @@ class BilinearInitializer(Initializer): ...@@ -561,8 +574,6 @@ class BilinearInitializer(Initializer):
raise ValueError("block must be framework.Block.") raise ValueError("block must be framework.Block.")
init_op = _is_inited_by(block, var, 'assign_value') init_op = _is_inited_by(block, var, 'assign_value')
if init_op is not None:
return init_op
shape = var.shape shape = var.shape
if len(shape) != 4: if len(shape) != 4:
...@@ -597,6 +608,9 @@ class BilinearInitializer(Initializer): ...@@ -597,6 +608,9 @@ class BilinearInitializer(Initializer):
'shape': list(shape), 'shape': list(shape),
value_name: values value_name: values
}) })
if init_op is not None and _is_duplicated_init_op(init_op, op):
block._remove_op(0)
return init_op
var.op = op var.op = op
return op return op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册