未验证 提交 9ae55dd7 编写于 作者: W Wu Yi 提交者: GitHub

fix dist transpile with memopt (#12974)

* fix dist transpile with memopt

* update api.spec

* polish dist transpile api

* update apispec

* update apispec
上级 902f19b4
......@@ -55,9 +55,10 @@ paddle.fluid.Inferencer.__init__ ArgSpec(args=['self', 'infer_func', 'param_path
paddle.fluid.Inferencer.infer ArgSpec(args=['self', 'inputs', 'return_numpy'], varargs=None, keywords=None, defaults=(True,))
paddle.fluid.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
paddle.fluid.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
paddle.fluid.InferenceTranspiler.__init__
paddle.fluid.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
......@@ -335,9 +336,10 @@ paddle.fluid.contrib.BeamSearchDecoder.update_array ArgSpec(args=['self', 'array
paddle.fluid.contrib.memory_usage ArgSpec(args=['program', 'batch_size'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_startup_program ArgSpec(args=['self', 'endpoint', 'pserver_program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.transpiler.DistributeTranspiler.get_trainer_program ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True))
paddle.fluid.transpiler.DistributeTranspiler.transpile ArgSpec(args=['self', 'trainer_id', 'program', 'pservers', 'trainers', 'sync_mode', 'startup_program'], varargs=None, keywords=None, defaults=(None, '127.0.0.1:6174', 1, True, None))
paddle.fluid.transpiler.InferenceTranspiler.__init__
paddle.fluid.transpiler.InferenceTranspiler.transpile ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.memory_optimize ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level'], varargs=None, keywords=None, defaults=(None, False, 0))
......
......@@ -31,6 +31,7 @@ Steps to transpile pserver:
"""
import math
import sys
import numpy as np
import collections
import six
......@@ -181,7 +182,8 @@ class DistributeTranspiler(object):
program=None,
pservers="127.0.0.1:6174",
trainers=1,
sync_mode=True):
sync_mode=True,
startup_program=None):
"""
Run the transpiler.
......@@ -194,13 +196,17 @@ class DistributeTranspiler(object):
list.
trainers (int): number of trainers in the distributed job.
sync_mode (bool): Do sync training or not, default is True.
startup_program (Program|None): startup_program to transpile,
default is fluid.default_main_program().
"""
if program is None:
program = default_main_program()
if startup_program is None:
startup_program = default_startup_program()
self.origin_program = program
self.origin_startup_program = default_startup_program().clone()
self.startup_program = startup_program
self.origin_startup_program = self.startup_program.clone()
self.startup_program = default_startup_program()
self.trainer_num = trainers
self.sync_mode = sync_mode
self.trainer_id = trainer_id
......@@ -376,20 +382,17 @@ class DistributeTranspiler(object):
return self.origin_program
def _get_trainer_startup_program(self,
recv_vars,
eplist,
startup_program=None):
def _get_trainer_startup_program(self, recv_vars, eplist):
"""
Get transpiled trainer side startup program.
Args:
startup_program(Program): Startup program.
recv_vars (list): Variable list to recv for current trainer_id
eplist (list): A list of strings indicating
Returns:
Program: trainer side startup program.
"""
if startup_program is None:
startup_program = self.startup_program
# FIXME(gongwb): delete not need ops.
......@@ -438,7 +441,18 @@ class DistributeTranspiler(object):
#add concat ops to merge splited parameters received from parameter servers.
if len(splited_var) <= 1:
continue
# NOTE: if enable memory optimization, origin vars maybe removed.
if startup_program.global_block().vars.has_key(varname):
orig_param = startup_program.global_block().vars[varname]
else:
origin_param_var = self.origin_program.global_block().vars[
varname]
orig_param = startup_program.global_block().create_var(
name=varname,
persistable=origin_param_var.persistable,
type=origin_param_var.type,
dtype=origin_param_var.dtype,
shape=origin_param_var.shape)
startup_program.global_block().append_op(
type="concat",
inputs={"X": splited_var},
......@@ -461,7 +475,9 @@ class DistributeTranspiler(object):
# NOTE: assume blocks of the same variable is not distributed
# on the same pserver, only change param/grad varnames for
# trainers to fetch.
sys.stderr.write("get_pserver_program() is deprecated, call\
get_pserver_programs() to get pserver main and startup\
in a single call.")
# step1
pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed
......@@ -651,32 +667,58 @@ class DistributeTranspiler(object):
endpoint)
pserver_program._sync_with_cpp()
# save pserver program to generate pserver side startup relatively.
self.pserver_program = pserver_program
return pserver_program
def get_pserver_programs(self, endpoint):
"""
Get pserver side main program and startup program for distributed training.
Args:
endpoint (str): current pserver endpoint.
Returns:
tuple: (main_program, startup_program), of type "Program"
"""
pserver_prog = self.get_pserver_program(endpoint)
pserver_startup = self.get_startup_program(endpoint)
return pserver_prog, pserver_startup
def get_startup_program(self,
endpoint,
pserver_program,
pserver_program=None,
startup_program=None):
"""
**Deprecated**
Get startup program for current parameter server.
Modify operator input variables if there are variables that
were split to several blocks.
Args:
endpoint (str): current pserver endpoint.
pserver_program (Program): call get_pserver_program first and
pass the result here.
startup_program (Program): if pass None, will use
default_startup_program
pserver_program (Program): deprecated, call get_pserver_program first.
startup_program (Program): deprecated, should pass startup_program
when initalizing
Returns:
Program: parameter server side startup program.
"""
sys.stderr.write("get_startup_program() is deprecated, call\
get_pserver_programs() to get pserver main and startup\
in a single call.")
if pserver_program != None:
sys.stderr.write("passing pserver_program to get_startup_program()\
is deprecated, you can use new API get_pserver_programs() to\
get both pserver main program and startup program.")
if startup_program != None:
sys.stderr.write("passing startup_program to get_startup_program()\
is deprecated, use fluid.program_guard() or pass this argument\
to transpile() call.")
s_prog = Program()
if not startup_program:
orig_s_prog = default_startup_program()
else:
orig_s_prog = startup_program
orig_s_prog = self.startup_program
s_prog.random_seed = orig_s_prog.random_seed
params = self.param_grad_ep_mapping[endpoint]["params"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册