# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings import os import paddle.fluid as fluid from paddle.distributed import fleet from paddle.fluid import core from paddle.distributed.ps.utils.public import * # noqa: F403 from paddle.fluid.framework import Program from paddle.fluid.compiler import CompiledProgram from paddle.fluid.executor import Executor from paddle.fluid.parallel_executor import ParallelExecutor from paddle.distributed.fleet.runtime.runtime_base import RuntimeBase from paddle.distributed.fleet.base.private_helper_function import wait_server_ready from paddle.distributed.fleet.proto import the_one_ps_pb2 from paddle.fluid.communicator import Communicator, HeterClient from google.protobuf import text_format from paddle.distributed.ps.coordinator import Coordinator __all__ = [ 'Table', 'SparseTable', 'GeoSparseTable', 'BarrierTable', 'TensorTable', 'DenseTable' ] def get_program_by_id(context, program_id): programs = context["origin_main_programs"] for i, program in enumerate(programs): if id(program) == program_id: return program, context["origin_startup_programs"][i], i return None, None, None def parse_table_class(varname, program_id, context): main_program, startup_program, idx = get_program_by_id(context, program_id) for op in main_program.global_block().ops: if not is_distributed_sparse_op(op) and not is_sparse_op(op): continue param_name = op.input("W")[0] if param_name == varname and op.type == "lookup_table" or op.type == "lookup_table_v2": if op.has_attr('table_class') and op.attr("table_class") != "none": return op.attr('table_class') else: return "MemorySparseTable" def check_embedding_dim(accessor_proto, varname, program_id, context): main_program, startup_program, idx = get_program_by_id(context, program_id) embedding_dim = 0 for var in main_program.list_vars(): if var.name == varname: embedding_dim = var.shape[1] print('new var: {}, {}, {}'.format(var, embedding_dim, accessor_proto.fea_dim)) break fea_dim = accessor_proto.fea_dim if accessor_proto.accessor_class == "SparseAccessor": if fea_dim != embedding_dim + 2: raise ValueError( "The fea_dim is wrong, it will be sparse_embedding_dim + 2: {}, but got {}" .format(embedding_dim + 2, fea_dim)) else: if fea_dim != embedding_dim: raise ValueError( "The fea_dim is wrong, it will be sparse_embedding_dim: {}, but got {}" .format(embedding_dim, fea_dim)) embedx_dim = accessor_proto.embedx_dim if accessor_proto.accessor_class == "SparseAccessor": if embedx_dim != embedding_dim - 1: raise ValueError( "The embedx_dim is wrong, it will be sparse_embedding_dim - 1: {}, but got {}" .format(embedding_dim - 1, embedx_dim)) else: if embedx_dim != embedding_dim - 3: raise ValueError( "The embedx_dim is wrong, it will be sparse_embedding_dim - 3: {}, but got {}" .format(embedding_dim - 3, embedx_dim)) class Service: def __init__(self): pass def _set(self, service_proto): service_proto.server_class = "BrpcPsServer" service_proto.client_class = "BrpcPsClient" service_proto.service_class = "BrpcPsService" service_proto.start_server_port = 0 service_proto.server_thread_num = 12 class GpuService(Service): def __init__(self): super(GpuService, self).__init__() def _set(self, service_proto): service_proto.server_class = 'PsLocalServer' service_proto.client_class = 'PsLocalClient' class Accessor: def __init__(self): self.accessor_class = "" self.optimizer = None self.feature_dim = 0 self.embedding_dim = 0 # TableAccessorParameter accessor def _set(self, accessor_proto, varname, program_id, context, common_accessor): main_program, startup_program, idx = get_program_by_id( context, program_id) embedding_dim = 0 for var in main_program.list_vars(): if var.name == varname: embedding_dim = var.shape[1] break if not accessor_proto.HasField("accessor_class"): # DownpourSparseValueAccessor if context['use_ps_gpu']: accessor_proto.accessor_class = "CtrDymfAccessor" else: accessor_proto.accessor_class = "SparseAccessor" if not accessor_proto.HasField("fea_dim"): if accessor_proto.accessor_class == "SparseAccessor": accessor_proto.fea_dim = embedding_dim + 2 else: accessor_proto.fea_dim = embedding_dim if not accessor_proto.HasField("embedx_dim"): if accessor_proto.accessor_class == "SparseAccessor": accessor_proto.embedx_dim = embedding_dim - 1 else: accessor_proto.embedx_dim = embedding_dim - 3 if not accessor_proto.HasField("embedx_threshold"): accessor_proto.embedx_threshold = 0 graph_sgd_param = accessor_proto.graph_sgd_param if not graph_sgd_param.HasField("nodeid_slot"): graph_sgd_param.nodeid_slot = 9008 if not graph_sgd_param.HasField("feature_learning_rate"): graph_sgd_param.feature_learning_rate = 0.05 ctr_accessor_param = accessor_proto.ctr_accessor_param if accessor_proto.embedx_dim == 0: ctr_accessor_param.zero_init = False if not ctr_accessor_param.HasField("nonclk_coeff"): ctr_accessor_param.nonclk_coeff = 0.1 if not ctr_accessor_param.HasField("click_coeff"): ctr_accessor_param.click_coeff = 1.0 if not ctr_accessor_param.HasField("base_threshold"): ctr_accessor_param.base_threshold = 0 if not ctr_accessor_param.HasField("delta_threshold"): ctr_accessor_param.delta_threshold = 0 if not ctr_accessor_param.HasField("delta_keep_days"): ctr_accessor_param.delta_keep_days = 16 if not ctr_accessor_param.HasField("show_click_decay_rate"): ctr_accessor_param.show_click_decay_rate = 1 if not ctr_accessor_param.HasField("delete_threshold"): ctr_accessor_param.delete_threshold = 0 if not ctr_accessor_param.HasField("delete_after_unseen_days"): ctr_accessor_param.delete_after_unseen_days = 30 if not ctr_accessor_param.HasField("ssd_unseenday_threshold"): ctr_accessor_param.ssd_unseenday_threshold = 1 for sgd_param in [ accessor_proto.embed_sgd_param, accessor_proto.embedx_sgd_param ]: if not sgd_param.HasField("name"): if common_accessor.accessor_class == "sgd": sgd_param.name = "SparseNaiveSGDRule" if common_accessor.accessor_class == "adam": sgd_param.name = "SparseAdamSGDRule" else: # for fl-ps, because geo accessor is 'sum' sgd_param.name = "SparseAdamSGDRule" if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule": if not sgd_param.adagrad.HasField("learning_rate"): sgd_param.adagrad.learning_rate = 0.05 if not sgd_param.adagrad.HasField("initial_g2sum"): sgd_param.adagrad.initial_g2sum = 3.0 if not sgd_param.adagrad.HasField("initial_range"): sgd_param.adagrad.initial_range = 0.0001 if len(sgd_param.adagrad.weight_bounds) == 0: sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0]) if sgd_param.name == "SparseNaiveSGDRule": if not sgd_param.naive.HasField("learning_rate"): learning_rate = common_accessor.initializers[-1].split( "&")[1] sgd_param.naive.learning_rate = float(learning_rate) if not sgd_param.naive.HasField("initial_range"): initial_range = common_accessor.initializers[0].split( "&")[-1] sgd_param.naive.initial_range = float(initial_range) if len(sgd_param.naive.weight_bounds) == 0: sgd_param.naive.weight_bounds.extend([-10.0, 10.0]) if sgd_param.name == "SparseAdamSGDRule" or sgd_param.name == "SparseSharedAdamSGDRule": if not sgd_param.adam.HasField("learning_rate"): learning_rate = common_accessor.initializers[-1].split( "&")[1] sgd_param.adam.learning_rate = float(learning_rate) if not sgd_param.adam.HasField("initial_range"): initial_range = common_accessor.initializers[0].split( "&")[-1] sgd_param.adam.initial_range = float(initial_range) attr_list = [x.split("&") for x in common_accessor.attrs] if not sgd_param.adam.HasField( "beta1_decay_rate" ) and common_accessor.accessor_class == "adam": sgd_param.adam.beta1_decay_rate = float(attr_list[0][1]) else: sgd_param.adam.beta1_decay_rate = 0.9 if not sgd_param.adam.HasField( "beta2_decay_rate" ) and common_accessor.accessor_class == "adam": sgd_param.adam.beta2_decay_rate = float(attr_list[1][1]) else: sgd_param.adam.beta2_decay_rate = 0.999 if not sgd_param.adam.HasField( "ada_epsilon" ) and common_accessor.accessor_class == "adam": sgd_param.adam.ada_epsilon = float(attr_list[2][1]) else: sgd_param.adam.ada_epsilon = 1e-08 if len(sgd_param.adam.weight_bounds) == 0: sgd_param.adam.weight_bounds.extend([-10.0, 10.0]) class CommonAccessor(Accessor): def __init__(self): super(CommonAccessor, self).__init__() self.table_name = '' self.entry = 'none' self.attrs = [] self.params = [] self.dims = [] self.trainer_num = 0 self.sync = False self.initializers = [] self.opt_input_map = {} self.opt_attr_map = {} self.opt_init_map = {} self.define_optimize_map() def define_optimize_map(self): opt_input_map = {} opt_input_map["sgd"] = [("Param", None), ("LearningRate", 1)] opt_input_map["adam"] = [("Param", None), ("Moment1", None), ("Moment2", None), ("Beta1Pow", 1), ("Beta2Pow", 1), ("LearningRate", 1)] opt_input_map["adam_d2sum"] = [("Param", None), ("D2Sum", None), ("G2Sum", None), ("Moment", None), ("MomentDecayRate", 1), ("AdaDecayRate", 1), ("AdaEpsilon", 1), ("LearningRate", 1)] opt_input_map["sum"] = [("Param", None)] opt_input_map["naive_adagrad"] = [("Param", None), ("G2Sum", 1), ("LearningRate", 1)] opt_input_map["summary"] = [("Param", None), ("SummaryDecayRate", 1)] opt_attr_map = {} opt_attr_map["sgd"] = [] opt_attr_map["sum"] = [] opt_attr_map["naive_adagrad"] = [] opt_attr_map["adam"] = [("beta1", "f"), ("beta2", "f"), ("epsilon", "f")] opt_attr_map["adam_d2sum"] = [("beta1", "f"), ("beta2", "f"), ("epsilon", "f")] opt_attr_map["summary"] = [("summary_decay_rate", "f")] opt_init_map = {} opt_init_map["gaussian_random"] = ["seed", "mean", "std"] opt_init_map["fill_constant"] = ["value"] opt_init_map["uniform_random"] = ["seed", "min", "max"] opt_init_map["truncated_gaussian_random"] = ["seed", "mean", "std"] self.opt_attr_map = opt_attr_map self.opt_input_map = opt_input_map self.opt_init_map = opt_init_map def parse_entry(self, varname, program_id, context): main_program, startup_program, idx = get_program_by_id( context, program_id) for op in main_program.global_block().ops: if not is_distributed_sparse_op(op) and not is_sparse_op(op): continue param_name = op.input("W")[0] if param_name == varname and op.type == "lookup_table": self.entry = op.attr('entry') break if param_name == varname and op.type == "lookup_table_v2": self.entry = "none" break def get_shard(self, total_dim, shard_num, pserver_id): blocksize = int(total_dim / shard_num + 1) if blocksize * (pserver_id + 1) <= total_dim: return blocksize else: if blocksize * pserver_id < total_dim: return total_dim - blocksize * pserver_id else: return 0 def get_initializer_attr(self, value_name, o_startup_program): l_in = "&" attr_str = "" origin_var_name = value_name # print("get_initializer_attr param name:", value_name) for op in o_startup_program.global_block().ops: if op.type in self.opt_init_map.keys( ) and origin_var_name == op.output("Out")[0]: init_attr = [op.type] # print("get_initializer_attr op type:", op.type) for attr in self.opt_init_map[op.type]: # print("get_initializer_attr opt_init_map attr:", attr) init_attr.append(str(op.attr(attr))) # print("get_initializer_attr op attr:", str(op.attr(attr))) attr_str = l_in.join(init_attr) break return attr_str def parse_by_optimizer(self, ctx, context): grad_name = ctx.origin_varnames()[0] is_sparse = ctx.is_sparse() size = ctx.sections()[0] single_dim = ctx.sections()[1] if ctx.is_sparse() else 1 adam_d2sum = context["user_defined_strategy"].adam_d2sum # print("parse_by_optimizer table_id:{} is_datanorm:{}".format( # ctx.table_id(), ctx.is_datanorm_table())) main_program, startup_program, idx = get_program_by_id( context, ctx.program_id()) pserver_id = get_role_id(context['role_maker']) pserver_num = len(get_ps_endpoints(context['role_maker'])) optimizer_ops = get_optimize_ops(main_program) # print("the one ps optimizer_ops:", optimizer_ops) # print("the one ps parse_by_optimizer grad_name:", grad_name) oop = None for op in optimizer_ops: if ("Param" in op.input_names) and ( op.input("Param")[0] == context['grad_name_to_param_name'][grad_name]): oop = op break if oop is None: raise ValueError("can not find optimizer for {}".format(grad_name)) params = [] dims = [] attrs = [] initializers = [] self.trainer_num = get_trainers(context['role_maker']) self.table_num = size self.table_dim = single_dim if oop.type != 'adam' and adam_d2sum == True: print('optimization algorithm is not adam, set adam_d2sum False') adam_d2sum = False print("adam_d2sum:", adam_d2sum) if context['ps_mode'] == DistributedMode.GEO: param_varnames = self.opt_input_map["sum"] attr_varnames = self.opt_attr_map["sum"] self.accessor_class = "sum" elif context['use_ps_gpu'] and is_sparse: param_varnames = self.opt_input_map["naive_adagrad"] attr_varnames = self.opt_attr_map["naive_adagrad"] self.accessor_class = "sgd" elif ctx.is_datanorm_table(): param_varnames = self.opt_input_map["summary"] attr_varnames = self.opt_attr_map["summary"] self.accessor_class = "summary" elif adam_d2sum and not is_sparse: param_varnames = self.opt_input_map["adam_d2sum"] attr_varnames = self.opt_attr_map["adam_d2sum"] self.accessor_class = "adam_d2sum" else: if oop.type != 'sgd' and oop.type != 'adam': raise ValueError( "The dense optimizer in PS is only supported SGD or Adam!") param_varnames = self.opt_input_map[oop.type] attr_varnames = self.opt_attr_map[oop.type] self.accessor_class = oop.type for (formal_name, shape) in param_varnames: params.append(formal_name) if self.accessor_class == "adam_d2sum": #for dims if shape is None: if is_sparse: shape = single_dim else: shape = self.get_shard(size, pserver_num, pserver_id) dims.append(shape) #for initializers if formal_name == "Param" or formal_name == "LearningRate": param = main_program.global_block().vars[oop.input( formal_name)[0]] #TODO: for dense learning_rate, can be different from sparse lr if formal_name == "LearningRate" and param.name != "learning_rate_" + str( idx): warnings.warn("will support decay soon") param = main_program.global_block().vars[ "learning_rate_" + str(idx)] initializer = self.get_initializer_attr( param.name, startup_program) elif formal_name == "MomentDecayRate": initializer = "fill_constant&0.99" elif formal_name == "AdaDecayRate": initializer = "fill_constant&0.9999" elif formal_name == "AdaEpsilon": initializer = "fill_constant&1.0e-8" else: initializer = "fill_constant&0" initializers.append(initializer) elif self.accessor_class == "summary": #for dims if shape is None: if is_sparse: shape = single_dim else: shape = self.get_shard(size, pserver_num, pserver_id) dims.append(shape) #for initializers if formal_name == "Param": param = main_program.global_block().vars[oop.input( formal_name)[0]] initializer = self.get_initializer_attr( param.name, startup_program) elif formal_name == "SummaryDecayRate": initializer = "fill_constant&0.999999" else: initializer = "fill_constant&0" initializers.append(initializer) else: if formal_name == "G2Sum": dims.append(1) initializer = "fill_constant&0" initializers.append(initializer) else: param = main_program.global_block().vars[oop.input( formal_name)[0]] if formal_name == "LearningRate" and param.name != "learning_rate_" + str( idx): warnings.warn("will support decay soon") param = main_program.global_block().vars[ "learning_rate_" + str(idx)] if shape is None: if is_sparse: shape = single_dim else: shape = self.get_shard(size, pserver_num, pserver_id) dims.append(shape) initializer = self.get_initializer_attr( param.name, startup_program) initializers.append(initializer) if self.accessor_class == 'summary': datanorm_ops = get_datanorm_ops(main_program) for op in datanorm_ops: if ("BatchSize" in op.input_names) and ( op.input("BatchSize")[0] == context['grad_name_to_param_name'][grad_name]): oop = op break for (attr_varname, type_) in attr_varnames: value = oop.attr(attr_varname) attrs.append("&".join([attr_varname, str(value)])) self.params = params self.dims = dims self.initializers = initializers self.attrs = attrs # CommonAccessorParameter common def _set(self, proto): proto.name = self.accessor_class proto.table_name = self.table_name proto.params.extend(self.params) proto.dims.extend(self.dims) proto.initializers.extend(self.initializers) proto.entry = self.entry proto.trainer_num = self.trainer_num proto.sync = self.sync proto.table_num = self.table_num proto.table_dim = self.table_dim proto.attr = "#".join(self.attrs) class Tensor: def __init__(self, tesnor_dcit): self.tensor_dict = tesnor_dcit def _set(self, tensor_proto): tensor_proto.main_program_id = self.tensor_dict.get( "main_program_id", 0) tensor_proto.startup_program_id = self.tensor_dict.get( "startup_program_id", 0) tensor_proto.feed_var_name = self.tensor_dict.get("feed_var_name", '') tensor_proto.fetch_var_name = self.tensor_dict.get("fetch_var_name", '') tensor_proto.tensor_table_class = self.tensor_dict.get( "tensor_table_class", '') class Table: def __init__(self): self.table_class = None self.shard_num = -1 self.type = None self.accessor = Accessor() self.shard_num = 256 self.common = CommonAccessor() self.tensor = None def _set(self, table_proto): pass class BarrierTable(Table): def __init__(self, context, idx): super(BarrierTable, self).__init__() self.type = None self.shard_num = 256 self.accessor.accessor_class = 'CommMergeAccessor' self.common.attrs = "" self.common.dims = [] self.common.params = [] self.is_heter_ps_mode = context['is_heter_ps_mode'] self.role_maker = context['role_maker'] self.idx = idx self.is_sync = context['is_sync'] def _set(self, table_proto): table_proto.table_id = self.idx table_proto.table_class = 'BarrierTable' table_proto.shard_num = 256 table_proto.type = the_one_ps_pb2.PS_OTHER_TABLE table_proto.accessor.accessor_class = "CommMergeAccessor" table_proto.accessor.fea_dim = 0 table_proto.accessor.embedx_dim = 0 table_proto.common.name = "" table_proto.common.table_name = "barrier_table" table_proto.common.sync = self.is_sync table_proto.common.entry = 'none' trainer_num = get_trainers(self.role_maker) if self.is_heter_ps_mode: trainer_num += len(self.role_maker._get_heter_worker_endpoints()) table_proto.common.trainer_num = trainer_num class TensorTable(Table): def __init__(self, idx, tensor_dict, role_maker): super(TensorTable, self).__init__() self.idx = idx self.tensor_dict = tensor_dict self.role_maker = role_maker def _set(self, table_proto): table_proto.table_id = self.idx table_proto.type = the_one_ps_pb2.PS_OTHER_TABLE table_proto.table_class = self.tensor_dict.get("tensor_table_class", '') table_proto.accessor.accessor_class = "CommMergeAccessor" table_proto.common.table_name = self.tensor_dict.get( "feed_var_name", '') table_proto.common.trainer_num = get_trainers(self.role_maker) tensor = Tensor(self.tensor_dict) tensor._set(table_proto.tensor) class SparseTable(Table): def __init__(self, context, send_ctx): super(SparseTable, self).__init__() self.context = context self.ctx = send_ctx self.type = None self.table_class = 'MemorySparseTable' self.accessor = Accessor() def _set(self, table_proto): ctx = self.ctx if ctx.is_tensor_table() or len( ctx.origin_varnames()) < 1 or (ctx.is_sparse() == False): return table_proto.table_id = ctx.table_id() table_proto.table_class = self.table_class table_proto.type = the_one_ps_pb2.PS_SPARSE_TABLE table_proto.shard_num = self.shard_num if table_proto.sparse_table_cache_file_num > len( get_ps_endpoints(self.context['role_maker'])): table_proto.sparse_table_cache_file_num = len( get_ps_endpoints(self.context['role_maker'])) self.common.table_name = self.context['grad_name_to_param_name'][ ctx.origin_varnames()[0]] self.common.parse_by_optimizer(ctx, self.context) self.common.parse_entry(self.common.table_name, ctx.program_id(), self.context) self.common.sync = True if self.context['is_sync'] else False self.common._set(table_proto.common) print('new table_name: {}'.format(self.common.table_name)) all_table_proto = self.context[ "user_defined_strategy"].sparse_table_configs usr_table_proto = all_table_proto.add() for proto in all_table_proto: if proto.table_name == self.common.table_name: usr_table_proto = proto break if usr_table_proto.HasField("table_class"): table_proto.table_class = usr_table_proto.table_class else: table_proto.table_class = 'MemorySparseTable' warnings.warn("The PS mode must use MemorySparseTable.") if usr_table_proto.HasField("shard_num"): table_proto.shard_num = usr_table_proto.shard_num else: if self.context['use_ps_gpu']: table_proto.shard_num = 37 warnings.warn( "The shard_num of sparse table is not set, use default value 37 in gpups." ) else: table_proto.shard_num = 1000 warnings.warn( "The shard_num of sparse table is not set, use default value 1000 in cpups." ) if usr_table_proto.HasField("enable_sparse_table_cache"): table_proto.enable_sparse_table_cache = usr_table_proto.enable_sparse_table_cache if usr_table_proto.HasField("sparse_table_cache_rate"): table_proto.sparse_table_cache_rate = usr_table_proto.sparse_table_cache_rate if usr_table_proto.HasField("sparse_table_cache_file_num"): table_proto.sparse_table_cache_file_num = usr_table_proto.sparse_table_cache_file_num if usr_table_proto.HasField("enable_revert"): table_proto.enable_revert = usr_table_proto.enable_revert if usr_table_proto.HasField("shard_merge_rate"): table_proto.shard_merge_rate = usr_table_proto.shard_merge_rate if usr_table_proto.accessor.ByteSize() == 0: warnings.warn( "The accessor of sparse table is not set, use default value.") table_proto.accessor.ParseFromString( usr_table_proto.accessor.SerializeToString()) self.accessor._set(table_proto.accessor, self.common.table_name, ctx.program_id(), self.context, self.common) check_embedding_dim(table_proto.accessor, self.common.table_name, ctx.program_id(), self.context) class GeoSparseTable(SparseTable): def __init__(self, context, send_ctx): super(GeoSparseTable, self).__init__(context, send_ctx) self.table_class = "MemorySparseGeoTable" if self.context['ps_mode'] != DistributedMode.GEO: raise ValueError("not geo sparse table!") def _set(self, table_proto): ctx = self.ctx if ctx.is_tensor_table() or len( ctx.origin_varnames()) < 1 or (ctx.is_sparse() == False): return table_proto.table_id = ctx.table_id() table_proto.table_class = self.table_class table_proto.type = the_one_ps_pb2.PS_SPARSE_TABLE table_proto.shard_num = self.shard_num table_proto.accessor.accessor_class = 'CommMergeAccessor' table_proto.accessor.fea_dim = ctx.sections()[0] table_proto.accessor.embedx_dim = ctx.sections()[1] self.common.table_name = self.context['grad_name_to_param_name'][ ctx.origin_varnames()[0]] self.common.parse_by_optimizer(ctx, self.context) self.common.parse_entry(self.common.table_name, ctx.program_id(), self.context) self.common.sync = False self.common._set(table_proto.common) class DenseTable(Table): def __init__(self, context, send_ctx): super(DenseTable, self).__init__() self.context = context self.ctx = send_ctx self.accessor = Accessor() def _set(self, table_proto): ctx = self.ctx if ctx.is_tensor_table() or len( ctx.origin_varnames()) < 1 or (ctx.is_sparse() == True): return table_proto.table_id = ctx.table_id() table_proto.type = the_one_ps_pb2.PS_DENSE_TABLE table_proto.table_class = "MemoryDenseTable" table_proto.shard_num = 256 table_proto.accessor.accessor_class = 'CommMergeAccessor' table_proto.accessor.fea_dim = ctx.sections()[0] table_proto.accessor.embedx_dim = 1 self.common.table_name = "MergedDense" self.common.parse_by_optimizer(ctx, self.context) self.common.parse_entry(self.common.table_name, ctx.program_id(), self.context) self.common.sync = True if self.context['is_sync'] else False self.common._set(table_proto.common) class Server: def __init__(self): pass def _set(self): pass class DownpourServer(Server): def __init__(self): super(DownpourServer, self).__init__() def _set(self): pass class Worker: def __init__(self): pass def _set(self): pass class DownpourWorker(Worker): def __init__(self): super(DownpourWorker, self).__init__() def _set(self): pass class fsClient: def __init__(self, fs_client_param): self.fs_client_param = fs_client_param def _set(self, proto): if not text_format.MessageToString(self.fs_client_param): return proto.uri = self.fs_client_param.uri proto.user = self.fs_client_param.user proto.passwd = self.fs_client_param.passwd proto.hadoop_bin = self.fs_client_param.hadoop_bin class PsDescBuilder(object): def __init__(self, context): self.context = context self.is_sync = context['is_sync'] self.ps_mode = context['ps_mode'] self.is_heter_ps_mode = context['is_heter_ps_mode'] self.use_ps_gpu = context['use_ps_gpu'] self.barrier_table_id = None self.send_ctx = get_the_one_send_context( self.context, split_dense_table=self.is_heter_ps_mode) self.tensor_table_dict = {} # TODO self._server_sub_program = [] self.tables = self._get_tables() self.service = self._get_service() self.fs_client = self._get_fs_client() self.ps_desc = the_one_ps_pb2.PSParameter() self.fl_desc = the_one_ps_pb2.FLParameter() def _get_tensor_tables(self): program_idx = 0 if not self.tensor_table_dict: self._server_sub_program.append(Program().desc) tables = [] for table_name in self.tensor_table_dict: tables.append(globals()['TensorTable'](len(tables), tensor_dict, self.context['role_maker'])) program_idx += 1 return tables def _get_tables(self): tables = [] for idx, (name, ctx) in enumerate(self.send_ctx.items()): print("idx, name, ctx:", idx, name, ctx) if ctx.is_sparse(): if self.ps_mode == DistributedMode.GEO: if (self.context['local_sparse'] and name[:-5] in self.context['local_sparse']) or ( not self.context['local_sparse']): tables.append(globals()['GeoSparseTable'](self.context, ctx)) else: tables.append(globals()['SparseTable'](self.context, ctx)) else: tables.append(globals()['SparseTable'](self.context, ctx)) else: tables.append(globals()['DenseTable'](self.context, ctx)) self.tensor_tables = self._get_tensor_tables() tables.extend(self.tensor_tables) tables.append(globals()['BarrierTable'](self.context, len(tables))) return tables def _get_service(self): if self.use_ps_gpu: return GpuService() else: return Service() def _get_fs_client(self): return fsClient(self.context["user_defined_strategy"].fs_client_param) def build_fl_client_desc(self, client_info): pass def build_worker_desc(self): for table in self.tables: table_proto = self.ps_desc.worker_param.downpour_worker_param.downpour_table_param.add( ) table._set(table_proto) table_proto = self.ps_desc.server_param.downpour_server_param.downpour_table_param.add( ) table._set(table_proto) if type(table) == BarrierTable and self.barrier_table_id is None: self.barrier_table_id = table.idx self.service._set( self.ps_desc.server_param.downpour_server_param.service_param) self.fs_client._set(self.ps_desc.fs_client_param) return text_format.MessageToString(self.ps_desc) def build_server_desc(self): self.sparse_table_maps = {} for table in self.tables: table_proto = self.ps_desc.server_param.downpour_server_param.downpour_table_param.add( ) table._set(table_proto) if table_proto.type == the_one_ps_pb2.PS_SPARSE_TABLE and table_proto.common is not None: self.sparse_table_maps[ table_proto.common.table_name] = table_proto.table_id self.service._set( self.ps_desc.server_param.downpour_server_param.service_param) self.fs_client._set(self.ps_desc.fs_client_param) return text_format.MessageToString(self.ps_desc) class TheOnePSRuntime(RuntimeBase): def __init__(self): super(TheOnePSRuntime, self).__init__() self._communicator = None self._server = None self._worker = fluid.core.DistFleetWrapper() self._coordinator = None self._server_sub_program = [] self._heter_client = None self._send_ctx = None def _set_basic_info(self, context): self.context = context self.role_maker = context["role_maker"] self.role_id = get_role_id(self.role_maker) self.debug = bool(int(os.getenv("PSERVER_DEBUG", "0"))) self.origin_main_program = context["origin_main_program"] self.origin_main_programs = context.get("origin_main_programs", [self.origin_main_program]) self.context["origin_main_programs"] = self.origin_main_programs self.context["origin_startup_programs"] = context.get( 'origin_startup_programs', [context['origin_startup_program']]) self.context[ 'is_heter_ps_mode'] = self.role_maker._is_heter_parameter_server_mode self.is_heter_ps_mode = self.context['is_heter_ps_mode'] self.context['trainer'] = TrainerRuntimeConfig( context['valid_strategy']) self.context['ps_mode'] = self.context['trainer'].mode self.context['use_ps_gpu'] = context['valid_strategy'].a_sync_configs[ 'use_ps_gpu'] self.context['is_sync'] = True if self.context[ 'ps_mode'] == DistributedMode.SYNC else False self.context['grad_name_to_param_name'] = {} self.context['tensor_table'] = {} # FL self.context['local_sparse'] = context[ "user_defined_strategy"].trainer_desc_configs["local_sparse"] self.context['remote_sparse'] = context[ "user_defined_strategy"].trainer_desc_configs["remote_sparse"] print("fl-ps > local_sparse: {}, remote_sparse: {}".format( self.context['local_sparse'], self.context['remote_sparse'])) build_var_distributed(self.context) self.trainer_endpoints = get_trainer_endpoints(self.role_maker) self.endpoints = get_ps_endpoints(self.role_maker) self.string_hosts = [] for idx, ep in enumerate(self.endpoints): host, port = ep.split(":") pshost = fluid.core.PSHost(host, int(port), idx) self.string_hosts.append(pshost.serialize_to_string()) self.with_coordinator = self.role_maker._with_coordinator self.coordinator_hosts = [] if self.with_coordinator: print("fl-ps > all ps addrs: {}".format(self.string_hosts)) coordinator_endpoints = self.role_maker._get_coordinator_endpoints() for idx, ep in enumerate(coordinator_endpoints): ip, port = ep.split(":") pshost = fluid.core.PSHost(ip, int(port), idx) self.coordinator_hosts.append(pshost.serialize_to_string()) self.ps_desc_builder = PsDescBuilder(self.context) def _init_all_params(self, scopes, send_ctx, recv_map): all_var_names = [] for name, ctx in send_ctx.items(): if ctx.is_sparse(): continue _, _, idx = get_program_by_id(self.context, ctx.program_id()) scope = scopes[idx] table_id = ctx.table_id() var_names = recv_map[table_id] #print("init params:", idx, table_id, var_names) self._worker.push_dense_params(scope, table_id, var_names) all_var_names.extend(var_names) return all_var_names def _pull_all_dense(self, scopes, send_ctx, recv_map): all_var_names = [] for name, ctx in send_ctx.items(): if ctx.is_sparse(): continue _, _, idx = get_program_by_id(self.context, ctx.program_id()) scope = scopes[idx] table_id = ctx.table_id() var_names = recv_map[table_id] #print("pull all dense:", idx, table_id, var_names) self._worker.pull_dense_params(scope, table_id, var_names) all_var_names.extend(var_names) return all_var_names def _init_params(self, program, scope, send_ctx, recv_map): all_var_names = [] for name, ctx in send_ctx.items(): if ctx.is_sparse(): continue if ctx.program_id() != id(program): continue table_id = ctx.table_id() var_names = recv_map[table_id] # print("init params:", table_id, var_names) self._worker.push_dense_params(scope, table_id, var_names) all_var_names.extend(var_names) return all_var_names def _pull_dense(self, program, scope, send_ctx, recv_map): all_var_names = [] for name, ctx in send_ctx.items(): if ctx.is_sparse(): continue if ctx.program_id() != id(program): continue table_id = ctx.table_id() var_names = recv_map[table_id] # print("pull dense:", table_id, var_names) self._worker.pull_dense_params(scope, table_id, var_names) all_var_names.extend(var_names) return all_var_names def _init_worker(self, scopes=None): worker_desc = self.ps_desc_builder.build_worker_desc() if self.context['use_ps_gpu']: main_program = self.context['loss'].block.program if not main_program._fleet_opt: main_program._fleet_opt = {} main_program._fleet_opt["use_ps_gpu"] = True gpus_env = os.getenv("FLAGS_selected_gpus") gpus_env = [int(s) for s in gpus_env.split(",")] main_program._fleet_opt["worker_places"] = gpus_env PSGPU = fluid.core.PSGPU() PSGPU.init_gpu_ps(gpus_env) def sync_strategy_envs(): kwargs = {} kwargs[ "pserver_endpoints"] = self.role_maker._get_pserver_endpoints() kwargs["trainer_id"] = self.role_maker._worker_index() return kwargs dense_map = get_the_one_recv_context( self.context, split_dense_table=self.is_heter_ps_mode) send_ctx = get_the_one_send_context( self.context, split_dense_table=self.is_heter_ps_mode, ep_list=self.endpoints) self._send_ctx = send_ctx trainer_config = self.context['trainer'] if self.debug: print("worker_desc: \n{}".format(worker_desc)) print("communicator send_ctx:") for key in send_ctx: print("{}: {}".format(key, send_ctx[key])) for key in dense_map: print("{}: {}".format(key, dense_map[key])) kwargs = {} kwargs['need_global_step'] = "0" kwargs["trainer_id"] = self.role_maker._role_id() kwargs["trainers"] = self.role_maker._worker_num() kwargs["barrier_table_id"] = self.ps_desc_builder.barrier_table_id if self.context['ps_mode'] == DistributedMode.SYNC: sync_kwargs = sync_strategy_envs() kwargs.update(sync_kwargs) print("communicator config:", trainer_config.get_communicator_flags()) self._worker.init_worker(worker_desc, self.string_hosts, self.role_id) if not self.is_heter_ps_mode: self.trainer_endpoint = get_trainer_endpoint(self.role_maker) print("fl-ps > trainer_endpoint: {}".format(self.trainer_endpoint)) print("fl-ps > with_coordinator? {}".format(self.with_coordinator)) print("fl-ps > coordinator addr: {}".format(self.coordinator_hosts)) if self.with_coordinator: self._worker.init_fl_worker(self.coordinator_hosts, self.role_id, self.trainer_endpoint) if self.context[ 'ps_mode'] == DistributedMode.GEO or self.is_heter_ps_mode: self._communicator = Communicator( trainer_config.mode, kwargs, trainer_config.get_communicator_flags()) self._communicator.init_with_ctx(send_ctx, dense_map, worker_desc, self.string_hosts, fluid.global_scope()) fleet.util.barrier() # info = self._communicator.get_client_info() info = self._worker.get_client_info() if isinstance(info, list) and len(info) > 0: all_info = self.role_maker._all_gather( info[0]) # 收集其他 client 的 service 地址 # for unittest if not isinstance(all_info, list): warnings.warn("gloo may not initialize correctly") all_info = [all_info] # self._communicator.set_clients(all_info) # self._communicator.create_client_to_client_connection() self._worker.set_clients(all_info) self._worker.create_client2client_connection() print('create c2c connection done') else: print('cannot create c2c connection') dist_strategy = self.context["valid_strategy"] is_test = bool(int(os.getenv("TEST_MODE", "0"))) if scopes is None: if len(self.origin_main_programs) > 1: raise ValueError( "You must set the scope list when you have Multiple programs" ) scopes = [fluid.global_scope()] if len(self.origin_main_programs) != len(scopes): raise VauleError("len(programs) != len(scopes)") self.scopes = scopes if not is_test: if self.context[ 'ps_mode'] == DistributedMode.GEO or self.is_heter_ps_mode == True: self._communicator.init_params(dense_map) else: if not self.context['use_ps_gpu']: if self.role_id == 0: print("entering self._init_all_params()") self._init_all_params(scopes, send_ctx, dense_map) fleet.util.barrier() # 保证 0 号 worker 参数 push_dense_param over if not self.context['use_ps_gpu']: self._pull_all_dense(scopes, send_ctx, dense_map) fleet.util.barrier() if self.context[ 'ps_mode'] == DistributedMode.GEO or self.is_heter_ps_mode == True: if not self._communicator.is_running(): self._communicator.start() else: warnings.warn("communicator has been initialized, skip") launch_barrier = dist_strategy.a_sync_configs["launch_barrier"] launch_barrier_flag = int(os.getenv("FLAGS_LAUNCH_BARRIER", "1")) if launch_barrier and launch_barrier_flag: wait_server_ready(self.role_maker._get_pserver_endpoints()) if self.is_heter_ps_mode and self.role_maker._get_next_trainers( ) != []: wait_server_ready(self.role_maker._get_next_trainers()) if self.is_heter_ps_mode: previous_trainers = [] if self.role_maker._get_previous_trainers() != []: previous_trainers = self.role_maker._get_previous_trainers() next_trainers = [] if self.role_maker._get_next_trainers() != []: next_trainers = self.role_maker._get_next_trainers() self._heter_client = HeterClient( next_trainers, previous_trainers, self.role_maker._role_id()) # --> HeterClient::GetInstance def _init_coordinator(self, scopes=None): if self._coordinator == None: self._coordinator = Coordinator(self.string_hosts) print(">>> curr node ip: {}".format(self.coordinator_hosts[0])) print(">>> all trainer endpoints: {}".format(self.trainer_endpoints)) self._coordinator.start_coordinator(self.coordinator_hosts[0], self.trainer_endpoints) def _make_fl_strategy(self): if self._coordinator == None: assert ("Coordinator py object is null!") else: self._coordinator.make_fl_strategy() def _init_server(self, dirname=None, var_names=None, **kwargs): server_desc = self.ps_desc_builder.build_server_desc() trainers = get_trainers(self.role_maker) if self.is_heter_ps_mode: trainers += len(self.role_maker._get_heter_worker_endpoints()) if self.debug: print("server_desc: \n{}".format(server_desc)) self._server = fluid.core.DistFleetWrapper() self._server.init_server(server_desc, self.string_hosts, self.role_id, trainers, self._server_sub_program) dist_varnames = get_sparse_tablenames(self.origin_main_programs, True) sparse_varnames = get_sparse_tablenames(self.origin_main_programs, False) distributed_varnames = dist_varnames + sparse_varnames if var_names is None: load_varnames = distributed_varnames else: for var_name in var_names: if var_name not in distributed_varnames: raise ValueError( "fleet.init server can only load sparse variables in {}" .format(distributed_varnames)) load_varnames = var_names if dirname is None or not load_varnames: return sparse_table_maps = self.ps_desc_builder.sparse_table_maps dirname = os.path.normpath(dirname) pserver_id = self.role_maker._role_id() for var_name in load_varnames: table_id = sparse_table_maps[var_name] self._server.load_sparse(dirname, "0", table_id) def _run_server(self): ep = get_ps_endpoint(self.role_maker) host, port = ep.split(":") self._server.run_server(host, int(port)) def _stop_worker(self): if self.context['ps_mode'] == DistributedMode.GEO: self._communicator.stop() self._worker.stop_worker() if self.is_heter_ps_mode: assert self._heter_client != None, "heter client should not be None in heterps mode" self._heter_client.stop() @staticmethod def __exclude_vars(exclude_var_names=[]): def is_valid(var): if var.name in exclude_var_names: return False from .utils.public import _get_varname_parts origin_varname, _, _ = _get_varname_parts(var.name) if origin_varname.endswith("@GRAD"): return False if origin_varname.startswith("learning_rate_"): return False if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ var.desc.type() == core.VarDesc.VarType.READER: return False return var.persistable return is_valid def _get_inference_model_path(self, dirname): if dirname.startswith("afs:") or dirname.startswith("hdfs:"): model_path = "./dnn_plugin" else: model_path = os.path.join(dirname, "dnn_plugin") return model_path def _ps_save_dense_params(self, executor, dirname, scope, program, var_names=None): dense_map = get_the_one_recv_context( self.context, split_dense_table=self.is_heter_ps_mode) send_ctx = get_the_one_send_context( self.context, split_dense_table=self.is_heter_ps_mode, ep_list=self.endpoints) if program is None or len(self.origin_main_programs) == 1: program = self.origin_main_programs[0] dense_var_names = self._pull_dense(program, scope, send_ctx, dense_map) save_var_names = dense_var_names if var_names is None else var_names vars = [program.global_block().var(i) for i in save_var_names] import paddle with paddle.static.scope_guard(scope): paddle.static.save_vars(executor, "./", program, vars=vars, filename=dirname) def _save_sparse_params(self, executor, dirname, context, main_program, mode): distributed_varnames = get_sparse_tablenames(self.origin_main_programs, True) values = [] model_path = self._get_inference_model_path(dirname) for id, names in context.items(): if names[0] not in distributed_varnames: # only save sparse param to local try: self._worker.recv_and_save_model(id, model_path) except: pass # save sparse & distributed param on server self._worker.save_one_model(id, dirname, mode) values.extend(names) # self._worker.save_all_model(dirname, mode) return values def _save_distributed_persistables(self, executor, dirname, main_program=None, mode=0, **kwargs): """ This function filters out all variables with `persistable==True` from the give `main_program` and then saves these variables to the folder `dirname` or file `filename`. The `dirname` is used to specify the folder where persistable variables are going to be saved. If you would like to save variables in separate files, set `filename` None; if you would like to save all variables in a single file, use `filename` to specify the file name. """ if isinstance(executor, ParallelExecutor): raise TypeError( "in fleet.save() function, executor must be as Executor type, ParallelExecutor is not allowed" ) if not isinstance(executor, Executor): raise TypeError( "in fleet.save() function, executor must be as Executor type") if main_program is None: main_program = self.context['origin_main_program'] if isinstance(main_program, CompiledProgram): raise TypeError( "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed" ) self._worker.save_all_model(dirname, mode) def _ps_inference_save_inference_model(self, executor, dirname, feeded_var_names, target_vars, main_program=None, export_for_deployment=True, mode=0): """ Prune the given `main_program` to build a new program especially for inference, and then save it and all related parameters to given `dirname` by the `executor`. """ if isinstance(executor, ParallelExecutor): raise TypeError( "in fleet.save() function, executor must be as Executor type, ParallelExecutor is not allowed" ) if not isinstance(executor, Executor): raise TypeError( "in fleet.save() function, executor must be as Executor type") import paddle program = self.origin_main_programs[ 0] if main_program is None else main_program _, _, idx = get_program_by_id(self.context, id(program)) scope = self.scopes[idx] print("save inference model scope idx:", idx) if isinstance(program, CompiledProgram): raise TypeError( "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed" ) feed_vars = [ program.global_block().var(name) for name in feeded_var_names ] infer_program = paddle.static.normalize_program(program, feed_vars, target_vars) infer_program._copy_dist_param_info_from(program) model_path = self._get_inference_model_path(dirname) model_basename = "__model__" model_basename = os.path.join(model_path, model_basename) paddle.save(infer_program, model_basename) sparses = get_the_one_recv_context( self.context, is_dense=False, split_dense_table=self.is_heter_ps_mode) sparse_names = self._save_sparse_params(executor, dirname, sparses, main_program, mode) dense_map = get_the_one_recv_context( self.context, split_dense_table=self.is_heter_ps_mode) send_ctx = get_the_one_send_context( self.context, split_dense_table=self.is_heter_ps_mode, ep_list=self.endpoints) self._pull_dense(program, scope, send_ctx, dense_map) generate_vars = self.context[ "user_defined_strategy"].trainer_desc_configs["stat_var_names"] generate_vars = [var for var in generate_vars] remaining_vars = list( filter(TheOnePSRuntime.__exclude_vars(sparse_names), infer_program.list_vars())) for var in remaining_vars: tensor = var.get_value(scope) paddle.save(tensor, os.path.join(model_path, var.name), use_binary_format=True) def _save_cache_model(self, dirname, **kwargs): mode = kwargs.get("mode", 1) table_id = kwargs.get("table_id", 0) self._worker.client_flush() fleet.util.barrier() cache_threshold = 0.0 if self.role_maker._is_first_worker(): cache_threshold = self._worker.get_cache_threshold(table_id) #check cache threshold right or not fleet.util.barrier() if self.role_maker._is_first_worker(): self._worker.cache_shuffle(table_id, dirname, mode, cache_threshold) fleet.util.barrier() feasign_num = -1 if self.role_maker._is_first_worker(): feasign_num = self._worker.save_cache(table_id, dirname, mode) fleet.util.barrier() return feasign_num def _check_save_pre_patch_done(self): fleet.util.barrier() if self.role_maker._is_first_worker(): self._worker.check_save_pre_patch_done() fleet.util.barrier() def _load_sparse_params(self, dirname, context, main_program, mode): distributed_varnames = get_sparse_tablenames(self.origin_main_programs, True) values = [] for id, names in context.items(): if names[0] not in distributed_varnames: # TODO: only load sparse param from local warnings.warn("varname is not in distributed_varnames, pass") # load sparse & distributed param on server self._worker.load_one_table(id, dirname, mode) values.extend(names) return values def _ps_inference_load_inference_model(self, dirname, mode=0, main_program=None): main_program = self.origin_main_programs[ 0] if main_program is None else main_program _, _, idx = get_program_by_id(self.context, id(main_program)) scope = self.scopes[idx] print("load inference model scope idx:", idx) if isinstance(main_program, CompiledProgram): raise TypeError( "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed" ) sparses = get_the_one_recv_context( self.context, is_dense=False, split_dense_table=self.is_heter_ps_mode) sparse_varnames = self._load_sparse_params(dirname, sparses, main_program, mode) dense_map = get_the_one_recv_context( self.context, split_dense_table=self.is_heter_ps_mode) send_ctx = get_the_one_send_context( self.context, split_dense_table=self.is_heter_ps_mode, ep_list=self.endpoints) recv_dense_varnames = [] for _, names in dense_map.items(): recv_dense_varnames.extend(names) loaded_varnames = sparse_varnames remaining_vars = list( filter(TheOnePSRuntime.__exclude_vars(loaded_varnames), main_program.list_vars())) model_path = self._get_inference_model_path(dirname) import paddle for var in remaining_vars: if var.name not in recv_dense_varnames: continue tensor = paddle.load(os.path.join(model_path, var.name)) var.set_value(tensor, scope) self._init_params(main_program, scope, send_ctx, dense_map) def _save_one_table(self, table_id, path, mode): fleet.util.barrier() if self.role_maker._is_first_worker(): self._worker.save_one_model(table_id, path, mode) fleet.util.barrier() def _save_dense_params(self, *args, **kwargs): fleet.util.barrier() if self.role_maker._is_first_worker(): self._ps_save_dense_params(*args, **kwargs) fleet.util.barrier() def _save_persistables(self, *args, **kwargs): fleet.util.barrier() if self.role_maker._is_first_worker(): self._save_distributed_persistables(*args, **kwargs) fleet.util.barrier() def _save_inference_model(self, *args, **kwargs): fleet.util.barrier() if self.role_maker._is_first_worker(): self._ps_inference_save_inference_model(*args, **kwargs) fleet.util.barrier() def _load_one_table(self, table_id, path, mode): fleet.util.barrier() if self.role_maker._is_first_worker(): self._worker.load_one_table(table_id, path, mode) fleet.util.barrier() def _load_persistables(self, path, mode): fleet.util.barrier() if self.role_maker._is_first_worker(): self._worker.load_model(path, mode) fleet.util.barrier() def _load_inference_model(self, path, mode): fleet.util.barrier() if self.role_maker._is_first_worker(): self._ps_inference_load_inference_model(path, mode) fleet.util.barrier() def _shrink(self, threshold=None): if threshold is not None: warnings.warn( "The param threshold is not used in MemorySparseTable, if you need to shrink, please set the config of accessor" ) else: threshold = 0 fleet.util.barrier() if self.role_maker._is_first_worker(): sparses = get_the_one_recv_context( self.context, is_dense=False, split_dense_table=self.role_maker. _is_heter_parameter_server_mode) for id, names in sparses.items(): self._worker.shrink_sparse_table(id, threshold) fleet.util.barrier()