# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function from functools import reduce import collections import math import os import warnings import six import paddle.fluid as fluid from paddle.fluid import core from paddle.fluid.core import CommContext import paddle.fluid.framework as framework from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundRobin, PSDispatcher from paddle.fluid.transpiler.details.program_utils import delete_ops OP_NAME_SCOPE = "op_namescope" CLIP_OP_NAME_SCOPE = "@CLIP" STEP_COUNTER = "@PS_STEP_COUNTER@" LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@" OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName() RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"] SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"} def _get_lr_ops(program): lr_ops = [] for index, op in enumerate(program.global_block().ops): role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME)) if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \ role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \ int(OPT_OP_ROLE_ATTR_VALUE): lr_ops.append(op) return lr_ops def _has_global_step(lr_ops): if len(lr_ops) > 0: for idx, op in enumerate(lr_ops): if op.type != 'increment': continue counter = op.input("X")[0] if counter == LEARNING_RATE_DECAY_COUNTER: return True return False def is_sparse_op(op): if op.type in SPARSE_OP_LIST and op.attr('is_sparse') is True and op.attr( 'is_distributed') is False: return True if op.type == "distributed_lookup_table" and op.attr( 'is_distributed') is False: return True return False def is_distributed_sparse_op(op): if op.type in SPARSE_OP_LIST and op.attr('is_distributed') is True: return True if op.type == "distributed_lookup_table" and op.attr( 'is_distributed') is True: return True return False def get_sparse_tablename(op): return op.input("W")[0] def get_sparse_tablenames(program, is_distributed): tablenames = set() if is_distributed: for op in program.global_block().ops: if is_distributed_sparse_op(op): tablenames.add(get_sparse_tablename(op)) else: for op in program.global_block().ops: if is_sparse_op(op): tablenames.add(get_sparse_tablename(op)) return list(tablenames) class MergedVariable: def __init__(self, merged, ordered, offsets): self.merged_var = merged self.ordered_vars = ordered self.offsets = offsets def Singleton(cls): _instance = {} def _singleton(*args, **kargs): if cls not in _instance: _instance[cls] = cls(*args, **kargs) return _instance[cls] return _singleton @Singleton class CompileTimeStrategy(object): def __init__(self, main_program, startup_program, strategy, role_maker): self.min_block_size = 8192 self.origin_main_program = main_program self.origin_startup_program = startup_program self.origin_ps_main_program = main_program self.origin_ps_startup_program = startup_program self.strategy = strategy self.role_maker = role_maker self.origin_sparse_pairs = [] self.origin_dense_pairs = [] self.merged_variables_pairs = [] self.merged_dense_pairs = [] self.merged_sparse_pairs = [] self.merged_variable_map = {} self.param_name_to_grad_name = {} self.grad_name_to_param_name = {} self.param_grad_ep_mapping = collections.OrderedDict() self.grad_param_mapping = collections.OrderedDict() self._build_var_distributed() # for heter-ps save variables self.origin_merged_variables_pairs = list(self.merged_variables_pairs) self.origin_merged_dense_pairs = list(self.merged_dense_pairs) self.origin_merged_sparse_pairs = list(self.merged_sparse_pairs) def get_distributed_mode(self): trainer = self.strategy.get_trainer_runtime_config() return trainer.mode def is_sync_mode(self): trainer = self.strategy.get_trainer_runtime_config() return trainer.mode == DistributedMode.SYNC def is_geo_mode(self): trainer = self.strategy.get_trainer_runtime_config() return trainer.mode == DistributedMode.GEO def is_async_mode(self): trainer = self.strategy.get_trainer_runtime_config() return trainer.mode == DistributedMode.ASYNC def get_role_id(self): try: return self.role_maker._role_id() except Exception: return self.role_maker.role_id() def get_trainers(self): try: return self.role_maker._worker_num() except Exception: return self.role_maker.worker_num() def get_ps_endpoint(self): try: return self.role_maker._get_pserver_endpoints()[self.get_role_id()] except Exception: return self.role_maker.get_pserver_endpoints()[self.get_role_id()] def get_ps_endpoints(self): try: return self.role_maker._get_pserver_endpoints() except Exception: return self.role_maker.get_pserver_endpoints() def get_heter_worker_endpoints(self): try: return self.role_maker._get_heter_worker_endpoints() except Exception: return self.role_maker.get_heter_worker_endpoints() def get_heter_worker_endpoint(self): try: return self.role_maker._get_heter_worker_endpoint() except Exception: return self.role_maker.get_heter_worker_endpoint() def get_origin_programs(self): return self.origin_main_program, self.origin_startup_program def get_origin_main_program(self): return self.origin_main_program def get_origin_startup_program(self): return self.origin_startup_program def set_origin_ps_main_program(self, program): self.origin_ps_main_program = program def set_origin_ps_startup_program(self, program): self.origin_ps_startup_program = program def get_origin_ps_main_program(self): return self.origin_ps_main_program def get_origin_ps_startup_program(self): return self.origin_ps_startup_program def get_sparse_varname_on_ps(self, is_distributed, endpoint=None): if not endpoint: endpoint = self.get_ps_endpoint() varnames = get_sparse_tablenames(self.get_origin_main_program(), is_distributed) ps_sparse_varnames = [] for varname in varnames: tables = self.get_var_distributed(varname, True) for i in range(len(tables)): table, ep, _ = tables[i] if ep == endpoint: ps_sparse_varnames.append(table) return ps_sparse_varnames def build_ctx(self, vars, mapping, is_grad, is_sparse, is_send, is_distributed=False): def get_grad_var_ep(slices): names = [] eps = [] sections = [] for slice in slices: if self.is_geo_mode(): if is_send: names.append("{}.delta".format(slice.name)) else: names.append(slice.name) elif is_grad and self.is_sync_mode() and self.get_trainers( ) > 1: names.append("{}.trainer_{}".format(slice.name, self.get_role_id())) else: names.append(slice.name) sections.append(slice.shape[0]) for ep, pairs in self.param_grad_ep_mapping.items(): params, grads = pairs["params"], pairs["grads"] for var in params + grads: if slice.name == var.name: eps.append(ep) break return names, eps, sections if isinstance(vars, MergedVariable): name = vars.merged_var.name slices = mapping[name] names, eps, sections = get_grad_var_ep(slices) origin_varnames = [var.name for var in vars.ordered_vars] else: name = vars.name slices = mapping[name] names, eps, sections = get_grad_var_ep(slices) origin_varnames = [vars.name] trainer_id = self.get_role_id() aggregate = True ctx = CommContext(name, names, eps, sections, origin_varnames, trainer_id, aggregate, is_sparse, is_distributed) return ctx def get_trainer_send_context(self): send_ctx = {} distibuted_varnames = get_sparse_tablenames(self.origin_main_program, True) if not self.is_geo_mode(): for merged in self.merged_dense_pairs: grad = merged[1] ctx = self.build_ctx(grad, self.grad_var_mapping, True, False, True) send_ctx[ctx.var_name()] = ctx for merged in self.merged_sparse_pairs: param = merged[0] grad = merged[1] param_name = param.merged_var.name is_distributed = True if param_name in distibuted_varnames else False ctx = self.build_ctx(grad, self.grad_var_mapping, True, True, True, is_distributed) send_ctx[ctx.var_name()] = ctx if self.is_async_mode(): name, ctx = self._step_ctx() send_ctx[name] = ctx else: for pairs in self.origin_sparse_pairs: param, grad = pairs param_name = param.name is_distributed = True if param_name in distibuted_varnames else False param_ctx = self.build_ctx(param, self.param_var_mapping, False, True, True, is_distributed) grad_ctx = self.build_ctx(grad, self.grad_var_mapping, True, True, True, is_distributed) ctx = CommContext(param_ctx.var_name(), param_ctx.split_varnames(), param_ctx.split_endpoints(), param_ctx.sections(), grad_ctx.origin_varnames(), param_ctx.trainer_id(), param_ctx.aggregate(), param_ctx.is_sparse(), param_ctx.is_distributed()) send_ctx[ctx.var_name()] = ctx name, ctx = self._step_ctx() send_ctx[name] = ctx return send_ctx def get_communicator_send_context(self): send_ctx = {} distibuted_varnames = get_sparse_tablenames(self.origin_main_program, True) if self.is_geo_mode(): for pairs in self.merged_dense_pairs: param = pairs[0] ctx = self.build_ctx(param, self.param_var_mapping, False, False, True) send_ctx[ctx.var_name()] = ctx for pairs in self.merged_sparse_pairs: param = pairs[0] param_name = param.merged_var.name is_distributed = True if param_name in distibuted_varnames else False ctx = self.build_ctx(param, self.param_var_mapping, False, True, True, is_distributed) send_ctx[ctx.var_name()] = ctx name, ctx = self._step_ctx() send_ctx[name] = ctx else: for merged in self.merged_dense_pairs: grad = merged[1] ctx = self.build_ctx(grad, self.grad_var_mapping, True, False, True) send_ctx[ctx.var_name()] = ctx for merged in self.merged_sparse_pairs: param, grad = merged param_name = param.merged_var.name is_distributed = True if param_name in distibuted_varnames else False ctx = self.build_ctx(grad, self.grad_var_mapping, True, True, True, is_distributed) send_ctx[ctx.var_name()] = ctx name, ctx = self._step_ctx() send_ctx[name] = ctx return send_ctx def get_communicator_recv_context(self, recv_type=1, use_origin_program=False): # recv_type # 1 : DENSE 2. SPARSE 3. DISTRIBUTED 4. ALL distibuted_varnames = get_sparse_tablenames(self.origin_main_program, True) sparse_varnames = [] for pairs in self.origin_sparse_pairs: param, grad = pairs sparse_varnames.append(param.name) dense_recv_ctx = {} sparse_recv_ctx = {} distributed_recv_ctx = {} variables_pairs = self.merged_variables_pairs if not use_origin_program else self.origin_merged_variables_pairs for merged in variables_pairs: params = merged[0] if params.merged_var.name in sparse_varnames: continue ctx = self.build_ctx(params, self.param_var_mapping, False, False, False) dense_recv_ctx[ctx.var_name()] = ctx for pairs in self.origin_sparse_pairs: param, grad = pairs if param.name in distibuted_varnames: ctx = self.build_ctx(param, self.param_var_mapping, False, True, False, True) distributed_recv_ctx[ctx.var_name()] = ctx else: ctx = self.build_ctx(param, self.param_var_mapping, False, True, False, False) sparse_recv_ctx[ctx.var_name()] = ctx if recv_type == 1: return dense_recv_ctx if recv_type == 2: return sparse_recv_ctx if recv_type == 3: return distributed_recv_ctx if recv_type == 4: dense_recv_ctx.update(sparse_recv_ctx) dense_recv_ctx.update(distributed_recv_ctx) return dense_recv_ctx assert ValueError( "recv_type can only be 1/2/3/4, 1 : DENSE 2. SPARSE 3. DISTRIBUTED 4. ALL" ) def get_server_runtime_config(self): return self.strategy.get_server_runtime_config() def get_var_distributed(self, varname, is_param): var_distributed = [] offset = 0 if is_param: params = self.param_var_mapping[varname] param_varnames = [var.name for var in params] for ep, pairs in self.param_grad_ep_mapping.items(): for p in pairs["params"]: if p.name in param_varnames: offset += p.shape[0] var_distributed.append((p.name, ep, p.shape[0])) else: grads = self.grad_var_mapping[varname] grad_varnames = [var.name for var in grads] for ep, pairs in self.param_grad_ep_mapping.items(): for g in pairs["grads"]: if g.name in grad_varnames: var_distributed.append((g.name, ep, g.shape[0])) return var_distributed def _step_ctx(self): name = STEP_COUNTER trainer_id = self.get_role_id() endpoints = self.get_ps_endpoints() sections = [1] * len(endpoints) names = [name] * len(endpoints) ctx = CommContext(name, names, endpoints, sections, [name], trainer_id, True, False, False) return name, ctx def _create_vars_from_blocklist(self, block_list): """ Create vars for each split. NOTE: only grads need to be named for different trainers, use add_trainer_suffix to rename the grad vars. Args: block_list (list[(varname, block_id, block_size)]): List of gradient blocks. add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True. Returns: var_mapping (collections.OrderedDict(varname->[new_varname_variable])):A dict mapping from original var name to each var split. """ # varname->[(block_id, current_block_size)] block_map = collections.OrderedDict() var_mapping = collections.OrderedDict() for block_str in block_list: varname, offset, size = block_str.split(":") if varname not in block_map: block_map[varname] = [] block_map[varname].append((int(offset), int(size))) for varname, split in six.iteritems(block_map): orig_var = self.merged_variable_map[varname] if len(split) == 1: var_mapping[varname] = [orig_var] self.var_distributed.add_distributed_var( origin_var=orig_var, slice_var=orig_var, block_id=0, offset=0, is_slice=False, vtype="Param") else: var_mapping[varname] = [] orig_shape = orig_var.shape orig_dim1_flatten = 1 if len(orig_shape) >= 2: orig_dim1_flatten = reduce(lambda x, y: x * y, orig_shape[1:]) for i, block in enumerate(split): size = block[1] rows = size // orig_dim1_flatten splited_shape = [rows] if len(orig_shape) >= 2: splited_shape.extend(orig_shape[1:]) new_var_name = "%s.block%d" % (varname, i) slice_var = vars_metatools.VarStruct( name=new_var_name, shape=splited_shape, dtype=orig_var.dtype, type=orig_var.type, lod_level=orig_var.lod_level, persistable=False) var_mapping[varname].append(slice_var) self.var_distributed.add_distributed_var( origin_var=orig_var, slice_var=slice_var, block_id=i, offset=-1, is_slice=False, vtype="Param") return var_mapping def _dispatcher(self): ps_dispatcher = RoundRobin(self.get_ps_endpoints()) ps_dispatcher.reset() grad_var_mapping_items = list(six.iteritems(self.grad_var_mapping)) sparse_gradnames = [grad.name for _, grad in self.origin_sparse_pairs] for grad_varname, splited_vars in grad_var_mapping_items: if grad_varname in sparse_gradnames: continue send_vars = [] for _, var in enumerate(splited_vars): send_vars.append(var) recv_vars = [] for _, var in enumerate(send_vars): recv_vars.append(self.grad_param_mapping[var]) eps = ps_dispatcher.dispatch(recv_vars) for i, ep in enumerate(eps): self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) for grad_varname, splited_vars in grad_var_mapping_items: if grad_varname not in sparse_gradnames: continue ps_dispatcher.reset() send_vars = [] for _, var in enumerate(splited_vars): send_vars.append(var) recv_vars = [] for _, var in enumerate(send_vars): recv_vars.append(self.grad_param_mapping[var]) eps = ps_dispatcher.dispatch(recv_vars) for i, ep in enumerate(eps): self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i]) self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i]) def _slice_variable(self, var_list, slice_count, min_block_size, uniform=False): """ We may need to split dense tensor to one or more blocks and put them equally onto parameter server. One block is a sub-tensor aligned by dim[0] of the tensor. We need to have a minimal block size so that the calculations in the parameter server side can gain better performance. By default minimum block size 8K elements (maybe 16bit or 32bit or 64bit). Args: var_list (list): List of variables. slice_count (int): Numel of count that variables will be sliced, which could be the pserver services' count. min_block_size (int): Minimum split block size. Returns: blocks (list[(varname, block_id, current_block_size)]): A list of VarBlocks. Each VarBlock specifies a shard of the var. """ blocks = [] for var in var_list: if not uniform: var_numel = reduce(lambda x, y: x * y, var.shape) split_count = 1 # if min_block_size == -1: # split_count = 1 # else: # split_count = slice_count # max_pserver_count = int( # math.floor(var_numel / float(min_block_size))) # if max_pserver_count == 0: # max_pserver_count = 1 # if max_pserver_count < slice_count: # split_count = max_pserver_count block_size = int(math.ceil(var_numel / float(split_count))) if len(var.shape) >= 2: # align by dim1(width) dim1 = reduce(lambda x, y: x * y, var.shape[1:]) remains = block_size % dim1 if remains != 0: block_size += dim1 - remains # update split_count after aligning split_count = int(math.ceil(var_numel / float(block_size))) for block_id in range(split_count): curr_block_size = min(block_size, var_numel - ( (block_id) * block_size)) block = vars_metatools.VarBlock(var.name, block_id, curr_block_size) blocks.append(str(block)) else: block_size = var.shape[0] / slice_count remainder = var.shape[0] % slice_count if block_size == 0: dim0s = [block_size] * remainder else: dim0s = [block_size] * slice_count for i in range(remainder): dim0s[i] = dim0s[i] + 1 dim1 = reduce(lambda x, y: x * y, var.shape[1:]) for block_id in range(len(dim0s)): numel = dim0s[block_id] * dim1 block = vars_metatools.VarBlock(var.name, block_id, numel) blocks.append(str(block)) return blocks def _get_param_grad_blocks(self, pairs, min_block_size, uniform=False): param_list = [] grad_list = [] param_grad_set = set() for p, g in pairs: # todo(tangwei12) skip parameter marked not trainable # if type(p) == Parameter and p.trainable == False: # continue p = p.merged_var g = g.merged_var if p.name not in param_grad_set: param_list.append(p) param_grad_set.add(p.name) if g.name not in param_grad_set: grad_list.append(g) param_grad_set.add(g.name) # when we slice var up into blocks, we will slice the var according to # pserver services' count. A pserver may have two or more listening ports. grad_blocks = self._slice_variable(grad_list, len(self.get_ps_endpoints()), min_block_size, uniform) param_blocks = self._slice_variable(param_list, len(self.get_ps_endpoints()), min_block_size, uniform) return param_blocks, grad_blocks def _var_slice_and_distribute(self): # update these mappings for further transpile: # 1. param_var_mapping : param var name->[split params vars] # 2. grad_var_mapping : grad var name->[split grads vars] # 3. grad_param_mapping : grad.blockx->param.blockx # 4. param_grad_ep_mapping : ep->{"params" : [], "grads" : [] } dps, dgs = self._get_param_grad_blocks(self.merged_dense_pairs, -1, False) sps, sgs = self._get_param_grad_blocks(self.merged_sparse_pairs, self.min_block_size, True) param_blocks = dps + sps grad_blocks = dgs + sgs assert (len(grad_blocks) == len(param_blocks)) # origin_param_name->[splited_param_vars] self.param_var_mapping = self._create_vars_from_blocklist(param_blocks) self.grad_var_mapping = self._create_vars_from_blocklist(grad_blocks) # dict(grad_splited_var->param_splited_var) self.grad_param_mapping = collections.OrderedDict() for g, p in zip(grad_blocks, param_blocks): g_name, g_bid, _ = g.split(":") p_name, p_bid, _ = p.split(":") self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \ self.param_var_mapping[p_name][int(p_bid)] print_maps = {} for k, v in self.grad_param_mapping.items(): print_maps[str(k)] = str(v) # create mapping of endpoint->split var to create pserver side program self.param_grad_ep_mapping = collections.OrderedDict() [ self.param_grad_ep_mapping.update({ ep: { "params": [], "grads": [] } }) for ep in self.get_ps_endpoints() ] def _build_var_distributed(self): self.var_distributed = vars_metatools.VarsDistributed() sparse_pairs, dense_pairs = self.get_param_grads() origin_for_sparse = [] origin_for_dense = [] param_name_grad_name = dict() grad_name_to_param_name = dict() for param, grad in sparse_pairs: param = vars_metatools.create_var_struct(param) grad = vars_metatools.create_var_struct(grad) origin_for_sparse.append((param, grad)) for param, grad in dense_pairs: param = vars_metatools.create_var_struct(param) grad = vars_metatools.create_var_struct(grad) origin_for_dense.append((param, grad)) for dense_pair in origin_for_dense: param, grad = dense_pair m_param = MergedVariable(param, [param], [0]) m_grad = MergedVariable(grad, [grad], [0]) self.merged_variables_pairs.append((m_param, m_grad)) self.merged_dense_pairs.append((m_param, m_grad)) for sparse_pair in origin_for_sparse: param, grad = sparse_pair m_param = MergedVariable(param, [param], [0]) m_grad = MergedVariable(grad, [grad], [0]) self.merged_variables_pairs.append((m_param, m_grad)) self.merged_sparse_pairs.append((m_param, m_grad)) for merged in self.merged_variables_pairs: m_param, m_grad = merged self.merged_variable_map[ m_param.merged_var.name] = m_param.merged_var self.merged_variable_map[m_grad.merged_var.name] = m_grad.merged_var param_merges = [] param_merges.extend(origin_for_sparse) param_merges.extend(origin_for_dense) for param, grad in param_merges: param_name_grad_name[param.name] = grad.name grad_name_to_param_name[grad.name] = param.name self.origin_sparse_pairs = origin_for_sparse self.origin_dense_pairs = origin_for_dense self.param_name_to_grad_name = param_name_grad_name self.grad_name_to_param_name = grad_name_to_param_name sparse_pair_map = collections.OrderedDict() for pair in self.origin_sparse_pairs + self.origin_dense_pairs: param, grad = pair sparse_pair_map[param.name] = str(param) sparse_pair_map[grad.name] = str(grad) self._var_slice_and_distribute() self._dispatcher() def get_param_grads(self): origin_program = self.origin_main_program def _get_params_grads(sparse_varnames): block = origin_program.global_block() dense_param_grads = [] sparse_param_grads = [] optimize_params = set() origin_var_dict = origin_program.global_block().vars role_id = int(core.op_proto_and_checker_maker.OpRole.Backward) for op in block.ops: if _is_opt_role_op(op): # delete clip op from opt_ops when run in Parameter Server mode if OP_NAME_SCOPE in op.all_attrs() \ and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE): op._set_attr("op_role", role_id) continue if op.attr(OP_ROLE_VAR_ATTR_NAME): param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0] grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1] if param_name not in optimize_params: optimize_params.add(param_name) param_grad = (origin_var_dict[param_name], origin_var_dict[grad_name]) if param_name in sparse_varnames: sparse_param_grads.append(param_grad) else: dense_param_grads.append(param_grad) return sparse_param_grads, dense_param_grads def _get_sparse_varnames(): varnames = [] for op in origin_program.global_block().ops: if op.type in SPARSE_OP_TYPE_DICT.keys() \ and op.attr('remote_prefetch') is True: param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0] varnames.append(param_name) return list(set(varnames)) sparse_varnames = _get_sparse_varnames() sparse_param_grads, dense_param_grads = _get_params_grads( sparse_varnames) return sparse_param_grads, dense_param_grads def remove_var_pair_by_grad(self, var_name): for index, pair in enumerate(self.merged_variables_pairs): var = pair[0] var_grad = pair[1] if var_grad.merged_var.name == var_name: del self.merged_variables_pairs[index] for index, pair in enumerate(self.merged_dense_pairs): var = pair[0] var_grad = pair[1] if var_grad.merged_var.name == var_name: del self.merged_dense_pairs[index] return for index, pair in enumerate(self.merged_sparse_pairs): var = pair[0] var_grad = pair[1] if var_grad.merged_var.name == var_name: del self.merged_sparse_pairs[index] return print("Not find {} in self.merge_pairs".format(var_name)) def _is_opt_role_op(op): # NOTE : depend on oprole to find out whether this op is for # optimize op_maker = core.op_proto_and_checker_maker optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize if op_maker.kOpRoleAttrName() in op.attr_names and \ int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role): return True return False def _get_optimize_ops(_program): block = _program.global_block() opt_ops = [] for op in block.ops: if _is_opt_role_op(op): # delete clip op from opt_ops when run in Parameter Server mode if OP_NAME_SCOPE in op.all_attrs() \ and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE): op._set_attr( "op_role", int(core.op_proto_and_checker_maker.OpRole.Backward)) continue opt_ops.append(op) return opt_ops def _get_varname_parts(varname): # returns origin, blockid, trainerid orig_var_name = "" trainer_part = "" block_part = "" trainer_idx = varname.find(".trainer_") if trainer_idx >= 0: trainer_part = varname[trainer_idx + 1:] else: trainer_idx = len(varname) block_index = varname.find(".block") if block_index >= 0: block_part = varname[block_index + 1:trainer_idx] else: block_index = len(varname) orig_var_name = varname[0:min(block_index, trainer_idx)] return orig_var_name, block_part, trainer_part def _orig_varname(varname): orig, _, _ = _get_varname_parts(varname) return orig