diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 343901cda3f505c3b3d2ed0c30cf7fea71c8b6b1..f086600702eeafc7948e168d77dfbd1d1c4b901c 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -309,10 +309,10 @@ class DistributeTranspiler(object): def get_pserver_program(self, endpoint): """ Get parameter server side program. - + Args: endpoint (str): current parameter server endpoint. - + Returns: Program: the program for current parameter server to run. """ @@ -514,7 +514,7 @@ class DistributeTranspiler(object): endpoint (str): current pserver endpoint. pserver_program (Program): call get_pserver_program first and pass the result here. - + Returns: Program: parameter server side startup program. """ @@ -550,10 +550,10 @@ class DistributeTranspiler(object): op_on_pserver = True new_outputs[key] = pserver_vars[op.output(key)[0]] - # most startup program ops have no inputs - new_inputs = self._get_input_map_from_op(pserver_vars, op) - if op_on_pserver: + # most startup program ops have no inputs + new_inputs = self._get_input_map_from_op(pserver_vars, op) + if op.type in [ "gaussian_random", "fill_constant", "uniform_random" ]: