未验证 提交 bcaf88d2 编写于 作者: W wangguanqun 提交者: GitHub

Ps optimizer multi programs (#39883)

* fix benchmark and communicator config

* fix bugs of the_one_ps

* multi program and fix bug in optimizer

* multi program in the_one_ps

* public commcontext

* ps optimizer multi programs

* the one ps merge

* fix bug in test
上级 faece382
......@@ -29,7 +29,6 @@ from ..fluid.layers import utils
from ..fluid.dygraph import layers
from ..fluid.dygraph.parallel import prepare_context
import paddle
from .fleet import fleet
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle import _C_ops
......@@ -1422,6 +1421,7 @@ def split(x,
"graph mode, plese use ParallelEmbedding, ParallelRowLinear, "
"ParallelColumnLinear instead.")
else:
from .fleet import fleet
assert fleet._role_maker, ("To use paddle.distributed.split, "
"you must call fleet.init() firstly.")
rank = fleet.worker_index()
......
......@@ -31,14 +31,19 @@ class ParameterServerOptimizer(MetaOptimizerBase):
self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = []
self.pass_ctx = PassContext()
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy):
super(ParameterServerOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
def _set_origin_programs(self, losses):
self.origin_main_programs = []
for loss in losses:
self.origin_main_programs.append(loss.block.program)
def _init_ps_pass_context(self, loss, startup_program):
self.pass_ctx = PassContext()
attrs = {}
# trainer
attrs["env"] = get_dist_env()
......@@ -46,9 +51,9 @@ class ParameterServerOptimizer(MetaOptimizerBase):
attrs['loss'] = loss
attrs['min_block_size'] = 81920
attrs['origin_main_program'] = loss.block.program
attrs['origin_main_programs'] = [loss.block.program]
attrs['origin_startup_program'] = startup_program
attrs['origin_startup_programs'] = [startup_program]
attrs['origin_main_programs'] = self.origin_main_programs
attrs['cloned_main'] = attrs['origin_main_program'].clone()
attrs['cloned_startup'] = attrs['origin_startup_program'].clone()
......@@ -90,10 +95,11 @@ class ParameterServerOptimizer(MetaOptimizerBase):
return False
def _can_apply(self):
if self._attrs['role_maker']._is_collective or self._attrs[
'k_steps'] < 0:
if self.role_maker._is_collective:
return False
return True
k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
return True if k_steps >= 0 else False
def minimize_impl(self,
loss,
......@@ -104,12 +110,37 @@ class ParameterServerOptimizer(MetaOptimizerBase):
no_grad_set)
if startup_program == None:
startup_program = paddle.static.default_startup_program()
print("program after inner optimizer minimize:",
str(loss.block.program))
self._set_origin_programs([loss])
self._init_ps_pass_context(loss, startup_program)
ps_builder = PsProgramBuilderFactory()._create_ps_program_builder(
self.pass_ctx)
ps_builder._build_programs()
return None, None
def minimize_losses_impl(self,
losses,
startup_program=None,
parameter_list=None,
no_grad_set=None):
if parameter_list is None:
parameter_list = [None] * len(losses)
for idx, loss in enumerate(losses):
startup_prog = startup_program[idx]
parameters = parameter_list[idx]
self.inner_opt.minimize(loss, startup_prog, parameters, no_grad_set)
self._set_origin_programs(losses)
for idx, loss in enumerate(losses):
print("ps_optimizer idx loss:", idx, loss)
startup_prog = startup_program[idx]
self._init_ps_pass_context(loss, startup_prog)
ps_builder = PsProgramBuilderFactory()._create_ps_program_builder(
self.pass_ctx)
ps_builder._build_programs()
startup_program[idx] = self.pass_ctx._attrs['cloned_startup']
return None, None
def _can_apply_geo(self, program):
def get_sys_free_mem():
plat = platform.system()
......
......@@ -74,6 +74,8 @@ class AppendSendOpsPass(PassBase): # 该 pass 被多种模式复用
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
attrs = pass_ctx._attrs
print("pass loss program id:", id(attrs['loss'].block.program))
print("pass main program id:", id(main_program))
ps_mode = attrs['ps_mode']
if ps_mode == DistributedMode.GEO:
send_ctx = get_geo_trainer_send_context(attrs) # geo 模式
......@@ -84,6 +86,8 @@ class AppendSendOpsPass(PassBase): # 该 pass 被多种模式复用
for merged_name, send in send_ctx.items():
if send.is_sparse() and ps_mode != DistributedMode.GEO:
continue
if send.program_id() != id(attrs['loss'].block.program):
continue
logger.info('merged_name, send: {}, {}'.format(merged_name, send))
is_sparse = 1 if send.is_sparse() else 0
is_sparse = 2 if send.is_distributed() else is_sparse
......@@ -496,6 +500,7 @@ class DeleteOptimizesPass(PassBase):
persistable=True)
def _apply_single_impl(self, main_program, startup_program, pass_ctx):
print("delete_optimizer_pass")
attrs = pass_ctx._attrs
optimizer_ops = get_optimize_ops(main_program)
lr_ops = get_lr_ops(main_program)
......
......@@ -40,12 +40,12 @@ def get_program_by_id(context, program_id):
programs = context["origin_main_programs"]
for i, program in enumerate(programs):
if id(program) == program_id:
return program, context["origin_startup_programs"][i]
return None, None
return program, context["origin_startup_programs"][i], i
return None, None, None
def parse_table_class(varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
main_program, startup_program, idx = get_program_by_id(context, program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue
......@@ -60,7 +60,7 @@ def parse_table_class(varname, program_id, context):
def check_embedding_dim(accessor_proto, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
main_program, startup_program, idx = get_program_by_id(context, program_id)
embedding_dim = 0
for var in main_program.list_vars():
if var.name == varname:
......@@ -94,10 +94,9 @@ class Service:
class GpuService(Service):
def __init__(self):
super(GpuService).__init__(self)
super(GpuService, self).__init__()
def _set(self, service_proto):
super(GpuService)._set(service_proto)
service_proto.server_class = 'PsLocalServer'
service_proto.client_class = 'PsLocalClient'
......@@ -111,7 +110,8 @@ class Accessor:
# TableAccessorParameter accessor
def _set(self, accessor_proto, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
main_program, startup_program, idx = get_program_by_id(context,
program_id)
embedding_dim = 0
for var in main_program.list_vars():
if var.name == varname:
......@@ -236,7 +236,8 @@ class CommonAccessor(Accessor):
self.opt_init_map = opt_init_map
def parse_entry(self, varname, program_id, context):
main_program, startup_program = get_program_by_id(context, program_id)
main_program, startup_program, idx = get_program_by_id(context,
program_id)
for op in main_program.global_block().ops:
if not is_distributed_sparse_op(op) and not is_sparse_op(op):
continue
......@@ -290,8 +291,8 @@ class CommonAccessor(Accessor):
print("parse_by_optimizer table_id:{} is_datanorm:{}".format(
ctx.table_id(), ctx.is_datanorm_table()))
main_program, startup_program = get_program_by_id(context,
ctx.program_id())
main_program, startup_program, idx = get_program_by_id(context,
ctx.program_id())
pserver_id = get_role_id(context['role_maker'])
pserver_num = len(get_ps_endpoints(context['role_maker']))
optimizer_ops = get_optimize_ops(main_program)
......@@ -359,10 +360,11 @@ class CommonAccessor(Accessor):
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
#TODO: for dense learning_rate, can be different from sparse lr
if formal_name == "LearningRate" and param.name != "learning_rate_0":
if formal_name == "LearningRate" and param.name != "learning_rate_" + str(
idx):
warnings.warn("will support decay soon")
param = main_program.global_block().vars[
"learning_rate_0"]
"learning_rate_" + str(idx)]
initializer = self.get_initializer_attr(param.name,
startup_program)
......@@ -404,10 +406,11 @@ class CommonAccessor(Accessor):
else:
param = main_program.global_block().vars[oop.input(
formal_name)[0]]
if formal_name == "LearningRate" and param.name != "learning_rate_0":
if formal_name == "LearningRate" and param.name != "learning_rate_" + str(
idx):
warnings.warn("will support decay soon")
param = main_program.global_block().vars[
"learning_rate_0"]
"learning_rate_" + str(idx)]
if shape is None:
if is_sparse:
......@@ -707,6 +710,7 @@ class PsDescBuilder(object):
self.ps_mode = context['ps_mode']
self.is_heter_ps_mode = context['is_heter_ps_mode']
self.use_ps_gpu = context['use_ps_gpu']
self.barrier_table_id = None
self.send_ctx = get_the_one_send_context(
self.context,
use_origin_program=True,
......@@ -767,6 +771,8 @@ class PsDescBuilder(object):
table_proto = self.ps_desc.server_param.downpour_server_param.downpour_table_param.add(
)
table._set(table_proto)
if type(table) == BarrierTable and self.barrier_table_id is None:
self.barrier_table_id = table.idx
self.service._set(
self.ps_desc.server_param.downpour_server_param.service_param)
return text_format.MessageToString(self.ps_desc)
......@@ -820,9 +826,9 @@ class TheOnePSRuntime(RuntimeBase):
self.context['tensor_table'] = {}
build_var_distributed(self.context)
endpoints = get_ps_endpoints(self.role_maker)
self.endpoints = get_ps_endpoints(self.role_maker)
self.string_hosts = []
for idx, ep in enumerate(endpoints):
for idx, ep in enumerate(self.endpoints):
host, port = ep.split(":")
pshost = fluid.core.PSHost(host, int(port), idx)
self.string_hosts.append(pshost.serialize_to_string())
......@@ -848,7 +854,7 @@ class TheOnePSRuntime(RuntimeBase):
kwargs["trainer_id"] = self.role_maker._worker_index()
return kwargs
proto_txt = worker_desc + "\n" + server_desc
proto_txt = worker_desc
debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
if debug:
print("worker: \n{}".format(proto_txt))
......@@ -859,7 +865,7 @@ class TheOnePSRuntime(RuntimeBase):
self.context,
split_dense_table=self.is_heter_ps_mode,
use_origin_program=self.is_heter_ps_mode,
ep_list=endpoints)
ep_list=self.endpoints)
trainer_config = self.context['trainer']
debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
......@@ -876,10 +882,7 @@ class TheOnePSRuntime(RuntimeBase):
kwargs["trainer_id"] = self.role_maker._role_id()
kwargs["trainers"] = self.role_maker._worker_num()
for table in server.servers[0].tables: #TODO
if table.table_class == "BarrierTable":
kwargs["barrier_table_id"] = table.id
break
kwargs["barrier_table_id"] = self.ps_desc_builder.barrier_table_id
if self.context['ps_mode'] == DistributedMode.SYNC:
sync_kwargs = sync_strategy_envs()
......@@ -1009,7 +1012,7 @@ class TheOnePSRuntime(RuntimeBase):
if origin_varname.endswith("@GRAD"):
return False
if origin_varname == "learning_rate_0":
if origin_varname.startswith("learning_rate_"):
return False
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
......@@ -1113,7 +1116,7 @@ class TheOnePSRuntime(RuntimeBase):
"in fleet.save() function, executor must be as Executor type")
if main_program is None:
main_program = self.context['origin_ps_main_program']
main_program = self.context['origin_main_program']
if isinstance(main_program, CompiledProgram):
raise TypeError(
......
......@@ -88,7 +88,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式
self.attrs['origin_main_program'] = self.cloned_main
if self.launch_barrier and self.launch_barrier_flag:
wait_server_ready(server_endpoints)
wait_server_ready(self.server_endpoints)
return
......@@ -103,10 +103,13 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
format(self.ps_mode, "PsProgramBuilder"))
def _build_trainer_programs(self):
print("build trainer program entry")
print("before ps program builder program:", self.cloned_main)
add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass",
self.attrs)
add_lr_decay_table_pass.apply([], [], self.pass_ctx)
print("before distributed op pass")
distributed_ops_pass = new_pass("distributed_ops_pass", self.attrs)
distributed_ops_pass.apply([self.cloned_main], [None], self.pass_ctx)
......@@ -126,9 +129,10 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder):
self.attrs['origin_main_program'] = self.cloned_main
self.attrs['origin_startup_program'] = self.cloned_startup
print("after ps program builder program:", self.cloned_main)
if self.launch_barrier and self.launch_barrier_flag:
wait_server_ready(server_endpoints)
wait_server_ready(self.server_endpoints)
return
......@@ -167,7 +171,7 @@ class GpuPsProgramBuilder(PsProgramBuilder):
self.attrs['origin_startup_program'] = self.cloned_startup
if self.launch_barrier and self.launch_barrier_flag:
wait_server_ready(server_endpoints)
wait_server_ready(self.server_endpoints)
return
......@@ -220,7 +224,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder):
[self.cloned_startup], self.pass_ctx)
if self.launch_barrier and self.launch_barrier_flag:
wait_server_ready(server_endpoints)
wait_server_ready(self.server_endpoints)
return
......
......@@ -450,9 +450,8 @@ def get_the_one_send_context(context,
idx = 0
for i, program in enumerate(origin_programs):
merged_dense_pairs = context['merged_dense_pairs'][i]
idx += get_dense_send_context(program, send_ctx, idx,
merged_dense_pairs, trainer_id,
split_dense_table)
idx = get_dense_send_context(program, send_ctx, idx, merged_dense_pairs,
trainer_id, split_dense_table)
distibuted_varnames = get_sparse_tablenames(origin_programs, True)
print("public distibuted_varnames:", distibuted_varnames)
for i, program in enumerate(origin_programs):
......
......@@ -146,9 +146,13 @@ class TestPsTrainerPass(PsPassTestBase):
self.config['ps_mode_config'] = "../ps/gpu_ps_config.yaml"
self.config['debug_new_minimize'] = '0'
self.config['log_dir'] = ps_log_root_dir + "gpubox_log_old_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch("gpu-ps")
self.config['debug_new_minimize'] = '1'
self.config['log_dir'] = ps_log_root_dir + "gpubox_log_new_minimize"
remove_path_if_exists(self.config['log_dir'])
self.ps_launch("gpu-ps")
file1 = '/ps_log/gpubox_run_minimize_debug:_0_worker_main.prototxt'
......
......@@ -382,6 +382,7 @@ class DnnTrainer(object):
ps_optimizer = ParameterServerOptimizer(inner_optimizer)
ps_optimizer._set_basic_info(loss, self.role_maker, inner_optimizer,
user_defined_strategy)
ps_optimizer._set_origin_programs([loss])
ps_optimizer._init_ps_pass_context(loss, startup_program)
_main = ps_optimizer.pass_ctx._attrs['cloned_main']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册