# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
"""
Steps to transpile trainer:
1. split variables into multiple blocks, aligned by product(dim[1:]) (width).
2. rename the split grad variables to add the trainer_id suffix ".trainer_%d".
3. modify the trainer program to add a split_op for each grad variable.
4. append send_op to send the split variables to the servers.
5. add recv_op to fetch params (split blocks or origin params) from the servers.
6. append concat_op to merge the split blocks and update the local weights.

Steps to transpile pserver:
1. create a new program for the parameter server.
2. create params and grad variables that are assigned to the current server instance.
3. create a sub-block in the server side program.
4. append ops that should run on the current server instance.
5. add listen_and_serv op.
"""

import math
import sys
import numpy as np
import collections
import six
import logging

from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework, unique_name
from ..framework import Program, default_main_program, \
    default_startup_program, Block, \
    Parameter, Variable, grad_var_name
from .details import *
from ..distribute_lookup_table import find_distributed_lookup_table
from functools import reduce

LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
DIST_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Dist
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched

PRINT_LOG = False


def log(*args):
    if PRINT_LOG:
        print(args)


class VarStruct(object):
    """
    record a subset of a Variable's properties in Python.
    """

    def __init__(self, name, shape, dtype, type, lod_level, persistable):
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.type = type
        self.lod_level = lod_level
        self.persistable = persistable


class VarDistributed(object):
    """
    a class to record how a var is distributed on parameter servers.
    the class records the relationship between the origin var and its slice
    vars, and the slice var's properties such as type/shape/offset/endpoint.
    """

    def __init__(self,
                 origin_var,
                 slice_var,
                 is_slice=None,
                 block_id=None,
                 offset=None,
                 vtype=None,
                 endpoint=None):
        """
        Args:
            origin_var(Variable|VarStruct): origin var properties
            slice_var(Variable|VarStruct): slice var properties
            is_slice(bool|None): whether this var is a slice of the origin var;
                a var counts as sliced only when it differs from the origin var
                and its block size is larger than 8192 elements.
            block_id(int|None): the block index of the slice var.
            offset(int|None): if the var is sliced, offset is the numel that
                precedes this slice in the origin var.
            vtype(str|None): a tag, such as Optimizer/Param/RemotePrefetch.
endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001" """ if isinstance(origin_var, Variable): self.origin = self.__create_var_struct(origin_var) else: self.origin = origin_var if isinstance(slice_var, Variable): self.slice = self.__create_var_struct(slice_var) else: self.slice = slice_var if self.equal(self.origin, self.slice): self.is_slice = False self.block_id = 0 self.offset = 0 else: self.is_slice = True self.block_id = 0 self.offset = 0 if is_slice is not None: self.is_slice = is_slice if block_id is not None: self.block_id = block_id if offset is not None: self.offset = offset self.vtype = vtype self.endpoint = endpoint @staticmethod def __create_var_struct(var): return VarStruct(var.name, var.shape, var.dtype, var.type, var.lod_level, var.persistable) @staticmethod def equal(var1, var2): """ the two var is equal or not. Returns: bool: equal will return True else False """ assert isinstance(var1, VarStruct) and isinstance(var2, VarStruct) return var1.name == var2.name and \ var1.type == var2.type and \ var1.shape == var2.shape and \ var1.dtype == var2.dtype and \ var1.lod_level == var2.lod_level and \ var1.persistable == var2.persistable def __str__(self): origin_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})". \ format(i="{", e="}", name=self.origin.name, type=self.origin.type, shape=self.origin.shape, dtype=self.origin.dtype) slice_var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})" \ ".slice({is_slice}).block({block_id}).offset({offset})". \ format(i="{", e="}", name=self.slice.name, type=self.slice.type, shape=self.slice.shape, dtype=self.slice.dtype, is_slice=self.is_slice, block_id=self.block_id, offset=self.offset) return "var owned: {}, origin var: ( {} ), slice var: ( {} ), endpoint: {} ".format( self.vtype, origin_var_str, slice_var_str, self.endpoint) class VarsDistributed(object): """ a gather about VarDistributed with many methods to find distributed vars. through the class, we can get overview about the distributed parameters on parameter servers. this class may centralized and convenient for developer to manage and get variable's distribute. other module can also use this to find variables such io.py. """ def __init__(self): self.distributed_vars = [] def add_distributed_var(self, origin_var, slice_var, is_slice=None, block_id=None, offset=None, vtype=None, endpoint=None): """ add distributed var in this. Args: origin_var(Variable|VarStruct): origin var properties slice_var(Variable|VarStruct): slice var properties is_slice(bool|None): slice or not, slice_var=True/False and its block size > 8192 are the judgement standard. block_id(int|None): the number about the slice var. offset(int|None): if the slice var is sliced, offset is the numel before the var. vtype(str|None): a tag, such as Optimizer/Param/RemoteProfetch. endpoint(str|None): which parameter the slice var on, such as "127.0.0.1:1001" Returns: None """ self.distributed_vars.append( VarDistributed(origin_var, slice_var, is_slice, block_id, offset, vtype, endpoint)) def get_distributed_var_by_slice(self, var_name): """ get distributed var by conditions. Args: var_name(str): slice var name, such as "w.traier0.block1" Returns: VarDistributed: distributed var. """ for dist_var in self.distributed_vars: if dist_var.slice.name == var_name: return dist_var return None @staticmethod def equal(var1, var2): """ the two var is equal or not. 
Returns: bool: equal will return True else False """ return var1.name == var2.name and \ var1.type == var2.type and \ var1.shape == var2.shape and \ var1.dtype == var2.dtype and \ var1.lod_level == var2.lod_level and \ var1.persistable == var2.persistable def get_distributed_var_by_origin_and_ep(self, origin_var_name, endpoint): """ get distributed var by conditions. Args: origin_var_name(str): endpoint(str): the parameter endpoint, such as "127.0.0.1:1001" Returns: VarDistributed: distributed var. """ for dist_var in self.distributed_vars: if dist_var.origin.name == origin_var_name and dist_var.endpoint == endpoint: return dist_var return None def get_distributed_vars_by_vtypes(self, vtypes, groupby=False): """ get distributed vars by conditions. Args: vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch" groupby(bool|False): group by origin var or not. Returns: list: distributed var list. dict: distributed var map when groupby=True """ vtype_vars = [] for var in self.distributed_vars: if var.vtype in vtypes: vtype_vars.append(var) if not groupby: return vtype_vars params_map = {} for var in vtype_vars: origin_var_name = var.origin.name if origin_var_name in params_map.keys(): optimizers = params_map.get(origin_var_name) else: optimizers = [] optimizers.append(var) params_map[origin_var_name] = optimizers return params_map def get_distributed_vars_by_ep(self, endpoint, vtype=None): """ get distributed vars by conditions. Args: endpoint(str): the parameter server endpoint, such as "127.0.0.1:2001" vtype(str|None): distributed var's vtype, such as "Optimizer", "RemotePrefetch" Returns: list: distributed var list. """ endpoint_vars = [] for var in self.distributed_vars: if var.endpoint == endpoint: endpoint_vars.append(var) if not vtype: return endpoint_vars vtype_vars = [] for var in endpoint_vars: if var.vtype == vtype: vtype_vars.append(var) return vtype_vars def overview(self): """ get the overview string about all params on all parameter servers. Returns: Str: overview string. """ vars_str = [] for var in self.distributed_vars: vars_str.append(str(var)) return "\n".join(vars_str) class VarBlock: def __init__(self, varname, offset, size): self.varname = varname # NOTE: real offset is offset * size self.offset = offset self.size = size def __str__(self): return "%s:%d:%d" % (self.varname, self.offset, self.size) def same_or_split_var(p_name, var_name): return p_name == var_name or p_name.startswith(var_name + ".block") def slice_variable(var_list, slice_count, min_block_size): """ We may need to split dense tensor to one or more blocks and put them equally onto parameter server. One block is a sub-tensor aligned by dim[0] of the tensor. We need to have a minimal block size so that the calculations in the parameter server side can gain better performance. By default minimum block size 8K elements (maybe 16bit or 32bit or 64bit). Args: var_list (list): List of variables. slice_count (int): Numel of count that variables will be sliced, which could be the pserver services' count. min_block_size (int): Minimum splitted block size. Returns: blocks (list[(varname, block_id, current_block_size)]): A list of VarBlocks. Each VarBlock specifies a shard of the var. 
""" blocks = [] for var in var_list: split_count = slice_count var_numel = reduce(lambda x, y: x * y, var.shape) max_pserver_count = int(math.floor(var_numel / float(min_block_size))) if max_pserver_count == 0: max_pserver_count = 1 if max_pserver_count < slice_count: split_count = max_pserver_count block_size = int(math.ceil(var_numel / float(split_count))) if len(var.shape) >= 2: # align by dim1(width) dim1 = reduce(lambda x, y: x * y, var.shape[1:]) remains = block_size % dim1 if remains != 0: block_size += dim1 - remains # update split_count after aligning split_count = int(math.ceil(var_numel / float(block_size))) for block_id in range(split_count): curr_block_size = min(block_size, var_numel - ( (block_id) * block_size)) block = VarBlock(var.name, block_id, curr_block_size) blocks.append(str(block)) return blocks class DistributeTranspilerConfig(object): """ .. py:attribute:: slice_var_up (bool) Do Tensor slice for pservers, default is True. .. py:attribute:: split_method (PSDispatcher) RoundRobin or HashName can be used. Try to choose the best method to balance loads for pservers. .. py:attribute:: min_block_size (int) Minimum number of splitted elements in block. According to : https://github.com/PaddlePaddle/Paddle/issues/8638#issuecomment-369912156 We can use bandwidth effiently when data size is larger than 2MB.If you want to change it, please be sure you have read the slice_variable function. """ slice_var_up = True split_method = None min_block_size = 8192 enable_dc_asgd = False # supported modes: pserver, nccl2 mode = "pserver" print_log = False wait_port = True class DistributeTranspiler(object): """ **DistributeTranspiler** Convert the fluid program to distributed data-parallelism programs. Supports two modes: pserver mode and nccl2 mode. In pserver mode, the main_program will be transformed to use a remote parameter server to do parameter optimization. And the optimization graph will be put into a parameter server program. In nccl2 mode, the transpiler will append a NCCL_ID broadcasting op in startup_program to share the NCCL_ID across the job nodes. After transpile_nccl2 called, you ***must*** pass trainer_id and num_trainers argument to ParallelExecutor to enable NCCL2 distributed mode. Examples: .. 
code-block:: python

            # for pserver mode
            pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
            trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
            current_endpoint = "192.168.0.1:6174"
            trainer_id = 0
            trainers = 4
            role = os.getenv("PADDLE_TRAINING_ROLE")

            t = fluid.DistributeTranspiler()
            t.transpile(
                trainer_id, pservers=pserver_endpoints, trainers=trainers)
            if role == "PSERVER":
                pserver_program = t.get_pserver_program(current_endpoint)
                pserver_startup_program = t.get_startup_program(current_endpoint,
                                                                pserver_program)
            elif role == "TRAINER":
                trainer_program = t.get_trainer_program()

            # for nccl2 mode
            config = fluid.DistributeTranspilerConfig()
            config.mode = "nccl2"
            t = fluid.DistributeTranspiler(config=config)
            t.transpile(trainer_id, trainers=workers, current_endpoint=curr_ep)
            exe = fluid.ParallelExecutor(
                use_cuda,
                loss_name=loss_var.name,
                num_trainers=len(workers.split(",")),
                trainer_id=trainer_id
            )
    """

    def __init__(self, config=None):
        if config is not None:
            self.config = config
        else:
            self.config = DistributeTranspilerConfig()

        if self.config.split_method is None:
            self.config.split_method = RoundRobin

        global PRINT_LOG
        if self.config.print_log:
            PRINT_LOG = True
        assert (self.config.min_block_size >= 8192)
        assert (self.config.split_method.__bases__[0] == PSDispatcher)

    def _transpile_nccl2(self,
                         trainer_id,
                         trainers,
                         current_endpoint,
                         startup_program=None,
                         wait_port=True):
        if not startup_program:
            startup_program = default_startup_program()
        if trainer_id >= 0:
            worker_endpoints = trainers.split(",")
            # send NCCL_ID to others or recv from trainer 0
            worker_endpoints.remove(current_endpoint)
            if trainer_id == 0 and wait_port:
                wait_server_ready(worker_endpoints)

            nccl_id_var = startup_program.global_block().create_var(
                name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
            startup_program.global_block().append_op(
                type="gen_nccl_id",
                inputs={},
                outputs={"NCCLID": nccl_id_var},
                attrs={
                    "endpoint": current_endpoint,
                    "endpoint_list": worker_endpoints,
                    "trainer_id": trainer_id
                })
            return nccl_id_var
        else:
            raise ValueError("must set trainer_id >= 0")

    def _get_all_remote_sparse_update_op(self, main_program):
        sparse_update_ops = []
        sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
        for op in main_program.global_block().ops:
            if op.type in sparse_update_op_types and op.attr(
                    'remote_prefetch') is True:
                sparse_update_ops.append(op)
        return sparse_update_ops

    def _update_remote_sparse_update_op(self, param_varname, height_sections,
                                        endpoint_map, table_names):
        for op in self.sparse_update_ops:
            if param_varname in op.input_arg_names:
                op._set_attr('epmap', endpoint_map)
                op._set_attr('table_names', table_names)
                op._set_attr('height_sections', height_sections)
                op._set_attr('trainer_id', self.trainer_id)

    def _is_input_of_remote_sparse_update_op(self, param_name):
        for op in self.sparse_update_ops:
            if param_name in op.input_arg_names:
                return True
        return False

    def transpile(self,
                  trainer_id,
                  program=None,
                  pservers="127.0.0.1:6174",
                  trainers=1,
                  sync_mode=True,
                  startup_program=None,
                  current_endpoint="127.0.0.1:6174"):
        """
        Run the transpiler.

        Args:
            trainer_id (int): id of the current trainer worker; if there are n
                workers, the id ranges from 0 to n-1.
            program (Program|None): program to transpile,
                default is fluid.default_main_program().
            startup_program (Program|None): startup_program to transpile,
                default is fluid.default_startup_program().
            pservers (str): comma separated ip:port string for the pserver
                list.
            trainers (int|str): in pserver mode this is the number of
                trainers, in nccl2 mode this is a string of trainer endpoints.
            sync_mode (bool): Do sync training or not, default is True.
            current_endpoint (str): the current endpoint; required when
                transpiling in nccl2 distributed mode and unused in pserver
                mode.
        """
        if program is None:
            program = default_main_program()
        if startup_program is None:
            startup_program = default_startup_program()
        self.origin_program = program
        self.startup_program = startup_program
        self.origin_startup_program = self.startup_program.clone()

        if self.config.mode == "nccl2":
            assert (isinstance(trainers, str))
            self.origin_program._trainers_endpoints = trainers.split(",")
            self._transpile_nccl2(
                trainer_id,
                trainers,
                current_endpoint,
                startup_program=startup_program,
                wait_port=self.config.wait_port)
            return

        self.trainer_num = trainers
        self.sync_mode = sync_mode
        self.trainer_id = trainer_id
        pserver_endpoints = pservers.split(",")
        self.pserver_endpoints = pserver_endpoints
        self.vars_overview = VarsDistributed()
        self.optimize_ops, self.params_grads = self._get_optimize_pass()

        ps_dispatcher = self.config.split_method(self.pserver_endpoints)
        self.table_name = find_distributed_lookup_table(self.origin_program)
        self.has_distributed_lookup_table = self.table_name is not None
        self.param_name_to_grad_name = dict()
        self.grad_name_to_param_name = dict()
        for param_var, grad_var in self.params_grads:
            self.param_name_to_grad_name[param_var.name] = grad_var.name
            self.grad_name_to_param_name[grad_var.name] = param_var.name

        # get all sparse update ops
        self.sparse_update_ops = self._get_all_remote_sparse_update_op(
            self.origin_program)
        # use_sparse_update_param_name -> split_height_section
        self.sparse_param_to_height_sections = dict()

        # add distributed attrs to program
        self.origin_program._is_distributed = True
        self.origin_program._endpoints = self.pserver_endpoints
        self.origin_program._ps_endpoint = current_endpoint
        self.origin_program._is_chief = self.trainer_id == 0
        self.origin_program._distributed_lookup_table = self.table_name if self.table_name else None

        # step 1: split and create vars, then put the split vars in dicts for
        # later use.
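        # Illustrative sketch of the dicts built in step 1 (hypothetical
        # 4096 x 4 parameter "fc_w", two pservers, two trainers, sync mode,
        # default min_block_size, shown for trainer 0); the real names follow
        # the ".block%d" / ".trainer_%d" suffix rules used below:
        #   self.param_var_mapping["fc_w"]     -> [fc_w.block0, fc_w.block1]
        #   self.grad_var_mapping["fc_w@GRAD"] -> [fc_w@GRAD.block0.trainer_0,
        #                                          fc_w@GRAD.block1.trainer_0]
        #   self.grad_param_mapping[fc_w@GRAD.block0.trainer_0] -> fc_w.block0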
        self._init_splited_vars()

        # step 2: insert send op to send gradient vars to parameter servers
        ps_dispatcher.reset()
        send_vars = []

        # in general cases, the number of pservers is a multiple of 2, and this
        # can lead to an uneven distribution between weights and biases:
        # fc_w@GRAD_trainer_0, fc_w@GRAD_trainer_1 --> pserver1
        # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
        # shuffling the map avoids the uneven distribution above
        grad_var_mapping_items = list(six.iteritems(self.grad_var_mapping))
        if not self.config.slice_var_up:
            np.random.seed(self.origin_program.random_seed)
            np.random.shuffle(grad_var_mapping_items)

        self.grad_name_to_send_dummy_out = dict()
        for grad_varname, splited_vars in grad_var_mapping_items:
            eplist = ps_dispatcher.dispatch(splited_vars)

            if not self.config.slice_var_up:
                assert (len(splited_vars) == 1)

            splited_grad_varname = grad_varname
            if len(splited_vars) == 1:
                splited_grad_varname = splited_vars[0].name
                index = find_op_by_output_arg(
                    program.global_block(), splited_grad_varname, reverse=True)
                if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
                    sparse_param_name = self.grad_name_to_param_name[
                        grad_varname]
                    if self._is_input_of_remote_sparse_update_op(
                            sparse_param_name):
                        self.sparse_param_to_height_sections[
                            sparse_param_name] = [splited_vars[0].shape[0]]
            elif len(splited_vars) > 1:
                orig_var = program.global_block().vars[splited_grad_varname]
                index = find_op_by_output_arg(
                    program.global_block(), splited_grad_varname, reverse=True)
                self._insert_split_op(program, orig_var, index, splited_vars)
                index += 1
            else:
                raise AssertionError("Can not insert the send op by original "
                                     "variable name :", splited_grad_varname)

            dummy_output = program.global_block().create_var(
                name=framework.generate_control_dev_var_name())
            self.grad_name_to_send_dummy_out[grad_varname] = dummy_output

            # get send op_role_var: if not split, the grad should have the
            # .trainer suffix; if split, grad should be the original grad var
            # name (split_by_ref and send will be on the same place).
            # ParallelExecutor will use op_role_var to get the expected device
            # place to run this op.
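            # For illustration (hypothetical parameter "fc_w", trainer_id 0):
            # when the grad is not split, op_role_var below is roughly
            # ["fc_w", "fc_w@GRAD.trainer_0"]; when it is split, it is
            # ["fc_w", "fc_w@GRAD"], i.e. the original grad var name.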
program.global_block()._insert_op( index=index + 1, type="send", inputs={"X": splited_vars}, outputs={"Out": dummy_output}, attrs={ "epmap": eplist, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, OP_ROLE_VAR_ATTR_NAME: [ self.grad_name_to_param_name[grad_varname], splited_grad_varname ], "sync_mode": not self.sync_mode, }) for _, var in enumerate(splited_vars): send_vars.append(var) if self.sync_mode: send_barrier_out = program.global_block().create_var( name=framework.generate_control_dev_var_name()) if self.has_distributed_lookup_table: self.grad_name_to_send_dummy_out[ self.table_name] = program.global_block().create_var( name=framework.generate_control_dev_var_name()) input_deps = list(self.grad_name_to_send_dummy_out.values()) program.global_block().append_op( type="send_barrier", inputs={"X": list(input_deps)}, outputs={"Out": send_barrier_out}, attrs={ "endpoints": pserver_endpoints, "sync_mode": self.sync_mode, "trainer_id": self.trainer_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) # step 3: insert recv op to receive parameters from parameter server recv_vars = [] for _, var in enumerate(send_vars): recv_vars.append(self.grad_param_mapping[var]) ps_dispatcher.reset() eplist = ps_dispatcher.dispatch(recv_vars) for i, ep in enumerate(eplist): self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) distributed_var = self.vars_overview.get_distributed_var_by_slice( recv_vars[i].name) distributed_var.endpoint = ep # step4: Concat the parameters splits together after recv. all_recv_outputs = [] for param_varname, splited_var in six.iteritems(self.param_var_mapping): eps = [] table_names = [] for var in splited_var: index = [v.name for v in recv_vars].index(var.name) eps.append(eplist[index]) table_names.append(var.name) if self.sync_mode: recv_dep_in = send_barrier_out else: # connect deps to send op in async mode recv_dep_in = self.grad_name_to_send_dummy_out[ self.param_name_to_grad_name[param_varname]] # get recv op_role_var, if not splited, the grad should have .trainer suffix # if splited, grad should be the original grad var name. ParallelExecutor # will use op_role_var to get expected device place to run this op. 
orig_grad_name = self.param_name_to_grad_name[param_varname] recv_op_role_var_name = orig_grad_name splited_trainer_grad = self.grad_var_mapping[orig_grad_name] if len(splited_trainer_grad) == 1: recv_op_role_var_name = splited_trainer_grad[0].name if param_varname in self.sparse_param_to_height_sections: for table_name in table_names: distributed_var = self.vars_overview.get_distributed_var_by_slice( table_name) distributed_var.vtype = "RemotePrefetch" height_sections = self.sparse_param_to_height_sections[ param_varname] self._update_remote_sparse_update_op( param_varname, height_sections, eps, table_names) else: all_recv_outputs.extend(splited_var) program.global_block().append_op( type="recv", inputs={"X": [recv_dep_in]}, outputs={"Out": splited_var}, attrs={ "epmap": eps, "trainer_id": self.trainer_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, OP_ROLE_VAR_ATTR_NAME: [param_varname, recv_op_role_var_name], "sync_mode": not self.sync_mode }) if self.sync_mode: # form a WAW dependency program.global_block().append_op( type="fetch_barrier", inputs={}, outputs={"Out": all_recv_outputs}, attrs={ "endpoints": pserver_endpoints, "trainer_id": self.trainer_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) for param_varname, splited_var in six.iteritems(self.param_var_mapping): if len(splited_var) <= 1: continue orig_param = program.global_block().vars[param_varname] if param_varname not in self.sparse_param_to_height_sections: program.global_block().append_op( type="concat", inputs={"X": splited_var}, outputs={"Out": [orig_param]}, attrs={ "axis": 0, RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE }) self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) if self.has_distributed_lookup_table: self._replace_lookup_table_op_with_prefetch(program, pserver_endpoints) self._split_table_grad_and_add_send_vars(program, pserver_endpoints) self._get_distributed_optimizer_vars() self.origin_program._parameters_on_pservers = self.vars_overview def get_trainer_program(self, wait_port=True): """ Get transpiled trainer side program. Returns: Program: trainer side program. """ # remove optimize ops and add a send op to main_program # FIXME(typhoonzero): Also ops like clip_gradient, lrn_decay? lr_ops = self._get_lr_ops() delete_ops(self.origin_program.global_block(), self.optimize_ops) delete_ops(self.origin_program.global_block(), lr_ops) # delete table init op if self.has_distributed_lookup_table: table_var = self.startup_program.global_block().vars[ self.table_name] table_param_init_op = [] for op in self.startup_program.global_block().ops: if self.table_name in op.output_arg_names: table_param_init_op.append(op) init_op_num = len(table_param_init_op) if init_op_num != 1: raise ValueError("table init op num should be 1, now is " + str( init_op_num)) table_init_op = table_param_init_op[0] self.startup_program.global_block().append_op( type="fake_init", inputs={}, outputs={"Out": table_var}, attrs={"shape": table_init_op.attr('shape')}) delete_ops(self.startup_program.global_block(), table_param_init_op) self.origin_program.__str__() if wait_port: wait_server_ready(self.pserver_endpoints) return self.origin_program def _get_trainer_startup_program(self, recv_vars, eplist): """ Get transpiled trainer side startup program. Args: recv_vars (list): Variable list to recv for current trainer_id eplist (list): A list of strings indicating Returns: Program: trainer side startup program. """ startup_program = self.startup_program # FIXME(gongwb): delete not need ops. 
# note that: some parameter is not trainable and those ops can't be deleted. for varname, splited_var in six.iteritems(self.param_var_mapping): # Get the eplist of recv vars eps = [] for var in splited_var: index = [v.name for v in recv_vars].index(var.name) eps.append(eplist[index]) for var in splited_var: if startup_program.global_block().has_var(var.name): continue startup_program.global_block().create_var( name=var.name, persistable=False, type=var.type, dtype=var.dtype, shape=var.shape, lod_level=var.lod_level) op = startup_program.global_block().append_op( type="recv", inputs={"X": []}, outputs={"Out": splited_var}, attrs={ "epmap": eps, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) fetch_barrier_out = startup_program.global_block().create_var( name=framework.generate_control_dev_var_name()) startup_program.global_block().append_op( type="fetch_barrier", inputs={}, outputs={"Out": fetch_barrier_out}, attrs={ "endpoints": self.pserver_endpoints, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) for varname, splited_var in six.iteritems(self.param_var_mapping): # add concat ops to merge splited parameters received from parameter servers. if len(splited_var) <= 1: continue # NOTE: if enable memory optimization, origin vars maybe removed. if varname in startup_program.global_block().vars: orig_param = startup_program.global_block().vars[varname] else: origin_param_var = self.origin_program.global_block().vars[ varname] orig_param = startup_program.global_block().create_var( name=varname, persistable=origin_param_var.persistable, type=origin_param_var.type, dtype=origin_param_var.dtype, shape=origin_param_var.shape) startup_program.global_block().append_op( type="concat", inputs={"X": splited_var}, outputs={"Out": [orig_param]}, attrs={"axis": 0}) return startup_program def get_pserver_program(self, endpoint): """ Get parameter server side program. Args: endpoint (str): current parameter server endpoint. Returns: Program: the program for current parameter server to run. """ # TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers. # NOTE: assume blocks of the same variable is not distributed # on the same pserver, only change param/grad varnames for # trainers to fetch. sys.stderr.write( "get_pserver_program() is deprecated, call get_pserver_programs() to get pserver main and startup in a single call.\n" ) # step1 pserver_program = Program() pserver_program.random_seed = self.origin_program.random_seed pserver_program._copy_dist_param_info_from(self.origin_program) # step2: Create vars to receive vars at parameter servers. recv_inputs = [] for v in self.param_grad_ep_mapping[endpoint]["params"]: self._clone_var(pserver_program.global_block(), v) for v in self.param_grad_ep_mapping[endpoint]["grads"]: # create vars for each trainer in global scope, so # we don't need to create them when grad arrives. 
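            # Rough sketch (hypothetical grad "fc_w@GRAD", two trainers, sync
            # mode): this endpoint creates fc_w@GRAD.trainer_0 and
            # fc_w@GRAD.trainer_1 for listen_and_serv to receive into, plus a
            # persistable fc_w@GRAD that holds the merged gradient.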
# change client side var name to origin name by # removing ".trainer_%d" suffix suff_idx = v.name.find(".trainer_") if suff_idx >= 0: orig_var_name = v.name[:suff_idx] else: orig_var_name = v.name # NOTE: single_trainer_var must be created for multi-trainer # case to merge grads from multiple trainers single_trainer_var = \ pserver_program.global_block().create_var( name=orig_var_name, persistable=True, type=v.type, dtype=v.dtype, shape=v.shape) if self.sync_mode and self.trainer_num > 1: for trainer_id in range(self.trainer_num): var = pserver_program.global_block().create_var( name="%s.trainer_%d" % (orig_var_name, trainer_id), persistable=False, type=v.type, dtype=v.dtype, shape=v.shape) recv_inputs.append(var) else: recv_inputs.append(single_trainer_var) # step 3 # Create a union-find data structure from optimize ops, # If two ops are connected, we could add these two ops # into one set. ufind = self._create_ufind(self.optimize_ops) # step 3.2 # Iterate through the ops and append optimize op which # located on current pserver opt_op_on_pserver = [] for _, op in enumerate(self.optimize_ops): if self._is_optimizer_op(op) and self._is_opt_op_on_pserver( endpoint, op): opt_op_on_pserver.append(op) # step 3.3 # prepare if dc asgd is enabled if self.config.enable_dc_asgd == True: assert (self.sync_mode == False) self.param_bak_list = [] # add param_bak for each trainer for p in self.param_grad_ep_mapping[endpoint]["params"]: # each parameter should have w_bak for each trainer id for i in range(self.trainer_num): param_bak_name = "%s.trainer_%d_bak" % (p.name, i) tmpvar = pserver_program.global_block().create_var( # NOTE: this var name format is used in `request_get_handler` name=param_bak_name, type=p.type, shape=p.shape, dtype=p.dtype) self.param_bak_list.append((p, tmpvar)) # step 3.4 # Iterate through the ops, and if an op and the optimize ops # which located on current pserver are in one set, then # append it into the sub program. 
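        # For illustration: if this endpoint owns a hypothetical slice
        # fc_w.block0, the optimizer op updating fc_w and any clip or
        # weight-decay ops wired to its gradient share inputs/outputs, so the
        # union-find places them in one set and they end up in the same
        # optimize block for this pserver.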
global_ops = [] def __append_optimize_op__(op, block, grad_to_block_id, merged_var, lr_ops): if self._is_optimizer_op(op): self._append_pserver_ops(block, op, endpoint, grad_to_block_id, self.origin_program, merged_var) elif op not in lr_ops: self._append_pserver_non_opt_ops(block, op) def __clone_lr_op_sub_block__(op, program, lr_block): if not op.has_attr('sub_block'): return origin_block_desc = op.attr('sub_block') origin_block = self.origin_program.block(origin_block_desc.id) assert isinstance(origin_block, Block) # we put the new sub block to new block to follow the block # hierarchy of the original blocks new_sub_block = program._create_block(lr_block.idx) # clone vars for var in origin_block.vars: new_sub_block._clone_variable(var) # clone ops for origin_op in origin_block.ops: cloned_op = self._clone_lr_op(program, new_sub_block, origin_op) # clone sub_block of op __clone_lr_op_sub_block__(cloned_op, program, new_sub_block) # reset the block of op op._set_attr('sub_block', new_sub_block) # append lr decay ops to the child block if exists lr_ops = self._get_lr_ops() # record optimize blocks and we can run them on pserver parallel optimize_blocks = [] if len(lr_ops) > 0: lr_decay_block = pserver_program._create_block( pserver_program.num_blocks - 1) optimize_blocks.append(lr_decay_block) for _, op in enumerate(lr_ops): cloned_op = self._append_pserver_non_opt_ops(lr_decay_block, op) # append sub blocks to pserver_program in lr_decay_op __clone_lr_op_sub_block__(cloned_op, pserver_program, lr_decay_block) # append op to the current block grad_to_block_id = [] pre_block_idx = pserver_program.num_blocks - 1 for idx, opt_op in enumerate(opt_op_on_pserver): per_opt_block = pserver_program._create_block(pre_block_idx) optimize_blocks.append(per_opt_block) optimize_target_param_name = opt_op.attr(OP_ROLE_VAR_ATTR_NAME)[0] # append grad merging ops before clip and weight decay # e.g. merge grad -> L2Decay op -> clip op -> optimize merged_var = None for _, op in enumerate(self.optimize_ops): # find the origin grad var before clipping/L2Decay, # merged_var should be the input var name of L2Decay grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1] if op.attr(OP_ROLE_VAR_ATTR_NAME)[ 0] == optimize_target_param_name: merged_var = self._append_pserver_grad_merge_ops( per_opt_block, grad_varname_for_block, endpoint, grad_to_block_id, self.origin_program) if merged_var: break # append optimize op once then append other ops. 
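            # Rough sketch (hypothetical grad "fc_w@GRAD" sliced into block0):
            # merged_var is then the pserver-side fc_w@GRAD.block0, and
            # grad_to_block_id gains an entry like "fc_w@GRAD.block0:2", i.e.
            # the merged grad name plus the index of the optimize block that
            # consumes it.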
if merged_var: for _, op in enumerate(self.optimize_ops): # optimizer is connected to itself if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \ op not in global_ops: log("append opt op: ", op.type, op.input_arg_names, merged_var) __append_optimize_op__(op, per_opt_block, grad_to_block_id, merged_var, lr_ops) # dedup grad to ids list grad_to_block_id = list(set(grad_to_block_id)) # append global ops if global_ops: opt_state_block = pserver_program._create_block( pserver_program.num_blocks - 1) optimize_blocks.append(opt_state_block) for glb_op in global_ops: __append_optimize_op__(glb_op, opt_state_block, grad_to_block_id, None, lr_ops) # process distributed lookup_table prefetch_var_name_to_block_id = [] if self.has_distributed_lookup_table: pserver_index = self.pserver_endpoints.index(endpoint) table_opt_block = self._create_table_optimize_block( pserver_index, pserver_program, pre_block_idx, grad_to_block_id) optimize_blocks.append(table_opt_block) lookup_table_var_name_to_block_id = self._create_prefetch_block( pserver_index, pserver_program, table_opt_block) checkpoint_block_id = self._create_checkpoint_save_block( pserver_program, table_opt_block.idx) pserver_program._distributed_lookup_table = self.table_name prefetch_var_name_to_block_id.extend( lookup_table_var_name_to_block_id) if len(optimize_blocks) == 0: logging.warn("pserver [" + str(endpoint) + "] has no optimize block!!") pre_block_idx = pserver_program.num_blocks - 1 empty_block = pserver_program._create_block(pre_block_idx) optimize_blocks.append(empty_block) # In some case, some parameter server will have no parameter to optimize # So we give an empty optimize block to parameter server. attrs = { "optimize_blocks": optimize_blocks, "endpoint": endpoint, "Fanin": self.trainer_num, "sync_mode": self.sync_mode, "grad_to_block_id": grad_to_block_id, } if self.has_distributed_lookup_table: attrs['checkpint_block_id'] = checkpoint_block_id if self.config.enable_dc_asgd: attrs['dc_asgd'] = True if len(prefetch_var_name_to_block_id) > 0: attrs[ 'prefetch_var_name_to_block_id'] = prefetch_var_name_to_block_id # step5 append the listen_and_serv op pserver_program.global_block().append_op( type="listen_and_serv", inputs={'X': recv_inputs}, outputs={}, attrs=attrs) pserver_program._sync_with_cpp() # save pserver program to generate pserver side startup relatively. self.pserver_program = pserver_program return pserver_program def get_pserver_programs(self, endpoint): """ Get pserver side main program and startup program for distributed training. Args: endpoint (str): current pserver endpoint. Returns: tuple: (main_program, startup_program), of type "Program" """ pserver_prog = self.get_pserver_program(endpoint) pserver_startup = self.get_startup_program( endpoint, pserver_program=pserver_prog) return pserver_prog, pserver_startup def get_startup_program(self, endpoint, pserver_program=None, startup_program=None): """ **Deprecated** Get startup program for current parameter server. Modify operator input variables if there are variables that were split to several blocks. Args: endpoint (str): current pserver endpoint. pserver_program (Program): deprecated, call get_pserver_program first. startup_program (Program): deprecated, should pass startup_program when initalizing Returns: Program: parameter server side startup program. 
""" s_prog = Program() orig_s_prog = self.startup_program s_prog.random_seed = orig_s_prog.random_seed params = self.param_grad_ep_mapping[endpoint]["params"] def _get_splited_name_and_shape(varname): for idx, splited_param in enumerate(params): pname = splited_param.name if same_or_split_var(pname, varname) and varname != pname: return pname, splited_param.shape return "", [] # 1. create vars in pserver program to startup program pserver_vars = pserver_program.global_block().vars created_var_map = collections.OrderedDict() for _, var in six.iteritems(pserver_vars): tmpvar = s_prog.global_block()._clone_variable(var) created_var_map[var.name] = tmpvar # 2. rename op outputs for op in orig_s_prog.global_block().ops: new_outputs = collections.OrderedDict() # do not append startup op if var is not on this pserver op_on_pserver = False # TODO(gongwb): remove this line. if op.type not in ["recv", "fetch_barrier", "concat"]: for key in op.output_names: newname, _ = _get_splited_name_and_shape(op.output(key)[0]) if newname: op_on_pserver = True new_outputs[key] = created_var_map[newname] elif op.output(key)[0] in pserver_vars: op_on_pserver = True new_outputs[key] = pserver_vars[op.output(key)[0]] if op_on_pserver: # most startup program ops have no inputs new_inputs = self._get_input_map_from_op(pserver_vars, op) if op.type in [ "gaussian_random", "fill_constant", "uniform_random" ]: op._set_attr("shape", list(new_outputs["Out"].shape)) s_prog.global_block().append_op( type=op.type, inputs=new_inputs, outputs=new_outputs, attrs=op.all_attrs()) if self.config.enable_dc_asgd: for p, p_bak in self.param_bak_list: startup_param_var = s_prog.global_block().vars[p.name] startup_tmpvar = s_prog.global_block().vars[p_bak.name] # copy init random value to param_bak s_prog.global_block().append_op( type="assign", inputs={"X": startup_param_var}, outputs={"Out": startup_tmpvar}) return s_prog # ====================== private transpiler functions ===================== def _get_slice_var_info(self, slice_var): block_suffix = "block" block_idx = 0 offset = 0 is_slice = False orig_var_name, block_name, _ = self._get_varname_parts(slice_var.name) if not block_name: return is_slice, block_idx, offset block_idx = int(block_name.split(block_suffix)[1]) skip_dim0 = 0 slice_vars = self.param_var_mapping[orig_var_name] orig_dim1_flatten = reduce(lambda x, y: x * y, slice_vars[0].shape[1:]) for slice_var in slice_vars[:block_idx]: skip_dim0 += slice_var.shape[0] offset = skip_dim0 * orig_dim1_flatten is_slice = True return is_slice, block_idx, offset def _get_distributed_optimizer_vars(self): def _get_distributed_optimizer_var(endpoint): opt_op_on_pserver = [] for _, op in enumerate(self.optimize_ops): if self._is_optimizer_op(op) and self._is_opt_op_on_pserver( endpoint, op): opt_op_on_pserver.append(op) for opt_op in opt_op_on_pserver: dist_var = None for key in opt_op.input_names: if key == "Param": param_name = opt_op.input(key)[0] dist_var = self.vars_overview.get_distributed_var_by_origin_and_ep( param_name, endpoint) break for key in opt_op.input_names: if key in ["Param", "Grad", "LearningRate"]: continue origin_var = self.origin_program.global_block().vars[ opt_op.input(key)[0]] # update accumulator variable shape new_shape = self._get_optimizer_input_shape( opt_op.type, key, origin_var.shape, dist_var.slice.shape) if new_shape == dist_var.slice.shape: splited_var = VarStruct( name=origin_var.name, shape=new_shape, dtype=origin_var.dtype, type=origin_var.type, lod_level=origin_var.lod_level, 
persistable=origin_var.persistable) self.vars_overview.add_distributed_var( origin_var=origin_var, slice_var=splited_var, is_slice=dist_var.is_slice, block_id=dist_var.block_id, offset=dist_var.offset, vtype="Optimizer", endpoint=endpoint) else: self.vars_overview.add_distributed_var( origin_var=origin_var, slice_var=origin_var, is_slice=False, block_id=0, offset=0, vtype="Optimizer", endpoint=endpoint) for ep in self.pserver_endpoints: _get_distributed_optimizer_var(ep) def _update_dist_lookup_table_vars(self, param_list, grad_list, params_grads): # TODO(wuyi): put find a way to put dist lookup table stuff all together. # update self.table_param_grad and self.trainer_side_table_grad_list program = self.origin_program if self.has_distributed_lookup_table: param_list = [ param for param in param_list if param.name != self.table_name ] grad_list = [ grad for grad in grad_list if grad.name != grad_var_name(self.table_name) ] self.table_param_grad = [ param_grad for param_grad in params_grads if param_grad[0].name == self.table_name ][0] table_grad_var = self.table_param_grad[1] if self.sync_mode: self.trainer_side_table_grad_list = [ program.global_block().create_var( name="%s.trainer_%d.pserver_%d" % (table_grad_var.name, self.trainer_id, index), type=table_grad_var.type, shape=table_grad_var.shape, dtype=table_grad_var.dtype) for index in range(len(self.pserver_endpoints)) ] else: self.trainer_side_table_grad_list = [ program.global_block().create_var( name="%s.pserver_%d" % (table_grad_var.name, index), type=table_grad_var.type, shape=table_grad_var.shape, dtype=table_grad_var.dtype) for index in range(len(self.pserver_endpoints)) ] return param_list, grad_list def _init_splited_vars(self): # update these mappings for further transpile: # 1. param_var_mapping: param var name -> [splited params vars] # 2. grad_var_mapping: grad var name -> [splited grads vars] # 3. grad_param_mapping: grad.blockx -> param.blockx # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []} param_list = [] grad_list = [] param_grad_set = set() for p, g in self.params_grads: # skip parameter marked not trainable if type(p) == Parameter and p.trainable == False: continue if p.name not in param_grad_set: param_list.append(p) param_grad_set.add(p.name) if g.name not in param_grad_set: grad_list.append(g) param_grad_set.add(g.name) param_list, grad_list = self._update_dist_lookup_table_vars( param_list, grad_list, self.params_grads) if self.config.slice_var_up: # when we slice var up into blocks, we will slice the var according to # pserver services' count. A pserver may have two or more listening ports. grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints), self.config.min_block_size) param_blocks = slice_variable(param_list, len(self.pserver_endpoints), self.config.min_block_size) else: # when we do NOT slice var up into blocks, we will always slice params # grads into one block. 
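            # Illustrative sketch of slice_variable (hypothetical var "fc_w"
            # with shape (4096, 4), i.e. 16384 elements, and the default
            # min_block_size of 8192):
            #
            #   slice_variable([fc_w], 2, 8192)  # -> ["fc_w:0:8192", "fc_w:1:8192"]
            #   slice_variable([fc_w], 1, 8192)  # -> ["fc_w:0:16384"]
            #
            # so with slice_var_up disabled every param/grad stays in a single
            # block on one pserver.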
grad_blocks = slice_variable(grad_list, 1, self.config.min_block_size) param_blocks = slice_variable(param_list, 1, self.config.min_block_size) assert (len(grad_blocks) == len(param_blocks)) # origin_param_name -> [splited_param_vars] self.param_var_mapping = self._create_vars_from_blocklist( self.origin_program, param_blocks) for orig_name, splited_vars in self.param_var_mapping.items(): orig_var = self.origin_program.global_block().var(orig_name) for splited_var in splited_vars: is_slice, block_id, offset = self._get_slice_var_info( splited_var) self.vars_overview.add_distributed_var( origin_var=orig_var, slice_var=splited_var, block_id=block_id, offset=offset, is_slice=is_slice, vtype="Param") # origin_grad_name -> [splited_grad_vars] self.grad_var_mapping = self._create_vars_from_blocklist( self.origin_program, grad_blocks, add_trainer_suffix=self.trainer_num > 1) # dict(grad_splited_var -> param_splited_var) self.grad_param_mapping = collections.OrderedDict() for g, p in zip(grad_blocks, param_blocks): g_name, g_bid, _ = g.split(":") p_name, p_bid, _ = p.split(":") self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \ self.param_var_mapping[p_name][int(p_bid)] # create mapping of endpoint -> split var to create pserver side program self.param_grad_ep_mapping = collections.OrderedDict() [ self.param_grad_ep_mapping.update({ ep: { "params": [], "grads": [] } }) for ep in self.pserver_endpoints ] # transpiler function for dis lookup_table def _replace_lookup_table_op_with_prefetch(self, program, pserver_endpoints): # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op self.all_in_ids_vars = [] self.all_prefetch_input_vars = [] self.all_prefetch_output_vars = [] self.all_out_emb_vars = [] lookup_table_op_index = -1 continue_search_lookup_table_op = True while continue_search_lookup_table_op: continue_search_lookup_table_op = False all_ops = program.global_block().ops for op in all_ops: if op.type == LOOKUP_TABLE_TYPE and self.table_name == op.input( "W")[0]: if not op.attr('is_distributed'): raise RuntimeError( "lookup_table_op that lookup an distributed embedding table" "should set is_distributed to true") continue_search_lookup_table_op = True lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list( all_ops).index(op) ids_name = op.input("Ids") out_name = op.output("Out") ids_var = program.global_block().vars[ids_name[0]] self.all_in_ids_vars.append(ids_var) out_var = program.global_block().vars[out_name[0]] self.all_out_emb_vars.append(out_var) # delete lookup_table_op delete_ops(program.global_block(), [op]) # break for loop break for index in range(len(self.pserver_endpoints)): in_var = program.global_block().create_var( name=str("prefetch_compress_in_tmp_" + str(index)), type=self.all_in_ids_vars[0].type, shape=self.all_in_ids_vars[0].shape, dtype=self.all_in_ids_vars[0].dtype) self.all_prefetch_input_vars.append(in_var) out_var = program.global_block().create_var( name=str("prefetch_compress_out_tmp_" + str(index)), type=self.all_out_emb_vars[0].type, shape=self.all_out_emb_vars[0].shape, dtype=self.all_out_emb_vars[0].dtype) self.all_prefetch_output_vars.append(out_var) # insert split_ids_op program.global_block()._insert_op( index=lookup_table_op_index, type="split_ids", inputs={'Ids': self.all_in_ids_vars}, outputs={"Out": self.all_prefetch_input_vars}) # insert prefetch_op program.global_block()._insert_op( index=lookup_table_op_index + 1, type="prefetch", inputs={'X': self.all_prefetch_input_vars}, 
outputs={"Out": self.all_prefetch_output_vars}, attrs={ "epmap": pserver_endpoints, # FIXME(qiao) temporarily disable this config because prefetch # is not act as other rpc op, it's more like a forward op # RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) # insert concat_op program.global_block()._insert_op( index=lookup_table_op_index + 2, type="merge_ids", inputs={ 'Ids': self.all_in_ids_vars, 'Rows': self.all_prefetch_input_vars, 'X': self.all_prefetch_output_vars }, outputs={"Out": self.all_out_emb_vars}) def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints): # 2. add split_ids_op and send_op to send gradient to pservers # there should only be one table_name all_ops = program.global_block().ops table_grad_name = grad_var_name(self.table_name) for op in all_ops: if table_grad_name in op.output_arg_names: op_index = list(all_ops).index(op) # insert split_ids_op program.global_block()._insert_op( index=op_index + 1, type="split_ids", inputs={ 'Ids': [program.global_block().vars[table_grad_name]] }, outputs={"Out": self.trainer_side_table_grad_list}, attrs={RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE}) program.global_block()._insert_op( index=op_index + 2, type="send", inputs={'X': self.trainer_side_table_grad_list}, outputs={ 'Out': [self.grad_name_to_send_dummy_out[self.table_name]] if self.sync_mode else [] }, attrs={ "sync_mode": not self.sync_mode, "epmap": pserver_endpoints, "trainer_id": self.trainer_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, OP_ROLE_VAR_ATTR_NAME: [ self.grad_name_to_param_name[table_grad_name], table_grad_name ] }) break def _create_prefetch_block(self, pserver_index, pserver_program, optimize_block): # STEP: create prefetch block table_var = pserver_program.global_block().vars[self.table_name] prefetch_var_name_to_block_id = [] prefetch_block = pserver_program._create_block(optimize_block.idx) trainer_ids = self.all_prefetch_input_vars[pserver_index] pserver_ids = pserver_program.global_block().create_var( name=trainer_ids.name, type=trainer_ids.type, shape=trainer_ids.shape, dtype=trainer_ids.dtype) trainer_out = self.all_prefetch_output_vars[pserver_index] pserver_out = pserver_program.global_block().create_var( name=trainer_out.name, type=trainer_out.type, shape=trainer_out.shape, dtype=trainer_out.dtype) prefetch_block.append_op( type="lookup_sparse_table", inputs={'Ids': pserver_ids, "W": table_var}, outputs={"Out": pserver_out}, attrs={ "is_sparse": True, # has no effect on lookup_table op "is_distributed": True, "padding_idx": -1 }) prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str( prefetch_block.idx)) return prefetch_var_name_to_block_id def _create_table_optimize_block(self, pserver_index, pserver_program, pre_block_idx, grad_to_block_id): # STEP: create table optimize block table_opt_block = pserver_program._create_block(pre_block_idx) # create table param and grad var in pserver program # create table optimize block in pserver program table_opt_op = [ op for op in self.optimize_ops if 'Param' in op.input_names and op.input("Param")[0] == self.table_name ][0] origin_param_var = self.origin_program.global_block().vars[ self.table_name] zero_dim = int( math.ceil(origin_param_var.shape[0] / float( len(self.pserver_endpoints)))) table_shape = list(origin_param_var.shape) table_shape[0] = zero_dim param_var = pserver_program.global_block().create_var( name=origin_param_var.name, shape=table_shape, dtype=origin_param_var.dtype, type=core.VarDesc.VarType.SELECTED_ROWS, persistable=True) # parameter must be 
selected rows param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) grad_var = pserver_program.global_block()._clone_variable( self.origin_program.global_block().vars[grad_var_name( self.table_name)]) lr_var = pserver_program.global_block()._clone_variable( self.origin_program.global_block().vars[table_opt_op.input( "LearningRate")[0]]) if self.sync_mode: # create grad vars in pserver program table_grad_var = self.table_param_grad[1] pserver_side_table_grad_list = [ pserver_program.global_block().create_var( name="%s.trainer_%d.pserver_%d" % (table_grad_var.name, index, pserver_index), type=table_grad_var.type, shape=table_grad_var.shape, dtype=table_grad_var.dtype) for index in range(self.trainer_num) ] # append sum op for pserver_side_table_grad_list table_opt_block.append_op( type="sum", inputs={"X": pserver_side_table_grad_list}, outputs={"Out": [grad_var]}, attrs={"use_mkldnn": False}) else: # in async_mode, for table gradient, it also need to be splited to each parameter server origin_grad_name = grad_var.name splited_grad_name = self.trainer_side_table_grad_list[ pserver_index].name if not splited_grad_name.startswith(origin_grad_name): raise ValueError("origin_grad_var: " + splited_grad_name + " grad_var:" + grad_var.name) grad_var = pserver_program.global_block()._rename_var( origin_grad_name, splited_grad_name) inputs = { "Param": [param_var], "Grad": [grad_var], "LearningRate": [lr_var] } outputs = {"ParamOut": [param_var]} # only support sgd now logging.warn( "distribute lookup table only support sgd optimizer, change it's optimizer to sgd instead of " + table_opt_op.type) table_opt_block.append_op(type="sgd", inputs=inputs, outputs=outputs) # add table parameter gradient and it's block id to grad_to_block_id grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx)) return table_opt_block def _create_checkpoint_save_block(self, pserver_program, pre_block_idx): """ create a new block to handle save checkpoint. """ pserver_program.global_block().create_var( name="kLookupTablePath", persistable=True, type=core.VarDesc.VarType.RAW) checkpoint_save_block = pserver_program._create_block(pre_block_idx) # this 'file_path' do not be used in save lookup table variable checkpoint_save_block.append_op( type='save', inputs={'X': [self.table_name]}, outputs={}, attrs={'file_path': "none"}) return checkpoint_save_block.idx def _create_vars_from_blocklist(self, program, block_list, add_trainer_suffix=False): """ Create vars for each split. NOTE: only grads need to be named for different trainers, use add_trainer_suffix to rename the grad vars. Args: program (ProgramDesc): ProgramDesc which gradients blong. block_list (list[(varname, block_id, block_size)]): List of gradient blocks. add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True. Returns: var_mapping (collections.OrderedDict(varname->[new_varname_variable])):A dict mapping from original var name to each var split. 
""" # varname->[(block_id, current_block_size)] block_map = collections.OrderedDict() var_mapping = collections.OrderedDict() for block_str in block_list: varname, offset, size = block_str.split(":") if varname not in block_map: block_map[varname] = [] block_map[varname].append((int(offset), int(size))) for varname, splited in six.iteritems(block_map): orig_var = program.global_block().var(varname) if len(splited) == 1: if self.sync_mode and add_trainer_suffix: new_var_name = "%s.trainer_%d" % \ (orig_var.name, self.trainer_id) program.global_block()._rename_var(varname, new_var_name) var_mapping[varname] = \ [program.global_block().var(new_var_name)] else: var_mapping[varname] = \ [program.global_block().var(orig_var.name)] continue var_mapping[varname] = [] orig_shape = orig_var.shape orig_dim1_flatten = 1 if len(orig_shape) >= 2: orig_dim1_flatten = reduce(lambda x, y: x * y, orig_shape[1:]) for i, block in enumerate(splited): size = block[1] rows = size // orig_dim1_flatten splited_shape = [rows] if len(orig_shape) >= 2: splited_shape.extend(orig_shape[1:]) new_var_name = "" if self.sync_mode and add_trainer_suffix: new_var_name = "%s.block%d.trainer_%d" % \ (varname, i, self.trainer_id) else: new_var_name = "%s.block%d" % \ (varname, i) var = program.global_block().create_var( name=new_var_name, persistable=False, dtype=orig_var.dtype, type=orig_var.type, shape=splited_shape) # flattend splited var var_mapping[varname].append(var) program.global_block()._sync_with_cpp() return var_mapping def _clone_var(self, block, var, persistable=True): return block.create_var( name=var.name, shape=var.shape, dtype=var.dtype, type=var.type, lod_level=var.lod_level, persistable=persistable) def _insert_split_op(self, program, orig_var, index, splited_vars): if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: height_sections = [] for v in splited_vars: height_sections.append(v.shape[0]) sparse_param_name = self.grad_name_to_param_name[orig_var.name] if self._is_input_of_remote_sparse_update_op(sparse_param_name): self.sparse_param_to_height_sections[ sparse_param_name] = height_sections program.global_block()._insert_op( index=index + 1, type="split_selected_rows", inputs={"X": orig_var}, outputs={"Out": splited_vars}, attrs={ "height_sections": height_sections, RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE }) elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: sections = [] for v in splited_vars: sections.append(v.shape[0]) program.global_block()._insert_op( index=index + 1, type="split_byref", inputs={"X": orig_var}, outputs={"Out": splited_vars}, attrs={ "sections": sections, RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE }) else: AssertionError("Variable type should be in set " "[LOD_TENSOR, SELECTED_ROWS]") def _get_optimizer_input_shape(self, op_type, varkey, orig_shape, param_shape): """ Returns the shape for optimizer inputs that need to be reshaped when Param and Grad is split to multiple servers. """ # HACK(typhoonzero): Should use functions of corresponding optimizer in # optimizer.py to get the shape, do not bind this in the transpiler. 
if op_type == "adam": if varkey in ["Moment1", "Moment2"]: return param_shape elif op_type == "adagrad": if varkey == "Moment": return param_shape elif op_type == "adamax": if varkey in ["Moment", "InfNorm"]: return param_shape elif op_type in ["momentum", "lars_momentum"]: if varkey == "Velocity": return param_shape elif op_type == "rmsprop": if varkey in ["Moment", "MeanSquare"]: return param_shape elif op_type == "decayed_adagrad": if varkey == "Moment": return param_shape elif op_type == "ftrl": if varkey in ["SquaredAccumulator", "LinearAccumulator"]: return param_shape elif op_type == "sgd": pass else: raise ValueError( "Not supported optimizer for distributed training: %s" % op_type) return orig_shape def _get_varname_parts(self, varname): # returns origin, blockid, trainerid orig_var_name = "" trainer_part = "" block_part = "" trainer_idx = varname.find(".trainer_") if trainer_idx >= 0: trainer_part = varname[trainer_idx + 1:] else: trainer_idx = len(varname) block_index = varname.find(".block") if block_index >= 0: block_part = varname[block_index + 1:trainer_idx] else: block_index = len(varname) orig_var_name = varname[0:min(block_index, trainer_idx)] return orig_var_name, block_part, trainer_part def _orig_varname(self, varname): orig, _, _ = self._get_varname_parts(varname) return orig def _append_pserver_grad_merge_ops(self, optimize_block, grad_varname_for_block, endpoint, grad_to_block_id, origin_program): program = optimize_block.program pserver_block = program.global_block() grad_block = None for g in self.param_grad_ep_mapping[endpoint]["grads"]: if self._orig_varname(g.name) == \ self._orig_varname(grad_varname_for_block): grad_block = g break if not grad_block: # do not append this op if current endpoint # is not dealing with this grad block return None orig_varname, block_name, trainer_name = self._get_varname_parts( grad_block.name) if block_name: merged_var_name = '.'.join([orig_varname, block_name]) else: merged_var_name = orig_varname merged_var = pserver_block.vars[merged_var_name] grad_to_block_id.append(merged_var.name + ":" + str(optimize_block.idx)) if self.sync_mode and self.trainer_num > 1: vars2merge = [] for i in range(self.trainer_num): per_trainer_name = "%s.trainer_%d" % \ (merged_var_name, i) vars2merge.append(pserver_block.vars[per_trainer_name]) optimize_block.append_op( type="sum", inputs={"X": vars2merge}, outputs={"Out": merged_var}, attrs={"use_mkldnn": False}) optimize_block.append_op( type="scale", inputs={"X": merged_var}, outputs={"Out": merged_var}, attrs={"scale": 1.0 / float(self.trainer_num)}) return merged_var def _append_dc_asgd_ops(self, block, param_var, grad_var): # NOTE: can not use grammar candy here, should put ops in specific block local_param_bak = block.create_var( name="%s.local_bak" % param_var.name, shape=param_var.shape, type=param_var.type, dtype=param_var.dtype, persistable=False) # trainer_id_var is block local trainer_id_var = block.create_var( name="@TRAINER_ID@", type=core.VarDesc.VarType.LOD_TENSOR, dtype=core.VarDesc.VarType.INT64, shape=[1], persistable=False) # ref_inputs = [x[1] for x in self.param_bak_list] ref_inputs = [] for p, p_bak in self.param_bak_list: if p.name == param_var.name: ref_inputs.append(p_bak) block.append_op( type="ref_by_trainer_id", inputs={"X": ref_inputs, "TrainerId": trainer_id_var}, outputs={"Out": local_param_bak}) def __create_temp_var__(): return block.create_var( name=unique_name.generate("tmp_dc_output"), shape=param_var.shape, type=param_var.type, dtype=param_var.dtype, 
        def __create_temp_var__():
            return block.create_var(
                name=unique_name.generate("tmp_dc_output"),
                shape=param_var.shape,
                type=param_var.type,
                dtype=param_var.dtype,
                persistable=False)

        o1 = __create_temp_var__()
        block.append_op(
            type="elementwise_sub",
            inputs={"X": param_var,
                    "Y": local_param_bak},
            outputs={"Out": o1})
        o2 = __create_temp_var__()
        block.append_op(
            type="elementwise_mul",
            inputs={"X": o1,
                    "Y": grad_var},
            outputs={"Out": o2})
        o3 = __create_temp_var__()
        block.append_op(
            type="elementwise_mul",
            inputs={"X": o2,
                    "Y": grad_var},
            outputs={"Out": o3})
        # TODO(typhoonzero): append scale
        o4 = __create_temp_var__()
        block.append_op(
            type="elementwise_add",
            inputs={"X": grad_var,
                    "Y": o3},
            outputs={"Out": o4})
        return o4

    def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
                            grad_to_block_id, origin_program, merged_var):
        program = optimize_block.program
        pserver_block = program.global_block()
        new_inputs = collections.OrderedDict()

        def _get_param_block(opt_op):
            # param is already created on the global program
            param_block = None
            for p in self.param_grad_ep_mapping[endpoint]["params"]:
                if same_or_split_var(p.name, opt_op.input("Param")[0]):
                    param_block = p
                    break
            return param_block

        if self.config.enable_dc_asgd:
            param_var = _get_param_block(opt_op)
            dc = self._append_dc_asgd_ops(optimize_block, param_var, merged_var)

        for key in opt_op.input_names:
            if key == "Grad":
                if self.config.enable_dc_asgd:
                    new_inputs[key] = dc
                else:
                    # NOTE: this handles l2decay on sparse gradients, which
                    # creates a new tensor for the decayed gradient instead of
                    # modifying the original one in place.
                    origin_grad_name = opt_op.input(key)[0]
                    if core.kNewGradSuffix() in origin_grad_name and \
                            pserver_block.has_var(origin_grad_name):
                        new_grad = pserver_block.var(origin_grad_name)
                        new_inputs[key] = new_grad
                    else:
                        new_inputs[key] = merged_var
            elif key == "Param":
                param_block = _get_param_block(opt_op)
                if not param_block:
                    return
                tmpvar = pserver_block.create_var(
                    name=param_block.name,
                    persistable=True,
                    dtype=param_block.dtype,
                    shape=param_block.shape)
                new_inputs[key] = tmpvar
            elif key == "LearningRate":
                # the learning rate variable has already been created by a
                # non-optimize op, so do not create it again.
                lr_varname = opt_op.input(key)[0]
                if lr_varname in pserver_block.vars:
                    new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]
                else:
                    origin_var = origin_program.global_block().vars[lr_varname]
                    tmpvar = pserver_block.create_var(
                        name=origin_var.name,
                        persistable=origin_var.persistable,
                        dtype=origin_var.dtype,
                        shape=origin_var.shape)
                    new_inputs[key] = tmpvar

        for key in opt_op.input_names:
            new_shape = None
            if key in ["Param", "Grad", "LearningRate"]:
                continue
            var = self.origin_program.global_block().vars[opt_op.input(key)[0]]
            param_var = new_inputs["Param"]
            # update accumulator variable shape
            new_shape = self._get_optimizer_input_shape(
                opt_op.type, key, var.shape, param_var.shape)
            tmpvar = pserver_block.create_var(
                name=var.name,
                persistable=var.persistable,
                dtype=var.dtype,
                shape=new_shape)
            new_inputs[key] = tmpvar

        # change output's ParamOut variable
        outputs = self._get_output_map_from_op(
            self.origin_program.global_block().vars, opt_op)
        outputs["ParamOut"] = new_inputs["Param"]
        optimize_block.append_op(
            type=opt_op.type,
            inputs=new_inputs,
            outputs=outputs,
            attrs=opt_op.all_attrs())
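
    # NOTE on _append_pserver_ops above: for a hypothetical parameter
    # fc_0.w_0 split across two pservers, the endpoint owning fc_0.w_0.block0
    # re-creates the optimize op with Param, Grad and the accumulators all
    # reshaped to that block, so each server only updates its own slice.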
    def _get_pserver_grad_param_var(self, var, var_dict):
        """
        Return the pserver-side grad/param variable, or None
        if the variable is not a grad/param, e.g.

        a@GRAD -> a@GRAD.block0
        a@GRAD -> a@GRAD (a is not split)
        fc_0.w_0 -> fc_0.w_0.block_0
        fc_0.w_0 -> fc_0.w_0 (weight is not split)
        _generated_var_123 -> None
        """
        grad_block = None
        for _, g in six.iteritems(var_dict):
            if self._orig_varname(g.name) == self._orig_varname(var.name):
                # skip per-trainer vars
                if g.name.find(".trainer_") == -1:
                    # only params or grads have split blocks
                    if self._orig_varname(g.name) in self.grad_name_to_param_name or \
                            self._orig_varname(g.name) in self.param_name_to_grad_name:
                        grad_block = g
                        break
        return grad_block

    def _clone_lr_op(self, program, block, op):
        inputs = self._get_input_map_from_op(
            self.origin_program.global_block().vars, op)
        for key, varlist in six.iteritems(inputs):
            if not isinstance(varlist, list):
                varlist = [varlist]
            for var in varlist:
                if var not in program.global_block().vars:
                    block._clone_variable(var)

        outputs = self._get_output_map_from_op(
            self.origin_program.global_block().vars, op)
        for key, varlist in six.iteritems(outputs):
            if not isinstance(varlist, list):
                varlist = [varlist]
            for var in varlist:
                if var not in program.global_block().vars:
                    block._clone_variable(var)

        return block.append_op(
            type=op.type, inputs=inputs, outputs=outputs, attrs=op.all_attrs())

    def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
        program = optimize_block.program
        # Append the ops for parameters that do not need to be optimized/updated
        inputs = self._get_input_map_from_op(
            self.origin_program.global_block().vars, opt_op)
        for key, varlist in six.iteritems(inputs):
            if not isinstance(varlist, list):
                varlist = [varlist]
            for i in range(len(varlist)):
                var = varlist[i]
                # for ops like clipping and weight decay, get the split var
                # (xxx.block0) for inputs/outputs
                grad_block = self._get_pserver_grad_param_var(
                    var, program.global_block().vars)
                if grad_block:
                    varlist[i] = grad_block
                elif var.name not in program.global_block().vars:
                    tmpvar = program.global_block()._clone_variable(var)
                    varlist[i] = tmpvar
                else:
                    varlist[i] = program.global_block().vars[var.name]
            inputs[key] = varlist

        outputs = self._get_output_map_from_op(
            self.origin_program.global_block().vars, opt_op)
        for key, varlist in six.iteritems(outputs):
            if not isinstance(varlist, list):
                varlist = [varlist]
            for i in range(len(varlist)):
                var = varlist[i]
                grad_block = self._get_pserver_grad_param_var(
                    var, program.global_block().vars)
                if grad_block:
                    varlist[i] = grad_block
                elif var.name not in program.global_block().vars:
                    tmpvar = program.global_block()._clone_variable(var)
                    varlist[i] = tmpvar
                else:
                    varlist[i] = program.global_block().vars[var.name]
            outputs[key] = varlist

        return optimize_block.append_op(
            type=opt_op.type,
            inputs=inputs,
            outputs=outputs,
            attrs=opt_op.all_attrs())

    def _is_op_connected(self, op1, op2):
        # If one op's input is another op's output or
        # one op's output is another op's input, we say
        # the two operators are connected.
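        # For example (hypothetical var name): if op1 writes "fc_0.tmp_0" and
        # op2 reads "fc_0.tmp_0", the two ops are treated as connected.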
        if set(op1.desc.output_arg_names()) & set(op2.desc.input_arg_names()) or \
                set(op1.desc.input_arg_names()) & set(op2.desc.output_arg_names()):
            return True
        return False

    def _create_ufind(self, optimize_ops):
        # Create a union-find data structure over the optimize ops
        ufind = UnionFind(optimize_ops)
        for i in range(len(optimize_ops)):
            for j in range(i, len(optimize_ops)):
                op1 = optimize_ops[i]
                op2 = optimize_ops[j]
                if self._is_op_connected(op1, op2):
                    ufind.union(op1, op2)
        return ufind

    def _is_optimizer_op(self, op):
        if "Param" in op.input_names and \
                "LearningRate" in op.input_names:
            return True
        return False

    def _is_opt_op_on_pserver(self, endpoint, op):
        param_names = [
            p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
        ]
        if op.input("Param")[0] in param_names:
            return True
        else:
            for n in param_names:
                param = op.input("Param")[0]
                if same_or_split_var(n, param) and n != param:
                    return True
            return False

    def _get_input_map_from_op(self, varmap, op):
        """Returns a dict from op input name to the vars in varmap."""
        iomap = collections.OrderedDict()
        for key in op.input_names:
            vars = []
            for varname in op.input(key):
                vars.append(varmap[varname])
            if len(vars) == 1:
                iomap[key] = vars[0]
            else:
                iomap[key] = vars
        return iomap

    def _get_output_map_from_op(self, varmap, op):
        """Returns a dict from op output name to the vars in varmap."""
        iomap = collections.OrderedDict()
        for key in op.output_names:
            vars = []
            for varname in op.output(key):
                vars.append(varmap[varname])
            if len(vars) == 1:
                iomap[key] = vars[0]
            else:
                iomap[key] = vars
        return iomap

    def _get_lr_ops(self):
        lr_ops = []
        block = self.origin_program.global_block()
        for op in block.ops:
            role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
            if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
                    role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
                    int(OPT_OP_ROLE_ATTR_VALUE):
                lr_ops.append(op)
                log("append lr op: ", op.type)
        return lr_ops

    def _get_lr_ops_deprecated(self):
        lr_ops = []
        # find learning rate variables by optimize op
        lr_vars = set()
        for op in self.optimize_ops:
            if self._is_optimizer_op(op):
                lr_vars.add(op.input("LearningRate")[0])

        find_ops = []
        # find ops whose output is the lr var
        block = self.origin_program.global_block()
        for op in block.ops:
            if set(op.output_arg_names) & lr_vars:
                find_ops.append(op)
        # build a union-find struct over the ops in default_main_program
        ufind = UnionFind(block.ops)

        for op1 in block.ops:
            for op2 in block.ops:
                # NOTE: we need to skip all optimize ops, since they are
                # connected with both forward/backward ops and lr ops; we only
                # need the lr ops.
                if op1 != op2 and self._is_op_connected(op1, op2) and \
                        not self._is_optimizer_op(op1) and not self._is_optimizer_op(op2):
                    ufind.union(op1, op2)
        # find all ops related to the lr var
        for op1 in block.ops:
            for op2 in find_ops:
                if ufind.is_connected(op1, op2):
                    lr_ops.append(op1)
                    # we only need to append the op once
                    break
        return lr_ops

    def _is_opt_role_op(self, op):
        # NOTE: depend on the op role to find out whether this op is an
        # optimize op
        op_maker = core.op_proto_and_checker_maker
        optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
        if op_maker.kOpRoleAttrName() in op.attr_names and \
                int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role):
            return True
        return False

    def _get_optimize_pass(self):
        """
        Get optimizer operators, parameters and gradients from origin_program

        Returns:
            opt_ops (list): optimize operators.
            params_grads (list): a list of [parameter, gradient] variable pairs.
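
        Example (illustrative names only):
            for a program trained with plain sgd on one fc layer, opt_ops would
            hold the sgd op and params_grads a single pair such as
            [fc_0.w_0, fc_0.w_0@GRAD].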
""" block = self.origin_program.global_block() opt_ops = [] params_grads = [] # tmp set to dedup optimize_params = set() origin_var_dict = self.origin_program.global_block().vars for op in block.ops: if self._is_opt_role_op(op): opt_ops.append(op) if op.attr(OP_ROLE_VAR_ATTR_NAME): param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0] grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1] if not param_name in optimize_params: optimize_params.add(param_name) log("adding param_grad pair: ", param_name, grad_name) params_grads.append([ origin_var_dict[param_name], origin_var_dict[grad_name] ]) else: pass return opt_ops, params_grads