# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function from functools import reduce import collections import math import os import warnings import logging import six import paddle.fluid as fluid from paddle.fluid import core import paddle.fluid.framework as framework #logging.basicConfig( # format='%(levelname)s - %(asctime)s - %(pathname)s: %(lineno)s - %(message)s', level=logging.INFO) #logger = logging.getLogger(__name__) OP_NAME_SCOPE = "op_namescope" CLIP_OP_NAME_SCOPE = "gradient_clip" STEP_COUNTER = "@PS_STEP_COUNTER@" LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@" OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName() RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize backward = core.op_proto_and_checker_maker.OpRole.Backward DEVICE_LIST = ["cpu", "gpu", "xpu"] COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"] SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"] SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"} SPARSE_GRAD_OP_TYPE_DICT = { "lookup_table_grad": "W", "lookup_table_v2_grad": "W" } DEFAULT_DEVICE = 'cpu' DATA_NORM_NAME = [".batch_size", ".batch_sum", ".batch_square_sum"] DATA_NORM_GRAD_NAME = [x + "@GRAD" for x in DATA_NORM_NAME] def logger_config(log_path, logging_name): logger = logging.getLogger(logging_name) logger.setLevel(level=logging.WARNING) handler = logging.FileHandler( log_path, mode='a', encoding='UTF-8', delay=True) handler.setLevel(logging.INFO) formatter = logging.Formatter( '%(levelname)s - %(asctime)s - %(pathname)s: %(lineno)s - %(message)s') handler.setFormatter(formatter) console = logging.StreamHandler() console.setLevel(logging.DEBUG) logger.addHandler(handler) logger.addHandler(console) return logger ps_log_root_dir = './ps_log/' logger = logger_config( log_path='./ps_usr_print_log', logging_name='ps_usr_print_log') class DistributedMode: SYNC = 0 ASYNC = 1 HALF_ASYNC = 2 GEO = 3 FL = 4 class TrainerRuntimeConfig(object): def __init__(self, valid_strategy): self.mode = None num_threads = os.getenv("CPU_NUM", "1") send_queue_size = num_threads k_steps = valid_strategy.a_sync_configs["k_steps"] logger.info("ps mode in strategy: {}, {}".format( valid_strategy.a_sync, valid_strategy.a_sync_configs["k_steps"])) if not valid_strategy.a_sync and k_steps == 0: self.mode = DistributedMode.SYNC if valid_strategy.a_sync and k_steps == 0: self.mode = DistributedMode.ASYNC if valid_strategy.a_sync and k_steps > 0: self.mode = DistributedMode.GEO send_queue_size = k_steps self.runtime_configs = {} self.runtime_configs['communicator_max_merge_var_num'] = os.getenv( "FLAGS_communicator_max_merge_var_num", send_queue_size) self.runtime_configs['communicator_send_queue_size'] = os.getenv( "FLAGS_communicator_send_queue_size", send_queue_size) self.runtime_configs[ 'communicator_independent_recv_thread'] = os.getenv( "FLAGS_communicator_independent_recv_thread", "1") self.runtime_configs[ 'communicator_min_send_grad_num_before_recv'] = os.getenv( "FLAGS_communicator_min_send_grad_num_before_recv", num_threads) self.runtime_configs['communicator_thread_pool_size'] = os.getenv( "FLAGS_communicator_thread_pool_size", "5") self.runtime_configs['communicator_send_wait_times'] = os.getenv( "FLAGS_communicator_send_wait_times", "5") self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv( "FLAGS_communicator_is_sgd_optimizer", "1") def get_communicator_flags(self): need_keys = [] num_threads = os.getenv("CPU_NUM", "1") mode_str = "" if self.mode is None or self.mode == DistributedMode.ASYNC: need_keys = self.runtime_configs.keys() mode_str = "async" elif self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC: mode_str = "sync or half_async" need_keys = [ 'communicator_max_merge_var_num', 'communicator_send_wait_times', 'communicator_thread_pool_size', 'communicator_send_queue_size' ] elif self.mode == DistributedMode.GEO: mode_str = "GEO" need_keys = [ 'communicator_thread_pool_size', 'communicator_send_wait_times', 'communicator_max_merge_var_num', 'communicator_send_queue_size' ] else: raise ValueError("Unsupported Mode") if self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC: max_merge_var_num = self.runtime_configs[ 'communicator_max_merge_var_num'] send_queue_size = self.runtime_configs[ 'communicator_send_queue_size'] if max_merge_var_num != num_threads: print('WARNING: In {} mode, communicator_max_merge_var_num ' 'must be equal to CPU_NUM. But received, ' 'communicator_max_merge_var_num = {}, CPU_NUM = ' '{}. communicator_max_merge_var_num will be forced to {}.' .format(mode_str, max_merge_var_num, num_threads, num_threads)) self.runtime_configs[ 'communicator_max_merge_var_num'] = num_threads if send_queue_size != num_threads: print('WARNING: In {} mode, communicator_send_queue_size ' 'must be equal to CPU_NUM. But received, ' 'communicator_send_queue_size = {}, CPU_NUM = ' '{}. communicator_send_queue_size will be forced to {}.' .format(mode_str, send_queue_size, num_threads, num_threads)) self.runtime_configs[ 'communicator_send_queue_size'] = num_threads return dict((key, str(self.runtime_configs[key])) for key in need_keys) def get_lr_ops(program): lr_ops = [] for index, op in enumerate(program.global_block().ops): role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME)) if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \ role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \ int(OPT_OP_ROLE_ATTR_VALUE): lr_ops.append(op) return lr_ops def get_optimize_ops(_program): block = _program.global_block() opt_ops = [] for op in block.ops: if _is_opt_role_op(op): # delete clip op from opt_ops when run in Parameter Server mode if OP_NAME_SCOPE in op.all_attrs() \ and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE): op._set_attr( "op_role", int(core.op_proto_and_checker_maker.OpRole.Backward)) continue opt_ops.append(op) return opt_ops def get_dist_env(): trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0')) trainer_endpoints = '' current_endpoint = '' num_trainers = 0 if os.getenv('PADDLE_TRAINER_ENDPOINTS'): trainer_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS') current_endpoint = trainer_endpoints.split(',')[trainer_id] num_trainers = len(trainer_endpoints.split(',')) return { 'trainer_id': trainer_id, 'num_trainers': num_trainers, 'current_endpoint': current_endpoint, 'trainer_endpoints': trainer_endpoints } def get_role_id(role_maker): try: return role_maker._role_id() except Exception: return role_maker.role_id() def get_ps_endpoint(role_maker): try: return role_maker._get_pserver_endpoints()[get_role_id(role_maker)] except Exception: return role_maker.get_pserver_endpoints()[get_role_id(role_maker)] def get_ps_endpoints(role_maker): try: return role_maker._get_pserver_endpoints() except Exception: return role_maker.get_pserver_endpoints() def get_heter_worker_endpoint(role_maker): try: return role_maker._get_heter_worker_endpoint() except Exception: return role_maker.get_heter_worker_endpoint() def get_trainer_endpoint(role_maker): try: return role_maker._get_trainer_endpoint() except Exception: return role_maker.get_trainer_endpoint() def get_previous_stage_trainers(role_maker): try: return role_maker._get_previous_trainers() except Exception: return role_maker.get_previous_trainers() def is_distributed_sparse_op(op): if op.type in SPARSE_OP_LIST and op.attr('is_distributed') is True: return True if op.type == "distributed_lookup_table" and op.attr( 'is_distributed') is True: return True return False def get_sparse_tablename(op): return op.input("W")[0] def is_sparse_op(op): if op.type in SPARSE_OP_LIST and op.attr('is_sparse') is True and op.attr( 'is_distributed') is False: return True if op.type == "distributed_lookup_table" and op.attr( 'is_distributed') is False: return True return False def get_sparse_tablenames(programs, is_distributed): tablenames = set() for program in programs: if is_distributed: for op in program.global_block().ops: if is_distributed_sparse_op(op): tablenames.add(get_sparse_tablename(op)) else: for op in program.global_block().ops: if is_sparse_op(op): tablenames.add(get_sparse_tablename(op)) return list(tablenames) def get_trainers(role_maker): try: return role_maker._worker_num() except Exception: return role_maker.worker_num() def get_dense_send_context(program, send_ctx, idx, merged_dense_pairs, trainer_id, split_dense_table=False): if len(merged_dense_pairs) < 1: return idx if not split_dense_table: dense_pairs = [] data_norm_pairs = [] for merged in merged_dense_pairs: is_data_norm = False grad = merged[1] varname = grad.merged_var.name for name in DATA_NORM_GRAD_NAME: if varname.endswith(name): is_data_norm = True if is_data_norm: data_norm_pairs.append(merged) else: dense_pairs.append(merged) # simple dense table origin_varnames = [] var_numel = 0 for merged in dense_pairs: grad = merged[1] origin_varnames.append(grad.merged_var.name) var = program.global_block().vars[grad.merged_var.name] var_numel += reduce(lambda x, y: x * y, var.shape) grad_name = "Dense@GRAD_" + str(idx) aggregate = True print("public get_dense_send_context dense_table:", grad_name, var_numel, origin_varnames) from paddle.fluid.core import CommContext dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel], origin_varnames, trainer_id, aggregate, False, False, idx, False, False, id(program)) send_ctx[grad_name] = dense_ctx idx += 1 if len(data_norm_pairs) <= 0: return idx # data norm table origin_varnames = [] var_numel = 0 for merged in data_norm_pairs: grad = merged[1] origin_varnames.append(grad.merged_var.name) var = program.global_block().vars[grad.merged_var.name] var_numel += reduce(lambda x, y: x * y, var.shape) grad_name = "DataNorm@GRAD_" + str(idx) aggregate = True print("public get_dense_send_context data_norm table:", grad_name, var_numel, origin_varnames) from paddle.fluid.core import CommContext data_norm_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel], origin_varnames, trainer_id, aggregate, False, False, idx, False, True, id(program)) send_ctx[grad_name] = data_norm_ctx idx += 1 else: for merged in merged_dense_pairs: grad = merged[1] origin_varname = grad.merged_var.name var = program.global_block().vars[origin_varname] var_numel = reduce(lambda x, y: x * y, var.shape) grad_name = origin_varname aggregate = True from paddle.fluid.core import CommContext dense_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel], [origin_varname], trainer_id, aggregate, False, False, idx, False, False, id(program)) send_ctx[grad_name] = dense_ctx idx += 1 return idx def get_geo_trainer_send_context(context): if context['ps_mode'] != DistributedMode.GEO: raise ValueError("ps mode: {} not matched {}", format(ps_mode, "get_geo_trainer_send_context")) send_ctx = {} trainer_id = get_role_id(context['role_maker']) origin_programs = context['origin_main_programs'] idx = 0 distibuted_varnames = get_sparse_tablenames(origin_programs, True) for i, program in enumerate(origin_programs): merged_sparse_pairs = context['merged_sparse_pairs'][i] for merged in merged_sparse_pairs: param, grad = merged grad_name = grad.merged_var.name param_name = param.merged_var.name is_distributed = True if param_name in distibuted_varnames else False var = program.global_block().vars[grad.merged_var.name] var_numel = reduce(lambda x, y: x * y, var.shape[1:]) from paddle.fluid.core import CommContext sparse_ctx = CommContext(grad_name, [grad_name], ["127.0.0.1:6071"], [var_numel], [grad_name], trainer_id, True, True, is_distributed, idx, False, False, id(program)) idx += 1 send_ctx[sparse_ctx.var_name()] = sparse_ctx if len(send_ctx) == 0: raise ValueError("GeoSGD require sparse parameters in your net.") if len(context['tensor_table']) > 0 and context['is_worker']: name, ctx = _step_ctx(idx, context['role_maker']) send_ctx[name] = ctx return send_ctx def _step_ctx(idx, role_maker): name = STEP_COUNTER trainer_id = get_role_id(role_maker) endpoints = get_ps_endpoints(role_maker) sections = [1] * len(endpoints) names = [name] * len(endpoints) from paddle.fluid.core import CommContext ctx = CommContext(name, names, endpoints, sections, [name], trainer_id, True, False, False, idx, True, False, -1) return name, ctx def get_the_one_send_context(context, split_dense_table=False, use_origin_program=False, ep_list=None): if ep_list is None: ep_list = ["127.0.0.1:6071"] send_ctx = {} trainer_id = get_role_id(context['role_maker']) origin_programs = context['origin_main_programs'] idx = 0 distibuted_varnames = get_sparse_tablenames(origin_programs, True) # print("public distibuted_varnames:", distibuted_varnames) for i, program in enumerate(origin_programs): merged_sparse_pairs = context['merged_sparse_pairs'][i] for merged in merged_sparse_pairs: param, grad = merged grad_name = grad.merged_var.name param_name = param.merged_var.name splited_varname = [] for i in range(len(ep_list)): splited_varname.append("{}.block{}".format(param_name, i)) is_distributed = True if param_name in distibuted_varnames else False var = program.global_block().vars[grad.merged_var.name] shape = list(var.shape) shape[0] = 0 if is_distributed else shape[0] # print("public get_the_one_send_context sparse:", grad_name, # splited_varname, shape) if grad_name in send_ctx: continue from paddle.fluid.core import CommContext sparse_ctx = CommContext(grad_name, splited_varname, ep_list, shape, [grad_name], trainer_id, True, True, is_distributed, idx, False, False, id(program)) idx += 1 send_ctx[sparse_ctx.var_name()] = sparse_ctx for i, program in enumerate(origin_programs): merged_dense_pairs = context['merged_dense_pairs'][i] idx = get_dense_send_context(program, send_ctx, idx, merged_dense_pairs, trainer_id, split_dense_table) if len(context['tensor_table']) > 0 and context['is_worker']: name, ctx = _step_ctx(idx, context['role_maker']) send_ctx[name] = ctx return send_ctx def find_heter_ops(program, default_device="cpu"): if default_device not in DEVICE_LIST: raise ValueError("Given device {} is not in device list {}".format( default_device, DEVICE_LIST)) def _is_heter_op(op, current_heter_device, default_device="cpu"): heter_devices = list(DEVICE_LIST) heter_devices.remove(default_device) op_device = op.attr("op_device") op_type = op.type if op_device in heter_devices: return True elif op_type in COMMUNICATE_OPS_TYPE and current_heter_device != default_device: # for distributed communciate ops: send & recv & barrier etc. # Todo: need update this method #op._set_attr('op_device', current_heter_device) return True elif op_device == None or op_device == default_device: op._set_attr('op_device', default_device) return False return False def _is_same_device(op, pre_device, default_device="cpu"): op_device = op.attr("op_device") if op_device == pre_device: return True if pre_device == default_device: return True return False def _append_heter_op(op, current_heter_block_ops, heter_ops): op_device = op.attr("op_device") if op_device not in heter_ops: heter_ops[op_device] = {} current_heter_block_ops.append(op) origin_porgram = program.clone() block = program.global_block() ''' re-place sum op to fix bug for union forward backward op ''' var2idx = {} op_list = list(block.ops) op_size = len(op_list) for i in range(op_size - 1, -1, -1): op_list = list(block.ops) op = op_list[i] if "_grad" in op.type: forward_op_type = op.type.split("_grad")[0] if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \ and op.attr('remote_prefetch') is True: param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0] if param_name in var2idx: ## insert sum op & remove sum op from var2idx and origin place op_list = list(block.ops) sum_op = op_list[var2idx[param_name]] sum_op_inputs = { sum_op.input_names[0]: [ block.vars[input] for input in sum_op.input_arg_names ] } sum_op_outputs = { sum_op.output_names[0]: [ block.vars[output] for output in sum_op.output_arg_names ] } block._insert_op( index=i + 1, type=sum_op.type, inputs=sum_op_inputs, outputs=sum_op_outputs, attrs=sum_op.all_attrs()) block._remove_op(var2idx[param_name] + 1) var2idx.pop(param_name) for var_ in var2idx: var2idx[var_] += 1 elif forward_op_type == "elementwise_mul": """ get output varname of pre op """ output_vars_no_grad = [] for key in op.output_names: for varname in op.output(key): if varname == "@EMPTY@": continue if "lod_tensor_blocking_queue" in varname: continue output_vars_no_grad.append(varname.split("@GRAD")[0]) for no_grad_var in output_vars_no_grad: if no_grad_var in var2idx: """ insert sum op & remove sum op from var2idx and origin place """ op_list = list(block.ops) sum_op = op_list[var2idx[no_grad_var]] sum_op_inputs = { sum_op.input_names[0]: [ block.vars[input] for input in sum_op.input_arg_names ] } sum_op_outputs = { sum_op.output_names[0]: [ block.vars[output] for output in sum_op.output_arg_names ] } block._insert_op( index=i + 1, type=sum_op.type, inputs=sum_op_inputs, outputs=sum_op_outputs, attrs=sum_op.all_attrs()) block._remove_op(var2idx[no_grad_var] + 1) var2idx.pop(no_grad_var) for var_ in var2idx: var2idx[var_] += 1 else: if op.type == "sum": var = op.output("Out")[0] if "@GRAD" in var: origin_var = var.split("@GRAD")[0] pre_op = op_list[i - 1] if "_grad" in pre_op.type: forward_op_type = pre_op.type.split("_grad")[0] if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \ and pre_op.attr('remote_prefetch') is True: param_name = pre_op.input(SPARSE_OP_TYPE_DICT[ forward_op_type])[0] if param_name == origin_var and op.attr( "op_device") == pre_op.attr("op_device"): continue else: var2idx[origin_var] = i elif forward_op_type == "elementwise_mul": output_vars = [] for key in pre_op.output_names: for varname in pre_op.output(key): if varname == "@EMPTY@": continue if "lod_tensor_blocking_queue" in varname: continue output_vars.append(varname) input_vars = [] for key in op.input_names: for varname in op.input(key): if varname == "@EMPTY@": continue if "lod_tensor_blocking_queue" in varname: continue input_vars.append(varname) is_match = False for varname in output_vars: if varname in input_vars: is_match = True break if is_match: continue else: var2idx[origin_var] = i else: var2idx[origin_var] = i origin_porgram = program.clone() block = program.global_block() program_block_ops = [] default_ops = {default_device: {}} heter_ops = {} block_index = 0 current_heter_block_ops = [] current_default_block_ops = [] current_heter_device = default_device is_heter = False for op in block.ops: if _is_heter_op(op, current_heter_device, default_device): # for gpu/xpu-op is_heter = True # for cpu-op block append if len(current_default_block_ops) > 1: default_ops[default_device][ block_index] = current_default_block_ops program_block_ops.append(current_default_block_ops) current_default_block_ops = [] block_index += 1 if _is_same_device(op, current_heter_device, default_device): # for gpu-op, gpu-op -> gpu-op,... current_heter_device = op.attr("op_device") _append_heter_op(op, current_heter_block_ops, heter_ops) else: # for gpu-op -> xpu-op, ... op_device = current_heter_block_ops[0].attr("op_device") heter_ops[op_device][block_index] = current_heter_block_ops program_block_ops.append(current_heter_block_ops) block_index += 1 current_heter_block_ops = [] current_heter_device = op.attr("op_device") _append_heter_op(op, current_heter_block_ops, heter_ops) elif is_heter: # for gpu/xpu-op -> cpu-op op_device = current_heter_block_ops[0].attr("op_device") heter_ops[op_device][block_index] = current_heter_block_ops program_block_ops.append(current_heter_block_ops) block_index += 1 current_heter_block_ops = [] current_heter_device = default_device is_heter = False current_default_block_ops.append(op) else: # for cpu-op current_default_block_ops.append(op) if current_default_block_ops != []: default_ops[default_device][block_index] = current_default_block_ops program_block_ops.append(current_default_block_ops) if current_heter_block_ops != []: op_device = current_heter_block_ops[0].attr("op_device") heter_ops[op_device][block_index] = current_heter_block_ops program_block_ops.append(current_heter_block_ops) if len(heter_ops) == 0: warnings.warn( "No heterogeneous OP was found in your program , " " please using fluid.device_guard() to run OPs on different device.") total_heter_ops = 0 heter_blocks = 0 for device in heter_ops.keys(): heter_block_dict = heter_ops[device] heter_blocks += len(heter_block_dict) for _, heter_block in heter_block_dict.items(): total_heter_ops += len(heter_block) print( "There are {} OPs in your main_program, and contains {} heter-OPs which is made up of {} heter-blocks.". format(len(block.ops), total_heter_ops, heter_blocks)) return origin_porgram, heter_ops, default_ops, program_block_ops def union_forward_gradient_op(program_block_ops_list): """ before analyzing the input & output of each block in program_block_list, we should union the forward op and corresponding gradient op to elimincate the unnecessary variable transmit """ """ fix for 2emb model, re-place sum op """ block_length = len(program_block_ops_list) union_program_block_ops_list = [] assert block_length % 2 != 0, "the length of program_block_ops_list should be odd" for i in range(0, block_length // 2): block_op_list = {"forward": program_block_ops_list[i]} block_op_list.update({ "backward": program_block_ops_list[block_length - 1 - i] }) union_program_block_ops_list.append(block_op_list) block_op_list = {"forward": [], "backward": []} for op in program_block_ops_list[block_length // 2]: if not "_grad" in op.type and not (op.type == "sum"): block_op_list["forward"].append(op) else: block_op_list["backward"].append(op) union_program_block_ops_list.append(block_op_list) return union_program_block_ops_list def find_block_joints(program, program_block_ops_list, heter_ops): block_var_detail = find_entrance_exit_private(program, program_block_ops_list) block_var_detail = entrance_exit_check(program, program_block_ops_list, block_var_detail, heter_ops) block_var_detail = delete_block_useless_exit( program, program_block_ops_list, block_var_detail) return block_var_detail def find_ops_list_input_output(program, ops_list): input_var_list = [] output_var_list = [] for op in ops_list: inputs = _get_input_map_from_op(program.global_block().vars, op) input_var_list += get_varlist_from_op_map(inputs) outputs = _get_output_map_from_op(program.global_block().vars, op) output_var_list += get_varlist_from_op_map(outputs) input_var_list = list(set(input_var_list)) output_var_list = list(set(output_var_list)) return input_var_list, output_var_list def find_entrance_exit_private(program, program_block_ops_list): block_var_detail = [] persistables = [] for index, block_op_list in enumerate(program_block_ops_list): ## forward block_input, block_output = find_ops_list_input_output( program, block_op_list["forward"]) persistables = screen_persistables( program, block_input) + screen_persistables(program, block_output) # find entrance & exit block_private_vars = list(set(block_input) & set(block_output)) block_entrance = list(set(block_input) - set(block_private_vars)) block_exit = list(set(block_output) - set(block_private_vars)) detail = { "forward": { "entrance": block_entrance, "exit": block_exit, "private": block_private_vars, "persistables": persistables } } ## backward bp_block_input, bp_block_output = find_ops_list_input_output( program, block_op_list["backward"]) bp_persistables = screen_persistables( program, bp_block_input) + screen_persistables(program, bp_block_output) # find entrance & exit bp_block_private_vars = list(set(bp_block_input) & set(bp_block_output)) bp_block_entrance = list( set(bp_block_input) - set(bp_block_private_vars)) bp_block_exit = list(set(bp_block_output) - set(bp_block_private_vars)) detail.update({ "backward": { "entrance": bp_block_entrance, "exit": bp_block_exit, "private": bp_block_private_vars, "persistables": bp_persistables } }) block_var_detail.append(detail) return block_var_detail def entrance_exit_check(program, program_block_ops_list, block_var_detail, heter_ops): for index in range(len(block_var_detail) - 1, -1, -1): if index - 1 < 0: break previous_block_exit = block_var_detail[index - 1]["forward"]["exit"] previous_block_exit.sort() current_block_entrance = block_var_detail[index]["forward"]["entrance"] backward_entrance = block_var_detail[index]["backward"]["entrance"] forward_all = block_var_detail[index]["forward"][ "entrance"] + block_var_detail[index]["forward"][ "private"] + block_var_detail[index]["forward"]["exit"] for var in backward_entrance: if not ("@GRAD" in var) and not (var in forward_all): current_block_entrance.append(var) current_block_entrance.sort() if previous_block_exit == current_block_entrance: continue exist_vars = list( set(previous_block_exit) & set(current_block_entrance)) need_add_vars = list(set(current_block_entrance) - set(exist_vars)) # var in different stage should not be ignored, since they are not placed in the same program & device #need_add_vars = find_need_var_from_previous_block( # need_add_vars, block_var_detail, index, heter_ops) previous_block_private = block_var_detail[index - 1]["forward"][ "private"] previous_block_entrance = block_var_detail[index - 1]["forward"][ "entrance"] for var in need_add_vars: if var not in previous_block_private and var not in previous_block_entrance: previous_block_entrance.append(var) previous_block_exit.append(var) if not var in current_block_entrance: current_block_entrance.append(var) for index in range(0, len(block_var_detail) - 1, 1): previous_block_exit = block_var_detail[index + 1]["backward"]["exit"] previous_block_exit.sort() current_block_entrance = block_var_detail[index]["backward"]["entrance"] current_block_entrance.sort() if previous_block_exit == current_block_entrance: continue exist_vars = list( set(previous_block_exit) & set(current_block_entrance)) need_add_vars = list(set(current_block_entrance) - set(exist_vars)) need_ignore_vars = [] for var in need_add_vars: if not "@GRAD" in var: need_ignore_vars.append(var) need_add_vars = list( set(need_add_vars).difference(set(need_ignore_vars))) previous_block_private = block_var_detail[index + 1]["backward"][ "private"] previous_block_entrance = block_var_detail[index + 1]["backward"][ "entrance"] for var in need_add_vars: if var not in previous_block_private and var not in previous_block_entrance: previous_block_entrance.append(var) previous_block_exit.append(var) return block_var_detail def delete_block_useless_exit(program, program_block_ops_list, block_var_detail): ## forward for index in range(len(block_var_detail)): if index == len(block_var_detail) - 1: break current_block_exit = block_var_detail[index]["forward"]["exit"] next_block_entrance = block_var_detail[index + 1]["forward"]["entrance"] need_delete_var = [] for var in current_block_exit: if var not in next_block_entrance: need_delete_var.append(var) for var in need_delete_var: current_block_exit.remove(var) ## backward for index in range(len(block_var_detail) - 1, -1, -1): if index - 1 < 0: break current_block_exit = block_var_detail[index]["backward"]["exit"] next_block_entrance = block_var_detail[index - 1]["backward"][ "entrance"] need_delete_var = [] for var in current_block_exit: if var not in next_block_entrance: need_delete_var.append(var) for var in need_delete_var: current_block_exit.remove(var) return block_var_detail def get_communicate_var_info(program, block_index, entrance_var_list, type="forward"): input_var_reshape_dim = [] input_var_reshape_name = [] if type == "forward": block_input_var_name = "forward_joint_{}_{}@Heter".format( block_index - 1, block_index) else: block_input_var_name = "backward_joint_{}_{}@Heter".format( block_index + 1, block_index) entrance_var_list.sort() # input # Heter_SERVER_BLOCK_index@JOINT_VAR -> slice -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> reshape -> var for name in entrance_var_list: var = program.global_block().vars[name] shape = var.shape recv_var_dim = -1 * reduce(lambda x, y: x * y, shape) input_var_reshape_dim.append(recv_var_dim) input_var_reshape_name.append("{}.input_reshape@Heter".format(name)) info = { "input_var_reshape_dim": input_var_reshape_dim, "input_var_reshape_name": input_var_reshape_name, "block_input_var_name": block_input_var_name, } return info def add_vars_by_var_list(var_name_list, origin_program, program, block): for var_name in var_name_list: if var_name not in program.global_block( ).vars and var_name not in block.vars: var = origin_program.global_block().vars[var_name] if var.persistable: program.global_block()._clone_variable( var, force_persistable=False) else: block._clone_variable(var, force_persistable=False) def _get_output_map_from_op(varmap, op): """Returns a dict from op output name to the vars in varmap.""" iomap = collections.OrderedDict() for key in op.output_names: vars = [] for varname in op.output(key): if varname == "@EMPTY@": continue if "lod_tensor_blocking_queue" in varname: continue vars.append(varmap[varname]) if len(vars) == 1: iomap[key] = vars[0] else: iomap[key] = vars return iomap def get_varlist_from_op_map(var_map): var_list = [] for key, varlist in six.iteritems(var_map): if not isinstance(varlist, list): varlist = [varlist] for i in range(len(varlist)): var = varlist[i] var_list.append(var.name) return var_list def _get_input_map_from_op(varmap, op): """Returns a dict from op input name to the vars in varmap.""" iomap = collections.OrderedDict() for key in op.input_names: vars = [] for varname in op.input(key): if varname == "@EMPTY@": continue if "lod_tensor_blocking_queue" in varname: continue vars.append(varmap[varname]) if len(vars) == 1: iomap[key] = vars[0] else: iomap[key] = vars return iomap def screen_persistables(program, var_list): need_remove = [] for var_name in var_list: if "@GRAD" in var_name: if "GRAD" != var_name.split("@")[-1]: continue origin_var_name = var_name.split("@GRAD")[0] var = program.global_block().vars[origin_var_name] else: var = program.global_block().vars[var_name] if fluid.io.is_persistable(var): need_remove.append(var_name) for var_name in need_remove: var_list.remove(var_name) return need_remove def block_append_op(program, origin_program, block, op): merge_ordereddict = origin_program.global_block().vars.copy() merge_ordereddict.update(block.vars) inputs = _get_input_map_from_op(merge_ordereddict, op) for key, varlist in six.iteritems(inputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: if var.name not in program.global_block( ).vars and var.name not in block.vars: if var.persistable: program.global_block()._clone_variable( var, force_persistable=False) else: block._clone_variable(var, force_persistable=False) outputs = _get_output_map_from_op(origin_program.global_block().vars, op) for key, varlist in six.iteritems(outputs): if not isinstance(varlist, list): varlist = [varlist] for var in varlist: if var.name not in program.global_block( ).vars and var.name not in block.vars: if var.persistable: program.global_block()._clone_variable( var, force_persistable=False) else: block._clone_variable(var, force_persistable=False) if "_grad" not in op.type: # for forward op return block.append_op( type=op.type, inputs=inputs, outputs=outputs, attrs=op.all_attrs()) else: # for grad op op_desc = op.desc op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() backward = core.op_proto_and_checker_maker.OpRole.Backward device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() # append grad op new_op_desc = block.desc.append_op() new_op_desc.copy_from(op_desc) new_op_desc._set_attr(op_role_attr_name, backward) # set device gard if op.desc.has_attr(device_attr_name): op_device = op_desc.attr(device_attr_name) new_op_desc._set_attr(device_attr_name, op_device) block._sync_with_cpp() def get_next_stage_trainers(role_maker): try: return role_maker._get_next_trainers() except Exception: return role_maker.get_next_trainers() def insert_communicate_op(orign_program, role_maker, heter_block, stage_id, first_op_index, block_var_detail, device, is_forward=True): if is_forward: next_heter_worker_endpoints = get_next_stage_trainers(role_maker) previous_heter_worker_endpoints = get_previous_stage_trainers( role_maker) entrance_var = block_var_detail[stage_id]["forward"]["entrance"] comm_info = get_communicate_var_info(orign_program, stage_id + 1, entrance_var) else: next_heter_worker_endpoints = get_next_stage_trainers(role_maker) previous_heter_worker_endpoints = get_previous_stage_trainers( role_maker) entrance_var = block_var_detail[stage_id - 1]["backward"]["exit"] comm_info = get_communicate_var_info(orign_program, stage_id - 1, entrance_var, "backward") heter_block._insert_op( index=first_op_index, type="send_and_recv", inputs={"X": heter_block.vars[entrance_var[0]]}, outputs={"Out": []}, attrs={ "mode": "forward" if is_forward else "backward", "send_var_name": entrance_var + ["microbatch_id"], "recv_var_name": [], "message_name": comm_info["block_input_var_name"], "next_endpoints": next_heter_worker_endpoints, "previous_endpoints": previous_heter_worker_endpoints, "trainer_id": get_role_id(role_maker), "op_device": device, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) return entrance_var def get_the_one_recv_context(context, is_dense=True, split_dense_table=False, use_origin_program=False): recv_id_maps = {} grad_name_to_param_name = {} if is_dense: send_ctx = get_the_one_send_context( context, split_dense_table=split_dense_table, use_origin_program=use_origin_program) for idx, (name, ctx) in enumerate(send_ctx.items()): if ctx.is_sparse(): continue if ctx.is_tensor_table(): continue origin_grad_varnames = ctx.origin_varnames() param_names = [] for grad_varname in origin_grad_varnames: param_name = context["grad_name_to_param_name"][grad_varname] param_names.append(param_name) recv_id_maps[ctx.table_id()] = param_names else: send_ctx = get_the_one_send_context( context, split_dense_table=False, use_origin_program=False, ep_list=None) for idx, (name, ctx) in enumerate(send_ctx.items()): if not ctx.is_sparse(): continue origin_grad_varnames = ctx.origin_varnames() param_names = [] for grad_varname in origin_grad_varnames: param_name = context["grad_name_to_param_name"][grad_varname] param_names.append(param_name) recv_id_maps[ctx.table_id()] = param_names return recv_id_maps def _get_varname_parts(varname): # returns origin, blockid, trainerid orig_var_name = "" trainer_part = "" block_part = "" trainer_idx = varname.find(".trainer_") if trainer_idx >= 0: trainer_part = varname[trainer_idx + 1:] else: trainer_idx = len(varname) block_index = varname.find(".block") if block_index >= 0: block_part = varname[block_index + 1:trainer_idx] else: block_index = len(varname) orig_var_name = varname[0:min(block_index, trainer_idx)] return orig_var_name, block_part, trainer_part dtype_to_size = { core.VarDesc.VarType.FP16: 2, core.VarDesc.VarType.FP32: 4, core.VarDesc.VarType.FP64: 8, core.VarDesc.VarType.INT16: 2, core.VarDesc.VarType.INT32: 4, core.VarDesc.VarType.INT64: 8, core.VarDesc.VarType.BOOL: 1, core.VarDesc.VarType.UINT8: 1, } def get_var_mem_size(var): m_size = reduce(lambda x, y: x * y, var.shape) m_size *= dtype_to_size[var.dtype] return m_size class MergedVariable: def __init__(self, merged, ordered, offsets): self.merged_var = merged self.ordered_vars = ordered self.offsets = offsets def build_var_distributed(context): origin_programs = context['origin_main_programs'] param_name_to_grad_name = {} grad_name_to_param_name = {} context["origin_sparse_pairs"] = [] context["origin_dense_pairs"] = [] context["merged_sparse_pairs"] = [] context['merged_dense_pairs'] = [] context["merged_variables_pairs"] = [] context["merged_variable_map"] = {} for origin_program in origin_programs: sparse_pairs, dense_pairs = get_param_grads(origin_program) # print("public build_var_distributed sparse_pairs:", sparse_pairs) # print("public build_var_distributed dense_pairs:", dense_pairs) origin_for_sparse = [] origin_for_dense = [] merged_sparse_pairs = [] merged_dense_pairs = [] merged_variables_pairs = [] for param, grad in sparse_pairs: origin_for_sparse.append((param, grad)) for param, grad in dense_pairs: origin_for_dense.append((param, grad)) for dense_pair in origin_for_dense: param, grad = dense_pair m_param = MergedVariable(param, [param], [0]) m_grad = MergedVariable(grad, [grad], [0]) merged_variables_pairs.append((m_param, m_grad)) merged_dense_pairs.append((m_param, m_grad)) # print("public build_var_distributed merged_dense_pairs:", # merged_dense_pairs) for sparse_pair in origin_for_sparse: param, grad = sparse_pair m_param = MergedVariable(param, [param], [0]) m_grad = MergedVariable(grad, [grad], [0]) merged_variables_pairs.append((m_param, m_grad)) merged_sparse_pairs.append((m_param, m_grad)) # print("public build_var_distributed merged_sparse_pairs:", # merged_sparse_pairs) for merged in merged_variables_pairs: m_param, m_grad = merged context["merged_variable_map"][ m_param.merged_var.name] = m_param.merged_var context["merged_variable_map"][ m_grad.merged_var.name] = m_grad.merged_var param_merges = [] param_merges.extend(origin_for_sparse) param_merges.extend(origin_for_dense) for param, grad in param_merges: param_name_to_grad_name[param.name] = grad.name grad_name_to_param_name[grad.name] = param.name context["origin_sparse_pairs"].append(origin_for_sparse) context["origin_dense_pairs"].append(origin_for_dense) context["merged_sparse_pairs"].append(merged_sparse_pairs) context['merged_dense_pairs'].append(merged_dense_pairs) context["param_name_to_grad_name"] = param_name_to_grad_name context["grad_name_to_param_name"] = grad_name_to_param_name # print("public build_var_distributed origin_sparse_pairs:", # context["origin_sparse_pairs"]) # print("public build_var_distributed origin_for_dense:", # context["origin_dense_pairs"]) # print("public build_var_distributed merged_sparse_pairs:", # context["merged_sparse_pairs"]) # print("public build_var_distributed merged_dense_pairs:", # context['merged_dense_pairs']) # print("public build_var_distributed param_name_to_grad_name:", # param_name_to_grad_name) # print("public build_var_distributed grad_name_to_param_name:", # grad_name_to_param_name) def _is_opt_role_op(op): # NOTE : depend on oprole to find out whether this op is for # optimize op_maker = core.op_proto_and_checker_maker optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize if op_maker.kOpRoleAttrName() in op.attr_names and \ int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role): return True return False def get_param_grads(origin_program): def _get_params_grads(sparse_varnames): block = origin_program.global_block() dense_param_grads = [] sparse_param_grads = [] optimize_params = set() origin_var_dict = origin_program.global_block().vars role_id = int(core.op_proto_and_checker_maker.OpRole.Backward) for op in block.ops: if _is_opt_role_op(op): # delete clip op from opt_ops when run in Parameter Server mode if OP_NAME_SCOPE in op.all_attrs() \ and CLIP_OP_NAME_SCOPE in op.attr(OP_NAME_SCOPE): op._set_attr("op_role", role_id) continue if op.attr(OP_ROLE_VAR_ATTR_NAME): param_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[0] grad_name = op.attr(OP_ROLE_VAR_ATTR_NAME)[1] if param_name not in optimize_params: optimize_params.add(param_name) param_grad = (origin_var_dict[param_name], origin_var_dict[grad_name]) if param_name in sparse_varnames: sparse_param_grads.append(param_grad) else: dense_param_grads.append(param_grad) return sparse_param_grads, dense_param_grads def _get_sparse_varnames(): varnames = [] for op in origin_program.global_block().ops: if op.type in SPARSE_OP_TYPE_DICT.keys() \ and op.attr('remote_prefetch') is True: param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0] varnames.append(param_name) return list(set(varnames)) sparse_varnames = _get_sparse_varnames() sparse_param_grads, dense_param_grads = _get_params_grads(sparse_varnames) return sparse_param_grads, dense_param_grads def delete_ops(block, ops): for op in ops: try: idx = list(block.ops).index(op) block._remove_op(idx) except Exception as e: print(e) def find_send_op(program): send_op_list = [] for op in program.global_block().ops: if op.type == "send": send_op_list.append(op) return send_op_list def find_op_input_output(program, block, op): input_var_list = [] output_var_list = [] inputs = _get_input_map_from_op(block.vars, op) input_var_list += get_varlist_from_op_map(inputs) outputs = _get_output_map_from_op(block.vars, op) output_var_list += get_varlist_from_op_map(outputs) input_var_list = list(set(input_var_list)) output_var_list = list(set(output_var_list)) return input_var_list, output_var_list def add_heter_send_op(program, heter_program, block, block_var_detail): def _get_send_op_dict(): send_op_dict = {} send_op_list = find_send_op(program) for op in send_op_list: input_list, _ = find_op_input_output(program, program.global_block(), op) for var in input_list: send_op_dict[var] = op return send_op_dict send_grad_var_list = [] send_op_dict = _get_send_op_dict() table_dict = {} for persistable_var in block_var_detail["backward"]["persistables"]: if "@GRAD" not in persistable_var: continue if "GRAD" != persistable_var.split("@")[-1]: continue if persistable_var not in send_op_dict: continue send_op = send_op_dict[persistable_var] is_sparse = send_op.attr('is_sparse') table_id = send_op.attr('table_id') send_varnames = send_op.attr('send_varnames') send_grad_var_list.append(persistable_var) if table_id not in table_dict: table_dict[table_id] = {} table_dict[table_id]['var_list'] = [] table_dict[table_id]['is_sparse'] = is_sparse table_dict[table_id]['send_varnames'] = send_varnames table_dict[table_id]['var_list'].append(persistable_var) for table_id in table_dict: dummy_output = block.create_var( name=framework.generate_control_dev_var_name()) send_input_vars = [ block.vars[union_var] for union_var in table_dict[table_id]['var_list'] ] block.append_op( type="send", inputs={"X": send_input_vars}, outputs={"Out": dummy_output}, attrs={ "send_varnames": table_dict[table_id]['send_varnames'], "is_sparse": is_sparse, "table_id": table_id, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE }) return send_grad_var_list def get_vars_name_in_block(block): vars_list = block.vars.keys() vars_name_list = [var_name for var_name in vars_list] return vars_name_list def delete_trainer_useless_var(program, static_var): static_var = list(set(static_var)) program_useful_var_list = [] for op in program.global_block().ops: input_var_list, output_var_list = find_op_input_output( program, program.global_block(), op) op_var_list = list(set(input_var_list).union(set(output_var_list))) program_useful_var_list = list( set(program_useful_var_list).union(set(op_var_list))) program_useful_var_list += static_var program_useless_var_list = list( set(get_vars_name_in_block(program.global_block())).difference( set(program_useful_var_list))) for var in program_useless_var_list: program.global_block()._remove_var(var) return program_useless_var_list def create_backward_block(program, origin_program, bp_ops_list, block_var_detail): pre_block_idx = program.num_blocks - 1 heter_block = program._create_block(pre_block_idx) for _, op in enumerate(bp_ops_list): if op.type == "send": send_varnames = op.attr('send_varnames') is_skip = False for varname in send_varnames: if varname not in program.global_block( ).vars and varname not in heter_block.vars: is_skip = True break if is_skip == True: continue block_append_op(program, origin_program, heter_block, op) entrance_vars = block_var_detail[0]["backward"]["entrance"] add_vars_by_var_list(entrance_vars, origin_program, program, heter_block) exit_vars = block_var_detail[0]["backward"]["exit"] add_vars_by_var_list(exit_vars, origin_program, program, heter_block) return heter_block def debug_program(file, program): with open(file, 'w+') as f: f.write(str(program))