# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import json
import queue
import copy
from enum import Enum
import paddle

# Indices into the per-node adjacency pair: graph[node_id] == [succ_ids, pred_ids]
SUCC = 0  # successor
PRED = 1  # predecessor


class CostNodeType(Enum):
    DEFAULT = 0
    COMPUTATION = 1
    COMMUNICATION = 2
    VARIABLE = 3
    MERGED = 4
    NOP = 5


class Cost(object):
    """Result container: estimated runtime plus static/peak memory."""

    def __init__(self):
        self.runtime = None
        self.static_mem = None
        self.peak_mem = None


class CostModelMode(Enum):
    DEFAULT = 0
    BENCHMARKING = 1  # costs based on trial runs
    ANALYSIS = 2  # costs based on analysis
    MIXED = 3


class CostNode(object):
    """Base node of the cost graph, wrapping a program variable or operator.

    Args:
        node: the wrapped program object (variable or op), or None.
        node_type(CostNodeType): kind of node.
        id: unique id of the node inside its sub-program graph.
    """

    def __init__(self, node, node_type, id=None):
        self.id = id
        self.node = node
        self.type = node_type
        self._cost = 0
        self.is_optim = False
        self.is_bwd = False

    @property
    def cost(self):
        """Cost of this node; the setter rejects negative values."""
        return self._cost

    @cost.setter
    def cost(self, cost):
        if cost < 0:
            raise ValueError('Cost must be above 0.')
        self._cost = cost


class MergedOpsCostNode(CostNode):
    """Node that stands in for several ops folded together during reduction.

    ``node_list`` holds the ids of the original (unmerged) op nodes.
    """

    def __init__(self, node_type, id=None, base_node_list=None, is_bwd=False):
        super(MergedOpsCostNode, self).__init__(None, node_type, id)
        self.node_list = base_node_list
        self.is_bwd = is_bwd


class CommOpCostNode(CostNode):
    """Cost node for a communication op (``c_*``, ``send*`` or ``recv*``)."""

    def __init__(self,
                 node,
                 node_type,
                 id=None,
                 comm_node_list=None,
                 is_bwd=False):
        super(CommOpCostNode, self).__init__(node, node_type, id)
        self.node_list = comm_node_list
        self.ranks = []
        self.comm_type = node.type
        self.is_bwd = is_bwd

    def set_ranks(self, ranks):
        self.ranks = ranks

    def set_shapes(self, input_shape, output_shape):
        self.input_shape = input_shape
        self.output_shape = output_shape

    def init_comm_cost(self, cluster=None):
        """Estimate the communication cost from the input tensor volume.

        The volume is ``4 * prod(input_shape)``. It is divided by an effective
        bandwidth that depends on the op type and on ``len(self.ranks)``.
        ``set_ranks`` and ``set_shapes`` must have been called first.
        ``cluster`` is currently unused.
        """
        # ref: https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md
        # should get from `cluster`
        BANDWIDTH = 32 * 1024 / 1000  # MB/ms, V100 PCIe
        num_ranks = len(self.ranks)
        # NOTE(review): assumes 4-byte elements and a fully known input shape;
        # a -1 (batch placeholder) dim would give a negative volume - confirm.
        comm_volumn = np.prod(self.input_shape) * 4

        is_collective = ('allreduce' in self.comm_type or
                         'gather' in self.comm_type)
        if is_collective and num_ranks <= 1:
            # A collective over a single rank moves no data, and the ring
            # formulas below would divide by zero.
            self._cost = 0
        elif 'allreduce' in self.comm_type:
            self._cost = comm_volumn / (BANDWIDTH * num_ranks /
                                        (2 * (num_ranks - 1)))
        elif 'gather' in self.comm_type:
            self._cost = comm_volumn / (BANDWIDTH * num_ranks /
                                        (num_ranks - 1))
        elif 'broadcast' in self.comm_type:
            self._cost = comm_volumn / BANDWIDTH
        elif 'send' in self.comm_type or 'recv' in self.comm_type:
            self._cost = comm_volumn / BANDWIDTH
        else:
            self._cost = 0


