Unverified commit d6011cb6, authored by caozhou, committed by GitHub

[Auto Parallel] Add o1 level tune (#52041)

* add tune o1 level

* add unittest
Parent 418b983c
@@ -14,6 +14,7 @@
 import json
 import os
+import re
 from enum import IntEnum, unique

 import paddle
@@ -449,7 +450,6 @@ class Cluster:
         npu_models = ["NPU"]
         dcu_models = ["DCU"]
         all_gpu_models = gpu_models + xpu_models + npu_models + dcu_models
-        assert gpu_model in all_gpu_models
         self._num_devices_per_machine = device_count

         def _convert_to_type(gpu_model):
@@ -462,6 +462,8 @@ class Cluster:
                 type = "NPU"
             elif gpu_model in dcu_models:
                 type = "DCU"
+            else:
+                type = "GPU"
             assert type is not None
             return type
@@ -470,6 +472,12 @@ class Cluster:
             model = None
             if gpu_model == "V100":
                 model = "Tesla V100-SXM2-" + str(gpu_memory) + "GB"
+            elif gpu_model == "A100":
+                model = "Tesla A100-SXM-" + str(gpu_memory) + "GB"
+            elif gpu_model == "A30":
+                model = "Tesla A30-SXM-" + str(gpu_memory) + "GB"
+            else:
+                model = gpu_model + str(gpu_memory) + "GB"
             assert model is not None
             return model
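For reference, a hedged sketch of what this naming helper (assumed here to be `_convert_to_model(gpu_model, gpu_memory)`) now produces; the input pairs are illustrative:

    # ("V100", 32) -> "Tesla V100-SXM2-32GB"
    # ("A100", 40) -> "Tesla A100-SXM-40GB"
    # ("A30", 24)  -> "Tesla A30-SXM-24GB"
    # Any other model falls through to plain concatenation,
    # e.g. ("P40", 24) -> "P4024GB" (no separator in the fallback branch).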
@@ -527,6 +535,8 @@ class Cluster:
                 device["memory"] = memory
                 device["sp_gflops"] = sp_gflops
                 device["dp_gflops"] = dp_gflops
+                # hard code
+                device["type"] = "GPU"
                 global_id_to_device_type[global_id] = type
                 global_id_to_node[global_id] = i
                 devices.append(device)
@@ -820,30 +830,82 @@ class Cluster:
         return self.__str__()


-def get_default_cluster():
+def get_default_cluster(json_config=None):
+    def is_by_json_config(json_config):
+        if not json_config:
+            return False
+        if "cluster" not in json_config:
+            return False
+        else:
+            if "path" not in json_config["cluster"]:
+                if "num_nodes" not in json_config["cluster"]:
+                    return False
+                if "num_gpus" not in json_config["cluster"]:
+                    return False
+                if "gpu_model" not in json_config["cluster"]:
+                    return False
+                if "gpu_memory" not in json_config["cluster"]:
+                    return False
+                return True
+            else:
+                return True
+
     cluster = Cluster()
-    local_device_count = os.getenv("PADDLE_LOCAL_SIZE")
-    if local_device_count is None:
-        local_device_count = 1
-    else:
-        local_device_count = int(local_device_count)
-    global_device_count = os.getenv("PADDLE_GLOBAL_SIZE")
-    if global_device_count is None:
-        node_count = 1
+    if json_config and is_by_json_config(json_config):
+        # Get GPU info by json config
+        if "path" in json_config["cluster"]:
+            cluster.build_from_file(json_config["cluster"]["path"])
+            return cluster
+        else:
+            node_count = json_config["cluster"]["num_nodes"]
+            local_device_count = json_config["cluster"]["num_gpus"]
+            gpu_model = json_config["cluster"]["gpu_model"]
+            memory = json_config["cluster"]["gpu_memory"]
     else:
-        global_device_count = int(global_device_count)
-        assert global_device_count % local_device_count == 0
-        node_count = int(global_device_count) // local_device_count
+        # Get GPU info by get_device_properties
+        local_device_count = os.getenv("PADDLE_LOCAL_SIZE")
+        if local_device_count is None:
+            local_device_count = 1
+        else:
+            local_device_count = int(local_device_count)
+        global_device_count = os.getenv("PADDLE_GLOBAL_SIZE")
+        if global_device_count is None:
+            node_count = 1
+        else:
+            global_device_count = int(global_device_count)
+            assert global_device_count % local_device_count == 0
+            node_count = int(global_device_count) // local_device_count
+        gpu_info = paddle.device.cuda.get_device_properties()
+        assert gpu_info, "Auto parallel just runs on gpu now."
+        gpu_name = gpu_info.name
+        try:
+            re_result = re.split(r'[ , -]', gpu_name)
+            gpu_model = re_result[1]
+            memory = int(re_result[-1][:-2])
+        except:
+            memory = int(gpu_info.total_memory) // (1000**3)
+            gpu_model = gpu_name
+
     print(
         "Node Count: ",
         node_count,
         "Local Device Size: ",
         local_device_count,
+        "GPU Model: ",
+        gpu_model,
+        "GPU Memory: ",
+        memory,
         "World size: ",
         paddle.distributed.get_world_size(),
         flush=True,
     )
     cluster.gen_default_config_cluster(
-        node_count=node_count, device_count=local_device_count
+        node_count=node_count,
+        device_count=local_device_count,
+        gpu_model=gpu_model,
+        gpu_memory=memory,
     )
     return cluster
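A minimal sketch of the two json_config shapes the new entry point accepts (key names taken from is_by_json_config above; the concrete values are illustrative assumptions):

    # Either point at a serialized cluster description on disk ...
    config_by_path = {"cluster": {"path": "/path/to/cluster.json"}}
    # ... or spell out the topology inline.
    config_by_spec = {
        "cluster": {
            "num_nodes": 1,
            "num_gpus": 8,
            "gpu_model": "A100",
            "gpu_memory": 40,
        }
    }
    cluster = get_default_cluster(json_config=config_by_spec)

Without a json_config, the function falls back to the PADDLE_LOCAL_SIZE / PADDLE_GLOBAL_SIZE environment variables and probes the GPU via get_device_properties.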
@@ -16,6 +16,7 @@ from collections import OrderedDict
 from functools import reduce

 import paddle
+from paddle.utils.flops import flops
 from ..cluster import LinkType
 from ..dist_tensor import DistributedTensor
@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
         output_desc = OrderedDict()

         # Get partitioned shape of input
+        input_var_desc = {}
         for input_name in op.input_names:
             var_name_list = op.input(input_name)
-            var_desc = []
+            input_var_desc[input_name] = []
             for var_name in var_name_list:
                 var = get_var_with_recursion(
                     var_name, op.block, op.block.program
@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
                     process,
                     shard_sizes,
                 )
-                var_desc.append((var.dtype, shape))
+                input_var_desc[input_name].append(shape)

                 # For special op such as embedding and its grad op
                 if (
@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
                     relative_idx = relative_idx * per_part_size
                     desc["attrs"]["start_index"] = relative_idx

-            input_desc[input_name] = var_desc
-        desc["inputs"] = input_desc
+        desc["inputs"] = input_var_desc

         for out_name in op.output_names:
             var_name_list = op.output(out_name)
@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
     return desc


-def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
+def build_comm_costs_from_descs(
+    op_cost_class, ctx, processes, descs, cluster, is_dp=False
+):
     """Build comm costs by descriptions"""
     comm_context = CommContext(cluster)
     group_ranks_list = []
@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
             comm_op_cost = op_cost_class(
                 op_desc=desc, comm_context=comm_context
             )
+            if is_dp:
+                comm_op_cost.cost.time *= 0.9
             comm_op_cost_list.append(comm_op_cost)
     return comm_op_cost_list
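A hedged usage sketch of the new flag: build_dp_costs (below) passes is_dp=True when pricing the data-parallel gradient allreduce, so each such comm op's modeled time is discounted by a fixed factor of 0.9 (presumably to account for overlap). The op-cost class name here is an assumption for illustration:

    costs = build_comm_costs_from_descs(
        AllreduceSumOpCost,  # assumed comm op cost class
        ctx,
        processes,
        c_allreduce_sum_descs,
        cluster,
        is_dp=True,  # each resulting cost.time is scaled by 0.9
    )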
@@ -389,6 +394,7 @@ def build_dp_costs(
     vars = dist_op.serial_op.block.vars
     var_name = var_names[0]
     has_found = False
+    is_input = True
     for name in dist_op.serial_op.input_arg_names:
         if var_name in name:
             var_name = name
@@ -400,6 +406,7 @@ def build_dp_costs(
             if var_name in name:
                 var_name = name
                 has_found = True
+                is_input = False
                 break
     if not has_found:
         return
@@ -418,6 +425,7 @@ def build_dp_costs(
         processes,
         c_allreduce_sum_descs,
         cluster,
+        is_dp=True,
     )
     result.append(comm_cost_list)
@@ -431,22 +439,11 @@ def build_dp_costs(
         desc = {}
         desc["op"] = op_type
         desc["inputs"] = {}
-        if var_name in dist_attr.inputs_dist_attrs:
-            dims_mapping = dist_attr.get_input_dims_mapping(var_name)
-        elif var_name in dist_attr.outputs_dist_attrs:
-            dims_mapping = dist_attr.get_output_dims_mapping(var_name)
-        else:
-            raise AssertionError(
-                "cannot find dims_mapping for {} in {}".format(
-                    var_name, dist_attr
-                )
-            )
-        # dims_mapping = (
-        #     dist_attr.get_input_dims_mapping(var_name)
-        #     if dist_attr.get_input_dims_mapping(var_name) is not None
-        #     else dist_attr.get_output_dims_mapping(var_name)
-        # )
+        dims_mapping = (
+            dist_attr.get_input_dims_mapping(var_name)
+            if is_input
+            else dist_attr.get_output_dims_mapping(var_name)
+        )
         var = get_var_with_recursion(
             var_name,
             dist_op.serial_op.block,
@@ -493,8 +490,6 @@ class CommContext:
         # if cluster has no info about those vars, it will be set by default
         self.base_ring = None
         self.base_tree = None
-        # self.base_inter_ring = None
-        # self.base_inter_tree = None
         self.intra_ring = None
         self.intra_tree = None
         self.inter_ring = None
@@ -508,8 +503,6 @@ class CommContext:
             # set default
             self.base_ring = 8.4
             self.base_tree = 0.0
-            # self.base_inter_ring = 9.6
-            # self.base_inter_tree = 28
             # NVL in default
             self.intra_ring = 3.4
             self.intra_tree = 28
@@ -681,6 +674,8 @@ class Cost:


 class OpCost:
+    OP_TYPE = "op"
+
     def __init__(self, op=None, op_desc=None):
         self._op = op
         self._op_desc = op_desc
@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
             )
         )

+    def calc_flops(self):
+        if not self.op_desc:
+            return 0
+        if "_grad" in self.__class__.OP_TYPE:
+            op_type = self.__class__.OP_TYPE[: len(self.__class__.OP_TYPE) - 5]
+            return 2 * flops(
+                op_type, self.op_desc["inputs"], self.op_desc["attrs"]
+            )
+        return flops(
+            self.__class__.OP_TYPE,
+            self.op_desc["inputs"],
+            self.op_desc["attrs"],
+        )
+
+    def calc_time(self):
+        flops_count = self.calc_flops()
+        return flops_count * 2.9e-7
+

 def register_op_cost(cls):
     op_type = cls.OP_TYPE
...
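A hedged walk-through of the base-class logic above (the OP_TYPE value is illustrative):

    # With OP_TYPE = "matmul_v2_grad":
    op_type = "matmul_v2_grad"[: len("matmul_v2_grad") - 5]  # drops "_grad" -> "matmul_v2"
    # calc_flops() then returns 2 * flops("matmul_v2", inputs, attrs),
    # i.e. a backward op is modeled as twice the forward FLOPs, and
    # calc_time() converts FLOPs to time via the fitted constant 2.9e-7.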
@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
         super().__init__(op=op, op_desc=op_desc, comm_context=comm_context)

     def calc_time(self):
-        return 0
+        return self.comm_count * 1 / (144 * 1e3)
...
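A hedged reading of the new IdentityOpCost.calc_time (assuming comm_count is the message size in bytes, as for the other CommOpCost subclasses): an identity "communication" is no longer modeled as free but is charged at a fixed effective rate.

    # Illustrative helper, not part of the patch; units follow the cost model's.
    def identity_time(nbytes):
        return nbytes * 1 / (144 * 1e3)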
@@ -22,15 +22,6 @@ class AdamOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ArgsortOpCost(CompOpCost):
@@ -39,15 +30,6 @@ class ArgsortOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class AssignOpCost(CompOpCost):
@@ -56,15 +38,6 @@ class AssignOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class AssignValueOpCost(CompOpCost):
@@ -73,15 +46,6 @@ class AssignValueOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class BeamSearchOpCost(CompOpCost):
@@ -90,15 +54,6 @@ class BeamSearchOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class BeamSearchDecodeOpCost(CompOpCost):
@@ -107,15 +62,6 @@ class BeamSearchDecodeOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class CastOpCost(CompOpCost):
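Every hunk in this file follows the same pattern: the per-class placeholder calc_flops/calc_time stubs are deleted, and the classes fall back to the flops-based defaults now defined on CompOpCost. A hedged sketch of the post-patch shape of one such class (the OP_TYPE value is an assumption for illustration):

    @register_op_cost
    class AdamOpCost(CompOpCost):
        OP_TYPE = "adam"  # assumed registered op type

        def __init__(self, op=None, op_desc=None, cluster=None):
            super().__init__(op=op, op_desc=op_desc, cluster=cluster)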
@@ -124,15 +70,6 @@ class CastOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ConcatOpCost(CompOpCost):
@@ -141,15 +78,6 @@ class ConcatOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class DropoutOpCost(CompOpCost):
@@ -158,15 +86,6 @@ class DropoutOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class DropoutGradOpCost(CompOpCost):
@@ -175,15 +94,6 @@ class DropoutGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ElementwiseAddOpCost(CompOpCost):
@@ -192,15 +102,6 @@ class ElementwiseAddOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ElementwiseAddGradOpCost(CompOpCost):
@@ -209,15 +110,6 @@ class ElementwiseAddGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ElementwiseDivOpCost(CompOpCost):
@@ -226,15 +118,6 @@ class ElementwiseDivOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ElementwiseDivGradOpCost(CompOpCost):
@@ -243,15 +126,6 @@ class ElementwiseDivGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ElementwiseMulOpCost(CompOpCost):
@@ -260,15 +134,6 @@ class ElementwiseMulOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ElementwiseMulGradOpCost(CompOpCost):
@@ -277,15 +142,6 @@ class ElementwiseMulGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ElementwiseSubOpCost(CompOpCost):
@@ -294,15 +150,6 @@ class ElementwiseSubOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ElementwiseSubGradOpCost(CompOpCost):
@@ -311,15 +158,6 @@ class ElementwiseSubGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class EqualOpCost(CompOpCost):
@@ -328,15 +166,6 @@ class EqualOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class EmbeddingOpCost(CompOpCost):
@@ -345,15 +174,6 @@ class EmbeddingOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class EmbeddingGradOpCost(CompOpCost):
@@ -362,15 +182,6 @@ class EmbeddingGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class FillConstantOpCost(CompOpCost):
@@ -379,15 +190,6 @@ class FillConstantOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class FillConstantBatchSizeLikeOpCost(CompOpCost):
@@ -396,15 +198,6 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
@@ -413,15 +206,6 @@ class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
@@ -430,15 +214,6 @@ class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class GatherOpCost(CompOpCost):
@@ -447,15 +222,6 @@ class GatherOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class GeluOpCost(CompOpCost):
@@ -464,15 +230,6 @@ class GeluOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class GeluGradOpCost(CompOpCost):
@@ -481,15 +238,6 @@ class GeluGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class GreaterEqualOpCost(CompOpCost):
@@ -498,15 +246,6 @@ class GreaterEqualOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class IncrementOpCost(CompOpCost):
@@ -515,11 +254,6 @@ class IncrementOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class IsEmptyOpCost(CompOpCost):
@@ -528,11 +262,6 @@ class IsEmptyOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class LayerNormOpCost(CompOpCost):
@@ -541,15 +270,6 @@ class LayerNormOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class LayerNormGradOpCost(CompOpCost):
@@ -558,15 +278,6 @@ class LayerNormGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class LessThanOpCost(CompOpCost):
@@ -575,15 +286,6 @@ class LessThanOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class LogicalNotOpCost(CompOpCost):
@@ -592,15 +294,6 @@ class LogicalNotOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class LogicalAndOpCost(CompOpCost):
@@ -609,15 +302,6 @@ class LogicalAndOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class LodResetOpCost(CompOpCost):
@@ -626,15 +310,6 @@ class LodResetOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class LogOpCost(CompOpCost):
@@ -643,15 +318,6 @@ class LogOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class LookupTableV2OpCost(CompOpCost):
@@ -660,15 +326,6 @@ class LookupTableV2OpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class LookupTableV2GradOpCost(CompOpCost):
@@ -677,15 +334,6 @@ class LookupTableV2GradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class MatmulOpCost(CompOpCost):
@@ -694,15 +342,6 @@ class MatmulOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class MatmulGradOpCost(CompOpCost):
@@ -711,15 +350,6 @@ class MatmulGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class MatmulV2OpCost(CompOpCost):
@@ -728,15 +358,6 @@ class MatmulV2OpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class MatmulV2GradOpCost(CompOpCost):
@@ -745,15 +366,6 @@ class MatmulV2GradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class MemcpyOpCost(CompOpCost):
@@ -762,15 +374,6 @@ class MemcpyOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class MulOpCost(CompOpCost):
@@ -779,15 +382,6 @@ class MulOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class MulGradOpCost(CompOpCost):
@@ -796,15 +390,6 @@ class MulGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class OneHotOpCost(CompOpCost):
@@ -813,15 +398,6 @@ class OneHotOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ReadFromArrayOpCost(CompOpCost):
@@ -830,15 +406,6 @@ class ReadFromArrayOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ReduceSumOpCost(CompOpCost):
@@ -847,15 +414,6 @@ class ReduceSumOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ReduceSumGradOpCost(CompOpCost):
@@ -864,15 +422,6 @@ class ReduceSumGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class Reshape2OpCost(CompOpCost):
@@ -881,15 +430,6 @@ class Reshape2OpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class Reshape2GradOpCost(CompOpCost):
@@ -898,15 +438,6 @@ class Reshape2GradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ReduceMeanOpCost(CompOpCost):
@@ -915,15 +446,6 @@ class ReduceMeanOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ReduceMeanGradOpCost(CompOpCost):
@@ -932,15 +454,6 @@ class ReduceMeanGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SamplingIdOpCost(CompOpCost):
@@ -949,15 +462,6 @@ class SamplingIdOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class ScaleOpCost(CompOpCost):
@@ -966,15 +470,6 @@ class ScaleOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SliceOpCost(CompOpCost):
@@ -983,15 +478,6 @@ class SliceOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SoftmaxOpCost(CompOpCost):
@@ -1000,15 +486,6 @@ class SoftmaxOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SoftmaxGradOpCost(CompOpCost):
@@ -1017,15 +494,6 @@ class SoftmaxGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SoftmaxWithCrossEntropyOpCost(CompOpCost):
@@ -1034,15 +502,6 @@ class SoftmaxWithCrossEntropyOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SoftmaxWithCrossEntropyGradOpCost(CompOpCost):
@@ -1051,15 +510,6 @@ class SoftmaxWithCrossEntropyGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SplitOpCost(CompOpCost):
@@ -1068,15 +518,6 @@ class SplitOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class Squeeze2OpCost(CompOpCost):
@@ -1085,15 +526,6 @@ class Squeeze2OpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SquareOpCost(CompOpCost):
@@ -1102,15 +534,6 @@ class SquareOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SquareGradOpCost(CompOpCost):
@@ -1119,15 +542,6 @@ class SquareGradOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class SumOpCost(CompOpCost):
@@ -1136,15 +550,6 @@ class SumOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
-    def calc_flops(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-
-    def calc_time(self):
-        # NOTE: The actual formula will be filled in the future
-        return 0
-

 @register_op_cost
 class TopKOpCost(CompOpCost):
@@ -1153,15 +558,6 @@ class TopKOpCost(CompOpCost):
     def __init__(self, op=None, op_desc=None, cluster=None):
         super().__init__(op=op, op_desc=op_desc, cluster=cluster)

-    # For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost @register_op_cost
class Transpose2OpCost(CompOpCost): class Transpose2OpCost(CompOpCost):
...@@ -1170,15 +566,6 @@ class Transpose2OpCost(CompOpCost): ...@@ -1170,15 +566,6 @@ class Transpose2OpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None): def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster) super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost @register_op_cost
class Transpose2GradOpCost(CompOpCost): class Transpose2GradOpCost(CompOpCost):
...@@ -1187,15 +574,6 @@ class Transpose2GradOpCost(CompOpCost): ...@@ -1187,15 +574,6 @@ class Transpose2GradOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None): def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster) super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost @register_op_cost
class Unsqueeze2OpCost(CompOpCost): class Unsqueeze2OpCost(CompOpCost):
...@@ -1204,15 +582,6 @@ class Unsqueeze2OpCost(CompOpCost): ...@@ -1204,15 +582,6 @@ class Unsqueeze2OpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None): def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster) super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost @register_op_cost
class WriteToArrayOpCost(CompOpCost): class WriteToArrayOpCost(CompOpCost):
...@@ -1220,12 +589,3 @@ class WriteToArrayOpCost(CompOpCost): ...@@ -1220,12 +589,3 @@ class WriteToArrayOpCost(CompOpCost):
def __init__(self, op=None, op_desc=None, cluster=None): def __init__(self, op=None, op_desc=None, cluster=None):
super().__init__(op=op, op_desc=op_desc, cluster=cluster) super().__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops function need to be overrided
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
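The stubs deleted above all returned 0; the classes evidently fall back to the CompOpCost base implementation now, and a subclass only overrides calc_flops/calc_time when it has a real cost model. A minimal sketch of what a concrete override could look like — the op type, the placeholder shapes, and the 2*M*N*K formula are illustrative assumptions, not code from this PR:

@register_op_cost
class MyMatmulOpCost(CompOpCost):
    OP_TYPE = "my_matmul"  # hypothetical op type

    def calc_flops(self):
        # Assumed shapes: a [M, K] x [K, N] matmul costs roughly
        # 2 * M * N * K floating point operations.
        M, K, N = 512, 512, 512  # placeholder shapes for the sketch
        return 2 * M * N * K

    def calc_time(self):
        # sp_gflops comes from the cluster description (see Cluster above);
        # FLOPs / (GFLOPS * 1e3) yields microseconds.
        device = self.cluster.machines[0].devices[0]
        return self.calc_flops() / (device.sp_gflops * 1e3)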
@@ -189,6 +189,9 @@ class CostEstimator:
                # Calc dist op cost
                dist_op = dist_context.get_dist_op_for_program(op)
+               if not dist_op:
+                   continue
                op_dist_attr = dist_op.dist_attr
                processes = op_dist_attr.process_mesh.process_ids
@@ -225,6 +228,8 @@ class CostEstimator:
                        for rank in group_ranks:
                            self.local_cost(rank).time = (
                                max_time + comm_op_cost.time
+                               if op.attr('op_role') != OpRole.Backward
+                               else max_time + 0.9 * comm_op_cost.time
                            )
                            if rank not in self._bubble_time_mapping:
                                self._bubble_time_mapping[rank] = 0
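The branch added above charges backward-pass communication at 90% of its modeled time, presumably to reflect partial overlap of backward communication with computation. For example, with max_time = 100 and comm_op_cost.time = 40 (both in microseconds), a rank is charged 140 on the forward path but 136 when the op role is OpRole.Backward.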
@@ -290,6 +295,7 @@ class CostEstimator:
                self._ordered_ops.append([op.desc.id(), op])
        self._ordered_ops.sort(key=lambda x: x[0])
+       parameters = set()

        for op_id, op in self._ordered_ops:
            if op.type in [
                "create_py_reader",
@@ -298,11 +304,14 @@ class CostEstimator:
            ]:
                continue
            dist_op = dist_context.get_dist_op_for_program(op)
+           if not dist_op:
+               continue
            process_mesh = dist_op.dist_attr.process_mesh
            for var_name in op.input_arg_names:
                input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
                    var_name
                )
                if var_name not in var_info:
                    var_info[var_name] = {}
                key = _convert_pm_and_dm_to_str(
@@ -311,6 +320,10 @@ class CostEstimator:
                if key not in var_info[var_name]:
                    var_info[var_name][key] = {}
                # It is even partition now
+               if "position" not in var_info[var_name][key]:
+                   var_info[var_name][key]["position"] = []
+               var_info[var_name][key]["position"].append(op_id)
                if "memory" not in var_info[var_name][key]:
                    var = dist_op.get_serial_input(var_name)
                    global_sizes = var.shape
@@ -324,9 +337,16 @@ class CostEstimator:
                    var_info[var_name][key]["memory"] = self._calculate_bytes(
                        sizes, dtype
                    )
-                   if "position" not in var_info[var_name][key]:
-                       var_info[var_name][key]["position"] = []
-                   var_info[var_name][key]["position"].append(op_id)
+                   if var.persistable:
+                       name = var_name + key
+                       if name not in parameters:
+                           parameters.add(name)
+                           for process in process_mesh.process_ids:
+                               if process not in memories:
+                                   memories[process] = 0
+                               memories[process] += var_info[var_name][key][
+                                   "memory"
+                               ]
            for var_name in op.output_arg_names:
                output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
@@ -339,6 +359,10 @@ class CostEstimator:
                    var_name
                )
                if key not in var_info[var_name]:
                    var_info[var_name][key] = {}
+               if "position" not in var_info[var_name][key]:
+                   var_info[var_name][key]["position"] = []
+               var_info[var_name][key]["position"].append(op_id)
                if "memory" not in var_info[var_name][key]:
                    var = dist_op.get_serial_output(var_name)
                    global_sizes = var.shape
@@ -352,11 +376,19 @@ class CostEstimator:
                    var_info[var_name][key]["memory"] = self._calculate_bytes(
                        sizes, dtype
                    )
-                   if "position" not in var_info[var_name][key]:
-                       var_info[var_name][key]["position"] = []
-                   var_info[var_name][key]["position"].append(op_id)
+                   if var.persistable:
+                       name = var_name + key
+                       if name not in parameters:
+                           parameters.add(name)
+                           for process in process_mesh.process_ids:
+                               if process not in memories:
+                                   memories[process] = 0
+                               memories[process] += var_info[var_name][key][
+                                   "memory"
+                               ]

        has_used_vars = set()
+       not_calc_vars = set()
        for op_id, op in self._ordered_ops:
            if op.type in [
                "create_py_reader",
@@ -367,6 +399,8 @@ class CostEstimator:
            can_free_memories = {}
            can_free_vars = set()
            dist_op = dist_context.get_dist_op_for_program(op)
+           if not dist_op:
+               continue
            process_mesh = dist_op.dist_attr.process_mesh
            for var_name in op.input_arg_names:
                input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
@@ -378,24 +412,30 @@ class CostEstimator:
                has_used_var = var_name + key
                var = dist_op.get_serial_input(var_name)
                # Not used
-               if var_name + key not in has_used_vars:
+               if (
+                   has_used_var not in has_used_vars
+                   and has_used_var not in parameters
+               ):
+                   if has_used_var in not_calc_vars:
+                       continue
                    has_used_vars.add(has_used_var)
                    for process in process_mesh.process_ids:
                        if process not in memories:
                            memories[process] = 0
                        memories[process] += var_info[var_name][key]["memory"]
                # Used
-               else:
-                   if op_id == var_info[var_name][key]["position"][-1]:
-                       if has_used_var not in can_free_vars:
-                           can_free_vars.add(has_used_var)
-                           if not var.persistable:
-                               for process in process_mesh.process_ids:
-                                   if process not in can_free_memories:
-                                       can_free_memories[process] = 0
-                                   can_free_memories[process] += var_info[
-                                       var_name
-                                   ][key]["memory"]
+               if op_id == var_info[var_name][key]["position"][-1]:
+                   if (
+                       has_used_var not in can_free_vars
+                       and not var.persistable
+                   ):
+                       can_free_vars.add(has_used_var)
+                       for process in process_mesh.process_ids:
+                           if process not in can_free_memories:
+                               can_free_memories[process] = 0
+                           can_free_memories[process] += var_info[var_name][
+                               key
+                           ]["memory"]
            for var_name in op.output_arg_names:
                output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
@@ -406,25 +446,36 @@ class CostEstimator:
                )
                has_used_var = var_name + key
                var = dist_op.get_serial_output(var_name)
+               if (
+                   op.type == "reshape2"
+                   or op.type == "transpose2"
+                   or op.type == "elementwise_add"
+               ):
+                   not_calc_vars.add(has_used_var)
+                   continue
                # Not used
-               if var_name + key not in has_used_vars:
+               if (
+                   has_used_var not in has_used_vars
+                   and has_used_var not in parameters
+               ):
                    has_used_vars.add(has_used_var)
                    for process in process_mesh.process_ids:
                        if process not in memories:
                            memories[process] = 0
                        memories[process] += var_info[var_name][key]["memory"]
                # Used
-               else:
-                   if op_id == var_info[var_name][key]["position"][-1]:
-                       if has_used_var not in can_free_vars:
-                           can_free_vars.add(has_used_var)
-                           if not var.persistable:
-                               for process in process_mesh.process_ids:
-                                   if process not in can_free_memories:
-                                       can_free_memories[process] = 0
-                                   can_free_memories[process] += var_info[
-                                       var_name
-                                   ][key]["memory"]
+               if op_id == var_info[var_name][key]["position"][-1]:
+                   if (
+                       has_used_var not in can_free_vars
+                       and not var.persistable
+                   ):
+                       can_free_vars.add(has_used_var)
+                       for process in process_mesh.process_ids:
+                           if process not in can_free_memories:
+                               can_free_memories[process] = 0
+                           can_free_memories[process] += var_info[var_name][
+                               key
+                           ]["memory"]

            # Calc peak memory
            for process in memories:
@@ -433,7 +484,6 @@ class CostEstimator:
                else:
                    if memories[process] > self.max_memories[process]:
                        self.max_memories[process] = memories[process]
-
            # Free memory
            for process in can_free_memories:
                if process in memories:
@@ -513,7 +563,7 @@ class CostEstimator:
        # Padding automatically
        max_len = 0
-       header = ["Execution Time(ms)", "Max Memory(MiB)"]
+       header = ["Execution Time(us)", "Max Memory(MiB)"]
        vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
        for memory in vals + header:
            if len(str(memory)) > max_len:
...
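Taken together, the hunks above turn _estimate_max_memory_by_dist_op into a liveness-style estimate: each variable's "position" list records every op that touches it, persistable parameters are charged exactly once up front, reshape2/transpose2/elementwise_add outputs are skipped (not_calc_vars, presumably because they reuse their input's storage), and a non-persistable variable becomes freeable after the op at position[-1]. A condensed, self-contained sketch of the same bookkeeping, reduced to a single process and using hypothetical names:

def estimate_peak_memory(ops, sizes):
    """ops: list of (op_id, var_names); sizes: bytes per variable."""
    last_use = {}
    for op_id, var_names in ops:
        for name in var_names:
            last_use[name] = op_id          # mirrors position[-1] above

    current = peak = 0
    live = set()
    for op_id, var_names in ops:
        for name in set(var_names):
            if name not in live:            # first use: allocate
                live.add(name)
                current += sizes[name]
        peak = max(peak, current)
        for name in set(var_names):
            if last_use[name] == op_id:     # last use: free
                live.discard(name)
                current -= sizes[name]
    return peak

# e.g. estimate_peak_memory([(0, ["x", "w"]), (1, ["x", "y"])],
#                           {"x": 4, "w": 8, "y": 4}) returns 12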
@@ -2716,6 +2716,8 @@ class Resharder:
            )
            # simplified processing: ignore union process mesh and output reshard
            dist_op = self.dist_context.get_dist_op_for_program(op)
+           if not dist_tensor or not dist_op:
+               return reshard_op_cost
            dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
                tensor.name
            )
...
@@ -16,17 +16,31 @@ import copy
import logging
import math
import os
+import pickle
+import sys
+import time
from abc import abstractmethod
from collections import OrderedDict
+from functools import reduce

+import numpy as np

import paddle
+from paddle.distributed.auto_parallel.cluster_v2 import DeviceMesh
from paddle.distributed.auto_parallel.completion import Completer
+from paddle.distributed.auto_parallel.cost import CostEstimator
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistAttr,
    TensorDistAttr,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
+from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
+from paddle.distributed.auto_parallel.utils import (
+    is_gradient_clip_op,
+    print_program_with_dist_attr,
+)
+from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Parameter, unique_name
@@ -1610,3 +1624,603 @@ class RuleBasedTuner:
                        idx
                    ][parallelism][key]
                    self._complete_sub_bwd_program(sub_program_dist_context)

    def _complete_sub_update_program(self, sub_program_dist_context):
        """
        Complete the optimizer ops according to the tensors they consume.
        Most of the logic is the same as the update completion in the completer.
        """
world_ranks = ProcessMesh(
[
i
for i in range(
self._cluster.get_num_machines()
* self._cluster._num_devices_per_machine
)
]
)
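        # world_ranks spans every device in the cluster, e.g.
        # ProcessMesh([0, 1, ..., 7]) for one machine with 8 devices.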
dist_tensors = sub_program_dist_context._dist_tensors_for_program
vars = self.full_main_program.global_block().vars
ops = self.full_main_program.global_block().ops
learning_rate_completed = False
for idx in range(len(ops)):
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
if is_gradient_clip_op(op):
if op.type in [
"sum",
"sqrt",
"fill_constant",
"elementwise_max",
"elementwise_div",
]:
op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = world_ranks
for in_name in op.input_arg_names:
in_var = vars[in_name]
if in_var.desc.original_id() in dist_tensors:
in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
in_var
)
op_dist_attr.set_input_dist_attr(
in_name, in_dist_attr
)
else:
in_dist_attr = TensorDistAttr()
in_dist_attr.process_mesh = world_ranks
in_dist_attr.dims_mapping = [
-1 for _ in range(len(in_var.shape))
]
op_dist_attr.set_input_dist_attr(
in_name, in_dist_attr
)
sub_program_dist_context.set_tensor_dist_attr_for_program(
in_var, in_dist_attr
)
for out_name in op.output_arg_names:
out_var = vars[out_name]
if out_var.desc.original_id() in dist_tensors:
out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
out_var
)
op_dist_attr.set_output_dist_attr(
out_name, out_dist_attr
)
else:
out_dist_attr = TensorDistAttr()
out_dist_attr.process_mesh = world_ranks
out_dist_attr.dims_mapping = [
-1 for _ in range(len(out_var.shape))
]
sub_program_dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr
)
op_dist_attr.set_output_dist_attr(
out_name, out_dist_attr
)
sub_program_dist_context.set_op_dist_attr_for_program(
op, op_dist_attr
)
else:
in_var = vars[op.input("X")[0]]
if in_var.desc.original_id() in dist_tensors:
in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
in_var
)
assert in_dist_attr is not None
ref_process_mesh = in_dist_attr.process_mesh
ref_dims_mapping = in_dist_attr.dims_mapping
if (
op.type == "cast"
and ops[idx + 1].type == "elementwise_mul"
):
ref_var = vars[ops[idx + 1].input("X")[0]]
ref_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
ref_var
)
assert ref_dist_attr is not None
ref_process_mesh = ref_dist_attr.process_mesh
out_var = vars[op.output("Out")[0]]
out_dist_attr = TensorDistAttr()
out_dist_attr.process_mesh = ref_process_mesh
if out_var.shape == in_var.shape:
out_dist_attr.dims_mapping = ref_dims_mapping
else:
assert (
len(out_var.shape) == 1
and out_var.shape[0] == 1
)
out_dist_attr.dims_mapping = [-1]
sub_program_dist_context.set_tensor_dist_attr_for_program(
out_var, out_dist_attr
)
op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = ref_process_mesh
for in_name in op.input_arg_names:
in_var = vars[in_name]
in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
in_var
)
op_dist_attr.set_input_dims_mapping(
in_name, in_dist_attr.dims_mapping
)
for out_name in op.output_arg_names:
out_var = vars[out_name]
out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
out_var
)
op_dist_attr.set_output_dims_mapping(
out_name, out_dist_attr.dims_mapping
)
op_dist_attr.set_input_dist_attr(
in_var.name, in_dist_attr
)
op_dist_attr.set_output_dist_attr(
out_var.name, out_dist_attr
)
sub_program_dist_context.set_op_dist_attr_for_program(
op, op_dist_attr
)
else:
continue
if "Grad" in op.input_names and "Param" in ops[idx].input_names:
assert (
len(op.input("Param")) == 1
), "Only support one-to-one now."
assert (
len(op.input("Grad")) == 1
), "Only support one-to-one now."
param = vars[op.input("Param")[0]]
grad_var = vars[op.input("Grad")[0]]
if param.desc.original_id() in dist_tensors:
param_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
param
)
assert param_dist_attr is not None
ref_process_mesh = sub_program_dist_context.get_tensor_dist_attr_for_program(
param
).process_mesh
assert ref_process_mesh is not None
ref_dims_mapping = sub_program_dist_context.get_tensor_dist_attr_for_program(
param
).dims_mapping
assert ref_dims_mapping is not None
op_dist_attr = OperatorDistAttr()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dims_mapping(
grad_var.name, ref_dims_mapping
)
op_dist_attr.set_input_dims_mapping(
param.name, ref_dims_mapping
)
op_dist_attr.set_output_dims_mapping(
param.name, ref_dims_mapping
)
learning_var = vars[op.input("LearningRate")[0]]
op_dist_attr.set_input_dims_mapping(
learning_var.name, [-1]
)
op_dist_attr.set_output_dims_mapping(
learning_var.name, [-1]
)
if not learning_rate_completed:
learning_rate_completed = True
var_dist_attr = TensorDistAttr()
var_dist_attr.process_mesh = world_ranks
var_dist_attr.dims_mapping = [-1]
sub_program_dist_context.set_tensor_dist_attr_for_program(
learning_var, var_dist_attr
)
for input_name in op.desc.input_names():
if input_name in [
'Param',
'Grad',
'LearningRate',
"SkipUpdate",
"Beta1Tensor",
"Beta2Tensor",
"EpsilonTensor",
]:
continue
if len(op.desc.input(input_name)) == 0:
continue
assert len(op.desc.input(input_name)) == 1
input_var = vars[op.desc.input(input_name)[0]]
input_var_attr = TensorDistAttr()
if (
"Beta1Pow" in input_name
or "Beta2Pow" in input_name
):
input_var_attr.dims_mapping = [-1]
op_dist_attr.set_input_dims_mapping(
input_var.name, [-1]
)
op_dist_attr.set_output_dims_mapping(
input_var.name, [-1]
)
else:
input_var_attr.dims_mapping = ref_dims_mapping
op_dist_attr.set_input_dims_mapping(
input_var.name, ref_dims_mapping
)
op_dist_attr.set_output_dims_mapping(
input_var.name, ref_dims_mapping
)
input_var_attr.process_mesh = ref_process_mesh
sub_program_dist_context.set_tensor_dist_attr_for_program(
input_var, input_var_attr
)
sub_program_dist_context.set_op_dist_attr_for_program(
op, op_dist_attr
)
continue
else:
continue
def complete_sub_update_programs(self):
for idx in self.sub_programs_dist_context:
for parallelism in self.sub_programs_dist_context[idx]:
for key in self.sub_programs_dist_context[idx][parallelism]:
sub_program_dist_context = self.sub_programs_dist_context[
idx
][parallelism][key]
self._complete_sub_update_program(sub_program_dist_context)
def convert_device_mesh_to_key(self, device_mesh):
"""Convert device mesh object to str."""
processes = ",".join([str(x) for x in device_mesh.device_ids])
topology = ",".join([str(x) for x in device_mesh.shape])
key = processes + ";" + topology
return key
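        # e.g. device_ids [0, 1, 2, 3, 4, 5, 6, 7] with shape [2, 4]
        # produce the key "0,1,2,3,4,5,6,7;2,4"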
def _get_sub_program_cost(self, dist_context):
"""Estimate the cost of dist context."""
cost_estimator = CostEstimator(self.full_main_program, self._cluster)
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(
dist_context
)
return global_cost.time, max_memory
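    # Note: global_cost.time is the modeled execution time in microseconds
    # (see the "Execution Time(us)" header above) and max_memory is the
    # estimated per-rank peak in bytes.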

    def combine_dist_contexts(self, dist_contexts):
        """Combine the dist attrs from multiple dist contexts into one dist context."""
        combined_dist_context = DistributedContext()
        # Set dist tensors; note that a shared param or var can be an input of multiple ops
for dist_context in dist_contexts:
for tensor_id in dist_context._dist_tensors_for_program:
dist_tensor = dist_context._dist_tensors_for_program[tensor_id]
if (
tensor_id
not in combined_dist_context._dist_tensors_for_program
):
combined_dist_context.add_dist_tensor_for_program(
dist_tensor
)
# set dist op
for op_id in dist_context._dist_ops_for_program:
dist_op = dist_context._dist_ops_for_program[op_id]
combined_dist_context.add_dist_op_for_program(dist_op)
for process_mesh in dist_context.process_meshes:
combined_dist_context.add_process_mesh(process_mesh)
return combined_dist_context
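    # Note: a tensor that is already present is skipped on purpose -- a shared
    # parameter may appear in several sub-program dist contexts -- while each
    # op is expected to belong to exactly one sub-program.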
def prepare(self):
"""Prepare the sub program, tensor dist attr setting, device meshes and so on that tuner need."""
# step1: cluster operators to layers
begin = time.time()
self.layers = self.cluster_operators()
end = time.time()
self._logger.info(
"Cluster operators to {} layers in {}s.".format(
len(self.layers), end - begin
)
)
# step2: generate sub program of each layer
begin = time.time()
self.gen_fwd_sub_programs_by_clone()
end = time.time()
self._logger.info(
"Generate programs of every layer in {}s.".format(end - begin)
)
# step3: partition devices to device meshes
begin = time.time()
n, m = (
self._cluster.get_num_machines(),
self._cluster._num_devices_per_machine,
)
device_meshes_list = ClusterPartitionUtil.partition_cluster(n, m)
end = time.time()
self._logger.info("Partition cluster in {}s.".format(end - begin))
# step4: transform device mesh to process meshes
dm_idx = 0
for device_meshes in device_meshes_list:
has_used_devices = 0
self.device_meshes_list.append([])
for device_mesh in device_meshes:
devices = reduce(lambda x, y: x * y, device_mesh)
processes = [
i
for i in range(has_used_devices, has_used_devices + devices)
]
device_mesh_shape = (
device_mesh
if device_mesh[0] != 1
else [device_mesh[i] for i in range(1, len(device_mesh))]
)
self.device_meshes_list[-1].append(
DeviceMesh(
mesh=np.array(processes)
.reshape(device_mesh_shape)
.tolist(),
name="device_mesh_" + str(dm_idx),
)
)
dm_idx += 1
has_used_devices += devices
process_mesh_shapes = convert_to_process_meshes(device_mesh)
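            # e.g. a [1, 8] device mesh is assumed to yield candidate
            # process mesh shapes such as [8], [2, 4] and [4, 2]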
for process_mesh_shape in process_mesh_shapes:
process_mesh = ProcessMesh(
np.array(processes).reshape(process_mesh_shape).tolist()
)
if process_mesh not in self.process_meshes:
self.process_meshes.append(process_mesh)
# step5: generate full program
begin = time.time()
self.gen_full_program()
end = time.time()
self._logger.info("Generate full program in {}s.".format(end - begin))
# step6: complete forward sub programs
begin = time.time()
for process_mesh in self.process_meshes:
self.complete_sub_fwd_programs(process_mesh)
end = time.time()
self._logger.info(
"Complete all sub forward programs in {}s.".format(end - begin)
)
if self.mode == "train":
# step7: complete backward sub programs
begin = time.time()
self.complete_sub_bwd_programs()
end = time.time()
self._logger.info(
"Complete all sub backward programs in {}s.".format(end - begin)
)
# step8: complete update sub programs
begin = time.time()
self.complete_sub_update_programs()
end = time.time()
self._logger.info(
"Complete all sub update programs in {}s.".format(end - begin)
)
def tune_o1(self):
"""The o1 level tuning."""
best_cost = sys.maxsize
best_dist_context = None
for device_meshes in self.device_meshes_list:
pp_stages = len(device_meshes)
average_layers = len(self.layers) // pp_stages
device_mesh_shape = device_meshes[0].shape
if len(device_mesh_shape) == 1:
device_mesh_shape.insert(0, 1)
process_mesh_shapes = convert_to_process_meshes(device_mesh_shape)
            # For example, device_mesh is [1, 8] and process_mesh is [8],
            # so the selective parallelism is dp or mp. Get the dp8 or mp8
            # cost and compare them to pick the best strategy.
for parallelism in ["dp", "mp", "dp_mp", "mp_dp"]:
for process_mesh_shape in process_mesh_shapes:
dist_context_of_device_meshes = None
for idx, device_mesh in enumerate(device_meshes):
device_mesh_shape = device_mesh.shape
process_mesh = ProcessMesh(
np.array(device_mesh.device_ids)
.reshape(process_mesh_shape)
.tolist()
)
selective_parallelisms = (
["dp", "mp"]
if len(process_mesh.shape) == 1
else ["dp_mp", "mp_dp"]
)
if parallelism not in selective_parallelisms:
total_cost_of_device_meshes = sys.maxsize
continue
key = self.convert_process_mesh_to_key(process_mesh)
if idx == len(device_meshes) - 1:
start = idx * average_layers
end = len(self.layers)
else:
start = idx * average_layers
end = (idx + 1) * average_layers
dist_context = self.combine_dist_contexts(
[
self.sub_programs_dist_context[j][parallelism][
key
]
for j in range(start, end)
]
)
dist_context_of_device_meshes = (
dist_context
if dist_context_of_device_meshes is None
else self.combine_dist_contexts(
[dist_context_of_device_meshes, dist_context]
)
)
if dist_context_of_device_meshes is not None:
cost, memory = self._get_sub_program_cost(
dist_context_of_device_meshes
)
self._logger.info(
"Cost Model: The max memory is {}GB and cost is {} when {} parallelism under process mesh shape {} on {} stages.".format(
memory / (1024**3),
cost,
parallelism,
process_mesh_shape,
len(device_meshes),
)
)
# 15% buffer is reserved for memory cost
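                    # e.g. with 32GB devices the cutoff is 0.85 * 32GB = 27.2GB;
                    # strategies whose estimated peak exceeds it are rejected
                    # by forcing their cost to sys.maxsize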
if memory > 0.85 * self.cluster.machines[0].devices[
0
].memory * (1024**3):
cost = sys.maxsize
if cost < best_cost:
best_cost = cost
best_dist_context = dist_context_of_device_meshes
self._logger.info(
"O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {}GB.".format(
parallelism,
process_mesh_shape,
len(device_meshes),
memory / (1024**3),
)
)
return best_dist_context
def tune_o2(self):
return None
def save_strategy(self, best_dist_context, path):
dist_attrs = {"tensor": {}, "op": {}, "process_meshes": []}
for key in best_dist_context._dist_tensors_for_program:
if key in self._dist_context._dist_tensors_for_program:
dist_tensor = best_dist_context._dist_tensors_for_program[key]
dist_attrs["tensor"][
key
] = dist_tensor.dist_attr.serialize_to_string()
assert dist_attrs["tensor"], "Tensor dist attrs must not be None."
for key in best_dist_context._dist_ops_for_program:
if key in self._dist_context._dist_ops_for_program:
dist_op = best_dist_context._dist_ops_for_program[key]
dist_attrs["op"][key] = dist_op.dist_attr.serialize_to_string()
assert dist_attrs["op"], "Op dist attrs must not be None."
for process_mesh in best_dist_context._process_meshes:
process_ids = process_mesh.process_ids
process_shape = process_mesh.shape
dist_attrs["process_meshes"].append([process_ids, process_shape])
dist_attrs["cluster"] = self._cluster
with open(path, 'wb') as f:
pickle.dump(dist_attrs, f)
self._logger.info("The strategy has been saved at {}".format(path))
def run_or_quit(self):
# Quit if just tune
if not self._is_run:
self._logger.info(
"The process will be quitted when just tune not run."
)
quit()
def tune(self):
begin = time.time()
self.match_program(self._dist_context.serial_main_program)
end = time.time()
self._logger.info("Pattern match in {}s.".format(end - begin))
if self._use_dp:
completer = Completer(self._dist_context)
completer.complete_forward_annotation()
print_program_with_dist_attr(
self._dist_context.serial_main_program, self._dist_context
)
# Save strategy if need
path = self._strategy_path
if path:
self.save_strategy(self._dist_context, path)
self.run_or_quit()
return
# prepare
self.prepare()
best_dist_context = None
if self.level == "o2":
best_dist_context = self.tune_o2()
elif self.level == "o1":
            # If the level is o1, all layers share the same parallelism.
            # In pipeline parallelism, layers are placed evenly across stages.
use_o2_level = False
for device_meshes in self.device_meshes_list:
if len(device_meshes) > 1:
shape = None
for device_mesh in device_meshes:
if shape is None:
shape = device_mesh.shape
continue
else:
if shape != device_mesh.shape:
self._logger.info(
"Warning: The o1 level is not be supported when the number of machines is prime numer which greaters than 1. We will use o2 level to tune."
)
use_o2_level = True
break
if use_o2_level:
best_dist_context = self.tune_o2()
else:
best_dist_context = self.tune_o1()
assert (
best_dist_context is not None
), "can not find a parallel strategy to run, please use passes such as recompute, amp or sharding."
for key in best_dist_context._dist_tensors_for_program:
if key in self._dist_context._dist_tensors_for_program:
self._dist_context._dist_tensors_for_program[
key
] = best_dist_context._dist_tensors_for_program[key]
for key in best_dist_context._dist_ops_for_program:
if key in self._dist_context._dist_ops_for_program:
self._dist_context._dist_ops_for_program[
key
] = best_dist_context._dist_ops_for_program[key]
self._dist_context._process_meshes = best_dist_context._process_meshes
end = time.time()
self._logger.info("Rule-based tuner end in {}s.".format(end - begin))
self._logger.info("The best strategy found is as follows: ")
print_program_with_dist_attr(self.full_main_program, best_dist_context)
# Save strategy if need
path = self._strategy_path
if path:
self.save_strategy(best_dist_context, path)
self.run_or_quit()
@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
        modeling.init_global()
        train_program = static.Program()
        start_program = static.Program()
-       place = paddle.set_device("gpu")
        batch_size = 8
        sequence_len = 512
        vocab_size = 1000
+       place = None
        train_program, start_program, loss, gen_data = get_gpt_model(
            train_program,
            start_program,
@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
            sequence_len,
            vocab_size,
        )

+       from paddle.distributed.auto_parallel.cluster import Cluster
        from paddle.distributed.auto_parallel.dist_context import (
            DistributedContext,
        )
-       from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            RuleBasedTuner,
        )

        clip = paddle.nn.ClipGradByGlobalNorm(0.2)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
+       cluster = Cluster()
+       cluster.gen_default_config_cluster(node_count=1, device_count=8)
        dist_context = DistributedContext(
            serial_main_prog=train_program,
            serial_startup_prog=start_program,
            serial_optimizer=opt,
            serial_loss=loss,
+           cluster=cluster,
        )
        dist_context.initialize()
        tuner = RuleBasedTuner(dist_context)
-       tuner.cluster_operators()
-       tuner.gen_full_program()
-       tuner.match_program(tuner._dist_context.serial_main_program)
-       process_mesh = ProcessMesh([0, 1])
-       tuner.gen_fwd_sub_programs_by_clone()
-       tuner.complete_sub_fwd_programs(process_mesh)
-       tuner.complete_sub_bwd_programs()
+       tuner.tune()


if __name__ == "__main__":
...