# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from functools import reduce

import collections
import math
import os

import six
from paddle.fluid import core
from paddle.fluid.core import CommContext
from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools
from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundRobin, PSDispatcher

OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "@CLIP"
STEP_COUNTER = "@PS_STEP_COUNTER@"
LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"

OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize


def _get_lr_ops(program):
    """Return the ops of the global block whose role is LRSched.

    An op is selected when its role attribute equals either the LRSched
    role or the bitwise OR of the LRSched and Optimize roles.
    """
    lr_sched_role = int(LR_SCHED_OP_ROLE_ATTR_VALUE)
    lr_sched_opt_role = lr_sched_role | int(OPT_OP_ROLE_ATTR_VALUE)
    return [
        op for op in program.global_block().ops
        if int(op.attr(RPC_OP_ROLE_ATTR_NAME)) in (lr_sched_role,
                                                   lr_sched_opt_role)
    ]


def _has_global_step(lr_ops):
    """Return True if any `increment` op in lr_ops takes the LR decay counter
    (`LEARNING_RATE_DECAY_COUNTER`) as its first "X" input."""
    for op in lr_ops:
        if op.type != 'increment':
            continue
        if op.input("X")[0] == LEARNING_RATE_DECAY_COUNTER:
            return True
    return False


def is_sparse_op(op):
    """Return True for a sparse, non-distributed table lookup op."""
    if op.type == "lookup_table" and op.attr('is_sparse') is True and op.attr(
            'is_distributed') is False:
        return True

    if op.type == "distributed_lookup_table" and op.attr(
            'is_distributed') is False:
        return True

    return False


def is_distributed_sparse_op(op):
    """Return True for a table lookup op flagged `is_distributed`."""
    if op.type == "lookup_table" and op.attr('is_distributed') is True:
        return True

    if op.type == "distributed_lookup_table" and op.attr(
            'is_distributed') is True:
        return True

    return False


def get_sparse_tablename(op):
    """Return the name of the first "W" input of a lookup op."""
    return op.input("W")[0]


def get_sparse_tablenames(program, is_distributed):
    """Collect the table names used by lookup ops in the global block.

    Args:
        program: program whose global block is scanned.
        is_distributed (bool): select ops matching
            `is_distributed_sparse_op` when true, `is_sparse_op` otherwise.

    Returns:
        list: de-duplicated table names (built from a set, so the order is
        not defined).
    """
    matches = is_distributed_sparse_op if is_distributed else is_sparse_op
    tablenames = set()
    for op in program.global_block().ops:
        if matches(op):
            tablenames.add(get_sparse_tablename(op))
    return list(tablenames)


class MergedVariable(object):
    """Plain record grouping a merged variable with its source variables."""

    def __init__(self, merged, ordered, offsets):
        self.merged_var = merged
        self.ordered_vars = ordered
        self.offsets = offsets