class TensorCostNode(CostNode):
    """Cost node for a program variable (tensor).

    Only float32, int32 (4 bytes) and int64 (8 bytes) dtypes are supported;
    other dtypes raise NotImplementedError.
    """

    def __init__(self,
                 node,
                 node_type,
                 id=None,
                 base_node_list=None,
                 batch_size=None,
                 shared_node_id=None):
        super(TensorCostNode, self).__init__(node, node_type, id)
        self.shape = node.shape
        self.dtype = node.dtype
        self.dtype_factor = 1
        self.persistable = None
        # id of the original variable when this node is a renamed copy
        self.shared_node_id = shared_node_id
        if self.dtype == paddle.float32 or node.dtype == paddle.int32:
            self.dtype_factor *= 4
        elif node.dtype == paddle.int64:
            self.dtype_factor *= 8
        else:
            raise NotImplementedError("{} not counted".format(node.dtype))

        self.batch_size = None
        if batch_size is not None:
            self.batch_size = batch_size

    def get_size(self):
        """Return the element count, substituting ``batch_size`` for -1 dims."""
        p = 1
        for i in self.node.shape:
            if i == -1:  # deal with placeholder
                assert self.batch_size is not None, "Batch size not decided."
                i = self.batch_size
            p *= i
        return p


class CompOpCostNode(CostNode):
    """Cost node for a computation op."""

    def __init__(self, node, node_type, id=None, is_bwd=False, is_optim=False):
        super(CompOpCostNode, self).__init__(node, node_type, id)
        self.is_bwd = is_bwd
        self.is_optim = is_optim

    def init_comp_cost(self, cost_data):
        """Look up this op's cost by op type in ``cost_data`` (0.0 if absent)."""
        # TODO: improve fluid.CostModel for more specific cost_data
        op_name = self.node.type
        if op_name in cost_data.keys():
            self.cost = cost_data[op_name]
        else:
            self.cost = 0.0


class PipeEvent(object):
    """A fwd/bwd/optim event on one pipeline stage during simulation."""

    def __init__(self, stage_id, event_name, duration, start_time=-1):
        self.stage_id = stage_id
        self.name = event_name
        self.duration = duration
        self.s_time = start_time
        self.e_time = -1


