From d56a0a1bea4be064e03648cfa587db4b01ce3d27 Mon Sep 17 00:00:00 2001 From: wangguanqun Date: Tue, 22 Feb 2022 16:23:02 +0800 Subject: [PATCH] fix bug in new the_one_ps (#39505) * fix benchmark and communicator config * fix bugs of the_one_ps * multi program and fix bug in optimizer * multi program in the_one_ps * public commcontext --- .../communicator/communicator_common.h | 13 +- paddle/fluid/pybind/fleet_py.cc | 6 +- .../fleet/meta_optimizers/ps_optimizer.py | 2 + .../distributed/passes/ps_trainer_pass.py | 4 +- python/paddle/distributed/ps/the_one_ps.py | 160 +++++--- .../ps/utils/ps_program_builder.py | 8 +- python/paddle/distributed/ps/utils/public.py | 380 ++++++++++++------ .../fleet/parameter_server/ir/public.py | 11 +- 8 files changed, 400 insertions(+), 184 deletions(-) diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator_common.h b/paddle/fluid/distributed/ps/service/communicator/communicator_common.h index 66784c53c00..27b282a945d 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator_common.h +++ b/paddle/fluid/distributed/ps/service/communicator/communicator_common.h @@ -31,7 +31,8 @@ struct CommContext { const std::vector &origin_names, int id, bool merge_add_ = true, bool is_sparse_ = true, bool is_distributed_ = false, int table_id_ = -1, - bool is_tensor_table_ = false) + bool is_tensor_table_ = false, bool is_datanorm_table_ = false, + int64_t program_id_ = -1) : var_name(name), splited_varnames(names), epmap(emap), @@ -42,7 +43,9 @@ struct CommContext { is_sparse(is_sparse_), is_distributed(is_distributed_), table_id(table_id_), - is_tensor_table(is_tensor_table_) {} + program_id(program_id_), + is_tensor_table(is_tensor_table_), + is_datanorm_table(is_datanorm_table_) {} CommContext(const CommContext &ctx) { var_name = ctx.var_name; @@ -55,7 +58,9 @@ struct CommContext { origin_varnames = ctx.origin_varnames; is_distributed = ctx.is_distributed; table_id = ctx.table_id; + program_id = ctx.program_id; is_tensor_table = ctx.is_tensor_table; + is_datanorm_table = ctx.is_datanorm_table; } std::string print() const { @@ -78,7 +83,9 @@ struct CommContext { ss << " is_sparse: " << is_sparse; ss << " is_distributed: " << is_distributed << "\n"; ss << " table_id: " << table_id << "\n"; + ss << " program_id: " << program_id << "\n"; ss << " is_tensor_table: " << is_tensor_table << "\n"; + ss << " is_datanorm_table: " << is_datanorm_table << "\n"; return ss.str(); } @@ -93,7 +100,9 @@ struct CommContext { bool is_sparse; bool is_distributed; int table_id; + int64_t program_id; bool is_tensor_table; + bool is_datanorm_table; }; } // namespace distributed diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 73c8f362d14..3145a9cf765 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -103,11 +103,13 @@ void BindCommunicatorContext(py::module* m) { py::init&, const std::vector&, const std::vector&, const std::vector&, int, bool, bool, bool, int, - bool>()) + bool, bool, int64_t>()) .def("var_name", [](const CommContext& self) { return self.var_name; }) .def("trainer_id", [](const CommContext& self) { return self.trainer_id; }) .def("table_id", [](const CommContext& self) { return self.table_id; }) + .def("program_id", + [](const CommContext& self) { return self.program_id; }) .def("split_varnames", [](const CommContext& self) { return self.splited_varnames; }) .def("split_endpoints", @@ -122,6 +124,8 @@ void BindCommunicatorContext(py::module* m) { [](const CommContext& 
self) { return self.origin_varnames; }) .def("is_tensor_table", [](const CommContext& self) { return self.is_tensor_table; }) + .def("is_datanorm_table", + [](const CommContext& self) { return self.is_datanorm_table; }) .def("__str__", [](const CommContext& self) { return self.print(); }); } diff --git a/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py index bc50bef0109..100a6882b1b 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ps_optimizer.py @@ -46,7 +46,9 @@ class ParameterServerOptimizer(MetaOptimizerBase): attrs['loss'] = loss attrs['min_block_size'] = 81920 attrs['origin_main_program'] = loss.block.program + attrs['origin_main_programs'] = [loss.block.program] attrs['origin_startup_program'] = startup_program + attrs['origin_startup_programs'] = [startup_program] attrs['cloned_main'] = attrs['origin_main_program'].clone() attrs['cloned_startup'] = attrs['origin_startup_program'].clone() diff --git a/python/paddle/distributed/passes/ps_trainer_pass.py b/python/paddle/distributed/passes/ps_trainer_pass.py index 3f39db69abd..284365ce066 100755 --- a/python/paddle/distributed/passes/ps_trainer_pass.py +++ b/python/paddle/distributed/passes/ps_trainer_pass.py @@ -560,9 +560,9 @@ class FakeInitOpsPass(PassBase): return True def _get_sparse_table_names(self, attrs): - dist_varnames = get_sparse_tablenames(attrs['origin_main_program'], + dist_varnames = get_sparse_tablenames(attrs['origin_main_programs'], True) - sparse_varnames = get_sparse_tablenames(attrs['origin_main_program'], + sparse_varnames = get_sparse_tablenames(attrs['origin_main_programs'], False) return list(set(dist_varnames + sparse_varnames)) diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index f842ca791f1..14a68ad9167 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -24,8 +24,8 @@ from paddle.fluid.compiler import CompiledProgram from paddle.fluid.executor import Executor from paddle.fluid.parallel_executor import ParallelExecutor from paddle.fluid.framework import Variable, Parameter -from .runtime_base import RuntimeBase -from ..base.private_helper_function import wait_server_ready +from paddle.distributed.fleet.runtime.runtime_base import RuntimeBase +from paddle.distributed.fleet.base.private_helper_function import wait_server_ready from paddle.fluid.communicator import Communicator, HeterClient from google.protobuf import text_format @@ -39,8 +39,17 @@ def conv_indent(indent): PSERVER_SAVE_SUFFIX = ".shard" -def parse_table_class(varname, o_main_program): - for op in o_main_program.global_block().ops: +def get_program_by_id(context, program_id): + programs = context["origin_main_programs"] + for i, program in enumerate(programs): + if id(program) == program_id: + return program, context["origin_startup_programs"][i] + return None, None + + +def parse_table_class(varname, program_id, context): + main_program, startup_program = get_program_by_id(context, program_id) + for op in main_program.global_block().ops: if not is_distributed_sparse_op(op) and not is_sparse_op(op): continue @@ -53,9 +62,10 @@ def parse_table_class(varname, o_main_program): return "MemorySparseTable" -def get_default_accessor_proto(accessor, varname, o_main_program): +def get_default_accessor_proto(accessor, varname, program_id, context): + main_program, startup_program = 
get_program_by_id(context, program_id) embedding_dim = 0 - for var in o_main_program.list_vars(): + for var in main_program.list_vars(): if var.name == varname: embedding_dim = var.shape[1] break @@ -123,9 +133,10 @@ def get_default_accessor_proto(accessor, varname, o_main_program): sgd_param.adam.weight_bounds.extend([-10.0, 10.0]) -def check_embedding_dim(accessor, varname, o_main_program): +def check_embedding_dim(accessor, varname, program_id, context): + main_program, startup_program = get_program_by_id(context, program_id) embedding_dim = 0 - for var in o_main_program.list_vars(): + for var in main_program.list_vars(): if var.name == varname: embedding_dim = var.shape[1] break @@ -172,6 +183,8 @@ class CommonAccessor: self.dims = [] self.trainer_num = 0 self.sync = "false" + self.table_num = None + self.table_dim = None self.initializers = [] self.opt_input_map = {} self.opt_attr_map = {} @@ -192,6 +205,7 @@ class CommonAccessor: opt_input_map["sum"] = [("Param", None)] opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1), ("LearningRate", 1)] + opt_input_map["summary"] = [("Param", None), ("SummaryDecayRate", 1)] opt_attr_map = {} opt_attr_map["sgd"] = [] @@ -201,6 +215,7 @@ class CommonAccessor: ("epsilon", "f")] opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"), ("epsilon", "f")] + opt_attr_map["summary"] = [] opt_init_map = {} opt_init_map["gaussian_random"] = ["seed", "mean", "std"] @@ -212,8 +227,9 @@ class CommonAccessor: self.opt_input_map = opt_input_map self.opt_init_map = opt_init_map - def parse_entry(self, varname, o_main_program): - for op in o_main_program.global_block().ops: + def parse_entry(self, varname, program_id, context): + main_program, startup_program = get_program_by_id(context, program_id) + for op in main_program.global_block().ops: if not is_distributed_sparse_op(op) and not is_sparse_op(op): continue @@ -243,23 +259,36 @@ class CommonAccessor: attr_str = "" origin_var_name = value_name + print("get_initializer_attr param name:", value_name) for op in o_startup_program.global_block().ops: if op.type in self.opt_init_map.keys( ) and origin_var_name == op.output("Out")[0]: init_attr = [op.type] + print("get_initializer_attr op type:", op.type) for attr in self.opt_init_map[op.type]: + print("get_initializer_attr opt_init_map attr:", attr) init_attr.append(str(op.attr(attr))) + print("get_initializer_attr op attr:", str(op.attr(attr))) attr_str = l_in.join(init_attr) break return attr_str - def parse_by_optimizer(self, grad_name, is_sparse, total_dims, context, - adam_d2sum): - main_program = context['origin_main_program'] - startup_program = context['startup_main_program'] + def parse_by_optimizer(self, ctx, context): + grad_name = ctx.origin_varnames()[0] + is_sparse = ctx.is_sparse() + size = ctx.sections()[0] + single_dim = ctx.sections()[1] if ctx.is_sparse() else 1 + adam_d2sum = context["user_defined_strategy"].adam_d2sum + print("parse_by_optimizer table_id:{} is_datanorm:{}".format( + ctx.table_id(), ctx.is_datanorm_table())) + + main_program, startup_program = get_program_by_id(context, + ctx.program_id()) pserver_id = get_role_id(context['role_maker']) pserver_num = len(get_ps_endpoints(context['role_maker'])) optimizer_ops = get_optimize_ops(main_program) + print("the one ps optimizer_ops:", optimizer_ops) + print("the one ps parse_by_optimizer grad_name:", grad_name) oop = None for op in optimizer_ops: @@ -278,6 +307,8 @@ class CommonAccessor: initializers = [] self.trainer_num = get_trainers(context['role_maker']) + 
self.table_num = size + self.table_dim = single_dim if oop.type != 'adam' and adam_d2sum == True: print('optimization algorithm is not adam, set adam_d2sum False') @@ -291,7 +322,11 @@ class CommonAccessor: param_varnames = self.opt_input_map["naive_adagrad"] attr_varnames = self.opt_attr_map["naive_adagrad"] self.accessor_class = "sgd" - elif adam_d2sum: + elif ctx.is_datanorm_table(): + param_varnames = self.opt_input_map["summary"] + attr_varnames = self.opt_attr_map["summary"] + self.accessor_class = "summary" + elif adam_d2sum and not is_sparse: param_varnames = self.opt_input_map["adam_d2sum"] attr_varnames = self.opt_attr_map["adam_d2sum"] self.accessor_class = "adam_d2sum" @@ -306,10 +341,9 @@ class CommonAccessor: #for dims if shape is None: if is_sparse: - shape = total_dims + shape = single_dim else: - shape = self.get_shard(total_dims, pserver_num, - pserver_id) + shape = self.get_shard(size, pserver_num, pserver_id) dims.append(shape) #for initializers @@ -333,6 +367,27 @@ class CommonAccessor: else: initializer = "fill_constant&0" initializers.append(initializer) + elif self.accessor_class == "summary": + #for dims + if shape is None: + if is_sparse: + shape = single_dim + else: + shape = self.get_shard(size, pserver_num, pserver_id) + dims.append(shape) + + #for initializers + if formal_name == "Param": + param = main_program.global_block().vars[oop.input( + formal_name)[0]] + + initializer = self.get_initializer_attr(param.name, + startup_program) + elif formal_name == "SummaryDecayRate": + initializer = "fill_constant&0.99999" + else: + initializer = "fill_constant&0" + initializers.append(initializer) else: if formal_name == "G2Sum": dims.append(1) @@ -348,9 +403,9 @@ class CommonAccessor: if shape is None: if is_sparse: - shape = total_dims + shape = single_dim else: - shape = self.get_shard(total_dims, pserver_num, + shape = self.get_shard(size, pserver_num, pserver_id) dims.append(shape) @@ -379,6 +434,10 @@ class CommonAccessor: attrs += "entry: \"{}\" ".format(self.entry) attrs += "trainer_num: {} ".format(self.trainer_num) attrs += "sync: {} ".format(self.sync) + if self.table_num: + attrs += "table_num: {} ".format(self.table_num) + if self.table_dim: + attrs += "table_dim: {} ".format(self.table_dim) for param in self.params: attrs += "params: \"{}\" ".format(param) @@ -448,10 +507,7 @@ class Table: accessor_str = accessor_str.format( conv_indent(indent), self.accessor_proto, conv_indent(indent)) attrs += accessor_str + "\n" - return table_str.format( - conv_indent(indent), attrs, conv_indent(indent)) - - if self.accessor is not None: + elif self.accessor is not None: attrs += self.accessor.to_string(indent) attrs += "\n" @@ -607,7 +663,9 @@ class TheOnePSRuntime(RuntimeBase): def _set_basic_info(self, context): self.context = context self.role_maker = context["role_maker"] + self.origin_main_program = context["origin_main_program"] + self.origin_main_programs = context["origin_main_programs"] self.context[ 'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode @@ -615,10 +673,13 @@ class TheOnePSRuntime(RuntimeBase): self.context['trainer'] = TrainerRuntimeConfig(context[ 'valid_strategy']) self.context['ps_mode'] = self.context['trainer'].mode - self.context['use_ps_gpu'] = context['valid_strategy'].use_ps_gpu + self.context['use_ps_gpu'] = context['valid_strategy'].a_sync_configs[ + 'use_ps_gpu'] self.is_sync = True if self.context[ 'ps_mode'] == DistributedMode.SYNC else False self.context['grad_name_to_param_name'] = {} + 
self.context['tensor_table'] = {} + build_var_distributed(self.context) def _init_worker(self): worker = self._get_fleet_proto(is_server=False, is_sync=self.is_sync) @@ -689,6 +750,7 @@ class TheOnePSRuntime(RuntimeBase): sync_kwargs = sync_strategy_envs() kwargs.update(sync_kwargs) + print("communicator config:", trainer_config.get_communicator_flags()) self._communicator = Communicator( trainer_config.mode, kwargs, trainer_config.get_communicator_flags()) @@ -893,7 +955,7 @@ class TheOnePSRuntime(RuntimeBase): common.table_name = self.context['grad_name_to_param_name'][ ctx.origin_varnames()[0]] - if self.ps_mode == DistributedMode.GEO: + if self.context['ps_mode'] == DistributedMode.GEO: table.table_class = "SparseGeoTable" else: all_table_proto = self.context[ @@ -907,7 +969,8 @@ class TheOnePSRuntime(RuntimeBase): table.table_class = table_proto.table_class else: table.table_class = parse_table_class( - common.table_name, self.origin_main_program) + common.table_name, + ctx.program_id(), self.context) if table.table_class != 'MemorySparseTable': table.table_class = 'MemorySparseTable' warnings.warn( @@ -925,12 +988,12 @@ class TheOnePSRuntime(RuntimeBase): warnings.warn( "The accessor of sparse table is not set, use default value." ) - get_default_accessor_proto(table_proto.accessor, - common.table_name, - self.origin_main_program) + get_default_accessor_proto( + table_proto.accessor, common.table_name, + ctx.program_id(), self.context) check_embedding_dim(table_proto.accessor, common.table_name, - self.origin_main_program) + ctx.program_id(), self.context) table.accessor_proto = text_format.MessageToString( table_proto.accessor) else: @@ -940,15 +1003,11 @@ class TheOnePSRuntime(RuntimeBase): common.table_name = "MergedDense" adam_d2sum = self.context["user_defined_strategy"].adam_d2sum - common.parse_by_optimizer(ctx.origin_varnames()[0], - ctx.is_sparse(), - ctx.sections()[1] if ctx.is_sparse() - else ctx.sections()[0], self.context, - adam_d2sum) + common.parse_by_optimizer(ctx, self.context) if ctx.is_sparse(): common.parse_entry(common.table_name, - self.origin_main_program) + ctx.program_id(), self.context) if is_sync: common.sync = "true" @@ -1023,8 +1082,9 @@ class TheOnePSRuntime(RuntimeBase): self._server.init_server(proto_txt, string_hosts, role_id, trainers, self._server_sub_program) - dist_varnames = get_sparse_tablenames(self.origin_main_program, True) - sparse_varnames = get_sparse_tablenames(self.origin_main_program, False) + dist_varnames = get_sparse_tablenames(self.origin_main_programs, True) + sparse_varnames = get_sparse_tablenames(self.origin_main_programs, + False) distributed_varnames = dist_varnames + sparse_varnames @@ -1070,6 +1130,7 @@ class TheOnePSRuntime(RuntimeBase): if var.name in exclude_var_names: return False + from .utils.public import _get_varname_parts origin_varname, _, _ = _get_varname_parts(var.name) if origin_varname.endswith("@GRAD"): return False @@ -1085,16 +1146,24 @@ class TheOnePSRuntime(RuntimeBase): return is_valid + def _get_inference_model_path(self, dirname): + if dirname.startswith("afs:") or dirname.startswith("hdfs:"): + model_path = "./dnn_plugin" + else: + model_path = os.path.join(dirname, "dnn_plugin") + return model_path + def _save_sparse_params(self, executor, dirname, context, main_program, mode): - distributed_varnames = get_sparse_tablenames( - self.context['origin_main_program'], True) + distributed_varnames = get_sparse_tablenames(self.origin_main_programs, + True) values = [] + model_path = 
self._get_inference_model_path(dirname) for id, names in context.items(): if names[0] not in distributed_varnames: # only save sparse param to local try: - self._worker.recv_and_save_model(id, dirname) + self._worker.recv_and_save_model(id, model_path) except: pass # save sparse & distributed param on server @@ -1221,10 +1290,7 @@ class TheOnePSRuntime(RuntimeBase): infer_program._copy_dist_param_info_from(program) - if dirname.startswith("afs:") or dirname.startswith("hdfs:"): - model_path = "./dnn_plugin" - else: - model_path = os.path.join(dirname, "dnn_plugin") + model_path = self._get_inference_model_path(dirname) model_basename = "__model__" model_basename = os.path.join(model_path, model_basename) paddle.save(infer_program, model_basename) @@ -1266,7 +1332,7 @@ class TheOnePSRuntime(RuntimeBase): self._ps_inference_save_persistables(*args, **kwargs) def _load_sparse_params(self, dirname, context, main_program, mode): - distributed_varnames = get_sparse_tablenames(self.origin_main_program, + distributed_varnames = get_sparse_tablenames(self.origin_main_programs, True) values = [] for id, names in context.items(): diff --git a/python/paddle/distributed/ps/utils/ps_program_builder.py b/python/paddle/distributed/ps/utils/ps_program_builder.py index c6afd0cb03b..25e4dc28bdc 100755 --- a/python/paddle/distributed/ps/utils/ps_program_builder.py +++ b/python/paddle/distributed/ps/utils/ps_program_builder.py @@ -79,7 +79,7 @@ class GeoPsProgramBuilder(PsProgramBuilder): # 仅 CPU 模式 super(GeoPsProgramBuilder, self).__init__(pass_ctx) if self.ps_mode != DistributedMode.GEO: raise ValueError("ps mode: {} not matched {}", - format(ps_mode, "GeoPsProgramBuilder")) + format(self.ps_mode, "GeoPsProgramBuilder")) def _build_trainer_programs(self): append_send_ops_pass = new_pass("append_send_ops_pass", self.attrs) @@ -97,9 +97,9 @@ class CpuSyncPsProgramBuilder(PsProgramBuilder): def __init__(self, pass_ctx): logger.info("start building cpu-sync-ps program") super(CpuSyncPsProgramBuilder, self).__init__(pass_ctx) - if self.ps_mode != DistributedMode.SYNC: + if self.ps_mode != DistributedMode.SYNC and self.ps_mode != DistributedMode.ASYNC: raise ValueError("ps mode: {} not matched {}", - format(ps_mode, "CpuSyncPsProgramBuilder")) + format(self.ps_mode, "CpuSyncPsProgramBuilder")) def _build_trainer_programs(self): add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass", @@ -178,7 +178,7 @@ class HeterAsyncPsProgramBuilder(PsProgramBuilder): if self.use_ps_gpu or self.ps_mode == DistributedMode.GEO or self.attrs[ 'is_heter_ps_mode'] == False: raise ValueError("ps mode: {} not matched {}", - format(ps_mode, "HeterAsyncPsProgramBuilder")) + format(self.ps_mode, "HeterAsyncPsProgramBuilder")) def _build_trainer_programs(self): add_lr_decay_table_pass = new_pass("add_lr_decay_table_pass", diff --git a/python/paddle/distributed/ps/utils/public.py b/python/paddle/distributed/ps/utils/public.py index 7743db1057d..ebec6900e38 100755 --- a/python/paddle/distributed/ps/utils/public.py +++ b/python/paddle/distributed/ps/utils/public.py @@ -54,6 +54,9 @@ SPARSE_GRAD_OP_TYPE_DICT = { } DEFAULT_DEVICE = 'cpu' +DATA_NORM_NAME = [".batch_size", ".batch_sum", ".batch_square_sum"] +DATA_NORM_GRAD_NAME = [x + "@GRAD" for x in DATA_NORM_NAME] + def logger_config(log_path, logging_name): logger = logging.getLogger(logging_name) @@ -84,6 +87,8 @@ class DistributedMode: class TrainerRuntimeConfig(object): def __init__(self, valid_strategy): self.mode = None + num_threads = os.getenv("CPU_NUM", "1") + send_queue_size = 
num_threads k_steps = valid_strategy.a_sync_configs["k_steps"] logger.info("ps mode in strategy: {}, {}".format( valid_strategy.a_sync, valid_strategy.a_sync_configs["k_steps"])) @@ -95,14 +100,13 @@ class TrainerRuntimeConfig(object): if valid_strategy.a_sync and k_steps > 0: self.mode = DistributedMode.GEO - - num_threads = os.getenv("CPU_NUM", "1") + send_queue_size = k_steps self.runtime_configs = {} self.runtime_configs['communicator_max_merge_var_num'] = os.getenv( - "FLAGS_communicator_max_merge_var_num", num_threads) + "FLAGS_communicator_max_merge_var_num", send_queue_size) self.runtime_configs['communicator_send_queue_size'] = os.getenv( - "FLAGS_communicator_send_queue_size", num_threads) + "FLAGS_communicator_send_queue_size", send_queue_size) self.runtime_configs[ 'communicator_independent_recv_thread'] = os.getenv( "FLAGS_communicator_independent_recv_thread", "1") @@ -116,6 +120,55 @@ class TrainerRuntimeConfig(object): self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv( "FLAGS_communicator_is_sgd_optimizer", "1") + def get_communicator_flags(self): + need_keys = [] + num_threads = os.getenv("CPU_NUM", "1") + mode_str = "" + if self.mode is None or self.mode == DistributedMode.ASYNC: + need_keys = self.runtime_configs.keys() + mode_str = "async" + elif self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC: + mode_str = "sync or half_async" + need_keys = [ + 'communicator_max_merge_var_num', + 'communicator_send_wait_times', 'communicator_thread_pool_size', + 'communicator_send_queue_size' + ] + elif self.mode == DistributedMode.GEO: + mode_str = "GEO" + need_keys = [ + 'communicator_thread_pool_size', 'communicator_send_wait_times', + 'communicator_max_merge_var_num', 'communicator_send_queue_size' + ] + else: + raise ValueError("Unsupported Mode") + + if self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC: + max_merge_var_num = self.runtime_configs[ + 'communicator_max_merge_var_num'] + send_queue_size = self.runtime_configs[ + 'communicator_send_queue_size'] + if max_merge_var_num != num_threads: + print('WARNING: In {} mode, communicator_max_merge_var_num ' + 'must be equal to CPU_NUM. But received, ' + 'communicator_max_merge_var_num = {}, CPU_NUM = ' + '{}. communicator_max_merge_var_num will be forced to {}.' + .format(mode_str, max_merge_var_num, num_threads, + num_threads)) + self.runtime_configs[ + 'communicator_max_merge_var_num'] = num_threads + if send_queue_size != num_threads: + print('WARNING: In {} mode, communicator_send_queue_size ' + 'must be equal to CPU_NUM. But received, ' + 'communicator_send_queue_size = {}, CPU_NUM = ' + '{}. communicator_send_queue_size will be forced to {}.' 
+ .format(mode_str, send_queue_size, num_threads, + num_threads)) + self.runtime_configs[ + 'communicator_send_queue_size'] = num_threads + + return dict((key, str(self.runtime_configs[key])) for key in need_keys) + def get_lr_ops(program): lr_ops = [] @@ -176,6 +229,13 @@ def get_ps_endpoint(role_maker): return role_maker.get_pserver_endpoints()[get_role_id(role_maker)] +def get_ps_endpoints(role_maker): + try: + return role_maker._get_pserver_endpoints() + except Exception: + return role_maker.get_pserver_endpoints() + + def get_heter_worker_endpoint(role_maker): try: return role_maker._get_heter_worker_endpoint() @@ -224,26 +284,20 @@ def is_sparse_op(op): return False -def get_sparse_tablenames(program, is_distributed): +def get_sparse_tablenames(programs, is_distributed): tablenames = set() - if is_distributed: - for op in program.global_block().ops: - if is_distributed_sparse_op(op): - tablenames.add(get_sparse_tablename(op)) - else: - for op in program.global_block().ops: - if is_sparse_op(op): - tablenames.add(get_sparse_tablename(op)) + for program in programs: + if is_distributed: + for op in program.global_block().ops: + if is_distributed_sparse_op(op): + tablenames.add(get_sparse_tablename(op)) + else: + for op in program.global_block().ops: + if is_sparse_op(op): + tablenames.add(get_sparse_tablename(op)) return list(tablenames) -def get_ps_endpoints(role_maker): - try: - return role_maker._get_pserver_endpoints() - except Exception: - return role_maker.get_pserver_endpoints() - - def get_trainers(role_maker): try: return role_maker._worker_num() @@ -251,7 +305,7 @@ def get_trainers(role_maker): return role_maker.worker_num() -def get_dense_send_context(context, +def get_dense_send_context(program, send_ctx, idx, merged_dense_pairs, @@ -260,34 +314,72 @@ def get_dense_send_context(context, if len(merged_dense_pairs) < 1: return idx if not split_dense_table: + dense_pairs = [] + data_norm_pairs = [] + for merged in merged_dense_pairs: + is_data_norm = False + grad = merged[1] + varname = grad.merged_var.name + for name in DATA_NORM_GRAD_NAME: + if varname.endswith(name): + is_data_norm = True + if is_data_norm: + data_norm_pairs.append(merged) + else: + dense_pairs.append(merged) + + # simple dense table origin_varnames = [] var_numel = 0 - for merged in merged_dense_pairs: + for merged in dense_pairs: grad = merged[1] origin_varnames.append(grad.merged_var.name) - var = context['origin_main_program'].global_block().vars[ - grad.merged_var.name] + var = program.global_block().vars[grad.merged_var.name] var_numel += reduce(lambda x, y: x * y, var.shape) - grad_name = "Dense@Grad" - trainer_id = get_role_id(context['role_maker']) + grad_name = "Dense@GRAD_" + str(idx) aggregate = True + print("public get_dense_send_context dense_table:", grad_name, + var_numel, origin_varnames) dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel], origin_varnames, trainer_id, - aggregate, False, False, idx, False) + aggregate, False, False, idx, False, False, + id(program)) send_ctx[grad_name] = dense_ctx idx += 1 + + if len(data_norm_pairs) <= 0: + return idx + + # data norm table + origin_varnames = [] + var_numel = 0 + for merged in data_norm_pairs: + grad = merged[1] + origin_varnames.append(grad.merged_var.name) + var = program.global_block().vars[grad.merged_var.name] + var_numel += reduce(lambda x, y: x * y, var.shape) + grad_name = "DataNorm@GRAD_" + str(idx) + aggregate = True + print("public get_dense_send_context data_norm table:", grad_name, + var_numel, 
origin_varnames) + data_norm_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], + [var_numel], origin_varnames, trainer_id, + aggregate, False, False, idx, False, True, + id(program)) + send_ctx[grad_name] = data_norm_ctx + idx += 1 else: for merged in merged_dense_pairs: grad = merged[1] origin_varname = grad.merged_var.name - var = context['origin_main_program'].global_block().vars[ - origin_varname] + var = program.global_block().vars[origin_varname] var_numel = reduce(lambda x, y: x * y, var.shape) grad_name = origin_varname aggregate = True dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel], [origin_varname], trainer_id, - aggregate, False, False, idx, False) + aggregate, False, False, idx, False, False, + id(program)) send_ctx[grad_name] = dense_ctx idx += 1 return idx @@ -299,25 +391,28 @@ def get_geo_trainer_send_context(context): format(ps_mode, "get_geo_trainer_send_context")) send_ctx = {} trainer_id = get_role_id(context['role_maker']) + origin_programs = context['origin_main_programs'] idx = 0 - distibuted_varnames = get_sparse_tablenames(context['origin_main_program'], - True) - for merged in context['merged_sparse_pairs']: - param, grad = merged - grad_name = grad.merged_var.name - param_name = param.merged_var.name - is_distributed = True if param_name in distibuted_varnames else False - - var = context['origin_main_program'].global_block().vars[ - grad.merged_var.name] - var_numel = reduce(lambda x, y: x * y, var.shape[1:]) - - sparse_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], - [var_numel], [grad_name], trainer_id, True, - True, is_distributed, idx, False) - idx += 1 - send_ctx[sparse_ctx.var_name()] = sparse_ctx + distibuted_varnames = get_sparse_tablenames(origin_programs, True) + for i, program in enumerate(origin_programs): + merged_sparse_pairs = context['merged_sparse_pairs'][i] + for merged in merged_sparse_pairs: + param, grad = merged + grad_name = grad.merged_var.name + param_name = param.merged_var.name + is_distributed = True if param_name in distibuted_varnames else False + + var = program.global_block().vars[grad.merged_var.name] + var_numel = reduce(lambda x, y: x * y, var.shape[1:]) + + sparse_ctx = CommContext(grad_name, [grad_name], + ["127.0.0.1:6071"], [var_numel], + [grad_name], trainer_id, True, True, + is_distributed, idx, False, False, + id(program)) + idx += 1 + send_ctx[sparse_ctx.var_name()] = sparse_ctx if len(send_ctx) == 0: raise ValueError("GeoSGD require sparse parameters in your net.") @@ -336,7 +431,7 @@ def _step_ctx(idx, role_maker): sections = [1] * len(endpoints) names = [name] * len(endpoints) ctx = CommContext(name, names, endpoints, sections, [name], trainer_id, - True, False, False, idx, True) + True, False, False, idx, True, False, -1) return name, ctx @@ -348,36 +443,45 @@ def get_the_one_send_context(context, ep_list = ["127.0.0.1:6071"] send_ctx = {} trainer_id = get_role_id(context['role_maker']) + origin_programs = context['origin_main_programs'] idx = 0 - idx += get_dense_send_context(context, send_ctx, idx, - context['merged_dense_pairs'], trainer_id, - split_dense_table) - distibuted_varnames = get_sparse_tablenames(context['origin_main_program'], - True) - for merged in context['merged_sparse_pairs']: - param, grad = merged - grad_name = grad.merged_var.name - param_name = param.merged_var.name - splited_varname = [] - - for i in range(len(ep_list)): - splited_varname.append("{}.block{}".format(param_name, i)) - - is_distributed = True if param_name in 
distibuted_varnames else False - - var = context['origin_main_program'].global_block().vars[ - grad.merged_var.name] - - shape = list(var.shape) - shape[0] = 0 if is_distributed else shape[0] - - sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape, - [grad_name], trainer_id, True, True, - is_distributed, idx, False) + for i, program in enumerate(origin_programs): + merged_dense_pairs = context['merged_dense_pairs'][i] + idx += get_dense_send_context(program, send_ctx, idx, + merged_dense_pairs, trainer_id, + split_dense_table) + distibuted_varnames = get_sparse_tablenames(origin_programs, True) + print("public distibuted_varnames:", distibuted_varnames) + for i, program in enumerate(origin_programs): + merged_sparse_pairs = context['merged_sparse_pairs'][i] + for merged in merged_sparse_pairs: + param, grad = merged + grad_name = grad.merged_var.name + param_name = param.merged_var.name + splited_varname = [] + + for i in range(len(ep_list)): + splited_varname.append("{}.block{}".format(param_name, i)) + + is_distributed = True if param_name in distibuted_varnames else False + + var = program.global_block().vars[grad.merged_var.name] + + shape = list(var.shape) + shape[0] = 0 if is_distributed else shape[0] + + print("public get_the_one_send_context sparse:", grad_name, + splited_varname, shape) + if grad_name in send_ctx: + continue + sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape, + [grad_name], trainer_id, True, True, + is_distributed, idx, False, False, + id(program)) - idx += 1 - send_ctx[sparse_ctx.var_name()] = sparse_ctx + idx += 1 + send_ctx[sparse_ctx.var_name()] = sparse_ctx if len(context['tensor_table']) > 0 and context['is_worker']: name, ctx = _step_ctx(idx, context['role_maker']) @@ -1073,7 +1177,7 @@ def get_the_one_recv_context(context, param_names = [] for grad_varname in origin_grad_varnames: - param_name = grad_name_to_param_name[grad_varname] + param_name = context["grad_name_to_param_name"][grad_varname] param_names.append(param_name) recv_id_maps[ctx.table_id()] = param_names else: @@ -1090,7 +1194,7 @@ def get_the_one_recv_context(context, param_names = [] for grad_varname in origin_grad_varnames: - param_name = grad_name_to_param_name[grad_varname] + param_name = context["grad_name_to_param_name"][grad_varname] param_names.append(param_name) recv_id_maps[ctx.table_id()] = param_names return recv_id_maps @@ -1141,58 +1245,88 @@ class MergedVariable: def build_var_distributed(context): - sparse_pairs, dense_pairs = get_param_grads(context['origin_main_program']) - origin_for_sparse = [] - origin_for_dense = [] - param_name_grad_name = {} + origin_programs = context['origin_main_programs'] + + param_name_to_grad_name = {} grad_name_to_param_name = {} - context["merged_variables_pairs"] = [] + context["origin_sparse_pairs"] = [] + context["origin_dense_pairs"] = [] context["merged_sparse_pairs"] = [] context['merged_dense_pairs'] = [] + context["merged_variables_pairs"] = [] context["merged_variable_map"] = {} - - for param, grad in sparse_pairs: - origin_for_sparse.append((param, grad)) - - for param, grad in dense_pairs: - origin_for_dense.append((param, grad)) - - for dense_pair in origin_for_dense: - param, grad = dense_pair - - m_param = MergedVariable(param, [param], [0]) - m_grad = MergedVariable(grad, [grad], [0]) - context["merged_variables_pairs"].append((m_param, m_grad)) - context["merged_dense_pairs"].append((m_param, m_grad)) - - for sparse_pair in origin_for_sparse: - param, grad = sparse_pair - - m_param = 
MergedVariable(param, [param], [0]) - m_grad = MergedVariable(grad, [grad], [0]) - context["merged_variables_pairs"].append((m_param, m_grad)) - context["merged_sparse_pairs"].append((m_param, m_grad)) - - for merged in context["merged_variables_pairs"]: - m_param, m_grad = merged - context["merged_variable_map"][ - m_param.merged_var.name] = m_param.merged_var - context["merged_variable_map"][ - m_grad.merged_var.name] = m_grad.merged_var - - param_merges = [] - param_merges.extend(origin_for_sparse) - param_merges.extend(origin_for_dense) - - for param, grad in param_merges: - param_name_grad_name[param.name] = grad.name - grad_name_to_param_name[grad.name] = param.name - - context["origin_sparse_pairs"] = origin_for_sparse - context["origin_dense_pairs"] = origin_for_dense - context["param_name_to_grad_name"] = param_name_grad_name + for origin_program in origin_programs: + sparse_pairs, dense_pairs = get_param_grads(origin_program) + print("public build_var_distributed sparse_pairs:", sparse_pairs) + print("public build_var_distributed dense_pairs:", dense_pairs) + origin_for_sparse = [] + origin_for_dense = [] + merged_sparse_pairs = [] + merged_dense_pairs = [] + merged_variables_pairs = [] + + for param, grad in sparse_pairs: + origin_for_sparse.append((param, grad)) + + for param, grad in dense_pairs: + origin_for_dense.append((param, grad)) + + for dense_pair in origin_for_dense: + param, grad = dense_pair + + m_param = MergedVariable(param, [param], [0]) + m_grad = MergedVariable(grad, [grad], [0]) + merged_variables_pairs.append((m_param, m_grad)) + merged_dense_pairs.append((m_param, m_grad)) + print("public build_var_distributed merged_dense_pairs:", + merged_dense_pairs) + + for sparse_pair in origin_for_sparse: + param, grad = sparse_pair + + m_param = MergedVariable(param, [param], [0]) + m_grad = MergedVariable(grad, [grad], [0]) + merged_variables_pairs.append((m_param, m_grad)) + merged_sparse_pairs.append((m_param, m_grad)) + print("public build_var_distributed merged_sparse_pairs:", + merged_sparse_pairs) + + for merged in merged_variables_pairs: + m_param, m_grad = merged + context["merged_variable_map"][ + m_param.merged_var.name] = m_param.merged_var + context["merged_variable_map"][ + m_grad.merged_var.name] = m_grad.merged_var + + param_merges = [] + param_merges.extend(origin_for_sparse) + param_merges.extend(origin_for_dense) + + for param, grad in param_merges: + param_name_to_grad_name[param.name] = grad.name + grad_name_to_param_name[grad.name] = param.name + + context["origin_sparse_pairs"].append(origin_for_sparse) + context["origin_dense_pairs"].append(origin_for_dense) + context["merged_sparse_pairs"].append(merged_sparse_pairs) + context['merged_dense_pairs'].append(merged_dense_pairs) + + context["param_name_to_grad_name"] = param_name_to_grad_name context["grad_name_to_param_name"] = grad_name_to_param_name + print("public build_var_distributed origin_sparse_pairs:", + context["origin_sparse_pairs"]) + print("public build_var_distributed origin_for_dense:", + context["origin_dense_pairs"]) + print("public build_var_distributed merged_sparse_pairs:", + context["merged_sparse_pairs"]) + print("public build_var_distributed merged_dense_pairs:", + context['merged_dense_pairs']) + print("public build_var_distributed param_name_to_grad_name:", + param_name_to_grad_name) + print("public build_var_distributed grad_name_to_param_name:", + grad_name_to_param_name) + def _is_opt_role_op(op): # NOTE : depend on oprole to find out whether this op is for diff --git 
a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py index 4b8c7ccbb69..b6ec09bab72 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py @@ -577,7 +577,7 @@ class CompileTimeStrategy(object): sparse_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel], [grad_name], trainer_id, True, True, - is_distributed, idx, False) + is_distributed, idx, False, False, -1) idx += 1 send_ctx[sparse_ctx.var_name()] = sparse_ctx @@ -615,7 +615,8 @@ class CompileTimeStrategy(object): aggregate = True dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel], origin_varnames, trainer_id, - aggregate, False, False, idx, False) + aggregate, False, False, idx, False, False, + -1) send_ctx[grad_name] = dense_ctx idx += 1 else: @@ -630,7 +631,7 @@ class CompileTimeStrategy(object): dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel], [origin_varname], trainer_id, aggregate, - False, False, idx, False) + False, False, idx, False, False, -1) send_ctx[grad_name] = dense_ctx idx += 1 return idx @@ -672,7 +673,7 @@ class CompileTimeStrategy(object): sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape, [grad_name], trainer_id, True, True, - is_distributed, idx, False) + is_distributed, idx, False, False, -1) idx += 1 send_ctx[sparse_ctx.var_name()] = sparse_ctx @@ -750,7 +751,7 @@ class CompileTimeStrategy(object): sections = [1] * len(endpoints) names = [name] * len(endpoints) ctx = CommContext(name, names, endpoints, sections, [name], trainer_id, - True, False, False, idx, True) + True, False, False, idx, True, False, -1) return name, ctx def _create_vars_from_blocklist(self, block_list): -- GitLab
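For reference, a minimal sketch of how the widened CommContext constructor is driven from Python after this patch. The positional order follows the py::init binding registered in fleet_py.cc above (the two new trailing fields are is_datanorm_table and program_id); the import path, the program object, and all placeholder names/values are illustrative assumptions, not part of the commit.

# Sketch only: construct a dense CommContext carrying the new fields.
import paddle
from paddle.fluid.core import CommContext  # assumed binding location

main_program = paddle.static.Program()     # hypothetical origin main program

dense_ctx = CommContext(
    "Dense@GRAD_0",          # var_name
    ["Dense@GRAD_0"],        # splited_varnames
    ["127.0.0.1:6071"],      # epmap (pserver endpoints)
    [1024],                  # sections: numel of the merged dense vars (placeholder)
    ["fc_0.w_0@GRAD"],       # origin_varnames (placeholder)
    0,                       # trainer_id
    True,                    # merge_add
    False,                   # is_sparse
    False,                   # is_distributed
    0,                       # table_id
    False,                   # is_tensor_table
    False,                   # is_datanorm_table (new; True for .batch_size/.batch_sum/.batch_square_sum grads)
    id(main_program))        # program_id (new; lets the runtime recover the program via get_program_by_id)

print(dense_ctx.table_id(), dense_ctx.program_id(), dense_ctx.is_datanorm_table())

This mirrors the dense branch of get_dense_send_context in public.py: passing id(program) as program_id is what allows parse_by_optimizer and parse_entry in the_one_ps.py to look up the right (main, startup) program pair when several origin programs are registered.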