class CompileTimeStrategy(object):
    """Compile-time view of how params/grads are sliced and placed on pservers.

    Construction immediately analyses the main program (see
    `_build_var_distributed`) and fills the pair lists and mappings below.
    """

    def __init__(self, main_program, startup_program, strategy, role_maker):
        self.min_block_size = 8192

        self.origin_main_program = main_program
        self.origin_startup_program = startup_program

        self.strategy = strategy
        self.role_maker = role_maker

        self.origin_sparse_pairs = []
        self.origin_dense_pairs = []

        self.merged_variables_pairs = []
        self.merged_dense_pairs = []
        self.merged_sparse_pairs = []

        self.merged_variable_map = {}
        self.param_name_to_grad_name = {}
        self.grad_name_to_param_name = {}

        self.param_grad_ep_mapping = collections.OrderedDict()
        self.grad_param_mapping = collections.OrderedDict()

        self._build_var_distributed()

    def get_distributed_mode(self):
        """Return the `mode` of the trainer runtime config."""
        trainer = self.strategy.get_trainer_runtime_config()
        return trainer.mode

    def is_sync_mode(self):
        """Return True when the trainer mode is DistributedMode.SYNC."""
        trainer = self.strategy.get_trainer_runtime_config()
        return trainer.mode == DistributedMode.SYNC

    def is_geo_mode(self):
        """Return True when the trainer mode is DistributedMode.GEO."""
        trainer = self.strategy.get_trainer_runtime_config()
        return trainer.mode == DistributedMode.GEO

    def is_async_mode(self):
        """Return True when the trainer mode is DistributedMode.ASYNC."""
        trainer = self.strategy.get_trainer_runtime_config()
        return trainer.mode == DistributedMode.ASYNC

    def get_role_id(self):
        return self.role_maker.role_id()

    def get_trainers(self):
        return self.role_maker.worker_num()

    def get_ps_endpoint(self):
        """Return the pserver endpoint indexed by this role's id."""
        return self.role_maker.get_pserver_endpoints()[self.get_role_id()]

    def get_ps_endpoints(self):
        return self.role_maker.get_pserver_endpoints()

    def get_origin_programs(self):
        """Return the (main, startup) programs given at construction."""
        return self.origin_main_program, self.origin_startup_program

    def get_origin_main_program(self):
        return self.origin_main_program

    def get_origin_startup_program(self):
        return self.origin_startup_program

    def get_sparse_varname_on_ps(self, is_distributed, endpoint=None):
        """Return names of sparse table slices placed on one endpoint.

        Args:
            is_distributed (bool): forwarded to `get_sparse_tablenames`.
            endpoint: endpoint to filter by; when falsy, this role's own
                endpoint (`get_ps_endpoint`) is used.

        Returns:
            list: slice names whose endpoint equals `endpoint`.
        """
        if not endpoint:
            endpoint = self.get_ps_endpoint()

        varnames = get_sparse_tablenames(self.get_origin_main_program(),
                                         is_distributed)
        ps_sparse_varnames = []
        for varname in varnames:
            for table, ep, _ in self.get_var_distributed(varname, True):
                if ep == endpoint:
                    ps_sparse_varnames.append(table)
        return ps_sparse_varnames

    def build_ctx(self,
                  vars,
                  mapping,
                  is_grad,
                  is_sparse,
                  is_send,
                  is_distributed=False):
        """Build a CommContext for one variable (or MergedVariable).

        Args:
            vars: a MergedVariable or a single variable struct with `.name`.
            mapping: dict from variable name to its list of slices.
            is_grad (bool): whether the variable is a gradient.
            is_sparse (bool): forwarded to CommContext.
            is_send (bool): whether the context is for sending.
            is_distributed (bool): forwarded to CommContext.

        Returns:
            CommContext: context built from the slice names, endpoints and
            dim-0 sections of the variable's slices.
        """

        def get_grad_var_ep(slices):
            names = []
            eps = []
            sections = []

            for slice_var in slices:
                if self.is_geo_mode():
                    # In GEO mode the sent variable carries a ".delta" suffix.
                    if is_send:
                        names.append("{}.delta".format(slice_var.name))
                    else:
                        names.append(slice_var.name)
                elif is_grad and self.is_sync_mode() and \
                        self.get_trainers() > 1:
                    # Sync-mode grads are named per trainer.
                    names.append("{}.trainer_{}".format(slice_var.name,
                                                        self.get_role_id()))
                else:
                    names.append(slice_var.name)

                sections.append(slice_var.shape[0])

                for ep, pairs in self.param_grad_ep_mapping.items():
                    params, grads = pairs["params"], pairs["grads"]

                    for var in params + grads:
                        if slice_var.name == var.name:
                            eps.append(ep)
                            # Only stops scanning this endpoint's variables.
                            break
            return names, eps, sections

        if isinstance(vars, MergedVariable):
            name = vars.merged_var.name
            origin_varnames = [var.name for var in vars.ordered_vars]
        else:
            name = vars.name
            origin_varnames = [vars.name]

        names, eps, sections = get_grad_var_ep(mapping[name])

        trainer_id = self.get_role_id()
        aggregate = True
        ctx = CommContext(name, names, eps, sections, origin_varnames,
                          trainer_id, aggregate, is_sparse, is_distributed)
        return ctx

    def get_trainer_send_context(self):
        """Build the trainer-side send contexts, keyed by context var name.

        Outside GEO mode, one context is built per merged dense and sparse
        grad, plus the step counter context in async mode. In GEO mode, one
        context is built per original sparse pair (param slices combined
        with the grad's origin var names), plus the step counter context.
        """
        send_ctx = {}
        distributed_varnames = get_sparse_tablenames(self.origin_main_program,
                                                     True)

        if not self.is_geo_mode():
            for merged in self.merged_dense_pairs:
                grad = merged[1]
                ctx = self.build_ctx(grad, self.grad_var_mapping, True, False,
                                     True)
                send_ctx[ctx.var_name()] = ctx

            for merged in self.merged_sparse_pairs:
                param, grad = merged[0], merged[1]
                is_distributed = \
                    param.merged_var.name in distributed_varnames

                ctx = self.build_ctx(grad, self.grad_var_mapping, True, True,
                                     True, is_distributed)
                send_ctx[ctx.var_name()] = ctx

            if self.is_async_mode():
                name, ctx = self._step_ctx()
                send_ctx[name] = ctx
        else:
            for pairs in self.origin_sparse_pairs:
                param, grad = pairs
                is_distributed = param.name in distributed_varnames

                param_ctx = self.build_ctx(param, self.param_var_mapping,
                                           False, True, True, is_distributed)
                grad_ctx = self.build_ctx(grad, self.grad_var_mapping, True,
                                          True, True, is_distributed)

                ctx = CommContext(param_ctx.var_name(),
                                  param_ctx.split_varnames(),
                                  param_ctx.split_endpoints(),
                                  param_ctx.sections(),
                                  grad_ctx.origin_varnames(),
                                  param_ctx.trainer_id(),
                                  param_ctx.aggregate(),
                                  param_ctx.is_sparse(),
                                  param_ctx.is_distributed())

                send_ctx[ctx.var_name()] = ctx
            name, ctx = self._step_ctx()
            send_ctx[name] = ctx
        return send_ctx

    def get_communicator_send_context(self):
        """Build the communicator send contexts, keyed by context var name.

        In GEO mode the contexts are built from merged params; otherwise
        from merged grads. The step counter context is always added.
        """
        send_ctx = {}
        distributed_varnames = get_sparse_tablenames(self.origin_main_program,
                                                     True)

        if self.is_geo_mode():
            for pairs in self.merged_dense_pairs:
                param = pairs[0]
                ctx = self.build_ctx(param, self.param_var_mapping, False,
                                     False, True)
                send_ctx[ctx.var_name()] = ctx

            for pairs in self.merged_sparse_pairs:
                param = pairs[0]
                is_distributed = \
                    param.merged_var.name in distributed_varnames

                ctx = self.build_ctx(param, self.param_var_mapping, False,
                                     True, True, is_distributed)
                send_ctx[ctx.var_name()] = ctx
            name, ctx = self._step_ctx()
            send_ctx[name] = ctx
        else:
            for merged in self.merged_dense_pairs:
                grad = merged[1]
                ctx = self.build_ctx(grad, self.grad_var_mapping, True, False,
                                     True)
                send_ctx[ctx.var_name()] = ctx

            for merged in self.merged_sparse_pairs:
                param, grad = merged
                is_distributed = \
                    param.merged_var.name in distributed_varnames

                ctx = self.build_ctx(grad, self.grad_var_mapping, True, True,
                                     True, is_distributed)
                send_ctx[ctx.var_name()] = ctx

            name, ctx = self._step_ctx()
            send_ctx[name] = ctx
        return send_ctx

    def get_communicator_recv_context(self, recv_type=1):
        """Build the receive contexts of the requested kind.

        Args:
            recv_type (int): 1 : DENSE 2. SPARSE 3. DISTRIBUTED 4. ALL

        Returns:
            dict: context var name -> CommContext.

        Raises:
            ValueError: if recv_type is not one of 1/2/3/4.
        """
        distributed_varnames = get_sparse_tablenames(self.origin_main_program,
                                                     True)
        sparse_varnames = [param.name for param, _ in self.origin_sparse_pairs]

        dense_recv_ctx = {}
        sparse_recv_ctx = {}
        distributed_recv_ctx = {}

        for merged in self.merged_variables_pairs:
            params = merged[0]
            if params.merged_var.name in sparse_varnames:
                continue

            ctx = self.build_ctx(params, self.param_var_mapping, False, False,
                                 False)
            dense_recv_ctx[ctx.var_name()] = ctx

        for pairs in self.origin_sparse_pairs:
            param, grad = pairs

            if param.name in distributed_varnames:
                ctx = self.build_ctx(param, self.param_var_mapping, False,
                                     True, False, True)
                distributed_recv_ctx[ctx.var_name()] = ctx
            else:
                ctx = self.build_ctx(param, self.param_var_mapping, False,
                                     True, False, False)
                sparse_recv_ctx[ctx.var_name()] = ctx

        if recv_type == 1:
            return dense_recv_ctx
        if recv_type == 2:
            return sparse_recv_ctx
        if recv_type == 3:
            return distributed_recv_ctx
        if recv_type == 4:
            dense_recv_ctx.update(sparse_recv_ctx)
            dense_recv_ctx.update(distributed_recv_ctx)
            return dense_recv_ctx
        # An `assert ValueError(...)` never fails because the exception
        # instance is truthy; raise so that a bad recv_type is reported.
        raise ValueError(
            "recv_type can only be 1/2/3/4, 1 : DENSE 2. SPARSE 3. DISTRIBUTED 4. ALL"
        )

    def get_server_runtime_config(self):
        return self.strategy.get_server_runtime_config()

    def get_var_distributed(self, varname, is_param):
        """Return the placement of every slice of a param or grad.

        Args:
            varname: name of the original param (or grad) variable.
            is_param (bool): look up param slices when true, grad slices
                otherwise.

        Returns:
            list[tuple]: (slice name, endpoint, slice dim-0 size), in
            endpoint iteration order.
        """
        if is_param:
            slices = self.param_var_mapping[varname]
            kind = "params"
        else:
            slices = self.grad_var_mapping[varname]
            kind = "grads"

        slice_names = [var.name for var in slices]
        var_distributed = []
        for ep, pairs in self.param_grad_ep_mapping.items():
            for var in pairs[kind]:
                if var.name in slice_names:
                    var_distributed.append((var.name, ep, var.shape[0]))
        return var_distributed

    def _step_ctx(self):
        """Return (name, CommContext) for the step counter variable.

        The context lists the counter once per pserver endpoint, each with
        section size 1.
        """
        name = STEP_COUNTER
        trainer_id = self.get_role_id()
        endpoints = self.get_ps_endpoints()
        sections = [1] * len(endpoints)
        names = [name] * len(endpoints)
        ctx = CommContext(name, names, endpoints, sections, [name], trainer_id,
                          True, False, False)
        return name, ctx

    def _create_vars_from_blocklist(self, block_list):
        """
        Create vars for each split.

        A variable with a single block maps to the original variable itself;
        a variable with several blocks maps to new `<varname>.block<i>`
        structs whose dim 0 is the block's row count. Every resulting
        variable is also registered in `self.var_distributed`.

        Args:
            block_list (list[str]): blocks encoded as
                "varname:block_id:block_size".
        Returns:
            var_mapping (collections.OrderedDict(varname->[new_varname_variable])):
                A dict mapping from original var name to each var split.
        """

        # varname->[(block_id, current_block_size)]
        block_map = collections.OrderedDict()
        var_mapping = collections.OrderedDict()

        for block_str in block_list:
            varname, offset, size = block_str.split(":")
            if varname not in block_map:
                block_map[varname] = []
            block_map[varname].append((int(offset), int(size)))

        for varname, split in six.iteritems(block_map):
            orig_var = self.merged_variable_map[varname]

            if len(split) == 1:
                var_mapping[varname] = [orig_var]
                self.var_distributed.add_distributed_var(
                    origin_var=orig_var,
                    slice_var=orig_var,
                    block_id=0,
                    offset=0,
                    is_slice=False,
                    vtype="Param")
                continue

            var_mapping[varname] = []
            orig_shape = orig_var.shape
            # Number of elements in one row (product of dims after dim 0).
            orig_dim1_flatten = 1
            if len(orig_shape) >= 2:
                orig_dim1_flatten = reduce(lambda x, y: x * y, orig_shape[1:])

            for i, block in enumerate(split):
                size = block[1]
                rows = size // orig_dim1_flatten
                splited_shape = [rows]
                if len(orig_shape) >= 2:
                    splited_shape.extend(orig_shape[1:])

                new_var_name = "%s.block%d" % (varname, i)
                slice_var = vars_metatools.VarStruct(
                    name=new_var_name,
                    shape=splited_shape,
                    dtype=orig_var.dtype,
                    type=orig_var.type,
                    lod_level=orig_var.lod_level,
                    persistable=False)
                var_mapping[varname].append(slice_var)

                self.var_distributed.add_distributed_var(
                    origin_var=orig_var,
                    slice_var=slice_var,
                    block_id=i,
                    offset=-1,
                    is_slice=False,
                    vtype="Param")

        return var_mapping

    def _dispatcher(self):
        """Assign every (param slice, grad slice) pair to a pserver endpoint.

        Dense grads are dispatched first, sharing one round-robin sequence.
        Sparse grads are dispatched afterwards, with the dispatcher reset
        before each sparse variable.
        """
        ps_dispatcher = RoundRobin(self.get_ps_endpoints())
        ps_dispatcher.reset()
        grad_var_mapping_items = list(six.iteritems(self.grad_var_mapping))

        sparse_gradnames = set(
            grad.name for _, grad in self.origin_sparse_pairs)

        def _assign(splited_vars):
            # Dispatch the params paired with these grad slices, then record
            # both sides under the endpoint each one was given.
            send_vars = list(splited_vars)
            recv_vars = [self.grad_param_mapping[var] for var in send_vars]
            eps = ps_dispatcher.dispatch(recv_vars)
            for i, ep in enumerate(eps):
                self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
                self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

        for grad_varname, splited_vars in grad_var_mapping_items:
            if grad_varname in sparse_gradnames:
                continue
            _assign(splited_vars)

        for grad_varname, splited_vars in grad_var_mapping_items:
            if grad_varname not in sparse_gradnames:
                continue
            ps_dispatcher.reset()
            _assign(splited_vars)

    def _slice_variable(self,
                        var_list,
                        slice_count,
                        min_block_size,
                        uniform=False):
        """
        We may need to split dense tensor to one or more blocks and put
        them equally onto parameter server. One block is a sub-tensor
        aligned by dim[0] of the tensor.

        We need to have a minimal block size so that the calculations in
        the parameter server side can gain better performance. By default
        minimum block size 8K elements (maybe 16bit or 32bit or 64bit).

        Args:
            var_list (list): List of variables.
            slice_count (int): Numel of count that variables will be sliced, which
                could be the pserver services' count.
            min_block_size (int): Minimum split block size. Currently unused:
                the non-uniform path always keeps a variable in one block.
            uniform (bool): when true, split dim 0 into `slice_count` parts
                as evenly as possible (the first `shape[0] % slice_count`
                parts get one extra row).
        Returns:
            blocks (list[(varname, block_id, current_block_size)]): A list
                of VarBlocks. Each VarBlock specifies a shard of the var.
                Each entry is the `str()` of a VarBlock.
        """
        blocks = []
        for var in var_list:
            if not uniform:
                var_numel = reduce(lambda x, y: x * y, var.shape)

                # NOTE: splitting by slice_count / min_block_size is
                # disabled here; the split count is fixed at 1.
                split_count = 1

                block_size = int(math.ceil(var_numel / float(split_count)))

                if len(var.shape) >= 2:
                    # align by dim1(width)
                    dim1 = reduce(lambda x, y: x * y, var.shape[1:])
                    remains = block_size % dim1
                    if remains != 0:
                        block_size += dim1 - remains
                # update split_count after aligning
                split_count = int(math.ceil(var_numel / float(block_size)))
                for block_id in range(split_count):
                    curr_block_size = min(block_size,
                                          var_numel - block_id * block_size)
                    block = vars_metatools.VarBlock(var.name, block_id,
                                                    curr_block_size)
                    blocks.append(str(block))
            else:
                # Floor division keeps row counts integral on Python 3;
                # "/" would give floats that later fail int() parsing.
                block_size = var.shape[0] // slice_count
                remainder = var.shape[0] % slice_count

                if block_size == 0:
                    # Fewer rows than slices: only `remainder` blocks exist.
                    dim0s = [block_size] * remainder
                else:
                    dim0s = [block_size] * slice_count
                for i in range(remainder):
                    dim0s[i] = dim0s[i] + 1

                # Initial value 1 keeps this valid for 1-D variables.
                dim1 = reduce(lambda x, y: x * y, var.shape[1:], 1)

                for block_id, dim0 in enumerate(dim0s):
                    numel = dim0 * dim1
                    block = vars_metatools.VarBlock(var.name, block_id, numel)
                    blocks.append(str(block))
        return blocks

    def _get_param_grad_blocks(self, pairs, min_block_size, uniform=False):
        """Slice the merged params and grads of `pairs` into blocks.

        Args:
            pairs: iterable of (MergedVariable param, MergedVariable grad).
            min_block_size (int): forwarded to `_slice_variable`.
            uniform (bool): forwarded to `_slice_variable`.

        Returns:
            tuple: (param_blocks, grad_blocks) as returned by
            `_slice_variable`.
        """
        param_list = []
        grad_list = []
        param_grad_set = set()
        for p, g in pairs:
            # todo(tangwei12) skip parameter marked not trainable
            # if type(p) == Parameter and p.trainable == False:
            # continue
            p = p.merged_var
            g = g.merged_var

            if p.name not in param_grad_set:
                param_list.append(p)
                param_grad_set.add(p.name)
            if g.name not in param_grad_set:
                grad_list.append(g)
                param_grad_set.add(g.name)

        # when we slice var up into blocks, we will slice the var according to
        # pserver services' count. A pserver may have two or more listening ports.
        ps_count = len(self.get_ps_endpoints())
        grad_blocks = self._slice_variable(grad_list, ps_count, min_block_size,
                                           uniform)
        param_blocks = self._slice_variable(param_list, ps_count,
                                            min_block_size, uniform)
        return param_blocks, grad_blocks

    def _var_slice_and_distribute(self):
        """Slice all merged variables and reset the placement mappings."""
        # update these mappings for further transpile:
        # 1. param_var_mapping : param var name->[split params vars]
        # 2. grad_var_mapping : grad var name->[split grads vars]
        # 3. grad_param_mapping : grad.blockx->param.blockx
        # 4. param_grad_ep_mapping : ep->{"params" : [], "grads" : [] }

        dps, dgs = self._get_param_grad_blocks(self.merged_dense_pairs, -1,
                                               False)
        sps, sgs = self._get_param_grad_blocks(self.merged_sparse_pairs,
                                               self.min_block_size, True)

        param_blocks = dps + sps
        grad_blocks = dgs + sgs

        assert (len(grad_blocks) == len(param_blocks))

        # origin_param_name->[splited_param_vars]
        self.param_var_mapping = self._create_vars_from_blocklist(param_blocks)
        self.grad_var_mapping = self._create_vars_from_blocklist(grad_blocks)

        # dict(grad_splited_var->param_splited_var)
        self.grad_param_mapping = collections.OrderedDict()
        for g, p in zip(grad_blocks, param_blocks):
            g_name, g_bid, _ = g.split(":")
            p_name, p_bid, _ = p.split(":")
            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
                self.param_var_mapping[p_name][int(p_bid)]

        # create mapping of endpoint->split var to create pserver side program
        self.param_grad_ep_mapping = collections.OrderedDict(
            (ep, {
                "params": [],
                "grads": []
            }) for ep in self.get_ps_endpoints())

    def _build_var_distributed(self):
        """Collect param/grad pairs from the program, then slice and dispatch.

        Fills the origin/merged pair lists, the name lookup dicts and
        `merged_variable_map`, then runs `_var_slice_and_distribute` and
        `_dispatcher`.
        """
        self.var_distributed = vars_metatools.VarsDistributed()

        sparse_pairs, dense_pairs = self.get_param_grads()

        origin_for_sparse = [(vars_metatools.create_var_struct(param),
                              vars_metatools.create_var_struct(grad))
                             for param, grad in sparse_pairs]
        origin_for_dense = [(vars_metatools.create_var_struct(param),
                             vars_metatools.create_var_struct(grad))
                            for param, grad in dense_pairs]

        # Each variable is wrapped as a single-member MergedVariable.
        for param, grad in origin_for_dense:
            m_param = MergedVariable(param, [param], [0])
            m_grad = MergedVariable(grad, [grad], [0])
            self.merged_variables_pairs.append((m_param, m_grad))
            self.merged_dense_pairs.append((m_param, m_grad))

        for param, grad in origin_for_sparse:
            m_param = MergedVariable(param, [param], [0])
            m_grad = MergedVariable(grad, [grad], [0])
            self.merged_variables_pairs.append((m_param, m_grad))
            self.merged_sparse_pairs.append((m_param, m_grad))

        for m_param, m_grad in self.merged_variables_pairs:
            self.merged_variable_map[
                m_param.merged_var.name] = m_param.merged_var
            self.merged_variable_map[m_grad.merged_var.name] = m_grad.merged_var

        param_name_grad_name = dict()
        grad_name_to_param_name = dict()
        for param, grad in origin_for_sparse + origin_for_dense:
            param_name_grad_name[param.name] = grad.name
            grad_name_to_param_name[grad.name] = param.name

        self.origin_sparse_pairs = origin_for_sparse
        self.origin_dense_pairs = origin_for_dense
        self.param_name_to_grad_name = param_name_grad_name
        self.grad_name_to_param_name = grad_name_to_param_name

        self._var_slice_and_distribute()
        self._dispatcher()

    def get_param_grads(self):
        """Collect (param, grad) variable pairs from the optimize ops.

        Clip ops found among the optimize-role ops have their "op_role"
        attribute rewritten to Backward as a side effect and are skipped.

        Returns:
            tuple: (sparse_param_grads, dense_param_grads); a param is
            sparse when it is the "W" input of a `lookup_table` op whose
            `remote_prefetch` attribute is True.
        """
        origin_program = self.origin_main_program

        def _get_params_grads(sparse_varnames):
            block = origin_program.global_block()

            dense_param_grads = []
            sparse_param_grads = []

            optimize_params = set()
            origin_var_dict = origin_program.global_block().vars
            role_id = int(core.op_proto_and_checker_maker.OpRole.Backward)
            for op in block.ops:
                if not _is_opt_role_op(op):
                    continue
                # delete clip op from opt_ops when run in Parameter Server mode
                if OP_NAME_SCOPE in op.all_attrs() \
                        and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE):
                    op._set_attr("op_role", role_id)
                    continue
                if not op.attr(OP_ROLE_VAR_ATTR_NAME):
                    continue

                param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0]
                grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
                if param_name in optimize_params:
                    # Only the first optimize op of a param is recorded.
                    continue
                optimize_params.add(param_name)
                param_grad = (origin_var_dict[param_name],
                              origin_var_dict[grad_name])

                if param_name in sparse_varnames:
                    sparse_param_grads.append(param_grad)
                else:
                    dense_param_grads.append(param_grad)
            return sparse_param_grads, dense_param_grads

        def _get_sparse_varnames():
            varnames = []
            op_types = {"lookup_table": "W"}
            for op in origin_program.global_block().ops:
                if op.type in op_types \
                        and op.attr('remote_prefetch') is True:
                    param_name = op.input(op_types[op.type])[0]
                    varnames.append(param_name)

            return list(set(varnames))

        sparse_varnames = _get_sparse_varnames()
        sparse_param_grads, dense_param_grads = _get_params_grads(
            sparse_varnames)

        return sparse_param_grads, dense_param_grads


