Unverified commit d6011cb6 authored by caozhou, committed by GitHub

[Auto Parallel] Add o1 level tune (#52041)

* Add o1 level tune

* Add unit tests
Parent 418b983c
......@@ -14,6 +14,7 @@
import json
import os
import re
from enum import IntEnum, unique
import paddle
......@@ -449,7 +450,6 @@ class Cluster:
npu_models = ["NPU"]
dcu_models = ["DCU"]
all_gpu_models = gpu_models + xpu_models + npu_models + dcu_models
assert gpu_model in all_gpu_models
self._num_devices_per_machine = device_count
def _convert_to_type(gpu_model):
......@@ -462,6 +462,8 @@ class Cluster:
type = "NPU"
elif gpu_model in dcu_models:
type = "DCU"
else:
type = "GPU"
assert type is not None
return type
......@@ -470,6 +472,12 @@ class Cluster:
model = None
if gpu_model == "V100":
model = "Tesla V100-SXM2-" + str(gpu_memory) + "GB"
elif gpu_model == "A100":
model = "Tesla A100-SXM-" + str(gpu_memory) + "GB"
elif gpu_model == "A30":
model = "Tesla A30-SXM-" + str(gpu_memory) + "GB"
else:
model = gpu_model + str(gpu_memory) + "GB"
assert model is not None
return model
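Worked example of the model-string mapping above (values chosen for illustration): gpu_model "A100" with gpu_memory 40 maps to "Tesla A100-SXM-40GB", "V100" with 32 maps to "Tesla V100-SXM2-32GB", and any other model falls through to the generic gpu_model + memory + "GB" form.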
......@@ -527,6 +535,8 @@ class Cluster:
device["memory"] = memory
device["sp_gflops"] = sp_gflops
device["dp_gflops"] = dp_gflops
# Hard-coded to GPU for now
device["type"] = "GPU"
global_id_to_device_type[global_id] = type
global_id_to_node[global_id] = i
devices.append(device)
......@@ -820,13 +830,45 @@ class Cluster:
return self.__str__()
def get_default_cluster():
def get_default_cluster(json_config=None):
def is_by_json_config(json_config):
if not json_config:
return False
if "cluster" not in json_config:
return False
else:
if "path" not in json_config["cluster"]:
if "num_nodes" not in json_config["cluster"]:
return False
if "num_gpus" not in json_config["cluster"]:
return False
if "gpu_model" not in json_config["cluster"]:
return False
if "gpu_memory" not in json_config["cluster"]:
return False
return True
else:
return True
cluster = Cluster()
if json_config and is_by_json_config(json_config):
# Get GPU info by json config
if "path" in json_config["cluster"]:
cluster.build_from_file(json_config["cluster"]["path"])
return cluster
else:
node_count = json_config["cluster"]["num_nodes"]
local_device_count = json_config["cluster"]["num_gpus"]
gpu_model = json_config["cluster"]["gpu_model"]
memory = json_config["cluster"]["gpu_memory"]
else:
# Get GPU info by get_device_properties
local_device_count = os.getenv("PADDLE_LOCAL_SIZE")
if local_device_count is None:
local_device_count = 1
else:
local_device_count = int(local_device_count)
global_device_count = os.getenv("PADDLE_GLOBAL_SIZE")
if global_device_count is None:
node_count = 1
......@@ -834,16 +876,36 @@ def get_default_cluster():
global_device_count = int(global_device_count)
assert global_device_count % local_device_count == 0
node_count = int(global_device_count) // local_device_count
gpu_info = paddle.device.cuda.get_device_properties()
assert gpu_info, "Auto parallel currently runs only on GPU."
gpu_name = gpu_info.name
try:
re_result = re.split(r'[ , -]', gpu_name)
gpu_model = re_result[1]
memory = int(re_result[-1][:-2])
except Exception:  # fall back when the device name cannot be parsed
memory = int(gpu_info.total_memory) // (1000**3)
gpu_model = gpu_name
print(
"Node Count: ",
node_count,
"Local Device Size: ",
local_device_count,
"GPU Model: ",
gpu_model,
"GPU Memory: ",
memory,
"World size: ",
paddle.distributed.get_world_size(),
flush=True,
)
cluster.gen_default_config_cluster(
node_count=node_count, device_count=local_device_count
node_count=node_count,
device_count=local_device_count,
gpu_model=gpu_model,
gpu_memory=memory,
)
return cluster
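A minimal usage sketch of the new json_config path (the concrete values and the import path are assumptions for illustration; the PR itself only changes the signature of get_default_cluster):

from paddle.distributed.auto_parallel.cluster import get_default_cluster  # assumed module path

# Describe a hypothetical 2-node, 8-GPU-per-node A100 cluster explicitly
# instead of probing the local devices via get_device_properties.
json_config = {
    "cluster": {
        "num_nodes": 2,       # number of machines
        "num_gpus": 8,        # devices per machine
        "gpu_model": "A100",  # mapped to "Tesla A100-SXM-<memory>GB"
        "gpu_memory": 40,     # GB per device
    }
}
cluster = get_default_cluster(json_config)

# Alternatively, point at a pre-built cluster description file:
# json_config = {"cluster": {"path": "/path/to/cluster.json"}}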
......@@ -16,6 +16,7 @@ from collections import OrderedDict
from functools import reduce
import paddle
from paddle.utils.flops import flops
from ..cluster import LinkType
from ..dist_tensor import DistributedTensor
......@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
output_desc = OrderedDict()
# Get partitioned shape of input
input_var_desc = {}
for input_name in op.input_names:
var_name_list = op.input(input_name)
var_desc = []
input_var_desc[input_name] = []
for var_name in var_name_list:
var = get_var_with_recursion(
var_name, op.block, op.block.program
......@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
process,
shard_sizes,
)
var_desc.append((var.dtype, shape))
input_var_desc[input_name].append(shape)
# For special ops such as embedding and its grad op
if (
......@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
relative_idx = relative_idx * per_part_size
desc["attrs"]["start_index"] = relative_idx
input_desc[input_name] = var_desc
desc["inputs"] = input_desc
desc["inputs"] = input_var_desc
for out_name in op.output_names:
var_name_list = op.output(out_name)
......@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
return desc
def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
def build_comm_costs_from_descs(
op_cost_class, ctx, processes, descs, cluster, is_dp=False
):
"""Build comm costs by descriptions"""
comm_context = CommContext(cluster)
group_ranks_list = []
......@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
comm_op_cost = op_cost_class(
op_desc=desc, comm_context=comm_context
)
if is_dp:
comm_op_cost.cost.time *= 0.9
comm_op_cost_list.append(comm_op_cost)
return comm_op_cost_list
......@@ -389,6 +394,7 @@ def build_dp_costs(
vars = dist_op.serial_op.block.vars
var_name = var_names[0]
has_found = False
is_input = True
for name in dist_op.serial_op.input_arg_names:
if var_name in name:
var_name = name
......@@ -400,6 +406,7 @@ def build_dp_costs(
if var_name in name:
var_name = name
has_found = True
is_input = False
break
if not has_found:
return
......@@ -418,6 +425,7 @@ def build_dp_costs(
processes,
c_allreduce_sum_descs,
cluster,
is_dp=True,
)
result.append(comm_cost_list)
......@@ -431,22 +439,11 @@ def build_dp_costs(
desc = {}
desc["op"] = op_type
desc["inputs"] = {}
if var_name in dist_attr.inputs_dist_attrs:
dims_mapping = dist_attr.get_input_dims_mapping(var_name)
elif var_name in dist_attr.outputs_dist_attrs:
dims_mapping = dist_attr.get_output_dims_mapping(var_name)
else:
raise AssertionError(
"cannot find dims_mapping for {} in {}".format(
var_name, dist_attr
)
dims_mapping = (
dist_attr.get_input_dims_mapping(var_name)
if is_input
else dist_attr.get_output_dims_mapping(var_name)
)
# dims_mapping = (
# dist_attr.get_input_dims_mapping(var_name)
# if dist_attr.get_input_dims_mapping(var_name) is not None
# else dist_attr.get_output_dims_mapping(var_name)
# )
var = get_var_with_recursion(
var_name,
dist_op.serial_op.block,
......@@ -493,8 +490,6 @@ class CommContext:
# If the cluster has no info about those vars, they will be set to defaults
self.base_ring = None
self.base_tree = None
# self.base_inter_ring = None
# self.base_inter_tree = None
self.intra_ring = None
self.intra_tree = None
self.inter_ring = None
......@@ -508,8 +503,6 @@ class CommContext:
# set default
self.base_ring = 8.4
self.base_tree = 0.0
# self.base_inter_ring = 9.6
# self.base_inter_tree = 28
# NVLink by default
self.intra_ring = 3.4
self.intra_tree = 28
......@@ -681,6 +674,8 @@ class Cost:
class OpCost:
OP_TYPE = "op"
def __init__(self, op=None, op_desc=None):
self._op = op
self._op_desc = op_desc
......@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
)
)
def calc_flops(self):
if not self.op_desc:
return 0
if "_grad" in self.__class__.OP_TYPE:
op_type = self.__class__.OP_TYPE[: len(self.__class__.OP_TYPE) - 5]
return 2 * flops(
op_type, self.op_desc["inputs"], self.op_desc["attrs"]
)
return flops(
self.__class__.OP_TYPE,
self.op_desc["inputs"],
self.op_desc["attrs"],
)
def calc_time(self):
flops_count = self.calc_flops()
return flops_count * 2.9e-7
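A self-contained sketch of the new base-class cost model (paddle-free; fake_flops and demo_desc are stand-ins for paddle.utils.flops and a real op_desc, while the 2x-FLOPs rule for grad ops and the 2.9e-7 per-FLOP factor mirror the code above):

def fake_flops(op_type, inputs, attrs):
    # Stand-in for paddle.utils.flops: 2*M*K*N for a matmul-like op.
    (m, k), (_, n) = inputs["X"][0], inputs["Y"][0]
    return 2 * m * k * n

def calc_time(op_type, op_desc):
    # Grad ops reuse the forward op's FLOPs formula and double it.
    fwd_type = op_type[: -len("_grad")] if "_grad" in op_type else op_type
    flops_count = fake_flops(fwd_type, op_desc["inputs"], op_desc["attrs"])
    if "_grad" in op_type:
        flops_count *= 2
    return flops_count * 2.9e-7  # same constant as CompOpCost.calc_time

demo_desc = {"inputs": {"X": [[8, 1024]], "Y": [[1024, 1024]]}, "attrs": {}}
print(calc_time("matmul_v2", demo_desc))       # forward cost
print(calc_time("matmul_v2_grad", demo_desc))  # backward cost, doubled FLOPs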
def register_op_cost(cls):
op_type = cls.OP_TYPE
......
......@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
super().__init__(op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
return 0
return self.comm_count * 1 / (144 * 1e3)
@register_op_cost
......
......@@ -22,15 +22,6 @@ class AdamOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ArgsortOpCost(CompOpCost):
......@@ -39,15 +30,6 @@ class ArgsortOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class AssignOpCost(CompOpCost):
......@@ -56,15 +38,6 @@ class AssignOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class AssignValueOpCost(CompOpCost):
......@@ -73,15 +46,6 @@ class AssignValueOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class BeamSearchOpCost(CompOpCost):
......@@ -90,15 +54,6 @@ class BeamSearchOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class BeamSearchDecodeOpCost(CompOpCost):
......@@ -107,15 +62,6 @@ class BeamSearchDecodeOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class CastOpCost(CompOpCost):
......@@ -124,15 +70,6 @@ class CastOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ConcatOpCost(CompOpCost):
......@@ -141,15 +78,6 @@ class ConcatOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class DropoutOpCost(CompOpCost):
......@@ -158,15 +86,6 @@ class DropoutOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class DropoutGradOpCost(CompOpCost):
......@@ -175,15 +94,6 @@ class DropoutGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseAddOpCost(CompOpCost):
......@@ -192,15 +102,6 @@ class ElementwiseAddOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseAddGradOpCost(CompOpCost):
......@@ -209,15 +110,6 @@ class ElementwiseAddGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseDivOpCost(CompOpCost):
......@@ -226,15 +118,6 @@ class ElementwiseDivOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseDivGradOpCost(CompOpCost):
......@@ -243,15 +126,6 @@ class ElementwiseDivGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseMulOpCost(CompOpCost):
......@@ -260,15 +134,6 @@ class ElementwiseMulOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseMulGradOpCost(CompOpCost):
......@@ -277,15 +142,6 @@ class ElementwiseMulGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseSubOpCost(CompOpCost):
......@@ -294,15 +150,6 @@ class ElementwiseSubOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseSubGradOpCost(CompOpCost):
......@@ -311,15 +158,6 @@ class ElementwiseSubGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class EqualOpCost(CompOpCost):
......@@ -328,15 +166,6 @@ class EqualOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class EmbeddingOpCost(CompOpCost):
......@@ -345,15 +174,6 @@ class EmbeddingOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class EmbeddingGradOpCost(CompOpCost):
......@@ -362,15 +182,6 @@ class EmbeddingGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class FillConstantOpCost(CompOpCost):
......@@ -379,15 +190,6 @@ class FillConstantOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class FillConstantBatchSizeLikeOpCost(CompOpCost):
......@@ -396,15 +198,6 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
......@@ -413,15 +206,6 @@ class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
......@@ -430,15 +214,6 @@ class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class GatherOpCost(CompOpCost):
......@@ -447,15 +222,6 @@ class GatherOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class GeluOpCost(CompOpCost):
......@@ -464,15 +230,6 @@ class GeluOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class GeluGradOpCost(CompOpCost):
......@@ -481,15 +238,6 @@ class GeluGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class GreaterEqualOpCost(CompOpCost):
......@@ -498,15 +246,6 @@ class GreaterEqualOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class IncrementOpCost(CompOpCost):
......@@ -515,11 +254,6 @@ class IncrementOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class IsEmptyOpCost(CompOpCost):
......@@ -528,11 +262,6 @@ class IsEmptyOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class LayerNormOpCost(CompOpCost):
......@@ -541,15 +270,6 @@ class LayerNormOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class LayerNormGradOpCost(CompOpCost):
......@@ -558,15 +278,6 @@ class LayerNormGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class LessThanOpCost(CompOpCost):
......@@ -575,15 +286,6 @@ class LessThanOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class LogicalNotOpCost(CompOpCost):
......@@ -592,15 +294,6 @@ class LogicalNotOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class LogicalAndOpCost(CompOpCost):
......@@ -609,15 +302,6 @@ class LogicalAndOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class LodResetOpCost(CompOpCost):
......@@ -626,15 +310,6 @@ class LodResetOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class LogOpCost(CompOpCost):
......@@ -643,15 +318,6 @@ class LogOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class LookupTableV2OpCost(CompOpCost):
......@@ -660,15 +326,6 @@ class LookupTableV2OpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class LookupTableV2GradOpCost(CompOpCost):
......@@ -677,15 +334,6 @@ class LookupTableV2GradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class MatmulOpCost(CompOpCost):
......@@ -694,15 +342,6 @@ class MatmulOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class MatmulGradOpCost(CompOpCost):
......@@ -711,15 +350,6 @@ class MatmulGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class MatmulV2OpCost(CompOpCost):
......@@ -728,15 +358,6 @@ class MatmulV2OpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class MatmulV2GradOpCost(CompOpCost):
......@@ -745,15 +366,6 @@ class MatmulV2GradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class MemcpyOpCost(CompOpCost):
......@@ -762,15 +374,6 @@ class MemcpyOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class MulOpCost(CompOpCost):
......@@ -779,15 +382,6 @@ class MulOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class MulGradOpCost(CompOpCost):
......@@ -796,15 +390,6 @@ class MulGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class OneHotOpCost(CompOpCost):
......@@ -813,15 +398,6 @@ class OneHotOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ReadFromArrayOpCost(CompOpCost):
......@@ -830,15 +406,6 @@ class ReadFromArrayOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ReduceSumOpCost(CompOpCost):
......@@ -847,15 +414,6 @@ class ReduceSumOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ReduceSumGradOpCost(CompOpCost):
......@@ -864,15 +422,6 @@ class ReduceSumGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class Reshape2OpCost(CompOpCost):
......@@ -881,15 +430,6 @@ class Reshape2OpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class Reshape2GradOpCost(CompOpCost):
......@@ -898,15 +438,6 @@ class Reshape2GradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ReduceMeanOpCost(CompOpCost):
......@@ -915,15 +446,6 @@ class ReduceMeanOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ReduceMeanGradOpCost(CompOpCost):
......@@ -932,15 +454,6 @@ class ReduceMeanGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SamplingIdOpCost(CompOpCost):
......@@ -949,15 +462,6 @@ class SamplingIdOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ScaleOpCost(CompOpCost):
......@@ -966,15 +470,6 @@ class ScaleOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SliceOpCost(CompOpCost):
......@@ -983,15 +478,6 @@ class SliceOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SoftmaxOpCost(CompOpCost):
......@@ -1000,15 +486,6 @@ class SoftmaxOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SoftmaxGradOpCost(CompOpCost):
......@@ -1017,15 +494,6 @@ class SoftmaxGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SoftmaxWithCrossEntropyOpCost(CompOpCost):
......@@ -1034,15 +502,6 @@ class SoftmaxWithCrossEntropyOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SoftmaxWithCrossEntropyGradOpCost(CompOpCost):
......@@ -1051,15 +510,6 @@ class SoftmaxWithCrossEntropyGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SplitOpCost(CompOpCost):
......@@ -1068,15 +518,6 @@ class SplitOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class Squeeze2OpCost(CompOpCost):
......@@ -1085,15 +526,6 @@ class Squeeze2OpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SquareOpCost(CompOpCost):
......@@ -1102,15 +534,6 @@ class SquareOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SquareGradOpCost(CompOpCost):
......@@ -1119,15 +542,6 @@ class SquareGradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class SumOpCost(CompOpCost):
......@@ -1136,15 +550,6 @@ class SumOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class TopKOpCost(CompOpCost):
......@@ -1153,15 +558,6 @@ class TopKOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class Transpose2OpCost(CompOpCost):
......@@ -1170,15 +566,6 @@ class Transpose2OpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class Transpose2GradOpCost(CompOpCost):
......@@ -1187,15 +574,6 @@ class Transpose2GradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class Unsqueeze2OpCost(CompOpCost):
......@@ -1204,15 +582,6 @@ class Unsqueeze2OpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class WriteToArrayOpCost(CompOpCost):
......@@ -1220,12 +589,3 @@ class WriteToArrayOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
......@@ -189,6 +189,9 @@ class CostEstimator:
# Calc dist op cost
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.process_ids
......@@ -225,6 +228,8 @@ class CostEstimator:
for rank in group_ranks:
self.local_cost(rank).time = (
max_time + comm_op_cost.time
if op.attr('op_role') != OpRole.Backward
else max_time + 0.9 * comm_op_cost.time
)
if rank not in self._bubble_time_mapping:
self._bubble_time_mapping[rank] = 0
......@@ -290,6 +295,7 @@ class CostEstimator:
self._ordered_ops.append([op.desc.id(), op])
self._ordered_ops.sort(key=lambda x: x[0])
parameters = set()
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader",
......@@ -298,11 +304,14 @@ class CostEstimator:
]:
continue
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
var_name
)
if var_name not in var_info:
var_info[var_name] = {}
key = _convert_pm_and_dm_to_str(
......@@ -311,6 +320,10 @@ class CostEstimator:
if key not in var_info[var_name]:
var_info[var_name][key] = {}
# It is an even partition for now
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
if "memory" not in var_info[var_name][key]:
var = dist_op.get_serial_input(var_name)
global_sizes = var.shape
......@@ -324,9 +337,16 @@ class CostEstimator:
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype
)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
if var.persistable:
name = var_name + key
if name not in parameters:
parameters.add(name)
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key][
"memory"
]
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
......@@ -339,6 +359,10 @@ class CostEstimator:
)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
if "memory" not in var_info[var_name][key]:
var = dist_op.get_serial_output(var_name)
global_sizes = var.shape
......@@ -352,11 +376,19 @@ class CostEstimator:
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype
)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
if var.persistable:
name = var_name + key
if name not in parameters:
parameters.add(name)
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key][
"memory"
]
has_used_vars = set()
not_calc_vars = set()
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader",
......@@ -367,6 +399,8 @@ class CostEstimator:
can_free_memories = {}
can_free_vars = set()
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
......@@ -378,24 +412,30 @@ class CostEstimator:
has_used_var = var_name + key
var = dist_op.get_serial_input(var_name)
# Not used
if var_name + key not in has_used_vars:
if (
has_used_var not in has_used_vars
and has_used_var not in parameters
):
if has_used_var in not_calc_vars:
continue
has_used_vars.add(has_used_var)
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
# Used
else:
if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars:
if (
has_used_var not in can_free_vars
and not var.persistable
):
can_free_vars.add(has_used_var)
if not var.persistable:
for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name
][key]["memory"]
can_free_memories[process] += var_info[var_name][
key
]["memory"]
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
......@@ -406,25 +446,36 @@ class CostEstimator:
)
has_used_var = var_name + key
var = dist_op.get_serial_output(var_name)
if (
op.type == "reshape2"
or op.type == "transpose2"
or op.type == "elementwise_add"
):
not_calc_vars.add(has_used_var)
continue
# Not used
if var_name + key not in has_used_vars:
if (
has_used_var not in has_used_vars
and has_used_var not in parameters
):
has_used_vars.add(has_used_var)
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
# Used
else:
if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars:
if (
has_used_var not in can_free_vars
and not var.persistable
):
can_free_vars.add(has_used_var)
if not var.persistable:
for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name
][key]["memory"]
can_free_memories[process] += var_info[var_name][
key
]["memory"]
# Calc peak memory
for process in memories:
......@@ -433,7 +484,6 @@ class CostEstimator:
else:
if memories[process] > self.max_memories[process]:
self.max_memories[process] = memories[process]
# Free memory
for process in can_free_memories:
if process in memories:
......@@ -513,7 +563,7 @@ class CostEstimator:
# Padding automatically
max_len = 0
header = ["Execution Time(ms)", "Max Memory(MiB)"]
header = ["Execution Time(us)", "Max Memory(MiB)"]
vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
for memory in vals + header:
if len(str(memory)) > max_len:
......
......@@ -2716,6 +2716,8 @@ class Resharder:
)
# simplified processing: ignore union process mesh and output reshard
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_tensor or not dist_op:
return reshard_op_cost
dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
tensor.name
)
......
......@@ -16,17 +16,31 @@ import copy
import logging
import math
import os
import pickle
import sys
import time
from abc import abstractmethod
from collections import OrderedDict
from functools import reduce
import numpy as np
import paddle
from paddle.distributed.auto_parallel.cluster_v2 import DeviceMesh
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
is_gradient_clip_op,
print_program_with_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Parameter, unique_name
......@@ -1610,3 +1624,603 @@ class RuleBasedTuner:
idx
][parallelism][key]
self._complete_sub_bwd_program(sub_program_dist_context)
def _complete_sub_update_program(self, sub_program_dist_context):
"""
Complete the optimizer ops according to the related tensors' dist attrs.
Most of the logic mirrors the update completion in the Completer.
"""
world_ranks = ProcessMesh(
[
i
for i in range(
self._cluster.get_num_machines()
* self._cluster._num_devices_per_machine
)
]
)
dist_tensors = sub_program_dist_context._dist_tensors_for_program
vars = self.full_main_program.global_block().vars
ops = self.full_main_program.global_block().ops
learning_rate_completed = False
for idx in range(len(ops)):
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
if is_gradient_clip_op(op):
if op.type in [
"sum",
"sqrt",
"fill_constant",
"elementwise_max",
"elementwise_div",
]:
op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = world_ranks
for in_name in op.input_arg_names:
in_var = vars[in_name]
if in_var.desc.original_id() in dist_tensors:
in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
in_var
)
op_dist_attr.set_input_dist_attr(
in_name, in_dist_attr
)
else:
in_dist_attr = TensorDistAttr()
in_dist_attr.process_mesh = world_ranks
in_dist_attr.dims_mapping = [
-1 for _ in range(len(in_var.shape))
]
op_dist_attr.set_input_dist_attr(
in_name, in_dist_attr
)
sub_program_dist_context.set_tensor_dist_attr_for_program(
in_var, in_dist_attr
)
for out_name in op.output_arg_names:
out_var = vars[out_name]
if out_var.desc.original_id() in dist_tensors:
out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
out_var
)
op_dist_attr.set_output_dist_attr(
out_name, out_dist_attr
)
else:
out_dist_attr = TensorDistAttr()
out_dist_attr.process_mesh = world_ranks
out_dist_attr.dims_mapping = [
-1 for _ in range(len(out_var.shape))
]
sub_program_dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr
)
op_dist_attr.set_output_dist_attr(
out_name, out_dist_attr
)
sub_program_dist_context.set_op_dist_attr_for_program(
op, op_dist_attr
)
else:
in_var = vars[op.input("X")[0]]
if in_var.desc.original_id() in dist_tensors:
in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
in_var
)
assert in_dist_attr is not None
ref_process_mesh = in_dist_attr.process_mesh
ref_dims_mapping = in_dist_attr.dims_mapping
if (
op.type == "cast"
and ops[idx + 1].type == "elementwise_mul"
):
ref_var = vars[ops[idx + 1].input("X")[0]]
ref_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
ref_var
)
assert ref_dist_attr is not None
ref_process_mesh = ref_dist_attr.process_mesh
out_var = vars[op.output("Out")[0]]
out_dist_attr = TensorDistAttr()
out_dist_attr.process_mesh = ref_process_mesh
if out_var.shape == in_var.shape:
out_dist_attr.dims_mapping = ref_dims_mapping
else:
assert (
len(out_var.shape) == 1
and out_var.shape[0] == 1
)
out_dist_attr.dims_mapping = [-1]
sub_program_dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr
)
op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = ref_process_mesh
for in_name in op.input_arg_names:
in_var = vars[in_name]
in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
in_var
)
op_dist_attr.set_input_dims_mapping(
in_name, in_dist_attr.dims_mapping
)
for out_name in op.output_arg_names:
out_var = vars[out_name]
out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
out_var
)
op_dist_attr.set_output_dims_mapping(
out_name, out_dist_attr.dims_mapping
)
op_dist_attr.set_input_dist_attr(
in_var.name, in_dist_attr
)
op_dist_attr.set_output_dist_attr(
out_var.name, out_dist_attr
)
sub_program_dist_context.set_op_dist_attr_for_program(
op, op_dist_attr
)
else:
continue
if "Grad" in op.input_names and "Param" in ops[idx].input_names:
assert (
len(op.input("Param")) == 1
), "Only support one-to-one now."
assert (
len(op.input("Grad")) == 1
), "Only support one-to-one now."
param = vars[op.input("Param")[0]]
grad_var = vars[op.input("Grad")[0]]
if param.desc.original_id() in dist_tensors:
param_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
param
)
assert param_dist_attr is not None
ref_process_mesh = sub_program_dist_context.get_tensor_dist_attr_for_program(
param
).process_mesh
assert ref_process_mesh is not None
ref_dims_mapping = sub_program_dist_context.get_tensor_dist_attr_for_program(
param
).dims_mapping
assert ref_dims_mapping is not None
op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dims_mapping(
grad_var.name, ref_dims_mapping
)
op_dist_attr.set_input_dims_mapping(
param.name, ref_dims_mapping
)
op_dist_attr.set_output_dims_mapping(
param.name, ref_dims_mapping
)
learning_var = vars[op.input("LearningRate")[0]]
op_dist_attr.set_input_dims_mapping(
learning_var.name, [-1]
)
op_dist_attr.set_output_dims_mapping(
learning_var.name, [-1]
)
if not learning_rate_completed:
learning_rate_completed = True
var_dist_attr = TensorDistAttr()
var_dist_attr.process_mesh = world_ranks
var_dist_attr.dims_mapping = [-1]
sub_program_dist_context.set_tensor_dist_attr_for_program(
learning_var, var_dist_attr
)
for input_name in op.desc.input_names():
if input_name in [
'Param',
'Grad',
'LearningRate',
"SkipUpdate",
"Beta1Tensor",
"Beta2Tensor",
"EpsilonTensor",
]:
continue
if len(op.desc.input(input_name)) == 0:
continue
assert len(op.desc.input(input_name)) == 1
input_var = vars[op.desc.input(input_name)[0]]
input_var_attr = TensorDistAttr()
if (
"Beta1Pow" in input_name
or "Beta2Pow" in input_name
):
input_var_attr.dims_mapping = [-1]
op_dist_attr.set_input_dims_mapping(
input_var.name, [-1]
)
op_dist_attr.set_output_dims_mapping(
input_var.name, [-1]
)
else:
input_var_attr.dims_mapping = ref_dims_mapping
op_dist_attr.set_input_dims_mapping(
input_var.name, ref_dims_mapping
)
op_dist_attr.set_output_dims_mapping(
input_var.name, ref_dims_mapping
)
input_var_attr.process_mesh = ref_process_mesh
sub_program_dist_context.set_tensor_dist_attr_for_program(
input_var, input_var_attr
)
sub_program_dist_context.set_op_dist_attr_for_program(
op, op_dist_attr
)
continue
else:
continue
def complete_sub_update_programs(self):
for idx in self.sub_programs_dist_context:
for parallelism in self.sub_programs_dist_context[idx]:
for key in self.sub_programs_dist_context[idx][parallelism]:
sub_program_dist_context = self.sub_programs_dist_context[
idx
][parallelism][key]
self._complete_sub_update_program(sub_program_dist_context)
def convert_device_mesh_to_key(self, device_mesh):
"""Convert device mesh object to str."""
processes = ",".join([str(x) for x in device_mesh.device_ids])
topology = ",".join([str(x) for x in device_mesh.shape])
key = processes + ";" + topology
return key
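# For illustration (assumed values): a device mesh with device_ids [0, 1, 2, 3]
# and shape [2, 2] is converted to the key "0,1,2,3;2,2".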
def _get_sub_program_cost(self, dist_context):
"""Estimate the cost of dist context."""
cost_estimator = CostEstimator(self.full_main_program, self._cluster)
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context
)
return global_cost.time, max_memory
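# A minimal usage sketch (assuming a completed dist context is available):
#   time_cost, max_memory = self._get_sub_program_cost(dist_context)
# time_cost is the estimated global execution time and max_memory is in bytes.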
def combine_dist_contexts(self, dist_contexts):
"""Combine the dist attr in dist contexts to one dist context."""
combined_dist_context = DistributedContext()
# set dist tensors; note that a shared param or var may be an input of multiple ops
for dist_context in dist_contexts:
for tensor_id in dist_context._dist_tensors_for_program:
dist_tensor = dist_context._dist_tensors_for_program[tensor_id]
if (
tensor_id
not in combined_dist_context._dist_tensors_for_program
):
combined_dist_context.add_dist_tensor_for_program(
dist_tensor
)
# set dist op
for op_id in dist_context._dist_ops_for_program:
dist_op = dist_context._dist_ops_for_program[op_id]
combined_dist_context.add_dist_op_for_program(dist_op)
for process_mesh in dist_context.process_meshes:
combined_dist_context.add_process_mesh(process_mesh)
return combined_dist_context
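# Note: if the same tensor (e.g. a shared parameter) appears in several sub dist
# contexts, only the first dist attr encountered is kept, while dist ops and
# process meshes from all contexts are accumulated.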
def prepare(self):
"""Prepare the sub program, tensor dist attr setting, device meshes and so on that tuner need."""
# step1: cluster operators to layers
begin = time.time()
self.layers = self.cluster_operators()
end = time.time()
self._logger.info(
"Cluster operators to {} layers in {}s.".format(
len(self.layers), end - begin
)
)
# step2: generate sub program of each layer
begin = time.time()
self.gen_fwd_sub_programs_by_clone()
end = time.time()
self._logger.info(
"Generate programs of every layer in {}s.".format(end - begin)
)
# step3: partition devices to device meshes
begin = time.time()
n, m = (
self._cluster.get_num_machines(),
self._cluster._num_devices_per_machine,
)
device_meshes_list = ClusterPartitionUtil.partition_cluster(n, m)
end = time.time()
self._logger.info("Partition cluster in {}s.".format(end - begin))
# step4: transform device mesh to process meshes
dm_idx = 0
for device_meshes in device_meshes_list:
has_used_devices = 0
self.device_meshes_list.append([])
for device_mesh in device_meshes:
devices = reduce(lambda x, y: x * y, device_mesh)
processes = [
i
for i in range(has_used_devices, has_used_devices + devices)
]
device_mesh_shape = (
device_mesh
if device_mesh[0] != 1
else [device_mesh[i] for i in range(1, len(device_mesh))]
)
self.device_meshes_list[-1].append(
DeviceMesh(
mesh=np.array(processes)
.reshape(device_mesh_shape)
.tolist(),
name="device_mesh_" + str(dm_idx),
)
)
dm_idx += 1
has_used_devices += devices
process_mesh_shapes = convert_to_process_meshes(device_mesh)
for process_mesh_shape in process_mesh_shapes:
process_mesh = ProcessMesh(
np.array(processes).reshape(process_mesh_shape).tolist()
)
if process_mesh not in self.process_meshes:
self.process_meshes.append(process_mesh)
# step5: generate full program
begin = time.time()
self.gen_full_program()
end = time.time()
self._logger.info("Generate full program in {}s.".format(end - begin))
# step6: complete forward sub programs
begin = time.time()
for process_mesh in self.process_meshes:
self.complete_sub_fwd_programs(process_mesh)
end = time.time()
self._logger.info(
"Complete all sub forward programs in {}s.".format(end - begin)
)
if self.mode == "train":
# step7: complete backward sub programs
begin = time.time()
self.complete_sub_bwd_programs()
end = time.time()
self._logger.info(
"Complete all sub backward programs in {}s.".format(end - begin)
)
# step8: complete update sub programs
begin = time.time()
self.complete_sub_update_programs()
end = time.time()
self._logger.info(
"Complete all sub update programs in {}s.".format(end - begin)
)
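# For illustration (assumed cluster of 1 machine with 8 GPUs): step3 may partition
# the devices into meshes such as [[1, 8]] or [[1, 4], [1, 4]], and step4 then
# derives candidate process meshes like [8], [2, 4] and [4, 2] from them.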
def tune_o1(self):
"""The o1 level tuning."""
best_cost = sys.maxsize
best_dist_context = None
for device_meshes in self.device_meshes_list:
pp_stages = len(device_meshes)
average_layers = len(self.layers) // pp_stages
device_mesh_shape = device_meshes[0].shape
if len(device_mesh_shape) == 1:
device_mesh_shape.insert(0, 1)
process_mesh_shapes = convert_to_process_meshes(device_mesh_shape)
# For example, if the device mesh is [1, 8], the process mesh is [8]
# and the selective parallelisms are dp and mp.
# Get the dp8 and mp8 costs and compare them to pick the best strategy.
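# With pp_stages pipeline stages, stage idx takes layers
# [idx * average_layers, (idx + 1) * average_layers); the last stage also takes
# any remaining layers.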
for parallelism in ["dp", "mp", "dp_mp", "mp_dp"]:
for process_mesh_shape in process_mesh_shapes:
dist_context_of_device_meshes = None
for idx, device_mesh in enumerate(device_meshes):
device_mesh_shape = device_mesh.shape
process_mesh = ProcessMesh(
np.array(device_mesh.device_ids)
.reshape(process_mesh_shape)
.tolist()
)
selective_parallelisms = (
["dp", "mp"]
if len(process_mesh.shape) == 1
else ["dp_mp", "mp_dp"]
)
if parallelism not in selective_parallelisms:
total_cost_of_device_meshes = sys.maxsize
continue
key = self.convert_process_mesh_to_key(process_mesh)
if idx == len(device_meshes) - 1:
start = idx * average_layers
end = len(self.layers)
else:
start = idx * average_layers
end = (idx + 1) * average_layers
dist_context = self.combine_dist_contexts(
[
self.sub_programs_dist_context[j][parallelism][
key
]
for j in range(start, end)
]
)
dist_context_of_device_meshes = (
dist_context
if dist_context_of_device_meshes is None
else self.combine_dist_contexts(
[dist_context_of_device_meshes, dist_context]
)
)
if dist_context_of_device_meshes is not None:
cost, memory = self._get_sub_program_cost(
dist_context_of_device_meshes
)
self._logger.info(
"Cost Model: The max memory is {}GB and cost is {} when {} parallelism under process mesh shape {} on {} stages.".format(
memory / (1024**3),
cost,
parallelism,
process_mesh_shape,
len(device_meshes),
)
)
# 15% buffer is reserved for memory cost
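# For example (assumed 32GB devices): a plan whose estimated max memory exceeds
# 0.85 * 32GB = 27.2GB is treated as infeasible and its cost is set to sys.maxsize.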
if memory > 0.85 * self.cluster.machines[0].devices[
0
].memory * (1024**3):
cost = sys.maxsize
if cost < best_cost:
best_cost = cost
best_dist_context = dist_context_of_device_meshes
self._logger.info(
"O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {}GB.".format(
parallelism,
process_mesh_shape,
len(device_meshes),
memory / (1024**3),
)
)
return best_dist_context
def tune_o2(self):
return None
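# The o2 level is not implemented in this change; tune_o2 is a placeholder that
# returns None, so the assertion in tune() fails if no o1 strategy is available.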
def save_strategy(self, best_dist_context, path):
dist_attrs = {"tensor": {}, "op": {}, "process_meshes": []}
for key in best_dist_context._dist_tensors_for_program:
if key in self._dist_context._dist_tensors_for_program:
dist_tensor = best_dist_context._dist_tensors_for_program[key]
dist_attrs["tensor"][
key
] = dist_tensor.dist_attr.serialize_to_string()
assert dist_attrs["tensor"], "Tensor dist attrs must not be None."
for key in best_dist_context._dist_ops_for_program:
if key in self._dist_context._dist_ops_for_program:
dist_op = best_dist_context._dist_ops_for_program[key]
dist_attrs["op"][key] = dist_op.dist_attr.serialize_to_string()
assert dist_attrs["op"], "Op dist attrs must not be None."
for process_mesh in best_dist_context._process_meshes:
process_ids = process_mesh.process_ids
process_shape = process_mesh.shape
dist_attrs["process_meshes"].append([process_ids, process_shape])
dist_attrs["cluster"] = self._cluster
with open(path, 'wb') as f:
pickle.dump(dist_attrs, f)
self._logger.info("The strategy has been saved at {}".format(path))
def run_or_quit(self):
# Quit if just tune
if not self._is_run:
self._logger.info(
"The process will be quitted when just tune not run."
)
quit()
def tune(self):
begin = time.time()
self.match_program(self._dist_context.serial_main_program)
end = time.time()
self._logger.info("Pattern match in {}s.".format(end - begin))
if self._use_dp:
completer = Completer(self._dist_context)
completer.complete_forward_annotation()
print_program_with_dist_attr(
self._dist_context.serial_main_program, self._dist_context
)
# Save strategy if need
path = self._strategy_path
if path:
self.save_strategy(self._dist_context, path)
self.run_or_quit()
return
# prepare
self.prepare()
best_dist_context = None
if self.level == "o2":
best_dist_context = self.tune_o2()
elif self.level == "o1":
# If the level is o1, all layers use the same parallelism.
# With pipeline parallelism, the layers are placed evenly across stages.
use_o2_level = False
for device_meshes in self.device_meshes_list:
if len(device_meshes) > 1:
shape = None
for device_mesh in device_meshes:
if shape is None:
shape = device_mesh.shape
continue
else:
if shape != device_mesh.shape:
self._logger.info(
"Warning: The o1 level is not be supported when the number of machines is prime numer which greaters than 1. We will use o2 level to tune."
)
use_o2_level = True
break
if use_o2_level:
best_dist_context = self.tune_o2()
else:
best_dist_context = self.tune_o1()
assert (
best_dist_context is not None
), "can not find a parallel strategy to run, please use passes such as recompute, amp or sharding."
for key in best_dist_context._dist_tensors_for_program:
if key in self._dist_context._dist_tensors_for_program:
self._dist_context._dist_tensors_for_program[
key
] = best_dist_context._dist_tensors_for_program[key]
for key in best_dist_context._dist_ops_for_program:
if key in self._dist_context._dist_ops_for_program:
self._dist_context._dist_ops_for_program[
key
] = best_dist_context._dist_ops_for_program[key]
self._dist_context._process_meshes = best_dist_context._process_meshes
end = time.time()
self._logger.info("Rule-based tuner end in {}s.".format(end - begin))
self._logger.info("The best strategy found is as follows: ")
print_program_with_dist_attr(self.full_main_program, best_dist_context)
# Save strategy if need
path = self._strategy_path
if path:
self.save_strategy(best_dist_context, path)
self.run_or_quit()
......@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
modeling.init_global()
train_program = static.Program()
start_program = static.Program()
place = paddle.set_device("gpu")
batch_size = 8
sequence_len = 512
vocab_size = 1000
place = None
train_program, start_program, loss, gen_data = get_gpt_model(
train_program,
start_program,
......@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
sequence_len,
vocab_size,
)
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
RuleBasedTuner,
)
clip = paddle.nn.ClipGradByGlobalNorm(0.2)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
dist_context = DistributedContext(
serial_main_prog=train_program,
serial_startup_prog=start_program,
serial_optimizer=opt,
serial_loss=loss,
cluster=cluster,
)
dist_context.initialize()
tuner = RuleBasedTuner(dist_context)
tuner.cluster_operators()
tuner.gen_full_program()
tuner.match_program(tuner._dist_context.serial_main_program)
process_mesh = ProcessMesh([0, 1])
tuner.gen_fwd_sub_programs_by_clone()
tuner.complete_sub_fwd_programs(process_mesh)
tuner.complete_sub_bwd_programs()
tuner.tune()
if __name__ == "__main__":
......