From a573a7ed7f4113cc7658b38f889e442bc805171e Mon Sep 17 00:00:00 2001
From: YipZLF <22539457+YipZLF@users.noreply.github.com>
Date: Tue, 19 Oct 2021 14:03:46 +0800
Subject: [PATCH] Add auto parallel cost model and unittests (#36363)

* Add auto parallel cost model and unittests
* Fixed code styles.
* Fixed bugs and code style
* Fixed typo
* Improved code style: object encapsulation.
* Fixed code.
* Refactored estimate_cost
* Fixed typo
---
 .../distributed/auto_parallel/__init__.py     |   1 +
 .../distributed/auto_parallel/cost_model.py   | 741 ++++++++++++++++++
 .../fluid/tests/unittests/CMakeLists.txt      |   3 +
 .../test_auto_parallel_cost_model.py          | 236 ++++++
 4 files changed, 981 insertions(+)
 create mode 100644 python/paddle/distributed/auto_parallel/cost_model.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py

diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py
index 31f92e2575a..2779a9feb0b 100644
--- a/python/paddle/distributed/auto_parallel/__init__.py
+++ b/python/paddle/distributed/auto_parallel/__init__.py
@@ -21,5 +21,6 @@ from .interface import ProcessMesh  # noqa: F401
 from .completion import complete_annotation  # noqa: F401
 from .completion import complete_backward_annotation  # noqa: F401
 from .reshard import reshard  # noqa: F401
+from .cost_model import estimate_cost
 
 __all__ = []
diff --git a/python/paddle/distributed/auto_parallel/cost_model.py b/python/paddle/distributed/auto_parallel/cost_model.py
new file mode 100644
index 00000000000..3fd438e2a62
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/cost_model.py
@@ -0,0 +1,741 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np +import json +import queue +import copy +from enum import Enum +import paddle + +SUCC = 0 # successor +PRED = 1 # predecessor + + +class CostNodeType(Enum): + DEFAULT = 0 + COMPUTATION = 1 + COMMUNICATION = 2 + VARIABLE = 3 + MERGED = 4 + NOP = 5 + + +class Cost(object): + def __init__(self): + self.runtime = None + self.static_mem = None + self.peak_mem = None + + +class CostModelMode(Enum): + DEFAULT = 0 + BENCHMARKING = 1 # costs based on trial runs + ANALYSIS = 2 # costs based on analysis + MIXED = 3 + + +class CostNode(object): + def __init__(self, node, node_type, id=None): + self.id = id + self.node = node + self.type = node_type + self._cost = 0 + self.is_optim = False + self.is_bwd = False + + @property + def cost(self): + return self._cost + + @cost.setter + def cost(self, cost): + if cost < 0: + raise ValueError('Cost must be above 0.') + self._cost = cost + + +class MergedOpsCostNode(CostNode): + def __init__(self, node_type, id=None, base_node_list=None, is_bwd=False): + super(MergedOpsCostNode, self).__init__(None, node_type, id) + self.node_list = base_node_list + self.is_bwd = is_bwd + + +class CommOpCostNode(CostNode): + def __init__(self, + node, + node_type, + id=None, + comm_node_list=None, + is_bwd=False): + super(CommOpCostNode, self).__init__(node, node_type, id) + self.node_list = comm_node_list + self.ranks = [] + self.comm_type = node.type + self.is_bwd = is_bwd + + def set_ranks(self, ranks): + self.ranks = ranks + + def set_shapes(self, input_shape, output_shape): + self.input_shape = input_shape + self.output_shape = output_shape + + def init_comm_cost(self, cluster=None): + # ref: https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md + # should get from `cluster` + BANDWIDTH = 32 * 1024 / 1000 # MB/ms, V100 PCIe + num_ranks = len(self.ranks) + comm_volumn = np.prod(self.input_shape) * 4 + + if 'allreduce' in self.comm_type: + self._cost = comm_volumn / (BANDWIDTH * num_ranks / + (2 * (num_ranks - 1))) + elif 'gather' in self.comm_type: + self._cost = comm_volumn / (BANDWIDTH * num_ranks / (num_ranks - 1)) + elif 'broadcast' in self.comm_type: + self._cost = comm_volumn / BANDWIDTH + elif 'send' in self.comm_type or 'recv' in self.comm_type: + self._cost = comm_volumn / BANDWIDTH + else: + self._cost = 0 + + +class TensorCostNode(CostNode): + def __init__(self, + node, + node_type, + id=None, + base_node_list=None, + batch_size=None, + shared_node_id=None): + super(TensorCostNode, self).__init__(node, node_type, id) + self.shape = node.shape + self.dtype = node.dtype + self.dtype_factor = 1 + self.persistable = None + self.shared_node_id = shared_node_id + if self.dtype == paddle.float32 or node.dtype == paddle.int32: + self.dtype_factor *= 4 + elif node.dtype == paddle.int64: + self.dtype_factor *= 8 + else: + raise NotImplementedError("{} not counted".format(v.node.dtype)) + + self.batch_size = None + if batch_size is not None: + self.batch_size = batch_size + + def get_size(self): + p = 1 + for i in self.node.shape: + if i == -1: # deal with placeholder + assert self.batch_size is not None, "Batch size not decided." 
+ i = self.batch_size + p *= i + return p + + +class CompOpCostNode(CostNode): + def __init__(self, node, node_type, id=None, is_bwd=False, is_optim=False): + super(CompOpCostNode, self).__init__(node, node_type, id) + self.is_bwd = is_bwd + self.is_optim = is_optim + + def init_comp_cost(self, cost_data): + # TODO: improve fluid.CostModel for more specific cost_data + op_name = self.node.type + if op_name in cost_data.keys(): + self.cost = cost_data[op_name] + else: + self.cost = 0.0 + + +class PipeEvent(object): + def __init__(self, stage_id, event_name, duration, start_time=-1): + self.stage_id = stage_id + self.name = event_name + self.duration = duration + self.s_time = start_time + self.e_time = -1 + + +class CostModel(object): + def __init__(self, + mode=CostModelMode.BENCHMARKING, + cluster=None, + batch_size=1, + microbatch_num=1, + opcall_overhead=0, + standalone_cost_data=None, + pipeline_config=None): + self.mode = mode + + # parameters + self.opcall_overhead = opcall_overhead + self.batch_size = batch_size + self.microbatch_num = microbatch_num + + self.nodes = {} # name -> node + + self.origin_graph = {} # original graph + self.op_graph = {} # op graph (no variables nodes) + self.runtime_graph = {} # runtime graph, for simulation + + self.cluster = cluster + self.cost_data = standalone_cost_data + self.pp2rank = pipeline_config + if self.pp2rank is not None: + self.rank2pp = {} + for stage_idx, ranks in enumerate(self.pp2rank): + for rank in ranks: + self.rank2pp[rank] = stage_idx + else: + self.rank2pp = None + + self.ring2rank = {} + + self.fwd_time = [] + self.bwd_time = [] + self.optim_time = [] + + def _parse_sub_program(self, program, nodes, graph, cost_data, sub_idx): + assert len( + program.blocks) == 1, "Program more than 1 block not supported." 
+ block = program.blocks[0] + + for var in block.vars.values(): + var_id = var.name + nodes[var_id] = TensorCostNode(var, CostNodeType.VARIABLE, var_id) + graph[var_id] = [[], []] + + for op in block.ops: + op_id = op.type + "_" + str(op.idx) + if op.type.startswith('c_') or op.type.startswith( + 'send') or op.type.startswith('recv'): + is_bwd = False + if op.type.startswith('c_'): + ring_id = op.attr('ring_id') + if ring_id not in self.ring2rank: + self.ring2rank[ring_id] = set() + self.ring2rank[ring_id].add(sub_idx) + is_bwd = '@GRAD' in op.output('Out')[0] + elif op.type.startswith('recv'): + is_bwd = '@GRAD' in op.output('Out')[0] + elif op.type.startswith('send'): + is_bwd = '@GRAD' in op.input('X')[0] + op_node = CommOpCostNode(op, CostNodeType.COMMUNICATION, op_id, + is_bwd) + else: + is_bwd = '_grad' in op.type + is_optim = 'LearningRate' in op.input_names + op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id, + is_bwd, is_optim) + op_node.init_comp_cost(cost_data) + + nodes[op_id] = op_node + graph[op_id] = [[], []] + + comm_input_shape = [0] + comm_output_shape = [0] + for i in range(len(op.input_names)): + try: + var_id = op.input(op.input_names[i])[0] + var_node = nodes[var_id] + graph[op_id][PRED].append(var_node.id) + graph[var_id][SUCC].append(op_node.id) + comm_input_shape = var_node.shape + except: + continue + for i in range(len(op.output_names)): + try: + var_id = op.output(op.output_names[i])[0] + var_node = nodes[var_id] + graph[op_id][SUCC].append(var_node.id) + graph[var_id][PRED].append(op_node.id) + comm_output_shape = var_node.shape + except: + continue + if op_node.type == CostNodeType.COMMUNICATION: + op_node.set_shapes(comm_input_shape, comm_output_shape) + + # resolve hazard: rename the r/w hazard variable nodes to ensure self.origin_graph is a DAG + new_var_dict = {} + for node_id, node in nodes.items(): + if node.type == CostNodeType.VARIABLE and node.node.persistable: + write_op_cnt = 0 + for pred_id in graph[node_id][PRED]: + pred = nodes[pred_id] + if pred.type == CostNodeType.COMPUTATION and ( + pred_id in graph[node_id][SUCC]): + + graph[pred_id][SUCC].remove(node_id) + graph[node_id][PRED].remove(pred_id) + + write_op_cnt += 1 + new_var_id = node_id + '_write_{}'.format(write_op_cnt) + new_var = TensorCostNode( + node.node, + CostNodeType.VARIABLE, + new_var_id, + shared_node_id=node_id) + + graph[new_var_id] = [[], []] + graph[pred_id][SUCC].append(new_var_id) + graph[new_var_id][PRED].append(pred_id) + + new_var_dict[new_var_id] = new_var + for k, v in new_var_dict.items(): + nodes[k] = v + return nodes + + def parse_program(self, distributed_program): + self.distributed_program = distributed_program + self.total_rank = len(self.distributed_program) + sub_prog_cnt = len(distributed_program) + self.nodes = [] * sub_prog_cnt + self.origin_graph = [] * sub_prog_cnt # original graph + self.op_graph = [] * sub_prog_cnt # op graph (no variables nodes) + self.runtime_graph = [] * sub_prog_cnt # runtime graph, for simulation + + for sub_idx, sub_prog in enumerate(distributed_program): + self.nodes.append({}) + self.origin_graph.append({}) + self.op_graph.append({}) + self.runtime_graph.append({}) + self._parse_sub_program( + sub_prog, self.nodes[sub_idx], self.origin_graph[sub_idx], + self.cost_data[0 if self.rank2pp is None else self.rank2pp[ + sub_idx]], sub_idx) + return self.nodes + + def _find_succ_op(self, node_id, sub_idx=0): + succ_ops_id = [] + for succ_id in self.origin_graph[sub_idx][node_id][SUCC]: + succ = self.nodes[sub_idx][succ_id] 
+ if succ.type == CostNodeType.COMMUNICATION or \ + succ.type == CostNodeType.COMPUTATION: + succ_ops_id.append(succ_id) + elif succ.type == CostNodeType.VARIABLE: + succ_ops_id = succ_ops_id + self._find_succ_op(succ_id, sub_idx) + else: + raise NotImplementedError( + 'This type of node not supported yet:{}'.format(succ.type)) + return succ_ops_id + + def build_op_graph(self): + for sub_idx in range(self.total_rank): + op_nodes_id = [] + for node_id, node in self.nodes[sub_idx].items(): + if node.type == CostNodeType.VARIABLE: + continue + self.op_graph[sub_idx][node_id] = [[], []] + op_nodes_id.append(node_id) + for op_id in op_nodes_id: + succ_nodes_id = self._find_succ_op(op_id, sub_idx) + + self.op_graph[sub_idx][op_id][SUCC] = succ_nodes_id + for succ_id in succ_nodes_id: + self.op_graph[sub_idx][succ_id][PRED].append(op_id) + + def build_runtime_graph(self): + self.runtime_graph = copy.deepcopy(self.op_graph) + + def eliminate_multi_edges(self, graph=None): + for node_id, edges in graph.items(): + graph[node_id][PRED] = list(set(edges[PRED])) + graph[node_id][SUCC] = list(set(edges[SUCC])) + + def merge_comm(self): + for sub_idx in range(self.total_rank): + for node_id, edges in self.op_graph[sub_idx].items(): + node = self.nodes[sub_idx][node_id] + if node_id.startswith('c_'): + ring_id = node.node.attr('ring_id') + node.set_ranks(list(self.ring2rank[ring_id])) + node.init_comm_cost(self.cluster) + elif node_id.startswith('send') or node_id.startswith('recv'): + peer_rank = node.node.attr('peer') + node.set_ranks([sub_idx, peer_rank]) + node.init_comm_cost(self.cluster) + else: + pass # Not communication op + + def _merge_node(self, to_merge_node_list, merge_type='linear', nodes=None): + nodes_list = [] + node_cost = 0 + for node in to_merge_node_list: + if isinstance(node, MergedOpsCostNode): + nodes_list += node.node_list + else: + nodes_list.append(node.id) + if merge_type == 'linear': + node_cost += node.cost + elif merge_type == 'branch': + node_cost = max(node_cost, node.cost) + else: + raise NotImplementedError( + 'This type of merging is not supported:{}'.format( + merge_type)) + merged_node_id = 'merged_' + str(len(nodes)) + is_bwd = to_merge_node_list[0].is_bwd + merged_node = MergedOpsCostNode( + CostNodeType.MERGED, + id=merged_node_id, + base_node_list=nodes_list, + is_bwd=is_bwd) + merged_node.cost = node_cost + return merged_node_id, merged_node + + def merge_linear(self): + ''' + This method does the following: + If X depends on Y only, they must be run sequentially. + [ e.g. A ->- C ->- D D and E depends on C only.] + [ B ->-/ \->- E C depends on A and B. ] + We merge X and Y into a new node and sum up their cost time. + ''' + cnt = 0 + for sub_idx in range(self.total_rank): + cnt += self._merge_linear( + self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=False) + cnt += self._merge_linear( + self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=True) + return cnt + + def merge_branch(self): + ''' + This method does the following: + If a node has more than one successor, there is *branch*. + [ e.g. 
A ->- B ->- D ] + [ \->- C ->- / , B and C can be run at the same time ] + case 1: if B or C is null (or D is directly dependent on A), + it's equivalent to A->C->D or A->B->D, fall back to self.merge_linear + case 2: if both B and C are some op, + merged_cost = max(cost(B), cost(C)) + ''' + cnt = 0 + for sub_idx in range(self.total_rank): + cnt += self._merge_branch( + self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=False) + cnt += self._merge_branch( + self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=True) + return cnt + + def _merge_linear(self, nodes, runtime_graph, is_bwd=False): + reduct_cnt = 0 + rt_nodes_id = list(runtime_graph.keys()) + for node_id in rt_nodes_id: + if node_id not in runtime_graph.keys(): + continue + node = nodes[node_id] + if not is_bwd == node.is_bwd or node.is_optim: + continue + edges = runtime_graph[node_id] + ind = len(edges[PRED]) # in_degree + if ind == 1: # only depend on one node + pred_id = edges[PRED][0] + pred = nodes[pred_id] + merged_node_id, merged_node = self._merge_node( + [node, pred], merge_type='linear', nodes=nodes) + nodes[merged_node_id] = merged_node + runtime_graph[merged_node_id] = [[], []] + + # delete edges and add new edges + succ = None + runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[SUCC]) + if len(runtime_graph[pred_id][SUCC]) > 1: + # predecessor has more than 1 successor + # the merged_node is to inherit the rest of its successors + succ = runtime_graph[pred_id][SUCC] + succ.remove(node_id) + runtime_graph[merged_node_id][SUCC] += succ + runtime_graph[merged_node_id][PRED] = runtime_graph[pred_id][ + PRED] + for i in runtime_graph[pred_id][PRED]: + runtime_graph[i][SUCC].remove(pred_id) + runtime_graph[i][SUCC].append(merged_node_id) + + for i in edges[SUCC]: + runtime_graph[i][PRED].remove(node_id) + runtime_graph[i][PRED].append(merged_node_id) + if succ is not None: + for i in succ: + runtime_graph[i][PRED].remove(pred_id) + runtime_graph[i][PRED].append(merged_node_id) + + runtime_graph.pop(node_id) + runtime_graph.pop(pred_id) + reduct_cnt += 1 + self.eliminate_multi_edges(runtime_graph) + return reduct_cnt # the number of nodes that have been reduced + + def _merge_branch(self, nodes, runtime_graph, is_bwd=False): + reduct_cnt = 0 + rt_nodes_id = list(runtime_graph.keys()) + for node_id in rt_nodes_id: + node = nodes[node_id] + if not is_bwd == node.is_bwd or node.is_optim: + continue + edges = runtime_graph[node_id] + outd = len(edges[SUCC]) # out_degree + if outd > 1: # branch out + succ_nodes_id = edges[SUCC] + + succ_to_elim = [] + for succ_id in succ_nodes_id: + for succ_2_id in succ_nodes_id: + tmp = runtime_graph[succ_2_id][SUCC] + if succ_id in tmp: + succ_to_elim.append(succ_id) + break + for id in succ_to_elim: + edges[SUCC].remove(id) + runtime_graph[id][PRED].remove(node_id) + reduct_cnt += 1 + + to_merge = True + if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]][ + SUCC]) < 1: + continue + end_node_id = runtime_graph[edges[SUCC][0]][SUCC][0] + for i in succ_nodes_id: + if len(runtime_graph[i][SUCC]) != 1 or \ + runtime_graph[i][SUCC][0] != end_node_id: + to_merge = False # if branches has different end node, we don't merge them + break + if to_merge: + to_merge_node_list = [nodes[i] for i in succ_nodes_id] + merged_node_id, merged_node = self._merge_node( + to_merge_node_list, merge_type='branch', nodes=nodes) + nodes[merged_node_id] = merged_node + runtime_graph[merged_node_id] = [[], []] + + # delete edges and add new edges + runtime_graph[merged_node_id][SUCC] = 
[end_node_id] + runtime_graph[merged_node_id][PRED] = edges[PRED] + + runtime_graph[end_node_id][PRED] = [merged_node_id] + runtime_graph[node_id][SUCC] = [merged_node_id] + + for i in succ_nodes_id: + runtime_graph.pop(i) + reduct_cnt += len(to_merge_node_list) - 1 + return reduct_cnt + + def get_runtime_cost(self): + def get_node_cost(node): + node_cost = node.cost + self.opcall_overhead + if isinstance(node, MergedOpsCostNode): + for it in node.node_list: + node_cost += self.opcall_overhead + return node_cost + + for sub_idx in range(self.total_rank): + fwd_cost = 0 + bwd_cost = 0 + optim_cost = 0 + for node_id in self.runtime_graph[sub_idx].keys(): + node = self.nodes[sub_idx][node_id] + if node.is_optim: + optim_cost += get_node_cost(node) + elif node.is_bwd: + bwd_cost += get_node_cost(node) + else: + fwd_cost += get_node_cost(node) + self.fwd_time.append(fwd_cost) + self.bwd_time.append(bwd_cost) + self.optim_time.append(optim_cost) + return self.fwd_time, self.bwd_time, self.optim_time + + def get_mem(self): + static_list = [] + top_list = [] + for sub_idx in range(self.total_rank): + static_mem, cur_mem, top_mem = self._simulate_mem( + self.nodes[sub_idx], self.origin_graph[sub_idx]) + static_list.append(static_mem) + top_list.append(top_mem) + return static_list, top_list + + def _simulate_mem(self, nodes, origin_graph): + q = queue.Queue(1024) + sim_graph = copy.deepcopy(origin_graph) + for node_id, node in nodes.items(): + if len(sim_graph[node_id][PRED]) == 0: + q.put(node_id) + + q.put('nop') + cur_mem = 0 + top_mem = -1 + static_mem = 0 + while not q.empty(): + node_id = q.get() + node = None + size = 0 + if node_id == 'nop': + top_mem = max(cur_mem, top_mem) + if q.empty(): + break + else: + q.put(node_id) + continue + else: + node = nodes[node_id] + if node.type == CostNodeType.VARIABLE: + size = node.get_size() + if node.node.persistable: + static_mem += size + cur_mem += size + edges = sim_graph[node_id] + if not (node.type == CostNodeType.VARIABLE and + node.node.persistable): + for succ_id in edges[SUCC]: + sim_graph[succ_id][PRED].remove(node_id) + if len(sim_graph[succ_id][PRED]) == 0: + q.put(succ_id) + for pred_id in edges[PRED]: + pred = nodes + if pred.type == CostNodeType.VARIABLE: + sim_graph[pred_id][SUCC].remove(node_id) + if len(sim_graph[pred_id][ + SUCC]) == 0 and not pred.node.persistable: + cur_mem -= pred.get_size() + return static_mem, cur_mem, top_mem + + def get_pipeline_time(self): + if self.total_rank <= 1: + return self.fwd_time[0] + self.bwd_time[0] + self.optim_time[0] + else: + return self._simulate_pipeline() + + def _simulate_pipeline(self): + stage_num = len(self.pp2rank) + event_list = [] + global_time = [0] * stage_num + total_time = 0 + fwd_cnt = list(range(stage_num, 0, -1)) + bwd_cnt = [self.microbatch_num] * stage_num + q = queue.Queue(1024) + + for i in range(self.microbatch_num): + q.put(PipeEvent(0, 'fwd', self.fwd_time[0])) + + while not q.empty(): + e = q.get() + stid = e.stage_id + if e.name == 'fwd': + if fwd_cnt[stid] > 0: + e.s_time = max(global_time[stid], e.s_time) + e.e_time = e.s_time + e.duration + event_list.append(e) + if stid != stage_num - 1: + q.put( + PipeEvent( + stid + 1, + 'fwd', + self.fwd_time[stid + 1], + start_time=e.e_time)) + else: + q.put( + PipeEvent( + stid, + 'bwd', + self.bwd_time[stid], + start_time=e.e_time)) + fwd_cnt[stid] -= 1 + global_time[stid] = e.e_time + else: + q.put(e) + elif e.name == 'bwd': + e.s_time = max(global_time[stid], e.s_time) + e.e_time = e.s_time + e.duration + 
event_list.append(e) + if stid != 0: + q.put( + PipeEvent( + stid - 1, + 'bwd', + self.bwd_time[stid - 1], + start_time=e.e_time)) + fwd_cnt[stid] += 1 + bwd_cnt[stid] -= 1 + if bwd_cnt[stid] == 0: + q.put( + PipeEvent( + stid, + 'optim', + self.optim_time[stid], + start_time=e.e_time)) + global_time[stid] = e.e_time + elif e.name == 'optim': + e.s_time = max(global_time[stid], e.s_time) + e.e_time = e.s_time + e.duration + event_list.append(e) + global_time[stid] = e.e_time + else: + raise NotImplementedError( + 'This type of pipe event is not supported yet.{}'.format( + e.name)) + + for t in global_time: + total_time = max(total_time, t) + return total_time + + def get_cost(self): + cost = Cost() + static_mem, peak_mem = self.get_mem() + cost.static_mem = static_mem + cost.peak_mem = peak_mem + self.merge_comm() + while True: + cnt = 0 + cnt += self.merge_linear() + cnt += self.merge_branch() + if cnt == 0: # can't be further merged + break + self.get_runtime_cost() + cost.runtime = self.get_pipeline_time() + return cost + + def init(self, distributed_program): + self.parse_program(distributed_program) + self.build_op_graph() + for sub_idx in range(self.total_rank): + self.eliminate_multi_edges(self.op_graph[sub_idx]) + self.build_runtime_graph() + + +def estimate_cost(distributed_program, cluster, pipeline_config, + standalone_cost_data, batch_size): + """ + Estimated cost from distributed program, cluster model and distributed settings. + + Args: + distributed_program(list): list of paddle programs + cluster(Cluster): cluster model + standalone_cost_data(CostData): cost data given by paddle.core + batch_size(int): batch size of the training workload + pipeline_config(list): configuration of pipeline stage allocation + """ + # the following line is left for now, cluster model will be involved in the future + assert cluster is None, "For now, cluster remains None" + cm_ctx = CostModel( + cluster=cluster, + batch_size=batch_size, + standalone_cost_data=standalone_cost_data, + pipeline_config=pipeline_config) + cm_ctx.init(distributed_program) + cost = cm_ctx.get_cost() + return cost diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index f883d7a80a4..90f59758a2f 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -91,6 +91,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_cost_model) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() @@ -234,6 +235,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_cost_model) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_data_unshard) elseif(WITH_GPU) if (${CUDNN_VERSION} VERSION_LESS 7100) @@ -608,6 +610,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_dpmppp MODULES 
test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS}) + py_test_modules(test_auto_parallel_cost_model MODULES test_auto_parallel_cost_model ENVS ${dist_ENVS}) endif(NOT WIN32) endif(NOT APPLE) if(WITH_DGC) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py new file mode 100644 index 00000000000..58d033ad658 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py @@ -0,0 +1,236 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.completion import complete_backward_annotation +from paddle.distributed.auto_parallel.reshard import reshard +from paddle.distributed.auto_parallel.cost_model import estimate_cost +import paddle.fluid.core as core + +paddle.enable_static() +_global_parallel_strategy = "dp_mp_pp" +ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) +_global_process_mesh = auto.ProcessMesh( + [[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH) +PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH) +PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH) +NUM_RANKS = 8 +STAGE_0_CNT = 5 +STAGE_1_CNT = 10 +pp_cfg = [[0, 1, 4, 5], [2, 3, 6, 7]] + +device = "gpu" if core.is_compiled_with_cuda() else "cpu" + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=256, + intermediate_size=4 * 256, + initializer_range=0.02, + is_distributed=True): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + + self.is_distributed = is_distributed + + def forward(self, input): + if self.is_distributed: + auto.shard_tensor( + self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1]) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + + return out + + +def get_single_node_data(): + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + + loss, train_program, startup_program = mlp_forward( + train_program, startup_program, is_distributed=False) + + 
cost_model = core.CostModel() + cost_data = cost_model.profile_measure(train_program, startup_program, + device, ["time"]) + + op_name2cost = [{}, {}] + for idx, op in enumerate(train_program.blocks[0].ops): + if idx <= STAGE_0_CNT: + op_name2cost[0][op.type] = cost_data.get_op_time_ms(idx) + elif idx <= STAGE_1_CNT: + op_name2cost[1][op.type] = cost_data.get_op_time_ms(idx) + return op_name2cost + + +def mlp_forward(train_program, start_program, is_distributed=True): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 256 + sequence_len = 128 + if is_distributed: + input = static.data( + name="input", shape=[batch_size, hidden_size], dtype='float32') + label = static.data( + name="label", shape=[batch_size, 1], dtype='float32') + else: + input = paddle.ones( + name="input", shape=[batch_size, hidden_size], dtype='float32') + label = paddle.ones( + name="label", shape=[batch_size, 1], dtype='float32') + + if is_distributed: + auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1]) + auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1]) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02, + is_distributed=is_distributed) + + predict = mlp(input) + error_cost = paddle.nn.functional.square_error_cost(predict, label) + loss = paddle.mean(error_cost) + + return loss, train_program, start_program + + +def get_dist_prog(train_program, startup_program, dist_context, rank_id): + global _global_process_mesh + dist_context.set_process_mesh(_global_process_mesh) + loss, train_program, startup_program = mlp_forward(train_program, + startup_program) + + # auto completion + complete_train_program = auto.complete_annotation(train_program, + dist_context) + + dist_strategy = fleet.DistributedStrategy() + dist_main_prog = [] + dist_startup_prog = [] + for rank_id in range(NUM_RANKS): + partitioner = Partitioner(dist_strategy, dist_context, rank_id) + # logical partition + auto_parallel_main_prog, auto_parallel_startup_prog = partitioner.transpile_forward( + complete_train_program, startup_program) + dist_params_grads = partitioner.apply_backward( + loss, complete_train_program, startup_program, + auto_parallel_main_prog, auto_parallel_startup_prog) + optimizer = paddle.fluid.optimizer.AdamOptimizer() + opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, + auto_parallel_main_prog, + auto_parallel_startup_prog) + dist_main_prog.append(auto_parallel_main_prog) + dist_startup_prog.append(auto_parallel_startup_prog) + return dist_main_prog, dist_startup_prog + + +def check_runtime_estimation(cost): + return cost.runtime > 0 + + +def check_memory_estimation(cost): + for i in range(NUM_RANKS): + if cost.static_mem[i] <= 0 or cost.peak_mem[i] <= 0: + return False + if cost.static_mem[i] > cost.peak_mem[i]: + return False + return True + + +def check_empty_program_runtime(cost): + return cost.runtime == 0 + + +def check_empty_program_memory(cost): + for mem in cost.peak_mem: + if mem > 0: + return False + for mem in cost.static_mem: + if mem > 0: + return False + return True + + +class TestCostModel(unittest.TestCase): + def test_empty_program_cost_model(self): + empty_program = paddle.static.Program() + startup_program = paddle.static.Program() + standalone_cost_data = [{}] + empty_pp_cfg = None + cluster = None + cost = estimate_cost( + [empty_program], + cluster=cluster, + pipeline_config=empty_pp_cfg, + standalone_cost_data=standalone_cost_data, + batch_size=1) + + 
+        self.assertTrue(check_empty_program_runtime(cost))
+        self.assertTrue(check_empty_program_memory(cost))
+
+    def test_auto_parallel_cost_model(self):
+        train_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        dist_context = DistributedContext()
+        standalone_cost_data = get_single_node_data()
+        distributed_program, dist_startup_prog = get_dist_prog(
+            train_program, startup_program, dist_context, 0)
+        for rank_id in range(NUM_RANKS):
+            complete_backward_annotation(distributed_program[rank_id],
+                                         dist_context)
+            reshard(distributed_program[rank_id], dist_startup_prog[rank_id],
+                    rank_id, dist_context)
+        cluster = None
+        cost = estimate_cost(
+            distributed_program,
+            cluster=cluster,
+            pipeline_config=pp_cfg,
+            standalone_cost_data=standalone_cost_data,
+            batch_size=4)
+        self.assertTrue(check_runtime_estimation(cost))
+        self.assertTrue(check_memory_estimation(cost))
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
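
Usage sketch: a minimal example of driving the new estimate_cost entry point, mirroring the
empty-program case exercised by test_empty_program_cost_model in the unit test above. It assumes
a Paddle build that includes this patch; cluster must remain None (estimate_cost asserts this),
pipeline_config=None and standalone_cost_data=[{}] follow the unit test, and the zero expected
values correspond to the same checks the test performs. A real distributed run would instead pass
the per-rank programs, the pipeline stage-to-rank mapping, and per-stage op cost dictionaries, as
in test_auto_parallel_cost_model.

    import paddle
    import paddle.static as static
    from paddle.distributed.auto_parallel.cost_model import estimate_cost

    paddle.enable_static()

    # One sub-program per rank; a single empty program models a trivial one-rank job.
    empty_program = static.Program()

    cost = estimate_cost(
        [empty_program],            # list of per-rank static programs
        cluster=None,               # cluster model not consumed yet; must stay None
        pipeline_config=None,       # no pipeline stage -> rank mapping
        standalone_cost_data=[{}],  # per-stage {op.type: time} dicts; empty for this sketch
        batch_size=1)

    # With no ops there is nothing to simulate, so every estimate is zero,
    # matching check_empty_program_runtime/check_empty_program_memory above.
    print(cost.runtime)     # 0
    print(cost.static_mem)  # [0], one entry per rank
    print(cost.peak_mem)    # [0]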