def _is_opt_role_op(op):
    """Return True if the op's role attribute is exactly OpRole.Optimize."""
    # NOTE : depend on oprole to find out whether this op is for
    # optimize
    op_maker = core.op_proto_and_checker_maker
    optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
    if op_maker.kOpRoleAttrName() in op.attr_names and \
            int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role):
        return True
    return False


def _get_optimize_ops(_program):
    """Return the optimize-role ops of the global block, excluding clip ops.

    Clip ops have their "op_role" attribute rewritten to Backward as a side
    effect and are left out of the result.
    """
    block = _program.global_block()
    opt_ops = []
    for op in block.ops:
        if not _is_opt_role_op(op):
            continue
        # delete clip op from opt_ops when run in Parameter Server mode
        if OP_NAME_SCOPE in op.all_attrs() \
                and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE):
            op._set_attr(
                "op_role",
                int(core.op_proto_and_checker_maker.OpRole.Backward))
            continue
        opt_ops.append(op)
    return opt_ops


def _get_varname_parts(varname):
    """Split a variable name into (origin, block part, trainer part).

    For example "w.block0.trainer_1" yields ("w", "block0", "trainer_1");
    parts that are absent are returned as empty strings.
    """
    trainer_part = ""
    block_part = ""

    trainer_idx = varname.find(".trainer_")
    if trainer_idx >= 0:
        trainer_part = varname[trainer_idx + 1:]
    else:
        trainer_idx = len(varname)

    block_index = varname.find(".block")
    if block_index >= 0:
        block_part = varname[block_index + 1:trainer_idx]
    else:
        block_index = len(varname)

    orig_var_name = varname[0:min(block_index, trainer_idx)]
    return orig_var_name, block_part, trainer_part


def _orig_varname(varname):
    """Return only the origin part of a (possibly suffixed) variable name."""
    orig, _, _ = _get_varname_parts(varname)
    return orig