diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 05fed72ee6471ba42007b5a9f09f89148ac27a30..53d6ca86a008f798af2854a154cce8b7242d2f35 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -309,10 +309,10 @@ class DistributeTranspiler(object): def get_pserver_program(self, endpoint): """ Get parameter server side program. - + Args: endpoint (str): current parameter server endpoint. - + Returns: Program: the program for current parameter server to run. """ @@ -516,7 +516,7 @@ class DistributeTranspiler(object): endpoint (str): current pserver endpoint. pserver_program (Program): call get_pserver_program first and pass the result here. - + Returns: Program: parameter server side startup program. """ @@ -552,10 +552,10 @@ class DistributeTranspiler(object): op_on_pserver = True new_outputs[key] = pserver_vars[op.output(key)[0]] - # most startup program ops have no inputs - new_inputs = self._get_input_map_from_op(pserver_vars, op) - if op_on_pserver: + # most startup program ops have no inputs + new_inputs = self._get_input_map_from_op(pserver_vars, op) + if op.type in [ "gaussian_random", "fill_constant", "uniform_random" ]: