diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6635aedbceb3d08569a6361d04dc0daa66e6c925..c020ff45ad3f3a72bf8a88622df333c1765a3d21 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -55,7 +55,7 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,)) paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) -paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None) +paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True)) paddle.fluid.InferenceTranspiler.__init__ @@ -328,7 +328,7 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) -paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program'], varargs=None, keywords=None, defaults=None) +paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True)) paddle.fluid.transpiler.InferenceTranspiler.__init__ diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index ff6e71bbfaf5aeb91150a8d4a48fa64be4421373..cd6cf558d5c4cd4a3fbced95e92c3f4dd94661af 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -530,7 +530,10 @@ class DistributeTranspiler(object): pserver_program._sync_with_cpp() return pserver_program - def get_startup_program(self, endpoint, pserver_program): + def get_startup_program(self, + endpoint, + pserver_program, + startup_program=None): """ Get startup program for current parameter server. Modify operator input variables if there are variables that @@ -540,12 +543,17 @@ class DistributeTranspiler(object): endpoint (str): current pserver endpoint. pserver_program (Program): call get_pserver_program first and pass the result here. + startup_program (Program): if pass None, will use + default_startup_program Returns: Program: parameter server side startup program. """ s_prog = Program() - orig_s_prog = default_startup_program() + if not startup_program: + orig_s_prog = default_startup_program() + else: + orig_s_prog = startup_program s_prog.random_seed = orig_s_prog.random_seed params = self.param_grad_ep_mapping[endpoint]["params"]