diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index 3911680bb6ac318ae45662665b5831c839e92ea5..8f891b19bb80aef69c224756993ed740cb953e47 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -62,13 +62,23 @@ def init_on_cpu():
     _force_init_on_cpu_ = pre_state
 
 
-def _is_inited_by(block, var, init_op_type):
+def _is_inited_by(block, var, init_op_types):
     for op in block.ops:
-        if var.name in op.output_arg_names and op.type == init_op_type:
+        if var.name in op.output_arg_names and op.type in init_op_types:
             return op
     return None
 
 
+def _is_duplicated_init_op(op1, op2):
+    if op1.block == op2.block and \
+            op1.type == op2.type and \
+            op1.output_arg_names == op2.output_arg_names and \
+            op1.idx != op2.idx and \
+            op1.all_attrs() == op2.all_attrs():
+        return True
+    return False
+
+
 class Initializer(object):
     """Base class for variable initializers
 
@@ -154,9 +164,7 @@ class ConstantInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
        assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'fill_constant')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var, ['fill_constant'])
         # Initialization Ops should be prepended and not appended
         op = block._prepend_op(
             type="fill_constant",
@@ -167,6 +175,9 @@ class ConstantInitializer(Initializer):
                 "value": float(self._value),
                 'force_cpu': self._force_cpu or force_init_on_cpu()
             })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
 
@@ -209,9 +220,7 @@ class UniformInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'uniform_random')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var, ['uniform_random'])
         # Initialization Ops should be prepended and not appended
         if self._seed == 0:
             self._seed = block.program.random_seed
@@ -225,6 +234,9 @@ class UniformInitializer(Initializer):
                 "max": self._high,
                 "seed": self._seed
             })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
 
@@ -266,9 +278,7 @@ class NormalInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'gaussian_random')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var, ['gaussian_random'])
         # Initialization Ops should be prepended and not appended
         if self._seed == 0:
             self._seed = block.program.random_seed
@@ -282,6 +292,9 @@ class NormalInitializer(Initializer):
                 "std": self._std_dev,
                 "seed": self._seed
             })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
 
@@ -351,9 +364,8 @@ class XavierInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'uniform_random')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var,
+                                ['uniform_random', 'gaussian_random'])
 
         f_in, f_out = self._compute_fans(var)
 
@@ -389,6 +401,9 @@ class XavierInitializer(Initializer):
                     "std": std,
                     "seed": self._seed
                 })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
 
@@ -454,13 +469,8 @@ class MSRAInitializer(Initializer):
         """
         assert isinstance(var, framework.Variable)
         assert isinstance(block, framework.Block)
-        init_op = _is_inited_by(block, var, 'uniform_random')
-        if init_op is not None:
-            return init_op
-
-        init_op = _is_inited_by(block, var, 'gaussian_random')
-        if init_op is not None:
-            return init_op
+        init_op = _is_inited_by(block, var,
+                                ['uniform_random', 'gaussian_random'])
 
         f_in, f_out = self._compute_fans(var)
 
@@ -495,6 +505,9 @@ class MSRAInitializer(Initializer):
                     "std": std,
                     "seed": self._seed
                 })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
 
@@ -561,8 +574,6 @@ class BilinearInitializer(Initializer):
             raise ValueError("block must be framework.Block.")
 
         init_op = _is_inited_by(block, var, 'assign_value')
-        if init_op is not None:
-            return init_op
 
         shape = var.shape
         if len(shape) != 4:
@@ -597,6 +608,9 @@ class BilinearInitializer(Initializer):
             'shape': list(shape),
             value_name: values
         })
+        if init_op is not None and _is_duplicated_init_op(init_op, op):
+            block._remove_op(0)
+            return init_op
         var.op = op
         return op
 
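
Note: below is a minimal, self-contained sketch of the prepend-then-deduplicate pattern this diff introduces. FakeOp, FakeBlock, and init_once are stand-ins invented purely for illustration (they are not Paddle's framework.Operator / framework.Block); the duplicate check mirrors _is_duplicated_init_op and the cleanup mirrors the new block._remove_op(0) branch in the initializers' __call__ methods.

# Stand-in classes for this sketch only; just enough structure to show
# how the prepend-then-deduplicate branch behaves.
class FakeOp(object):
    def __init__(self, block, op_type, outputs, attrs):
        self.block = block
        self.type = op_type
        self.output_arg_names = outputs
        self._attrs = attrs

    @property
    def idx(self):
        return self.block.ops.index(self)

    def all_attrs(self):
        return self._attrs


class FakeBlock(object):
    def __init__(self):
        self.ops = []

    def _prepend_op(self, op_type, outputs, attrs):
        op = FakeOp(self, op_type, outputs, attrs)
        self.ops.insert(0, op)
        return op

    def _remove_op(self, index):
        del self.ops[index]


def init_once(block, var_name, attrs):
    # Same shape as ConstantInitializer.__call__ after this diff:
    # find an existing init op, prepend a new one, then drop the new
    # one again if it merely duplicates the existing op.
    existing = None
    for op in block.ops:
        if var_name in op.output_arg_names and op.type in ['fill_constant']:
            existing = op
            break
    op = block._prepend_op('fill_constant', [var_name], attrs)
    if existing is not None and \
            existing.block == op.block and \
            existing.type == op.type and \
            existing.output_arg_names == op.output_arg_names and \
            existing.idx != op.idx and \
            existing.all_attrs() == op.all_attrs():
        block._remove_op(0)  # discard the freshly prepended duplicate
        return existing
    return op


block = FakeBlock()
first = init_once(block, 'w', {'value': 0.0})
second = init_once(block, 'w', {'value': 0.0})
assert first is second and len(block.ops) == 1    # duplicate removed
third = init_once(block, 'w', {'value': 1.0})
assert third is not first and len(block.ops) == 2  # different attrs kept

The net effect is that re-initializing the same parameter is a no-op unless the op type or attributes actually change; prepending first and removing afterwards keeps the op-construction path identical in both cases, at the cost of building one throwaway op when a duplicate already exists.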