From 6a15d407b043630420dcbf528990dc915e39a3e0 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Tue, 16 Aug 2022 17:35:56 +0800 Subject: [PATCH] [Auto Paralle]Add reshard cost and update estimator (#45118) * update reshard cost and cost estimator * add unittest * add dropout cost * fix import error * fix reshard code style error * improve unittest coverage --- .../auto_parallel/cost/comp_op_cost.py | 19 + .../auto_parallel/cost/estimate_cost.py | 382 +++++++++++++++++- .../auto_parallel/operators/dist_embedding.py | 3 +- .../dist_fill_constant_batch_size_like.py | 2 +- .../auto_parallel/operators/dist_matmul.py | 3 +- .../auto_parallel/operators/dist_softmax.py | 3 +- .../auto_parallel/operators/dist_transpose.py | 3 +- .../distributed/auto_parallel/reshard.py | 206 ++++++++++ .../auto_parallel_relaunch_with_planner.py | 27 ++ .../auto_parallel/test_new_cost_model.py | 1 + .../test_auto_parallel_reshard_dpmppp.py | 17 + .../test_auto_parallel_reshard_mppp.py | 19 +- 12 files changed, 659 insertions(+), 26 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py index bdfcbfe06d3..b4ac972bcfd 100644 --- a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py @@ -148,6 +148,25 @@ class ConcatOpCost(CompOpCost): return 0 +@register_op_cost +class DropoutOpCost(CompOpCost): + OP_TYPE = "dropout" + + def __init__(self, op=None, op_desc=None, cluster=None): + super(DropoutOpCost, self).__init__(op=op, + op_desc=op_desc, + cluster=cluster) + + # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided + def calc_flops(self): + # NOTE: The actual formula will be filled in the future + return 0 + + def calc_time(self): + # NOTE: The actual formula will be filled in the future + return 0 + + @register_op_cost class ElementwiseAddOpCost(CompOpCost): OP_TYPE = "elementwise_add" diff --git a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py index 5a1aeec2d9f..7bdde90b6a7 100644 --- a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -12,26 +12,56 @@ # See the License for the specific language governing permissions and # limitations under the License +from collections import OrderedDict +from functools import reduce + +import paddle +import paddle.fluid.core as core +from paddle.distributed.fleet.meta_optimizers.common import OpRole + +from .base_cost import Cost +from ..operators.common import get_distributed_operator_impl_container +from ..dist_tensor import DistributedTensor + class CostEstimator: + _sepical_op_type = ["fused_attention", "fused_feedforward"] def __init__(self, program, - cluster=None, - dist_context=None, - mode="modeling"): + cluster, + mode="modeling", + rank=None, + loop_count=10): self._program = program self._cluster = cluster - self._dist_context = dist_context self._check_mode(mode) self._mode = mode - self._global_cost = None - self._local_cost = {} + self._rank = rank if rank is not None else paddle.distributed.get_rank() + self._loop_count = loop_count + self._global_cost = Cost() + self._local_cost_mapping = {} + self._detailed_cost = OrderedDict( + ) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}} + self._bubble_time_mapping = {} + self._ordered_ops = [] + + @property + def loop_count(self): + return self._loop_count + + @property + def detailed_cost(self): + return self._detailed_cost @property def program(self): return self._program + @property + def rank(self): + return self._rank + @property def dist_context(self): return self._dist_context @@ -46,25 +76,337 @@ class CostEstimator: @property def global_cost(self): + max_time = 0 + memory = 0 + flops = 0 + for rank in self._local_cost_mapping: + cost = self._local_cost_mapping[rank] + if cost.time > max_time: + max_time = cost.time + memory += cost.memory + flops += cost.flops + self._global_cost.time = max_time + self._global_cost.memory = memory + self._global_cost.flops = flops return self._global_cost - @property - def local_cost(self): - return self._local_cost - - def get_op_cost(self): - return 0 + def local_cost(self, rank=None): + rank = self.rank if rank is None else rank + if rank not in self._local_cost_mapping: + self._local_cost_mapping[rank] = Cost() - def get_tensor_cost(self): - return 0 + return self._local_cost_mapping[rank] - def get_global_cost(self): - return 0 - - def get_local_cost(self, rank=None): - return 0 + def local_bubble_time(self, rank=None): + rank = self.rank if rank is None else rank + return self._bubble_time_mapping[rank] def _check_mode(self, mode): if mode not in ["modeling", "profiling"]: raise ValueError( "Just support modeling and profiling, but got {}".format(mode)) + + def _is_special_var_name(self, var_name): + special_var_name = ["lod_tensor_blocking_queue_0"] + if var_name in special_var_name: + return True + return False + + def _estimate_core(self, dist_context, resharder, block): + from ..reshard import get_var_with_recursion + ops = block.ops + loop_count = None + if block.desc.id != self.program.global_block().desc.id: + loop_count = self.loop_count + else: + loop_count = 1 + for i in range(loop_count): + for op in ops: + self._detailed_cost[op.desc.id()] = OrderedDict() + # if in the while sub block, the detail of cost is the last cost + detail = self._detailed_cost[op.desc.id()] + detail["reshard_cost"] = OrderedDict() # + detail["dist_op_cost"] = [] + if int(op.attr('op_role')) == int(OpRole.Optimize): + continue + if op.type in [ + "create_py_reader", "create_double_buffer_reader", + "read" + ]: + continue + + # NOTE: It does not support nested loop and 
just supports while op when op has sub block now. + if op.type == "while": + while_block = self.program.blocks[op.attr("sub_block").id] + self._estimate_core(dist_context, resharder, while_block) + continue + + for var_name in op.input_arg_names: + if self._is_special_var_name(var_name): + continue + var = get_var_with_recursion(var_name, block, self.program) + reshard_cost = resharder.get_cost(op, var, self.cluster) + + # calc reshard cost + if reshard_cost is not None: + detail["reshard_cost"][var_name] = reshard_cost + + comm_costs = reshard_cost[0] + local_comp_cost = reshard_cost[1] + for comm_cost in comm_costs: + # time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost. + # comm sync + for item in comm_cost: + group_ranks, cost = item + max_time = None + cost_time = {} + for rank in group_ranks: + rank_cost = self.local_cost(rank) + cost_time[rank] = rank_cost.time + if max_time is None: + max_time = rank_cost.time + else: + if max_time < rank_cost.time: + max_time = rank_cost.time + + for rank in group_ranks: + self.local_cost( + rank).time = max_time + cost.time + + if rank not in self._bubble_time_mapping: + self._bubble_time_mapping[rank] = 0 + + self._bubble_time_mapping[rank] += ( + max_time - cost_time[rank]) + + for rank in local_comp_cost: + for comp_cost in local_comp_cost[rank]: + self.local_cost(rank).time += comp_cost.time + + # calc dist op cost + dist_op = dist_context.get_dist_op_for_program(op) + op_dist_attr = dist_op.dist_attr + processes = op_dist_attr.process_mesh.processes + + container = get_distributed_operator_impl_container( + op_dist_attr.impl_type) + dist_impl = container.impls[op_dist_attr.impl_idx] + + dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op, + dist_context, self.cluster) + detail["dist_op_cost"] = dist_op_cost + + if dist_op_cost is None: + assert dist_op.serial_op.type in CostEstimator._sepical_op_type + continue + for item in dist_op_cost: + if isinstance(item, list): + # comm sync + for comm_op_cost in item: + max_time = None + cost_time = {} + group_ranks = comm_op_cost.group_ranks + for rank in comm_op_cost.group_ranks: + rank_cost = self.local_cost(rank) + cost_time[rank] = rank_cost.time + if max_time is None: + max_time = rank_cost.time + else: + if max_time < rank_cost.time: + max_time = rank_cost.time + for rank in group_ranks: + self.local_cost( + rank).time = max_time + comm_op_cost.time + if rank not in self._bubble_time_mapping: + self._bubble_time_mapping[rank] = 0 + self._bubble_time_mapping[rank] += ( + max_time - cost_time[rank]) + elif isinstance(item, dict): + # op just one + for rank in processes: + # dp+pp+mp + if rank not in item: + continue + self.local_cost(rank).time += item[rank].time + + def prepare(self): + self._global_cost = Cost() + self._local_cost_mapping = {} + self._detailed_cost = OrderedDict() + self._bubble_time_mapping = {} + + def _calculate_bytes(self, sizes, dtype): + if sizes: + total_count = reduce(lambda x, y: x * y, sizes) + else: + total_count = 0 + + if dtype == paddle.float64 or dtype == paddle.int64: + dtype_factor = 8 + elif dtype == paddle.float32 or dtype == paddle.int32: + dtype_factor = 4 + elif dtype == paddle.float16 or dtype == paddle.bfloat16 \ + or dtype == paddle.int16: + dtype_factor = 2 + elif dtype == paddle.int8 or dtype == paddle.uint8: + dtype_factor = 1 + else: + dtype_factor = 8 + + memory = total_count * dtype_factor + return memory + + def _estimate_max_memory_by_dist_op(self, dist_context): + # This 
estimation will be improved, now reshard and inplace are not considered. + # Persist var is not free. + def _convert_pm_and_dm_to_str(process_mesh, dims_mapping): + processes = ",".join([str(x) for x in process_mesh.processes]) + topology = ",".join([str(x) for x in process_mesh.topology]) + dims_mapping = ",".join([str(x) for x in dims_mapping]) + result = processes + topology + dims_mapping + return result + + memories = {} + max_memories = {} + var_info = { + } # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]} + + for block in self.program.blocks: + for op in block.ops: + self._ordered_ops.append([op.desc.id(), op]) + self._ordered_ops.sort(key=lambda x: x[0]) + + for op_id, op in self._ordered_ops: + dist_op = dist_context.get_dist_op_for_program(op) + process_mesh = dist_op.dist_attr.process_mesh + for var_name in op.input_arg_names: + input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( + var_name) + if var_name not in var_info: + var_info[var_name] = {} + key = _convert_pm_and_dm_to_str(process_mesh, + input_dims_mapping) + if key not in var_info[var_name]: + var_info[var_name][key] = {} + # it is even partition now + if "memory" not in var_info[var_name][key]: + var = dist_op.get_serial_input(var_name) + global_sizes = var.shape + dtype = var.dtype + sizes = DistributedTensor.get_local_sizes( + global_sizes, input_dims_mapping, process_mesh.topology, + process_mesh.processes) + var_info[var_name][key]["memory"] = self._calculate_bytes( + sizes, dtype) + if "position" not in var_info[var_name][key]: + var_info[var_name][key]["position"] = [] + var_info[var_name][key]["position"].append(op_id) + + for var_name in op.output_arg_names: + output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping( + var_name) + if var_name not in var_info: + var_info[var_name] = {} + key = _convert_pm_and_dm_to_str(process_mesh, + output_dims_mapping) + if key not in var_info[var_name]: + var_info[var_name][key] = {} + if "memory" not in var_info[var_name][key]: + var = dist_op.get_serial_output(var_name) + global_sizes = var.shape + dtype = var.dtype + sizes = DistributedTensor.get_local_sizes( + global_sizes, output_dims_mapping, + process_mesh.topology, process_mesh.processes) + var_info[var_name][key]["memory"] = self._calculate_bytes( + sizes, dtype) + if "position" not in var_info[var_name][key]: + var_info[var_name][key]["position"] = [] + var_info[var_name][key]["position"].append(op_id) + + has_used_vars = set() + for op_id, op in self._ordered_ops: + can_free_memories = {} + can_free_vars = set() + dist_op = dist_context.get_dist_op_for_program(op) + process_mesh = dist_op.dist_attr.process_mesh + for var_name in op.input_arg_names: + input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( + var_name) + key = _convert_pm_and_dm_to_str(process_mesh, + input_dims_mapping) + has_used_var = var_name + key + var = dist_op.get_serial_input(var_name) + # not used + if var_name + key not in has_used_vars: + has_used_vars.add(has_used_var) + for process in process_mesh.processes: + if process not in memories: + memories[process] = 0 + memories[process] += var_info[var_name][key]["memory"] + # used + else: + if op_id == var_info[var_name][key]["position"][-1]: + if has_used_var not in can_free_vars: + can_free_vars.add(has_used_var) + if not var.persistable: + for process in process_mesh.processes: + if process not in can_free_memories: + can_free_memories[process] = 0 + can_free_memories[process] += var_info[ + var_name][key]["memory"] + + for 
var_name in op.output_arg_names: + output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping( + var_name) + key = _convert_pm_and_dm_to_str(process_mesh, + output_dims_mapping) + has_used_var = var_name + key + var = dist_op.get_serial_output(var_name) + # not used + if var_name + key not in has_used_vars: + has_used_vars.add(has_used_var) + for process in process_mesh.processes: + if process not in memories: + memories[process] = 0 + memories[process] += var_info[var_name][key]["memory"] + # used + else: + if op_id == var_info[var_name][key]["position"][-1]: + if has_used_var not in can_free_vars: + can_free_vars.add(has_used_var) + if not var.persistable: + for process in process_mesh.processes: + if process not in can_free_memories: + can_free_memories[process] = 0 + can_free_memories[process] += var_info[ + var_name][key]["memory"] + + # calc peak memory + for process in memories: + if process not in max_memories: + max_memories[process] = memories[process] + else: + if memories[process] > max_memories[process]: + max_memories[process] = memories[process] + + # free memory + for process in can_free_memories: + if process in memories: + memories[process] -= can_free_memories[process] + + # Calculate the max memory in all ranks + max_memory = max(max_memories.values()) + + return max_memory + + def estimate(self, dist_context, resharder=None): + self.prepare() + from ..reshard import Resharder + resharder = Resharder(self.program, None, self.rank, dist_context, + []) if resharder is None else resharder + + block = self.program.global_block() + self._estimate_core(dist_context, resharder, block) + + return self.global_cost diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index bf12ebb4589..cf7779a02a1 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -34,7 +34,8 @@ from ..process_group import new_process_group from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs, build_dp_costs -from ..cost import EmbeddingOpCost, EmbeddingGradOpCost, AllreduceSumOpCost, IdentityOpCost +from ..cost import EmbeddingOpCost, EmbeddingGradOpCost +from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost class DistributedEmbedding(DistributedOperatorImplContainer): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py index d39a775d16e..3b519c2cc5b 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py @@ -32,7 +32,7 @@ from .dist_default import DistributedDefaultImpl0 from ..cost import FillConstantBatchSizeLikeOpCost from ..cost import build_comp_desc_from_dist_op, build_dp_costs from ..cost import build_comp_costs_from_descs -from ..cost import AllreduceSumOpCost +from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py 
b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 18ceb79ea8f..f4c3e5a5800 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -39,8 +39,9 @@ from ..utils import _get_comm_group, _get_corresponding_rank from .dist_default import DistributedDefaultImpl0 from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs from ..cost import build_comm_costs_from_descs, build_comp_costs_from_descs -from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost, IdentityOpCost, AllreduceSumOpCost +from ..cost import MatmulV2OpCost, MatmulOpCost, MulOpCost from ..cost import MatmulV2GradOpCost, MatmulGradOpCost, MulGradOpCost +from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost, IdentityOpCost def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index bef18d1da8a..890eb670def 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -24,11 +24,12 @@ from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping from .dist_default import DistributedDefaultImpl0 -from ..cost import AllreduceSumOpCost, _g_op_cost_factory +from ..cost import _g_op_cost_factory from ..cost import build_comp_desc_from_dist_op, build_dp_costs from ..cost import build_comp_costs_from_descs from ..cost import SoftmaxOpCost, SoftmaxGradOpCost from paddle.distributed.fleet.meta_optimizers.common import OpRole +from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost class DistributedSoftmax(DistributedOperatorImplContainer): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py index e5b4a51c4db..88024f3777f 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -24,10 +24,11 @@ from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping from .dist_default import DistributedDefaultImpl0 -from ..cost import AllreduceSumOpCost, Transpose2OpCost, Transpose2GradOpCost +from ..cost import Transpose2OpCost, Transpose2GradOpCost from ..cost import build_comp_desc_from_dist_op, build_comm_desc_from_dist_op, build_dp_costs from ..cost import build_comp_costs_from_descs from paddle.distributed.fleet.meta_optimizers.common import OpRole +from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost class DistributedTranspose2(DistributedOperatorImplContainer): diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 6b902d6fb77..10c9162a233 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -2065,3 +2065,209 @@ class Resharder: # reset some variable when remove operation ended Resharder.while_block_info = {} + + def get_cost(self, op, tensor, cluster): + # NOTE: The program should be the serial_program which is not been parted + 
global _g_special_ops + not_supported_op_type = _g_special_ops + ["while"] + reshard_op_cost = None + if op.type in not_supported_op_type: + return reshard_op_cost + else: + tensor_name = tensor.name + if tensor_name == "lod_tensor_blocking_queue_0": + return reshard_op_cost + else: + dist_tensor = self.dist_context.get_dist_tensor_for_program( + tensor) + # simplified processing: ignore union process mesh and output reshard + dist_op = self.dist_context.get_dist_op_for_program(op) + dims_mapping = dist_op.dist_attr.get_input_dims_mapping( + tensor.name) + process_mesh = dist_op.dist_attr.process_mesh + dist_attr = [process_mesh, dims_mapping] + if dist_tensor is not None and self.need_reshard( + dist_tensor, dist_attr): + if tensor_name not in self._has_resharded: + self._has_resharded[tensor_name] = [dist_op] + else: + for item in self._has_resharded[tensor_name]: + item_dist_attr = item.dist_attr + item_dims_mapping = item_dist_attr.get_input_dims_mapping( + tensor_name) + item_process_mesh = item_dist_attr.process_mesh + if dims_mapping == item_dims_mapping and item_process_mesh == process_mesh: + return reshard_op_cost + self._has_resharded[tensor_name].append(dist_op) + + reshard_op_desc = self.find_op_desc_seq(dist_tensor, + dist_attr, + serial=True) + dtype = dist_tensor.serial_tensor.dtype + reshard_op_cost = self.parse_op_desc_for_cost( + reshard_op_desc, dtype, cluster) + + return reshard_op_cost + + def _concat_partitions_for_cost(self, partition_tensor_list, + partition_index, dtype, rank_id, + local_rank_comp_cost, cluster): + if not partition_tensor_list: + partition_tensor_list.append(partition_index) + else: + i = 0 + has_concat = False + while i < len(partition_tensor_list): + concat_axis, first_order, new_partition = Resharder.compute_concat_info( + partition_tensor_list[i], partition_index) + if concat_axis != -1: + has_concat = True + concat_desc = {} + concat_desc["op"] = "concat" + concat_desc["attrs"] = {"axis": concat_axis} + if first_order == 0: + concat_desc["inputs"] = { + "X": [(dtype, partition_tensor_list[i]), + (dtype, partition_index)] + } + else: + concat_desc["inputs"] = { + "X": [(dtype, partition_index), + (dtype, partition_tensor_list[i])] + } + partition_tensor_list.pop(i) + if rank_id not in local_rank_comp_cost: + local_rank_comp_cost[rank_id] = [] + local_rank_comp_cost[rank_id].append( + ConcatOpCost(op_desc=concat_desc, cluster=cluster)) + self._concat_partitions_for_cost(partition_tensor_list, + new_partition, dtype, + rank_id, + local_rank_comp_cost, + cluster) + break + i += 1 + if not has_concat: + partition_tensor_list.append(partition_index) + + def parse_op_desc_for_cost(self, reshard_op_desc, dtype, cluster): + + def _get_idx(comm_ranks, group_ranks): + res, is_the_same = None, False + idx = 0 + while idx < len(comm_ranks): + if comm_ranks[idx] == set(group_ranks): + is_the_same = True + + for rank in group_ranks: + if rank in comm_ranks[idx]: + res = idx + comm_ranks[idx].add(rank) + if res is None: + idx += 1 + else: + break + return res, is_the_same + + comm_context = CommContext(cluster) + # run communication op before computation op + # TODO: Communication cost is not calculated when the var has been transfered by the same group in the past + comm_costs = [] + comm_ranks = [] + local_rank_comp_cost = {} + for key in reshard_op_desc: + partition_tensor_list = [] + op_desc_list = reshard_op_desc[key] + for op_desc in op_desc_list: + if isinstance(op_desc, SendOpDesc): + group_ranks = [key, op_desc.dst] + shape = op_desc.shape + 
send_desc = build_comm_desc("send_v2", group_ranks, dtype, + shape) + idx, is_the_same = _get_idx(comm_ranks, group_ranks) + if idx is None: + comm_costs.append([ + (group_ranks, + SendOpCost(op_desc=send_desc, + comm_context=comm_context)) + ]) + comm_ranks.append(set(group_ranks)) + else: + if not is_the_same: + comm_costs[idx].append( + (group_ranks, + SendOpCost(op_desc=send_desc, + comm_context=comm_context))) + elif isinstance(op_desc, AllGatherOpDesc): + # NOTE: fill_const and other unnecessary op is not calculated because those cost is very small + group_ranks = op_desc.group + shape = op_desc.shape + allgather_desc = build_comm_desc("c_allgather", group_ranks, + dtype, shape) + split_inputs_shape = [] + for idx, dim in enumerate(shape): + if idx == 0: + split_inputs_shape.append(dim * len(group_ranks)) + else: + split_inputs_shape.append(dim) + idx, is_the_same = _get_idx(comm_ranks, group_ranks) + if idx is None: + comm_costs.append([ + (group_ranks, + AllgatherOpCost(op_desc=allgather_desc, + comm_context=comm_context)) + ]) + comm_ranks.append(set(group_ranks)) + else: + if not is_the_same: + comm_costs[idx].append( + (group_ranks, + AllgatherOpCost(op_desc=allgather_desc, + comm_context=comm_context))) + # calc the split op cost + if key not in local_rank_comp_cost: + local_rank_comp_cost[key] = [] + split_desc = {} + split_desc["op"] = "split" + split_desc["inputs"] = { + "inputs": [(dtype, split_inputs_shape)] + } + split_desc["attrs"] = {"num": len(group_ranks), "axis": 0} + local_rank_comp_cost[key].append( + SplitOpCost(op_desc=split_desc, cluster=cluster)) + elif isinstance(op_desc, ConcatOpDesc): + partition_index_list = op_desc._partition_index_list + for idx, partion_idex in enumerate(partition_index_list): + self._concat_partitions_for_cost( + partition_tensor_list, partion_idex, dtype, key, + local_rank_comp_cost, cluster) + + elif isinstance(op_desc, SliceOpDesc): + if key not in local_rank_comp_cost: + local_rank_comp_cost[key] = [] + assert len( + partition_tensor_list) == 1 or not partition_tensor_list + to_slice_tensor_shape = [] + if len(partition_tensor_list) == 1: + for item in partition_tensor_list[0]: + to_slice_tensor_shape.append(item[1] - item[0]) + else: + to_slice_tensor_shape = op_desc.shape + slice_desc = {} + slice_desc["op"] = "slice" + infer_flags = list(1 for i in range(len(op_desc.axes))) + slice_desc["attrs"] = { + "axes": op_desc.axes, + "starts": op_desc.starts, + "ends": op_desc.ends, + "infer_flags": infer_flags + } + slice_desc["inputs"] = { + "Input": [(dtype, to_slice_tensor_shape)] + } + local_rank_comp_cost[key].append( + SliceOpCost(op_desc=slice_desc, cluster=cluster)) + + res = (comm_costs, local_rank_comp_cost) + + return res diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_with_planner.py b/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_with_planner.py index 20d45e32b7a..b40a61ed34c 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_with_planner.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_with_planner.py @@ -15,6 +15,9 @@ import paddle import paddle.static as static from paddle.distributed import fleet +from paddle.distributed.auto_parallel.cost import CostEstimator +from paddle.distributed.auto_parallel.cluster import Cluster +from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context def train(): @@ -39,6 +42,30 @@ def train(): _, _, distributed_startup_program, 
distributed_main_program = optimizer.minimize( loss, start_program) + # add cost estimator + dist_context = get_default_distributed_context() + cluster = Cluster() + for op in train_program.global_block().ops: + dist_op = dist_context.get_dist_op_for_program(op) + for var_name in op.input_arg_names: + dims_mapping = dist_op.dist_attr.get_input_dims_mapping(var_name) + if dims_mapping is None: + dist_op.dist_attr.set_input_dims_mapping( + var_name, [ + -1 for i in range( + len(train_program.global_block().vars[var_name]. + shape)) + ]) + cluster.gen_default_config_cluster(device_count=2) + cost_estimator = CostEstimator(train_program, cluster) + global_cost = cost_estimator.estimate(dist_context) + max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context) + # test cache + global_cost = cost_estimator.estimate(dist_context) + max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context) + assert global_cost.time > 0 + assert max_memory > 0 + places = static.cuda_places() loader.set_batch_generator(batch_generator_creator(), places=places) exe = paddle.static.Executor(places[0]) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py index 6b0db61b984..e463adfc66d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_new_cost_model.py @@ -19,6 +19,7 @@ import tempfile import paddle import paddle.distributed.auto_parallel.cost as cost_model + from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_from_op from paddle.distributed.auto_parallel.cost.base_cost import build_comp_desc_str_for_predict from paddle.distributed.auto_parallel.cost.base_cost import calc_time_by_modeling diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index d6d613225d7..33396f283ec 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -29,6 +29,8 @@ from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr +from paddle.distributed.auto_parallel.cost import CostEstimator +from paddle.distributed.auto_parallel.cluster import Cluster paddle.enable_static() _global_parallel_strategy = "dp_mp_pp" @@ -196,6 +198,21 @@ class TestMLPReshard(unittest.TestCase): rank_id = 2 dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, rank_id) + + # test estimator + cluster = Cluster() + cluster.gen_default_config_cluster(device_count=8) + cost_estimator = CostEstimator(train_program, cluster) + global_cost = cost_estimator.estimate(dist_context) + max_memory = cost_estimator._estimate_max_memory_by_dist_op( + dist_context) + # test cache + global_cost = cost_estimator.estimate(dist_context) + max_memory = cost_estimator._estimate_max_memory_by_dist_op( + dist_context) + assert global_cost.time > 0 + assert max_memory > 0 + resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, dist_context, dist_params_grads) resharder.reshard() diff --git 
a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index 5c699881c21..d5de1c12873 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -29,6 +29,8 @@ from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import Resharder from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr +from paddle.distributed.auto_parallel.cost import CostEstimator +from paddle.distributed.auto_parallel.cluster import Cluster paddle.enable_static() _global_parallel_strategy = "mp_pp" @@ -247,7 +249,7 @@ class TestMLPReshard(unittest.TestCase): def test_allgather(self): train_program = paddle.static.Program() startup_program = paddle.static.Program() - process_mesh = auto.ProcessMesh(mesh=[0, 3]) + process_mesh = auto.ProcessMesh(mesh=[0, 1]) with static.program_guard(train_program, startup_program): x = paddle.static.data(name="x", shape=[4, 4], dtype='float32') x = auto.shard_tensor(x, @@ -284,6 +286,21 @@ class TestMLPReshard(unittest.TestCase): dist_context.block_state.parse_forward_blocks(complete_train_program) partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition( complete_train_program, startup_program, []) + + # test estimator + cluster = Cluster() + cluster.gen_default_config_cluster(device_count=2) + cost_estimator = CostEstimator(train_program, cluster) + global_cost = cost_estimator.estimate(dist_context) + max_memory = cost_estimator._estimate_max_memory_by_dist_op( + dist_context) + # test cache + global_cost = cost_estimator.estimate(dist_context) + max_memory = cost_estimator._estimate_max_memory_by_dist_op( + dist_context) + assert global_cost.time > 0 + assert max_memory > 0 + resharder = Resharder(partitioned_main_prog, partitioned_startup_prog, rank_id, dist_context, partitioned_params_grads) resharder.reshard() -- GitLab
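
Usage sketch (illustrative, not part of the applied diff): the CostEstimator.estimate and _estimate_max_memory_by_dist_op entry points added by this patch are driven the same way in the updated unittests above. A minimal Python sketch of that flow, assuming `train_program` is a static Program whose ops already carry dist attributes in the default distributed context, and a generated cluster as in the tests (the `device_count=2` value is taken from auto_parallel_relaunch_with_planner.py and is only an example):

    from paddle.distributed.auto_parallel.cluster import Cluster
    from paddle.distributed.auto_parallel.cost import CostEstimator
    from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context

    # Describe the hardware; the tests build a default config instead of a real cluster file.
    cluster = Cluster()
    cluster.gen_default_config_cluster(device_count=2)

    # Distributed context holding the per-op/per-tensor dist attributes of train_program.
    dist_context = get_default_distributed_context()

    # Estimate the annotated serial program: estimate() walks the ops, adds reshard costs
    # obtained from Resharder.get_cost plus dist-op costs, and aggregates per-rank local
    # costs into a global Cost (max time across ranks, summed memory and flops).
    cost_estimator = CostEstimator(train_program, cluster)
    global_cost = cost_estimator.estimate(dist_context)
    max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)

    assert global_cost.time > 0 and max_memory > 0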