diff --git a/python/paddle/distributed/auto_parallel/cluster.py b/python/paddle/distributed/auto_parallel/cluster.py index b1cb2b8a741c1409ec557623cb5e074e9267f2c0..9cb9cde457eeb74f3ea6a645176ef89bd8577128 100644 --- a/python/paddle/distributed/auto_parallel/cluster.py +++ b/python/paddle/distributed/auto_parallel/cluster.py @@ -14,6 +14,7 @@ import json import os +import re from enum import IntEnum, unique import paddle @@ -449,7 +450,6 @@ class Cluster: npu_models = ["NPU"] dcu_models = ["DCU"] all_gpu_models = gpu_models + xpu_models + npu_models + dcu_models - assert gpu_model in all_gpu_models self._num_devices_per_machine = device_count def _convert_to_type(gpu_model): @@ -462,6 +462,8 @@ class Cluster: type = "NPU" elif gpu_model in dcu_models: type = "DCU" + else: + type = "GPU" assert type is not None return type @@ -470,6 +472,12 @@ class Cluster: model = None if gpu_model == "V100": model = "Tesla V100-SXM2-" + str(gpu_memory) + "GB" + elif gpu_model == "A100": + model = "Tesla A100-SXM-" + str(gpu_memory) + "GB" + elif gpu_model == "A30": + model = "Tesla A30-SXM-" + str(gpu_memory) + "GB" + else: + model = gpu_model + str(gpu_memory) + "GB" assert model is not None return model @@ -527,6 +535,8 @@ class Cluster: device["memory"] = memory device["sp_gflops"] = sp_gflops device["dp_gflops"] = dp_gflops + # hard code + device["type"] = "GPU" global_id_to_device_type[global_id] = type global_id_to_node[global_id] = i devices.append(device) @@ -820,30 +830,82 @@ class Cluster: return self.__str__() -def get_default_cluster(): +def get_default_cluster(json_config=None): + def is_by_json_config(json_config): + if not json_config: + return False + if "cluster" not in json_config: + return False + else: + if "path" not in json_config["cluster"]: + if "num_nodes" not in json_config["cluster"]: + return False + if "num_gpus" not in json_config["cluster"]: + return False + if "gpu_model" not in json_config["cluster"]: + return False + if "gpu_memory" not in json_config["cluster"]: + return False + return True + else: + return True + cluster = Cluster() - local_device_count = os.getenv("PADDLE_LOCAL_SIZE") - if local_device_count is None: - local_device_count = 1 - else: - local_device_count = int(local_device_count) - global_device_count = os.getenv("PADDLE_GLOBAL_SIZE") - if global_device_count is None: - node_count = 1 + if json_config and is_by_json_config(json_config): + # Get GPU info by json config + if "path" in json_config["cluster"]: + cluster.build_from_file(json_config["cluster"]["path"]) + return cluster + else: + node_count = json_config["cluster"]["num_nodes"] + local_device_count = json_config["cluster"]["num_gpus"] + gpu_model = json_config["cluster"]["gpu_model"] + memory = json_config["cluster"]["gpu_memory"] else: - global_device_count = int(global_device_count) - assert global_device_count % local_device_count == 0 - node_count = int(global_device_count) // local_device_count + # Get GPU info by get_device_properties + local_device_count = os.getenv("PADDLE_LOCAL_SIZE") + if local_device_count is None: + local_device_count = 1 + else: + local_device_count = int(local_device_count) + + global_device_count = os.getenv("PADDLE_GLOBAL_SIZE") + if global_device_count is None: + node_count = 1 + else: + global_device_count = int(global_device_count) + assert global_device_count % local_device_count == 0 + node_count = int(global_device_count) // local_device_count + + gpu_info = paddle.device.cuda.get_device_properties() + assert gpu_info, "Auto parallel just runs on gpu 
now." + + gpu_name = gpu_info.name + try: + re_result = re.split(r'[ , -]', gpu_name) + gpu_model = re_result[1] + memory = int(re_result[-1][:-2]) + except: + memory = int(gpu_info.total_memory) // (1000**3) + gpu_model = gpu_name + print( "Node Count: ", node_count, "Local Device Size: ", local_device_count, + "GPU Model: ", + gpu_model, + "GPU Memory: ", + memory, "World size: ", paddle.distributed.get_world_size(), flush=True, ) cluster.gen_default_config_cluster( - node_count=node_count, device_count=local_device_count + node_count=node_count, + device_count=local_device_count, + gpu_model=gpu_model, + gpu_memory=memory, ) return cluster diff --git a/python/paddle/distributed/auto_parallel/cost/base_cost.py b/python/paddle/distributed/auto_parallel/cost/base_cost.py index 32a51302d18ef498622d3390bd7d85e4043103eb..4046b8cc4dba587ac5c75b1a0d29baf1c773c9e9 100644 --- a/python/paddle/distributed/auto_parallel/cost/base_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/base_cost.py @@ -16,6 +16,7 @@ from collections import OrderedDict from functools import reduce import paddle +from paddle.utils.flops import flops from ..cluster import LinkType from ..dist_tensor import DistributedTensor @@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context): output_desc = OrderedDict() # Get partitioned shape of input + input_var_desc = {} for input_name in op.input_names: var_name_list = op.input(input_name) - var_desc = [] + input_var_desc[input_name] = [] for var_name in var_name_list: var = get_var_with_recursion( var_name, op.block, op.block.program @@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context): process, shard_sizes, ) - var_desc.append((var.dtype, shape)) + input_var_desc[input_name].append(shape) # For special op such as embedding and its grad op if ( @@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context): relative_idx = relative_idx * per_part_size desc["attrs"]["start_index"] = relative_idx - input_desc[input_name] = var_desc - desc["inputs"] = input_desc + desc["inputs"] = input_var_desc for out_name in op.output_names: var_name_list = op.output(out_name) @@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None): return desc -def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster): +def build_comm_costs_from_descs( + op_cost_class, ctx, processes, descs, cluster, is_dp=False +): """Build comm costs by descriptions""" comm_context = CommContext(cluster) group_ranks_list = [] @@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster): comm_op_cost = op_cost_class( op_desc=desc, comm_context=comm_context ) + if is_dp: + comm_op_cost.cost.time *= 0.9 comm_op_cost_list.append(comm_op_cost) return comm_op_cost_list @@ -389,6 +394,7 @@ def build_dp_costs( vars = dist_op.serial_op.block.vars var_name = var_names[0] has_found = False + is_input = True for name in dist_op.serial_op.input_arg_names: if var_name in name: var_name = name @@ -400,6 +406,7 @@ def build_dp_costs( if var_name in name: var_name = name has_found = True + is_input = False break if not has_found: return @@ -418,6 +425,7 @@ def build_dp_costs( processes, c_allreduce_sum_descs, cluster, + is_dp=True, ) result.append(comm_cost_list) @@ -431,22 +439,11 @@ def build_dp_costs( desc = {} desc["op"] = op_type desc["inputs"] = {} - if var_name in dist_attr.inputs_dist_attrs: - dims_mapping = dist_attr.get_input_dims_mapping(var_name) - elif var_name in 
dist_attr.outputs_dist_attrs: - dims_mapping = dist_attr.get_output_dims_mapping(var_name) - else: - raise AssertionError( - "cannot find dims_mapping for {} in {}".format( - var_name, dist_attr - ) - ) - - # dims_mapping = ( - # dist_attr.get_input_dims_mapping(var_name) - # if dist_attr.get_input_dims_mapping(var_name) is not None - # else dist_attr.get_output_dims_mapping(var_name) - # ) + dims_mapping = ( + dist_attr.get_input_dims_mapping(var_name) + if is_input + else dist_attr.get_output_dims_mapping(var_name) + ) var = get_var_with_recursion( var_name, dist_op.serial_op.block, @@ -493,8 +490,6 @@ class CommContext: # if cluster has no info about those vars, it will be set by default self.base_ring = None self.base_tree = None - # self.base_inter_ring = None - # self.base_inter_tree = None self.intra_ring = None self.intra_tree = None self.inter_ring = None @@ -508,8 +503,6 @@ class CommContext: # set default self.base_ring = 8.4 self.base_tree = 0.0 - # self.base_inter_ring = 9.6 - # self.base_inter_tree = 28 # NVL in default self.intra_ring = 3.4 self.intra_tree = 28 @@ -681,6 +674,8 @@ class Cost: class OpCost: + OP_TYPE = "op" + def __init__(self, op=None, op_desc=None): self._op = op self._op_desc = op_desc @@ -883,6 +878,24 @@ class CompOpCost(OpCost): ) ) + def calc_flops(self): + if not self.op_desc: + return 0 + if "_grad" in self.__class__.OP_TYPE: + op_type = self.__class__.OP_TYPE[: len(self.__class__.OP_TYPE) - 5] + return 2 * flops( + op_type, self.op_desc["inputs"], self.op_desc["attrs"] + ) + return flops( + self.__class__.OP_TYPE, + self.op_desc["inputs"], + self.op_desc["attrs"], + ) + + def calc_time(self): + flops_count = self.calc_flops() + return flops_count * 2.9e-7 + def register_op_cost(cls): op_type = cls.OP_TYPE diff --git a/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py index 41ccc2265bbb5db410721a972d57453215f1cf94..d23e9b08090c222ca95e6f9838c96ee9bb61ac3b 100644 --- a/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/comm_op_cost.py @@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost): super().__init__(op=op, op_desc=op_desc, comm_context=comm_context) def calc_time(self): - return 0 + return self.comm_count * 1 / (144 * 1e3) @register_op_cost diff --git a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py index edf0a3397f88691ee2e6f360bf9b34cbc1b4ff88..ea6d2ef571ca99214be5a852ba422aaa5f59cb9b 100644 --- a/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/comp_op_cost.py @@ -22,15 +22,6 @@ class AdamOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ArgsortOpCost(CompOpCost): @@ -39,15 +30,6 @@ class ArgsortOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - 
return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class AssignOpCost(CompOpCost): @@ -56,15 +38,6 @@ class AssignOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class AssignValueOpCost(CompOpCost): @@ -73,15 +46,6 @@ class AssignValueOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class BeamSearchOpCost(CompOpCost): @@ -90,15 +54,6 @@ class BeamSearchOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class BeamSearchDecodeOpCost(CompOpCost): @@ -107,15 +62,6 @@ class BeamSearchDecodeOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class CastOpCost(CompOpCost): @@ -124,15 +70,6 @@ class CastOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ConcatOpCost(CompOpCost): @@ -141,15 +78,6 @@ class ConcatOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class DropoutOpCost(CompOpCost): @@ -158,15 +86,6 @@ class DropoutOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be 
filled in the future - return 0 - @register_op_cost class DropoutGradOpCost(CompOpCost): @@ -175,15 +94,6 @@ class DropoutGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ElementwiseAddOpCost(CompOpCost): @@ -192,15 +102,6 @@ class ElementwiseAddOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ElementwiseAddGradOpCost(CompOpCost): @@ -209,15 +110,6 @@ class ElementwiseAddGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ElementwiseDivOpCost(CompOpCost): @@ -226,15 +118,6 @@ class ElementwiseDivOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ElementwiseDivGradOpCost(CompOpCost): @@ -243,15 +126,6 @@ class ElementwiseDivGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ElementwiseMulOpCost(CompOpCost): @@ -260,15 +134,6 @@ class ElementwiseMulOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ElementwiseMulGradOpCost(CompOpCost): @@ -277,15 +142,6 @@ class ElementwiseMulGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # 
NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ElementwiseSubOpCost(CompOpCost): @@ -294,15 +150,6 @@ class ElementwiseSubOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ElementwiseSubGradOpCost(CompOpCost): @@ -311,15 +158,6 @@ class ElementwiseSubGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class EqualOpCost(CompOpCost): @@ -328,15 +166,6 @@ class EqualOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class EmbeddingOpCost(CompOpCost): @@ -345,15 +174,6 @@ class EmbeddingOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class EmbeddingGradOpCost(CompOpCost): @@ -362,15 +182,6 @@ class EmbeddingGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class FillConstantOpCost(CompOpCost): @@ -379,15 +190,6 @@ class FillConstantOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class FillConstantBatchSizeLikeOpCost(CompOpCost): @@ -396,15 +198,6 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def 
calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost): @@ -413,15 +206,6 @@ class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost): @@ -430,15 +214,6 @@ class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class GatherOpCost(CompOpCost): @@ -447,15 +222,6 @@ class GatherOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class GeluOpCost(CompOpCost): @@ -464,15 +230,6 @@ class GeluOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class GeluGradOpCost(CompOpCost): @@ -481,15 +238,6 @@ class GeluGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class GreaterEqualOpCost(CompOpCost): @@ -498,15 +246,6 @@ class GreaterEqualOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class IncrementOpCost(CompOpCost): @@ -515,11 +254,6 @@ class IncrementOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the 
future - return 0 - @register_op_cost class IsEmptyOpCost(CompOpCost): @@ -528,11 +262,6 @@ class IsEmptyOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class LayerNormOpCost(CompOpCost): @@ -541,15 +270,6 @@ class LayerNormOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class LayerNormGradOpCost(CompOpCost): @@ -558,15 +278,6 @@ class LayerNormGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class LessThanOpCost(CompOpCost): @@ -575,15 +286,6 @@ class LessThanOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class LogicalNotOpCost(CompOpCost): @@ -592,15 +294,6 @@ class LogicalNotOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class LogicalAndOpCost(CompOpCost): @@ -609,15 +302,6 @@ class LogicalAndOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class LodResetOpCost(CompOpCost): @@ -626,15 +310,6 @@ class LodResetOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class LogOpCost(CompOpCost): @@ -643,15 +318,6 @@ class LogOpCost(CompOpCost): def __init__(self, op=None, 
op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class LookupTableV2OpCost(CompOpCost): @@ -660,15 +326,6 @@ class LookupTableV2OpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class LookupTableV2GradOpCost(CompOpCost): @@ -677,15 +334,6 @@ class LookupTableV2GradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class MatmulOpCost(CompOpCost): @@ -694,15 +342,6 @@ class MatmulOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class MatmulGradOpCost(CompOpCost): @@ -711,15 +350,6 @@ class MatmulGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class MatmulV2OpCost(CompOpCost): @@ -728,15 +358,6 @@ class MatmulV2OpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class MatmulV2GradOpCost(CompOpCost): @@ -745,15 +366,6 @@ class MatmulV2GradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class MemcpyOpCost(CompOpCost): @@ -762,15 +374,6 @@ class MemcpyOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): 
super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class MulOpCost(CompOpCost): @@ -779,15 +382,6 @@ class MulOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class MulGradOpCost(CompOpCost): @@ -796,15 +390,6 @@ class MulGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class OneHotOpCost(CompOpCost): @@ -813,15 +398,6 @@ class OneHotOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ReadFromArrayOpCost(CompOpCost): @@ -830,15 +406,6 @@ class ReadFromArrayOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ReduceSumOpCost(CompOpCost): @@ -847,15 +414,6 @@ class ReduceSumOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ReduceSumGradOpCost(CompOpCost): @@ -864,15 +422,6 @@ class ReduceSumGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class Reshape2OpCost(CompOpCost): @@ -881,15 +430,6 @@ class Reshape2OpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # 
For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class Reshape2GradOpCost(CompOpCost): @@ -898,15 +438,6 @@ class Reshape2GradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ReduceMeanOpCost(CompOpCost): @@ -915,15 +446,6 @@ class ReduceMeanOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ReduceMeanGradOpCost(CompOpCost): @@ -932,15 +454,6 @@ class ReduceMeanGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SamplingIdOpCost(CompOpCost): @@ -949,15 +462,6 @@ class SamplingIdOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class ScaleOpCost(CompOpCost): @@ -966,15 +470,6 @@ class ScaleOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SliceOpCost(CompOpCost): @@ -983,15 +478,6 @@ class SliceOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SoftmaxOpCost(CompOpCost): @@ -1000,15 +486,6 @@ class SoftmaxOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops 
function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SoftmaxGradOpCost(CompOpCost): @@ -1017,15 +494,6 @@ class SoftmaxGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SoftmaxWithCrossEntropyOpCost(CompOpCost): @@ -1034,15 +502,6 @@ class SoftmaxWithCrossEntropyOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SoftmaxWithCrossEntropyGradOpCost(CompOpCost): @@ -1051,15 +510,6 @@ class SoftmaxWithCrossEntropyGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SplitOpCost(CompOpCost): @@ -1068,15 +518,6 @@ class SplitOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class Squeeze2OpCost(CompOpCost): @@ -1085,15 +526,6 @@ class Squeeze2OpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SquareOpCost(CompOpCost): @@ -1102,15 +534,6 @@ class SquareOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SquareGradOpCost(CompOpCost): @@ -1119,15 +542,6 @@ class SquareGradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and 
calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class SumOpCost(CompOpCost): @@ -1136,15 +550,6 @@ class SumOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class TopKOpCost(CompOpCost): @@ -1153,15 +558,6 @@ class TopKOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class Transpose2OpCost(CompOpCost): @@ -1170,15 +566,6 @@ class Transpose2OpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class Transpose2GradOpCost(CompOpCost): @@ -1187,15 +574,6 @@ class Transpose2GradOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class Unsqueeze2OpCost(CompOpCost): @@ -1204,15 +582,6 @@ class Unsqueeze2OpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 - @register_op_cost class WriteToArrayOpCost(CompOpCost): @@ -1220,12 +589,3 @@ class WriteToArrayOpCost(CompOpCost): def __init__(self, op=None, op_desc=None, cluster=None): super().__init__(op=op, op_desc=op_desc, cluster=cluster) - - # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided - def calc_flops(self): - # NOTE: The actual formula will be filled in the future - return 0 - - def calc_time(self): - # NOTE: The actual formula will be filled in the future - return 0 diff --git a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py index f5c1172cef50da99a45bd258a765f55824f178c9..b948241da369cd10ad1aca6106926c93272b9275 100644 --- a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py +++ 
b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py @@ -189,6 +189,9 @@ class CostEstimator: # Calc dist op cost dist_op = dist_context.get_dist_op_for_program(op) + if not dist_op: + continue + op_dist_attr = dist_op.dist_attr processes = op_dist_attr.process_mesh.process_ids @@ -225,6 +228,8 @@ class CostEstimator: for rank in group_ranks: self.local_cost(rank).time = ( max_time + comm_op_cost.time + if op.attr('op_role') != OpRole.Backward + else max_time + 0.9 * comm_op_cost.time ) if rank not in self._bubble_time_mapping: self._bubble_time_mapping[rank] = 0 @@ -290,6 +295,7 @@ class CostEstimator: self._ordered_ops.append([op.desc.id(), op]) self._ordered_ops.sort(key=lambda x: x[0]) + parameters = set() for op_id, op in self._ordered_ops: if op.type in [ "create_py_reader", @@ -298,11 +304,14 @@ class CostEstimator: ]: continue dist_op = dist_context.get_dist_op_for_program(op) + if not dist_op: + continue process_mesh = dist_op.dist_attr.process_mesh for var_name in op.input_arg_names: input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( var_name ) + if var_name not in var_info: var_info[var_name] = {} key = _convert_pm_and_dm_to_str( @@ -311,6 +320,10 @@ class CostEstimator: if key not in var_info[var_name]: var_info[var_name][key] = {} # It is even partition now + if "position" not in var_info[var_name][key]: + var_info[var_name][key]["position"] = [] + var_info[var_name][key]["position"].append(op_id) + if "memory" not in var_info[var_name][key]: var = dist_op.get_serial_input(var_name) global_sizes = var.shape @@ -324,9 +337,16 @@ class CostEstimator: var_info[var_name][key]["memory"] = self._calculate_bytes( sizes, dtype ) - if "position" not in var_info[var_name][key]: - var_info[var_name][key]["position"] = [] - var_info[var_name][key]["position"].append(op_id) + if var.persistable: + name = var_name + key + if name not in parameters: + parameters.add(name) + for process in process_mesh.process_ids: + if process not in memories: + memories[process] = 0 + memories[process] += var_info[var_name][key][ + "memory" + ] for var_name in op.output_arg_names: output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping( @@ -339,6 +359,10 @@ class CostEstimator: ) if key not in var_info[var_name]: var_info[var_name][key] = {} + if "position" not in var_info[var_name][key]: + var_info[var_name][key]["position"] = [] + var_info[var_name][key]["position"].append(op_id) + if "memory" not in var_info[var_name][key]: var = dist_op.get_serial_output(var_name) global_sizes = var.shape @@ -352,11 +376,19 @@ class CostEstimator: var_info[var_name][key]["memory"] = self._calculate_bytes( sizes, dtype ) - if "position" not in var_info[var_name][key]: - var_info[var_name][key]["position"] = [] - var_info[var_name][key]["position"].append(op_id) + if var.persistable: + name = var_name + key + if name not in parameters: + parameters.add(name) + for process in process_mesh.process_ids: + if process not in memories: + memories[process] = 0 + memories[process] += var_info[var_name][key][ + "memory" + ] has_used_vars = set() + not_calc_vars = set() for op_id, op in self._ordered_ops: if op.type in [ "create_py_reader", @@ -367,6 +399,8 @@ class CostEstimator: can_free_memories = {} can_free_vars = set() dist_op = dist_context.get_dist_op_for_program(op) + if not dist_op: + continue process_mesh = dist_op.dist_attr.process_mesh for var_name in op.input_arg_names: input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping( @@ -378,24 +412,30 @@ class CostEstimator: has_used_var 
= var_name + key var = dist_op.get_serial_input(var_name) # Not used - if var_name + key not in has_used_vars: + if ( + has_used_var not in has_used_vars + and has_used_var not in parameters + ): + if has_used_var in not_calc_vars: + continue has_used_vars.add(has_used_var) for process in process_mesh.process_ids: if process not in memories: memories[process] = 0 memories[process] += var_info[var_name][key]["memory"] # Used - else: - if op_id == var_info[var_name][key]["position"][-1]: - if has_used_var not in can_free_vars: - can_free_vars.add(has_used_var) - if not var.persistable: - for process in process_mesh.process_ids: - if process not in can_free_memories: - can_free_memories[process] = 0 - can_free_memories[process] += var_info[ - var_name - ][key]["memory"] + if op_id == var_info[var_name][key]["position"][-1]: + if ( + has_used_var not in can_free_vars + and not var.persistable + ): + can_free_vars.add(has_used_var) + for process in process_mesh.process_ids: + if process not in can_free_memories: + can_free_memories[process] = 0 + can_free_memories[process] += var_info[var_name][ + key + ]["memory"] for var_name in op.output_arg_names: output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping( @@ -406,25 +446,36 @@ class CostEstimator: ) has_used_var = var_name + key var = dist_op.get_serial_output(var_name) + if ( + op.type == "reshape2" + or op.type == "transpose2" + or op.type == "elementwise_add" + ): + not_calc_vars.add(has_used_var) + continue # Not used - if var_name + key not in has_used_vars: + if ( + has_used_var not in has_used_vars + and has_used_var not in parameters + ): has_used_vars.add(has_used_var) for process in process_mesh.process_ids: if process not in memories: memories[process] = 0 memories[process] += var_info[var_name][key]["memory"] # Used - else: - if op_id == var_info[var_name][key]["position"][-1]: - if has_used_var not in can_free_vars: - can_free_vars.add(has_used_var) - if not var.persistable: - for process in process_mesh.process_ids: - if process not in can_free_memories: - can_free_memories[process] = 0 - can_free_memories[process] += var_info[ - var_name - ][key]["memory"] + if op_id == var_info[var_name][key]["position"][-1]: + if ( + has_used_var not in can_free_vars + and not var.persistable + ): + can_free_vars.add(has_used_var) + for process in process_mesh.process_ids: + if process not in can_free_memories: + can_free_memories[process] = 0 + can_free_memories[process] += var_info[var_name][ + key + ]["memory"] # Calc peak memory for process in memories: @@ -433,7 +484,6 @@ class CostEstimator: else: if memories[process] > self.max_memories[process]: self.max_memories[process] = memories[process] - # Free memory for process in can_free_memories: if process in memories: @@ -513,7 +563,7 @@ class CostEstimator: # Padding automatically max_len = 0 - header = ["Execution Time(ms)", "Max Memory(MiB)"] + header = ["Execution Time(us)", "Max Memory(MiB)"] vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)] for memory in vals + header: if len(str(memory)) > max_len: diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 6fc9b9d27eee2fac48b630751d628c4ce7fe46e4..0e3a1e03b91706af6429e9dcf358725f3fc77e01 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -2716,6 +2716,8 @@ class Resharder: ) # simplified processing: ignore union process mesh and output reshard dist_op = 
self.dist_context.get_dist_op_for_program(op) + if not dist_tensor or not dist_op: + return reshard_op_cost dims_mapping = dist_op.dist_attr.get_input_dims_mapping( tensor.name ) diff --git a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py index e2713cea63aae66096296575d45ba98edde22c3d..cebc31df5a8cf098ed4e2a1e3eaa3ef439bcf7d1 100644 --- a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py @@ -16,17 +16,31 @@ import copy import logging import math import os +import pickle +import sys +import time from abc import abstractmethod from collections import OrderedDict +from functools import reduce + +import numpy as np import paddle +from paddle.distributed.auto_parallel.cluster_v2 import DeviceMesh from paddle.distributed.auto_parallel.completion import Completer +from paddle.distributed.auto_parallel.cost import CostEstimator from paddle.distributed.auto_parallel.dist_attribute import ( OperatorDistAttr, TensorDistAttr, ) from paddle.distributed.auto_parallel.dist_context import DistributedContext from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor +from paddle.distributed.auto_parallel.process_mesh import ProcessMesh +from paddle.distributed.auto_parallel.utils import ( + is_gradient_clip_op, + print_program_with_dist_attr, +) +from paddle.distributed.fleet.meta_optimizers.common import OpRole from paddle.fluid import program_guard from paddle.fluid.backward import append_backward from paddle.fluid.framework import Parameter, unique_name @@ -1610,3 +1624,603 @@ class RuleBasedTuner: idx ][parallelism][key] self._complete_sub_bwd_program(sub_program_dist_context) + + def _complete_sub_update_program(self, sub_program_dist_context): + """ + Complete the opt OP according to the tensor. + Most of the logic is the same as the update completion in the completer. 
+ """ + world_ranks = ProcessMesh( + [ + i + for i in range( + self._cluster.get_num_machines() + * self._cluster._num_devices_per_machine + ) + ] + ) + dist_tensors = sub_program_dist_context._dist_tensors_for_program + + vars = self.full_main_program.global_block().vars + ops = self.full_main_program.global_block().ops + learning_rate_completed = False + for idx in range(len(ops)): + op = ops[idx] + if int(op.attr('op_role')) == int(OpRole.Optimize): + if is_gradient_clip_op(op): + if op.type in [ + "sum", + "sqrt", + "fill_constant", + "elementwise_max", + "elementwise_div", + ]: + op_dist_attr = OperatorDistAttr() + op_dist_attr.process_mesh = world_ranks + for in_name in op.input_arg_names: + in_var = vars[in_name] + if in_var.desc.original_id() in dist_tensors: + in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program( + in_var + ) + op_dist_attr.set_input_dist_attr( + in_name, in_dist_attr + ) + else: + in_dist_attr = TensorDistAttr() + in_dist_attr.process_mesh = world_ranks + in_dist_attr.dims_mapping = [ + -1 for _ in range(len(in_var.shape)) + ] + op_dist_attr.set_input_dist_attr( + in_name, in_dist_attr + ) + sub_program_dist_context.set_tensor_dist_attr_for_program( + in_var, in_dist_attr + ) + for out_name in op.output_arg_names: + out_var = vars[out_name] + if out_var.desc.original_id() in dist_tensors: + out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program( + out_var + ) + op_dist_attr.set_output_dist_attr( + out_name, out_dist_attr + ) + else: + out_dist_attr = TensorDistAttr() + out_dist_attr.process_mesh = world_ranks + out_dist_attr.dims_mapping = [ + -1 for _ in range(len(out_var.shape)) + ] + sub_program_dist_context.set_tensor_dist_attr_for_program( + out_var, out_dist_attr + ) + op_dist_attr.set_output_dist_attr( + out_name, out_dist_attr + ) + sub_program_dist_context.set_op_dist_attr_for_program( + op, op_dist_attr + ) + else: + in_var = vars[op.input("X")[0]] + if in_var.desc.original_id() in dist_tensors: + in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program( + in_var + ) + assert in_dist_attr is not None + ref_process_mesh = in_dist_attr.process_mesh + ref_dims_mapping = in_dist_attr.dims_mapping + + if ( + op.type == "cast" + and ops[idx + 1].type == "elementwise_mul" + ): + ref_var = vars[ops[idx + 1].input("X")[0]] + ref_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program( + ref_var + ) + assert ref_dist_attr is not None + ref_process_mesh = ref_dist_attr.process_mesh + + out_var = vars[op.output("Out")[0]] + out_dist_attr = TensorDistAttr() + out_dist_attr.process_mesh = ref_process_mesh + if out_var.shape == in_var.shape: + out_dist_attr.dims_mapping = ref_dims_mapping + else: + assert ( + len(out_var.shape) == 1 + and out_var.shape[0] == 1 + ) + out_dist_attr.dims_mapping = [-1] + sub_program_dist_context.set_tensor_dist_attr_for_program( + out_var, out_dist_attr + ) + + op_dist_attr = OperatorDistAttr() + op_dist_attr.process_mesh = ref_process_mesh + for in_name in op.input_arg_names: + in_var = vars[in_name] + in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program( + in_var + ) + op_dist_attr.set_input_dims_mapping( + in_name, in_dist_attr.dims_mapping + ) + for out_name in op.output_arg_names: + out_var = vars[out_name] + out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program( + out_var + ) + op_dist_attr.set_output_dims_mapping( + out_name, out_dist_attr.dims_mapping + ) + op_dist_attr.set_input_dist_attr( + in_var.name, in_dist_attr + ) 
+                            op_dist_attr.set_output_dist_attr(
+                                out_var.name, out_dist_attr
+                            )
+
+                            sub_program_dist_context.set_op_dist_attr_for_program(
+                                op, op_dist_attr
+                            )
+                        else:
+                            continue
+
+                if "Grad" in op.input_names and "Param" in ops[idx].input_names:
+                    assert (
+                        len(op.input("Param")) == 1
+                    ), "Only support one-to-one now."
+                    assert (
+                        len(op.input("Grad")) == 1
+                    ), "Only support one-to-one now."
+                    param = vars[op.input("Param")[0]]
+                    grad_var = vars[op.input("Grad")[0]]
+                    if param.desc.original_id() in dist_tensors:
+                        param_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
+                            param
+                        )
+                        assert param_dist_attr is not None
+                        ref_process_mesh = sub_program_dist_context.get_tensor_dist_attr_for_program(
+                            param
+                        ).process_mesh
+                        assert ref_process_mesh is not None
+                        ref_dims_mapping = sub_program_dist_context.get_tensor_dist_attr_for_program(
+                            param
+                        ).dims_mapping
+                        assert ref_dims_mapping is not None
+                        op_dist_attr = OperatorDistAttr()
+                        op_dist_attr.process_mesh = ref_process_mesh
+                        op_dist_attr.set_input_dims_mapping(
+                            grad_var.name, ref_dims_mapping
+                        )
+                        op_dist_attr.set_input_dims_mapping(
+                            param.name, ref_dims_mapping
+                        )
+                        op_dist_attr.set_output_dims_mapping(
+                            param.name, ref_dims_mapping
+                        )
+                        learning_var = vars[op.input("LearningRate")[0]]
+                        op_dist_attr.set_input_dims_mapping(
+                            learning_var.name, [-1]
+                        )
+                        op_dist_attr.set_output_dims_mapping(
+                            learning_var.name, [-1]
+                        )
+
+                        if not learning_rate_completed:
+                            learning_rate_completed = True
+                            var_dist_attr = TensorDistAttr()
+                            var_dist_attr.process_mesh = world_ranks
+                            var_dist_attr.dims_mapping = [-1]
+                            sub_program_dist_context.set_tensor_dist_attr_for_program(
+                                learning_var, var_dist_attr
+                            )
+
+                        for input_name in op.desc.input_names():
+
+                            if input_name in [
+                                'Param',
+                                'Grad',
+                                'LearningRate',
+                                "SkipUpdate",
+                                "Beta1Tensor",
+                                "Beta2Tensor",
+                                "EpsilonTensor",
+                            ]:
+                                continue
+                            if len(op.desc.input(input_name)) == 0:
+                                continue
+
+                            assert len(op.desc.input(input_name)) == 1
+                            input_var = vars[op.desc.input(input_name)[0]]
+                            input_var_attr = TensorDistAttr()
+
+                            if (
+                                "Beta1Pow" in input_name
+                                or "Beta2Pow" in input_name
+                            ):
+                                input_var_attr.dims_mapping = [-1]
+                                op_dist_attr.set_input_dims_mapping(
+                                    input_var.name, [-1]
+                                )
+                                op_dist_attr.set_output_dims_mapping(
+                                    input_var.name, [-1]
+                                )
+                            else:
+                                input_var_attr.dims_mapping = ref_dims_mapping
+                                op_dist_attr.set_input_dims_mapping(
+                                    input_var.name, ref_dims_mapping
+                                )
+                                op_dist_attr.set_output_dims_mapping(
+                                    input_var.name, ref_dims_mapping
+                                )
+
+                            input_var_attr.process_mesh = ref_process_mesh
+                            sub_program_dist_context.set_tensor_dist_attr_for_program(
+                                input_var, input_var_attr
+                            )
+
+                        sub_program_dist_context.set_op_dist_attr_for_program(
+                            op, op_dist_attr
+                        )
+                        continue
+                    else:
+                        continue
+
+    def complete_sub_update_programs(self):
+        for idx in self.sub_programs_dist_context:
+            for parallelism in self.sub_programs_dist_context[idx]:
+                for key in self.sub_programs_dist_context[idx][parallelism]:
+                    sub_program_dist_context = self.sub_programs_dist_context[
+                        idx
+                    ][parallelism][key]
+                    self._complete_sub_update_program(sub_program_dist_context)
+
+    def convert_device_mesh_to_key(self, device_mesh):
+        """Convert device mesh object to str."""
+        processes = ",".join([str(x) for x in device_mesh.device_ids])
+        topology = ",".join([str(x) for x in device_mesh.shape])
+        key = processes + ";" + topology
+        return key
+
+    def _get_sub_program_cost(self, dist_context):
+        """Estimate the cost of dist context."""
+        cost_estimator = CostEstimator(self.full_main_program, self._cluster)
+        global_cost = cost_estimator.estimate(dist_context)
+        max_memory = cost_estimator._estimate_max_memory_by_dist_op(
+            dist_context
+        )
+        return global_cost.time, max_memory
+
+    def combine_dist_contexts(self, dist_contexts):
+        """Combine the dist attrs in dist contexts into one dist context."""
+        combined_dist_context = DistributedContext()
+        # Set dist tensors; a shared param or var may be the input of multiple ops.
+        for dist_context in dist_contexts:
+            for tensor_id in dist_context._dist_tensors_for_program:
+                dist_tensor = dist_context._dist_tensors_for_program[tensor_id]
+                if (
+                    tensor_id
+                    not in combined_dist_context._dist_tensors_for_program
+                ):
+                    combined_dist_context.add_dist_tensor_for_program(
+                        dist_tensor
+                    )
+
+            # Set dist ops.
+            for op_id in dist_context._dist_ops_for_program:
+                dist_op = dist_context._dist_ops_for_program[op_id]
+                combined_dist_context.add_dist_op_for_program(dist_op)
+
+            for process_mesh in dist_context.process_meshes:
+                combined_dist_context.add_process_mesh(process_mesh)
+
+        return combined_dist_context
+
+    def prepare(self):
+        """Prepare the sub programs, tensor dist attr settings, device meshes and so on that the tuner needs."""
+
+        # step1: cluster operators to layers
+        begin = time.time()
+        self.layers = self.cluster_operators()
+        end = time.time()
+        self._logger.info(
+            "Cluster operators to {} layers in {}s.".format(
+                len(self.layers), end - begin
+            )
+        )
+
+        # step2: generate sub program of each layer
+        begin = time.time()
+        self.gen_fwd_sub_programs_by_clone()
+        end = time.time()
+        self._logger.info(
+            "Generate programs of every layer in {}s.".format(end - begin)
+        )
+
+        # step3: partition devices to device meshes
+        begin = time.time()
+        n, m = (
+            self._cluster.get_num_machines(),
+            self._cluster._num_devices_per_machine,
+        )
+        device_meshes_list = ClusterPartitionUtil.partition_cluster(n, m)
+        end = time.time()
+        self._logger.info("Partition cluster in {}s.".format(end - begin))
+
+        # step4: transform device mesh to process meshes
+        dm_idx = 0
+        for device_meshes in device_meshes_list:
+            has_used_devices = 0
+            self.device_meshes_list.append([])
+            for device_mesh in device_meshes:
+                devices = reduce(lambda x, y: x * y, device_mesh)
+                processes = [
+                    i
+                    for i in range(has_used_devices, has_used_devices + devices)
+                ]
+                device_mesh_shape = (
+                    device_mesh
+                    if device_mesh[0] != 1
+                    else [device_mesh[i] for i in range(1, len(device_mesh))]
+                )
+                self.device_meshes_list[-1].append(
+                    DeviceMesh(
+                        mesh=np.array(processes)
+                        .reshape(device_mesh_shape)
+                        .tolist(),
+                        name="device_mesh_" + str(dm_idx),
+                    )
+                )
+                dm_idx += 1
+                has_used_devices += devices
+                process_mesh_shapes = convert_to_process_meshes(device_mesh)
+                for process_mesh_shape in process_mesh_shapes:
+                    process_mesh = ProcessMesh(
+                        np.array(processes).reshape(process_mesh_shape).tolist()
+                    )
+                    if process_mesh not in self.process_meshes:
+                        self.process_meshes.append(process_mesh)
+
+        # step5: generate full program
+        begin = time.time()
+        self.gen_full_program()
+        end = time.time()
+        self._logger.info("Generate full program in {}s.".format(end - begin))
+
+        # step6: complete forward sub programs
+        begin = time.time()
+        for process_mesh in self.process_meshes:
+            self.complete_sub_fwd_programs(process_mesh)
+        end = time.time()
+        self._logger.info(
+            "Complete all sub forward programs in {}s.".format(end - begin)
+        )
+
+        if self.mode == "train":
+            # step7: complete backward sub programs
+            begin = time.time()
+            self.complete_sub_bwd_programs()
+            end = time.time()
+            self._logger.info(
+                "Complete all sub backward programs in {}s.".format(end - begin)
+            )
+
+            # step8: complete update sub programs
+            begin = time.time()
+            self.complete_sub_update_programs()
+            end = time.time()
+            self._logger.info(
+                "Complete all sub update programs in {}s.".format(end - begin)
+            )
+
+    def tune_o1(self):
+        """The o1 level tuning."""
+        best_cost = sys.maxsize
+        best_dist_context = None
+
+        for device_meshes in self.device_meshes_list:
+            pp_stages = len(device_meshes)
+            average_layers = len(self.layers) // pp_stages
+            device_mesh_shape = device_meshes[0].shape
+            if len(device_mesh_shape) == 1:
+                device_mesh_shape.insert(0, 1)
+            process_mesh_shapes = convert_to_process_meshes(device_mesh_shape)
+
+            # For example, device_mesh is [1, 8] and process_mesh is [8].
+            # The selective parallelisms are dp and mp.
+            # Get the dp8 and mp8 costs and compare them to get the best strategy.
+            for parallelism in ["dp", "mp", "dp_mp", "mp_dp"]:
+                for process_mesh_shape in process_mesh_shapes:
+                    dist_context_of_device_meshes = None
+                    for idx, device_mesh in enumerate(device_meshes):
+                        device_mesh_shape = device_mesh.shape
+                        process_mesh = ProcessMesh(
+                            np.array(device_mesh.device_ids)
+                            .reshape(process_mesh_shape)
+                            .tolist()
+                        )
+
+                        selective_parallelisms = (
+                            ["dp", "mp"]
+                            if len(process_mesh.shape) == 1
+                            else ["dp_mp", "mp_dp"]
+                        )
+                        if parallelism not in selective_parallelisms:
+                            total_cost_of_device_meshes = sys.maxsize
+                            continue
+
+                        key = self.convert_process_mesh_to_key(process_mesh)
+
+                        if idx == len(device_meshes) - 1:
+                            start = idx * average_layers
+                            end = len(self.layers)
+                        else:
+                            start = idx * average_layers
+                            end = (idx + 1) * average_layers
+
+                        dist_context = self.combine_dist_contexts(
+                            [
+                                self.sub_programs_dist_context[j][parallelism][
+                                    key
+                                ]
+                                for j in range(start, end)
+                            ]
+                        )
+
+                        dist_context_of_device_meshes = (
+                            dist_context
+                            if dist_context_of_device_meshes is None
+                            else self.combine_dist_contexts(
+                                [dist_context_of_device_meshes, dist_context]
+                            )
+                        )
+                    if dist_context_of_device_meshes is not None:
+                        cost, memory = self._get_sub_program_cost(
+                            dist_context_of_device_meshes
+                        )
+
+                        self._logger.info(
+                            "Cost Model: The max memory is {}GB and the cost is {} when using {} parallelism under process mesh shape {} on {} stages.".format(
+                                memory / (1024**3),
+                                cost,
+                                parallelism,
+                                process_mesh_shape,
+                                len(device_meshes),
+                            )
+                        )
+                        # 15% buffer is reserved for memory cost
+                        if memory > 0.85 * self.cluster.machines[0].devices[
+                            0
+                        ].memory * (1024**3):
+                            cost = sys.maxsize
+
+                        if cost < best_cost:
+                            best_cost = cost
+                            best_dist_context = dist_context_of_device_meshes
+                            self._logger.info(
+                                "O1 level: a better strategy has been found: {} parallelism under process mesh shape {} on {} stages with max memory {}GB.".format(
+                                    parallelism,
+                                    process_mesh_shape,
+                                    len(device_meshes),
+                                    memory / (1024**3),
+                                )
+                            )
+
+        return best_dist_context
+
+    def tune_o2(self):
+        return None
+
+    def save_strategy(self, best_dist_context, path):
+        dist_attrs = {"tensor": {}, "op": {}, "process_meshes": []}
+        for key in best_dist_context._dist_tensors_for_program:
+            if key in self._dist_context._dist_tensors_for_program:
+                dist_tensor = best_dist_context._dist_tensors_for_program[key]
+                dist_attrs["tensor"][
+                    key
+                ] = dist_tensor.dist_attr.serialize_to_string()
+        assert dist_attrs["tensor"], "Tensor dist attrs must not be None."
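+
+        # Only serialize dist attrs of ops that also exist in the original dist context.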
+        for key in best_dist_context._dist_ops_for_program:
+            if key in self._dist_context._dist_ops_for_program:
+                dist_op = best_dist_context._dist_ops_for_program[key]
+                dist_attrs["op"][key] = dist_op.dist_attr.serialize_to_string()
+        assert dist_attrs["op"], "Op dist attrs must not be None."
+
+        for process_mesh in best_dist_context._process_meshes:
+            process_ids = process_mesh.process_ids
+            process_shape = process_mesh.shape
+            dist_attrs["process_meshes"].append([process_ids, process_shape])
+
+        dist_attrs["cluster"] = self._cluster
+        with open(path, 'wb') as f:
+            pickle.dump(dist_attrs, f)
+        self._logger.info("The strategy has been saved at {}".format(path))
+
+    def run_or_quit(self):
+        # Quit if we only tune and do not run.
+        if not self._is_run:
+            self._logger.info(
+                "The process will quit since it only tunes and does not run."
+            )
+            quit()
+
+    def tune(self):
+        begin = time.time()
+        self.match_program(self._dist_context.serial_main_program)
+        end = time.time()
+        self._logger.info("Pattern match in {}s.".format(end - begin))
+
+        if self._use_dp:
+            completer = Completer(self._dist_context)
+            completer.complete_forward_annotation()
+            print_program_with_dist_attr(
+                self._dist_context.serial_main_program, self._dist_context
+            )
+            # Save the strategy if needed.
+            path = self._strategy_path
+            if path:
+                self.save_strategy(self._dist_context, path)
+                self.run_or_quit()
+            return
+
+        # prepare
+        self.prepare()
+
+        best_dist_context = None
+        if self.level == "o2":
+            best_dist_context = self.tune_o2()
+
+        elif self.level == "o1":
+            # If the level is o1, all layers use the same parallelism.
+            # In pipeline parallelism, layers are placed evenly across the stages.
+            use_o2_level = False
+            for device_meshes in self.device_meshes_list:
+                if len(device_meshes) > 1:
+                    shape = None
+                    for device_mesh in device_meshes:
+                        if shape is None:
+                            shape = device_mesh.shape
+                            continue
+                        else:
+                            if shape != device_mesh.shape:
+                                self._logger.info(
+                                    "Warning: The o1 level is not supported when the number of machines is a prime number greater than 1. The o2 level will be used to tune."
+                                )
+                                use_o2_level = True
+                                break
+            if use_o2_level:
+                best_dist_context = self.tune_o2()
+            else:
+                best_dist_context = self.tune_o1()
+
+        assert (
+            best_dist_context is not None
+        ), "Cannot find a parallel strategy to run; please use passes such as recompute, amp or sharding."
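+
+        # Copy the dist attrs of the best strategy back into the original dist context.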
+        for key in best_dist_context._dist_tensors_for_program:
+            if key in self._dist_context._dist_tensors_for_program:
+                self._dist_context._dist_tensors_for_program[
+                    key
+                ] = best_dist_context._dist_tensors_for_program[key]
+        for key in best_dist_context._dist_ops_for_program:
+            if key in self._dist_context._dist_ops_for_program:
+                self._dist_context._dist_ops_for_program[
+                    key
+                ] = best_dist_context._dist_ops_for_program[key]
+        self._dist_context._process_meshes = best_dist_context._process_meshes
+
+        end = time.time()
+        self._logger.info("Rule-based tuner ended in {}s.".format(end - begin))
+        self._logger.info("The best strategy found is as follows:")
+        print_program_with_dist_attr(self.full_main_program, best_dist_context)
+
+        # Save the strategy if needed.
+        path = self._strategy_path
+        if path:
+            self.save_strategy(best_dist_context, path)
+            self.run_or_quit()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
index 33fdc63cdc3115ec7c4420e65aba841f582f8d6c..15e9390cc400c7b7e13715638dece560168e9544 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_rule_based_tuner.py
@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
         modeling.init_global()
         train_program = static.Program()
         start_program = static.Program()
-        place = paddle.set_device("gpu")
         batch_size = 8
         sequence_len = 512
         vocab_size = 1000
+        place = None
         train_program, start_program, loss, gen_data = get_gpt_model(
             train_program,
             start_program,
@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
             sequence_len,
             vocab_size,
         )
+        from paddle.distributed.auto_parallel.cluster import Cluster
         from paddle.distributed.auto_parallel.dist_context import (
             DistributedContext,
         )
-        from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
             RuleBasedTuner,
         )

         clip = paddle.nn.ClipGradByGlobalNorm(0.2)
         opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
+
+        cluster = Cluster()
+        cluster.gen_default_config_cluster(node_count=1, device_count=8)
         dist_context = DistributedContext(
             serial_main_prog=train_program,
             serial_startup_prog=start_program,
             serial_optimizer=opt,
             serial_loss=loss,
+            cluster=cluster,
         )
         dist_context.initialize()
         tuner = RuleBasedTuner(dist_context)
-        tuner.cluster_operators()
-        tuner.gen_full_program()
-        tuner.match_program(tuner._dist_context.serial_main_program)
-        process_mesh = ProcessMesh([0, 1])
-        tuner.gen_fwd_sub_programs_by_clone()
-        tuner.complete_sub_fwd_programs(process_mesh)
-        tuner.complete_sub_bwd_programs()
+        tuner.tune()


 if __name__ == "__main__":