class CostModel(object):
    """Estimate runtime and memory of a list of per-rank sub-programs.

    Args:
        mode(CostModelMode): how costs are obtained.
        cluster: cluster model (currently unused).
        batch_size(int): substituted for -1 dims when sizing tensors.
        microbatch_num(int): number of micro-batches in pipeline simulation.
        opcall_overhead: extra cost added per op call.
        standalone_cost_data: per-stage op-type -> cost lookups.
        pipeline_config(list): ``pipeline_config[stage]`` is the list of
            ranks in that stage; None means no pipeline.
    """

    def __init__(self,
                 mode=CostModelMode.BENCHMARKING,
                 cluster=None,
                 batch_size=1,
                 microbatch_num=1,
                 opcall_overhead=0,
                 standalone_cost_data=None,
                 pipeline_config=None):

        self.mode = mode

        # parameters
        self.opcall_overhead = opcall_overhead
        self.batch_size = batch_size
        self.microbatch_num = microbatch_num

        self.nodes = {}  # name -> node

        self.origin_graph = {}  # original graph
        self.op_graph = {}  # op graph (no variables nodes)
        self.runtime_graph = {}  # runtime graph, for simulation

        self.cluster = cluster
        self.cost_data = standalone_cost_data
        self.pp2rank = pipeline_config
        if self.pp2rank is not None:
            self.rank2pp = {}
            for stage_idx, ranks in enumerate(self.pp2rank):
                for rank in ranks:
                    self.rank2pp[rank] = stage_idx
        else:
            self.rank2pp = None

        self.ring2rank = {}

        self.fwd_time = []
        self.bwd_time = []
        self.optim_time = []

    def _parse_sub_program(self, program, nodes, graph, cost_data, sub_idx):
        """Fill ``nodes``/``graph`` with the variables and ops of ``program``.

        Args:
            program: a single-block paddle program (one rank's sub-program).
            nodes(dict): filled in-place with node id -> CostNode.
            graph(dict): filled in-place with node id -> [succ ids, pred ids].
            cost_data: op-type -> cost lookup for computation ops.
            sub_idx(int): index of this sub-program (its rank).

        Returns:
            dict: ``nodes``.
        """
        assert len(
            program.blocks) == 1, "Program more than 1 block not supported."
        block = program.blocks[0]

        for var in block.vars.values():
            var_id = var.name
            # Pass the workload batch size so get_size() can resolve -1 dims.
            nodes[var_id] = TensorCostNode(
                var,
                CostNodeType.VARIABLE,
                var_id,
                batch_size=self.batch_size)
            graph[var_id] = [[], []]

        for op in block.ops:
            op_id = op.type + "_" + str(op.idx)
            if op.type.startswith('c_') or op.type.startswith(
                    'send') or op.type.startswith('recv'):
                is_bwd = False
                if op.type.startswith('c_'):
                    ring_id = op.attr('ring_id')
                    self.ring2rank.setdefault(ring_id, set()).add(sub_idx)
                    is_bwd = '@GRAD' in op.output('Out')[0]
                elif op.type.startswith('recv'):
                    is_bwd = '@GRAD' in op.output('Out')[0]
                elif op.type.startswith('send'):
                    is_bwd = '@GRAD' in op.input('X')[0]
                # Keyword is required: CommOpCostNode's 4th positional
                # parameter is comm_node_list, not is_bwd.
                op_node = CommOpCostNode(
                    op, CostNodeType.COMMUNICATION, op_id, is_bwd=is_bwd)
            else:
                is_bwd = '_grad' in op.type
                is_optim = 'LearningRate' in op.input_names
                op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id,
                                         is_bwd, is_optim)
                op_node.init_comp_cost(cost_data)

            nodes[op_id] = op_node
            graph[op_id] = [[], []]

            comm_input_shape = [0]
            comm_output_shape = [0]
            for input_name in op.input_names:
                try:
                    var_id = op.input(input_name)[0]
                    var_node = nodes[var_id]
                except Exception:
                    # empty input slot, or a variable not in this block
                    continue
                graph[op_id][PRED].append(var_node.id)
                graph[var_id][SUCC].append(op_node.id)
                comm_input_shape = var_node.shape
            for output_name in op.output_names:
                try:
                    var_id = op.output(output_name)[0]
                    var_node = nodes[var_id]
                except Exception:
                    # empty output slot, or a variable not in this block
                    continue
                graph[op_id][SUCC].append(var_node.id)
                graph[var_id][PRED].append(op_node.id)
                comm_output_shape = var_node.shape
            if op_node.type == CostNodeType.COMMUNICATION:
                op_node.set_shapes(comm_input_shape, comm_output_shape)

        # resolve hazard: rename the r/w hazard variable nodes to ensure self.origin_graph is a DAG
        new_var_dict = {}
        for node_id, node in nodes.items():
            if node.type == CostNodeType.VARIABLE and node.node.persistable:
                write_op_cnt = 0
                # Iterate over a copy: the body removes entries from this list.
                for pred_id in list(graph[node_id][PRED]):
                    pred = nodes[pred_id]
                    if pred.type == CostNodeType.COMPUTATION and (
                            pred_id in graph[node_id][SUCC]):
                        # The op both reads and writes this variable; point its
                        # write at a fresh copy to break the 2-cycle.
                        graph[pred_id][SUCC].remove(node_id)
                        graph[node_id][PRED].remove(pred_id)

                        write_op_cnt += 1
                        new_var_id = node_id + '_write_{}'.format(write_op_cnt)
                        new_var = TensorCostNode(
                            node.node,
                            CostNodeType.VARIABLE,
                            new_var_id,
                            batch_size=node.batch_size,
                            shared_node_id=node_id)
                        graph[new_var_id] = [[], []]
                        graph[pred_id][SUCC].append(new_var_id)
                        graph[new_var_id][PRED].append(pred_id)

                        new_var_dict[new_var_id] = new_var
        nodes.update(new_var_dict)
        return nodes

    def parse_program(self, distributed_program):
        """Parse every sub-program into per-rank node dicts and graphs."""
        self.distributed_program = distributed_program
        self.total_rank = len(self.distributed_program)
        self.nodes = []
        self.origin_graph = []  # original graph
        self.op_graph = []  # op graph (no variables nodes)
        self.runtime_graph = []  # runtime graph, for simulation

        for sub_idx, sub_prog in enumerate(distributed_program):
            self.nodes.append({})
            self.origin_graph.append({})
            self.op_graph.append({})
            self.runtime_graph.append({})
            stage_idx = 0 if self.rank2pp is None else self.rank2pp[sub_idx]
            self._parse_sub_program(sub_prog, self.nodes[sub_idx],
                                    self.origin_graph[sub_idx],
                                    self.cost_data[stage_idx], sub_idx)
        return self.nodes

    def _find_succ_op(self, node_id, sub_idx=0):
        """Return ids of the ops reachable from ``node_id`` through variables only."""
        succ_ops_id = []
        for succ_id in self.origin_graph[sub_idx][node_id][SUCC]:
            succ = self.nodes[sub_idx][succ_id]
            if succ.type == CostNodeType.COMMUNICATION or \
                succ.type == CostNodeType.COMPUTATION:
                succ_ops_id.append(succ_id)
            elif succ.type == CostNodeType.VARIABLE:
                succ_ops_id = succ_ops_id + self._find_succ_op(succ_id, sub_idx)
            else:
                raise NotImplementedError(
                    'This type of node not supported yet:{}'.format(succ.type))
        return succ_ops_id

    def build_op_graph(self):
        """Build the op-only graph (variables collapsed into direct op edges)."""
        for sub_idx in range(self.total_rank):
            op_nodes_id = []
            for node_id, node in self.nodes[sub_idx].items():
                if node.type == CostNodeType.VARIABLE:
                    continue
                self.op_graph[sub_idx][node_id] = [[], []]
                op_nodes_id.append(node_id)
            for op_id in op_nodes_id:
                succ_nodes_id = self._find_succ_op(op_id, sub_idx)

                self.op_graph[sub_idx][op_id][SUCC] = succ_nodes_id
                for succ_id in succ_nodes_id:
                    self.op_graph[sub_idx][succ_id][PRED].append(op_id)

    def build_runtime_graph(self):
        self.runtime_graph = copy.deepcopy(self.op_graph)

    def eliminate_multi_edges(self, graph=None):
        """Remove duplicate edges in-place, keeping first-occurrence order.

        Order is preserved (instead of round-tripping through ``set``) so the
        merge passes visit edges in the same order on every run.
        """
        for node_id, edges in graph.items():
            graph[node_id][PRED] = list(dict.fromkeys(edges[PRED]))
            graph[node_id][SUCC] = list(dict.fromkeys(edges[SUCC]))

    def merge_comm(self):
        """Assign ranks to every communication op and compute its cost."""
        for sub_idx in range(self.total_rank):
            for node_id in self.op_graph[sub_idx]:
                node = self.nodes[sub_idx][node_id]
                if node_id.startswith('c_'):
                    ring_id = node.node.attr('ring_id')
                    node.set_ranks(list(self.ring2rank[ring_id]))
                    node.init_comm_cost(self.cluster)
                elif node_id.startswith('send') or node_id.startswith('recv'):
                    peer_rank = node.node.attr('peer')
                    node.set_ranks([sub_idx, peer_rank])
                    node.init_comm_cost(self.cluster)
                else:
                    pass  # Not communication op

    def _merge_node(self, to_merge_node_list, merge_type='linear', nodes=None):
        """Create a MergedOpsCostNode from ``to_merge_node_list``.

        'linear' sums the costs; 'branch' takes the maximum. The new id is
        derived from ``len(nodes)``; the caller must insert the node.

        Returns:
            tuple: (merged_node_id, merged_node)
        """
        nodes_list = []
        node_cost = 0
        for node in to_merge_node_list:
            if isinstance(node, MergedOpsCostNode):
                nodes_list += node.node_list
            else:
                nodes_list.append(node.id)
            if merge_type == 'linear':
                node_cost += node.cost
            elif merge_type == 'branch':
                node_cost = max(node_cost, node.cost)
            else:
                raise NotImplementedError(
                    'This type of merging is not supported:{}'.format(
                        merge_type))
        merged_node_id = 'merged_' + str(len(nodes))
        is_bwd = to_merge_node_list[0].is_bwd
        merged_node = MergedOpsCostNode(
            CostNodeType.MERGED,
            id=merged_node_id,
            base_node_list=nodes_list,
            is_bwd=is_bwd)
        merged_node.cost = node_cost
        return merged_node_id, merged_node

    def merge_linear(self):
        '''
        This method does the following: 
        If X depends on Y only, they must be run sequentially.
            [ e.g. A ->- C ->- D   D and E depends on C only.]
            [      B ->-/ \->- E   C depends on A and B.     ]
        We merge X and Y into a new node and sum up their cost time.
        '''
        cnt = 0
        for sub_idx in range(self.total_rank):
            cnt += self._merge_linear(
                self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=False)
            cnt += self._merge_linear(
                self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=True)
        return cnt

    def merge_branch(self):
        '''
        This method does the following:
        If a node has more than one successor, there is *branch*.
            [ e.g. A ->- B ->- D                                       ]
            [       \->- C ->- / , B and C can be run at the same time ]
            case 1: if B or C is null (or D is directly dependent on A),
                    it's equivalent to A->C->D or A->B->D, fall back to self.merge_linear
            case 2: if both B and C are some op,
                    merged_cost = max(cost(B), cost(C))
        '''
        cnt = 0
        for sub_idx in range(self.total_rank):
            cnt += self._merge_branch(
                self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=False)
            cnt += self._merge_branch(
                self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=True)
        return cnt

    def _merge_linear(self, nodes, runtime_graph, is_bwd=False):
        """Fold every node with exactly one predecessor into that predecessor.

        Only non-optimizer nodes whose ``is_bwd`` matches are considered.
        Returns the number of merges performed.
        """
        reduct_cnt = 0
        rt_nodes_id = list(runtime_graph.keys())
        for node_id in rt_nodes_id:
            # skip nodes already folded away earlier in this pass
            if node_id not in runtime_graph:
                continue
            node = nodes[node_id]
            if not is_bwd == node.is_bwd or node.is_optim:
                continue
            edges = runtime_graph[node_id]
            ind = len(edges[PRED])  # in_degree
            if ind == 1:  # only depend on one node
                pred_id = edges[PRED][0]
                pred = nodes[pred_id]
                merged_node_id, merged_node = self._merge_node(
                    [node, pred], merge_type='linear', nodes=nodes)
                nodes[merged_node_id] = merged_node
                runtime_graph[merged_node_id] = [[], []]

                # delete edges and add new edges
                succ = None
                runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[SUCC])
                if len(runtime_graph[pred_id][SUCC]) > 1:
                    # predecessor has more than 1 successor
                    # the merged_node is to inherit the rest of its successors
                    succ = runtime_graph[pred_id][SUCC]
                    succ.remove(node_id)
                    runtime_graph[merged_node_id][SUCC] += succ
                runtime_graph[merged_node_id][PRED] = runtime_graph[pred_id][
                    PRED]
                for i in runtime_graph[pred_id][PRED]:
                    runtime_graph[i][SUCC].remove(pred_id)
                    runtime_graph[i][SUCC].append(merged_node_id)

                for i in edges[SUCC]:
                    runtime_graph[i][PRED].remove(node_id)
                    runtime_graph[i][PRED].append(merged_node_id)
                if succ is not None:
                    for i in succ:
                        runtime_graph[i][PRED].remove(pred_id)
                        runtime_graph[i][PRED].append(merged_node_id)

                runtime_graph.pop(node_id)
                runtime_graph.pop(pred_id)
                reduct_cnt += 1
        self.eliminate_multi_edges(runtime_graph)
        return reduct_cnt  # the number of nodes that have been reduced

    def _merge_branch(self, nodes, runtime_graph, is_bwd=False):
        """Merge parallel branches that share a fork node and a join node.

        Transitive edges out of a fork are removed first. Then, if every
        remaining successor of the fork has the fork as its only predecessor
        and one common successor, they are replaced by one merged node
        (cost = max of the branches). Returns the number of nodes/edges
        reduced.
        """
        reduct_cnt = 0
        rt_nodes_id = list(runtime_graph.keys())
        for node_id in rt_nodes_id:
            # skip nodes already folded away earlier in this pass
            if node_id not in runtime_graph:
                continue
            node = nodes[node_id]
            if not is_bwd == node.is_bwd or node.is_optim:
                continue
            edges = runtime_graph[node_id]
            outd = len(edges[SUCC])  # out_degree
            if outd > 1:  # branch out
                succ_nodes_id = edges[SUCC]

                # Drop edge node->X when X is also reached via a sibling.
                succ_to_elim = []
                for succ_id in succ_nodes_id:
                    for succ_2_id in succ_nodes_id:
                        tmp = runtime_graph[succ_2_id][SUCC]
                        if succ_id in tmp:
                            succ_to_elim.append(succ_id)
                            break
                for elim_id in succ_to_elim:
                    edges[SUCC].remove(elim_id)
                    runtime_graph[elim_id][PRED].remove(node_id)
                    reduct_cnt += 1

                if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]][
                        SUCC]) < 1:
                    continue
                end_node_id = runtime_graph[edges[SUCC][0]][SUCC][0]
                to_merge = True
                for i in succ_nodes_id:
                    # Each branch must hang only between this fork and the
                    # common join node; otherwise removing it would leave
                    # dangling edges in the graph.
                    if len(runtime_graph[i][SUCC]) != 1 or \
                        runtime_graph[i][SUCC][0] != end_node_id or \
                        runtime_graph[i][PRED] != [node_id]:
                        to_merge = False  # if branches has different end node, we don't merge them
                        break
                if to_merge:
                    to_merge_node_list = [nodes[i] for i in succ_nodes_id]
                    merged_node_id, merged_node = self._merge_node(
                        to_merge_node_list, merge_type='branch', nodes=nodes)
                    nodes[merged_node_id] = merged_node
                    runtime_graph[merged_node_id] = [[], []]

                    # delete edges and add new edges
                    runtime_graph[merged_node_id][SUCC] = [end_node_id]
                    runtime_graph[merged_node_id][PRED] = [node_id]

                    # Keep the join node's other predecessors.
                    merged_ids = set(succ_nodes_id)
                    end_preds = [
                        p for p in runtime_graph[end_node_id][PRED]
                        if p not in merged_ids
                    ]
                    runtime_graph[end_node_id][PRED] = end_preds + [
                        merged_node_id
                    ]
                    runtime_graph[node_id][SUCC] = [merged_node_id]

                    for i in succ_nodes_id:
                        runtime_graph.pop(i)
                    reduct_cnt += len(to_merge_node_list) - 1
        return reduct_cnt

    def get_runtime_cost(self):
        """Sum node costs of each rank's runtime graph into fwd/bwd/optim time.

        Returns:
            tuple: (fwd_time, bwd_time, optim_time), each a per-rank list.
        """

        def get_node_cost(node):
            node_cost = node.cost + self.opcall_overhead
            if isinstance(node, MergedOpsCostNode):
                # every op folded into a merged node pays its own call overhead
                for _ in node.node_list:
                    node_cost += self.opcall_overhead
            return node_cost

        # Recompute from scratch so repeated calls don't accumulate entries.
        self.fwd_time = []
        self.bwd_time = []
        self.optim_time = []
        for sub_idx in range(self.total_rank):
            fwd_cost = 0
            bwd_cost = 0
            optim_cost = 0
            for node_id in self.runtime_graph[sub_idx]:
                node = self.nodes[sub_idx][node_id]
                if node.is_optim:
                    optim_cost += get_node_cost(node)
                elif node.is_bwd:
                    bwd_cost += get_node_cost(node)
                else:
                    fwd_cost += get_node_cost(node)
            self.fwd_time.append(fwd_cost)
            self.bwd_time.append(bwd_cost)
            self.optim_time.append(optim_cost)
        return self.fwd_time, self.bwd_time, self.optim_time

    def get_mem(self):
        """Simulate memory per rank; returns (static_mem list, peak_mem list)."""
        static_list = []
        top_list = []
        for sub_idx in range(self.total_rank):
            static_mem, cur_mem, top_mem = self._simulate_mem(
                self.nodes[sub_idx], self.origin_graph[sub_idx])
            static_list.append(static_mem)
            top_list.append(top_mem)
        return static_list, top_list

    def _simulate_mem(self, nodes, origin_graph):
        """Walk ``origin_graph`` level by level, tracking variable sizes.

        The 'nop' sentinel marks the end of a level, where ``top_mem`` is
        updated. Returns (static_mem, cur_mem, top_mem).
        """
        # Unbounded: a bounded queue would block forever once full.
        q = queue.Queue()
        sim_graph = copy.deepcopy(origin_graph)
        for node_id in nodes:
            if len(sim_graph[node_id][PRED]) == 0:
                q.put(node_id)

        q.put('nop')
        cur_mem = 0
        top_mem = -1
        static_mem = 0
        while not q.empty():
            node_id = q.get()
            if node_id == 'nop':
                top_mem = max(cur_mem, top_mem)
                if q.empty():
                    break
                q.put(node_id)
                continue

            node = nodes[node_id]
            if node.type == CostNodeType.VARIABLE:
                size = node.get_size()
                if node.node.persistable:
                    static_mem += size
                cur_mem += size
            edges = sim_graph[node_id]
            # NOTE(review): successors of persistable variables are never
            # released here, so ops reading them are never visited - confirm
            # this is intended.
            if not (node.type == CostNodeType.VARIABLE and
                    node.node.persistable):
                for succ_id in edges[SUCC]:
                    sim_graph[succ_id][PRED].remove(node_id)
                    if len(sim_graph[succ_id][PRED]) == 0:
                        q.put(succ_id)
            # NOTE(review): a node is only enqueued once its sim_graph PRED
            # list is empty, so this loop never runs and memory is never freed;
            # origin_graph[node_id][PRED] may have been intended - confirm.
            for pred_id in edges[PRED]:
                pred = nodes[pred_id]
                if pred.type == CostNodeType.VARIABLE:
                    sim_graph[pred_id][SUCC].remove(node_id)
                    if len(sim_graph[pred_id][
                            SUCC]) == 0 and not pred.node.persistable:
                        cur_mem -= pred.get_size()
        return static_mem, cur_mem, top_mem

    def get_pipeline_time(self):
        """Return the estimated step time from the per-rank fwd/bwd/optim times."""
        if self.total_rank <= 1:
            return self.fwd_time[0] + self.bwd_time[0] + self.optim_time[0]
        if self.pp2rank is None:
            # No pipeline stages: ranks run side by side, so the step is
            # bounded by the slowest rank.
            return max(f + b + o
                       for f, b, o in zip(self.fwd_time, self.bwd_time,
                                          self.optim_time))
        return self._simulate_pipeline()

    def _simulate_pipeline(self):
        """Event-driven simulation of the pipeline schedule; returns total time.

        NOTE(review): stage ids index ``fwd_time``/``bwd_time``/``optim_time``,
        which are per-rank lists; this assumes stage i runs on rank i - confirm.
        """
        stage_num = len(self.pp2rank)
        event_list = []
        global_time = [0] * stage_num
        total_time = 0
        # forward slots available per stage before a backward must run
        fwd_cnt = list(range(stage_num, 0, -1))
        bwd_cnt = [self.microbatch_num] * stage_num
        # Unbounded: a bounded queue would block forever once full.
        q = queue.Queue()

        for _ in range(self.microbatch_num):
            q.put(PipeEvent(0, 'fwd', self.fwd_time[0]))

        while not q.empty():
            e = q.get()
            stid = e.stage_id
            if e.name == 'fwd':
                if fwd_cnt[stid] > 0:
                    e.s_time = max(global_time[stid], e.s_time)
                    e.e_time = e.s_time + e.duration
                    event_list.append(e)
                    if stid != stage_num - 1:
                        q.put(
                            PipeEvent(
                                stid + 1,
                                'fwd',
                                self.fwd_time[stid + 1],
                                start_time=e.e_time))
                    else:
                        q.put(
                            PipeEvent(
                                stid,
                                'bwd',
                                self.bwd_time[stid],
                                start_time=e.e_time))
                    fwd_cnt[stid] -= 1
                    global_time[stid] = e.e_time
                else:
                    q.put(e)
            elif e.name == 'bwd':
                e.s_time = max(global_time[stid], e.s_time)
                e.e_time = e.s_time + e.duration
                event_list.append(e)
                if stid != 0:
                    q.put(
                        PipeEvent(
                            stid - 1,
                            'bwd',
                            self.bwd_time[stid - 1],
                            start_time=e.e_time))
                fwd_cnt[stid] += 1
                bwd_cnt[stid] -= 1
                if bwd_cnt[stid] == 0:
                    q.put(
                        PipeEvent(
                            stid,
                            'optim',
                            self.optim_time[stid],
                            start_time=e.e_time))
                global_time[stid] = e.e_time
            elif e.name == 'optim':
                e.s_time = max(global_time[stid], e.s_time)
                e.e_time = e.s_time + e.duration
                event_list.append(e)
                global_time[stid] = e.e_time
            else:
                raise NotImplementedError(
                    'This type of pipe event is not supported yet.{}'.format(
                        e.name))

        for t in global_time:
            total_time = max(total_time, t)
        return total_time

    def get_cost(self):
        """Run memory simulation, graph reduction and runtime estimation.

        Returns:
            Cost: with ``static_mem``/``peak_mem`` per rank and ``runtime``.
        """
        cost = Cost()
        static_mem, peak_mem = self.get_mem()
        cost.static_mem = static_mem
        cost.peak_mem = peak_mem
        self.merge_comm()
        while True:
            cnt = 0
            cnt += self.merge_linear()
            cnt += self.merge_branch()
            if cnt == 0:  # can't be further merged
                break
        self.get_runtime_cost()
        cost.runtime = self.get_pipeline_time()
        return cost

    def init(self, distributed_program):
        """Parse ``distributed_program`` and build the op and runtime graphs."""
        self.parse_program(distributed_program)
        self.build_op_graph()
        for sub_idx in range(self.total_rank):
            self.eliminate_multi_edges(self.op_graph[sub_idx])
        self.build_runtime_graph()


def estimate_cost(distributed_program, cluster, pipeline_config,
                  standalone_cost_data, batch_size):
    """
    Estimated cost from distributed program, cluster model and distributed settings.
    
    Args:
        distributed_program(list): list of paddle programs
        cluster(Cluster): cluster model; must currently be None
        pipeline_config(list): configuration of pipeline stage allocation
        standalone_cost_data(CostData): cost data given by paddle.core
        batch_size(int): batch size of the training workload

    Returns:
        Cost: estimated runtime and static/peak memory.
    """
    # the following line is left for now, cluster model will be involved in the future
    assert cluster is None, "For now, cluster remains None"
    cm_ctx = CostModel(
        cluster=cluster,
        batch_size=batch_size,
        standalone_cost_data=standalone_cost_data,
        pipeline_config=pipeline_config)
    cm_ctx.init(distributed_program)
    cost = cm_ctx.get_cost()
    return cost