diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 0e640bf280d396504deec1183821da3e8a156530..6f2768f59a131f6a2b5ee66fb535d645de75e5bb 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -62,6 +62,13 @@ def init_on_cpu(): _force_init_on_cpu_ = pre_state +def _is_inited_by(block, var, init_op_type): + for op in block.ops: + if var.name in op.output_arg_names and op.type == init_op_type: + return op + return None + + class Initializer(object): """Base class for variable initializers @@ -147,6 +154,9 @@ class ConstantInitializer(Initializer): """ assert isinstance(var, framework.Variable) assert isinstance(block, framework.Block) + init_op = _is_inited_by(block, var, 'uniform_random') + if init_op is not None: + return init_op # Initialization Ops should be prepended and not appended op = block._prepend_op( type="fill_constant", @@ -199,6 +209,9 @@ class UniformInitializer(Initializer): """ assert isinstance(var, framework.Variable) assert isinstance(block, framework.Block) + init_op = _is_inited_by(block, var, 'uniform_random') + if init_op is not None: + return init_op # Initialization Ops should be prepended and not appended if self._seed == 0: self._seed = block.program.random_seed @@ -253,6 +266,9 @@ class NormalInitializer(Initializer): """ assert isinstance(var, framework.Variable) assert isinstance(block, framework.Block) + init_op = _is_inited_by(block, var, 'uniform_random') + if init_op is not None: + return init_op # Initialization Ops should be prepended and not appended if self._seed == 0: self._seed = block.program.random_seed @@ -335,6 +351,10 @@ class XavierInitializer(Initializer): """ assert isinstance(var, framework.Variable) assert isinstance(block, framework.Block) + init_op = _is_inited_by(block, var, 'uniform_random') + if init_op is not None: + return init_op + f_in, f_out = self._compute_fans(var) # If fan_in and fan_out are passed, use them @@ -434,6 +454,10 @@ class MSRAInitializer(Initializer): """ assert isinstance(var, framework.Variable) assert isinstance(block, framework.Block) + init_op = _is_inited_by(block, var, 'uniform_random') + if init_op is not None: + return init_op + f_in, f_out = self._compute_fans(var) # If fan_in is passed, use it @@ -532,6 +556,10 @@ class BilinearInitializer(Initializer): if not isinstance(block, framework.Block): raise ValueError("block must be framework.Block.") + init_op = _is_inited_by(block, var, 'uniform_random') + if init_op is not None: + return init_op + shape = var.shape if len(shape) != 4: raise ValueError("the length of shape must be 4.")