Unverified Commit d6011cb6 authored by caozhou, committed by GitHub

[Auto Parallel] Add o1 level tune (#52041)

* add tune o1 level

* add unittest
Parent 418b983c
......@@ -14,6 +14,7 @@
import json
import os
import re
from enum import IntEnum, unique
import paddle
......@@ -449,7 +450,6 @@ class Cluster:
npu_models = ["NPU"]
dcu_models = ["DCU"]
all_gpu_models = gpu_models + xpu_models + npu_models + dcu_models
assert gpu_model in all_gpu_models
self._num_devices_per_machine = device_count
def _convert_to_type(gpu_model):
......@@ -462,6 +462,8 @@ class Cluster:
type = "NPU"
elif gpu_model in dcu_models:
type = "DCU"
else:
type = "GPU"
assert type is not None
return type
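For reference, a minimal standalone sketch of what _convert_to_type now does with the fallback branch added above; only npu_models, dcu_models and the new else are visible in this hunk, so the gpu/xpu model lists below are illustrative:

def convert_to_type(gpu_model,
                    gpu_models=("V100", "A100", "A30"),
                    xpu_models=("XPU",),
                    npu_models=("NPU",),
                    dcu_models=("DCU",)):
    # Map a device model name to a coarse device type; unknown models
    # now fall back to "GPU" instead of failing the assertion.
    if gpu_model in gpu_models:
        return "GPU"
    if gpu_model in xpu_models:
        return "XPU"
    if gpu_model in npu_models:
        return "NPU"
    if gpu_model in dcu_models:
        return "DCU"
    return "GPU"

# e.g. convert_to_type("H800") -> "GPU" (falls through to the new default)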
......@@ -470,6 +472,12 @@ class Cluster:
model = None
if gpu_model == "V100":
model = "Tesla V100-SXM2-" + str(gpu_memory) + "GB"
elif gpu_model == "A100":
model = "Tesla A100-SXM-" + str(gpu_memory) + "GB"
elif gpu_model == "A30":
model = "Tesla A30-SXM-" + str(gpu_memory) + "GB"
else:
model = gpu_model + str(gpu_memory) + "GB"
assert model is not None
return model
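Likewise, a standalone sketch of the model-string branches above; the enclosing helper's name is not visible in this hunk, so the function name here is hypothetical:

def gpu_model_string(gpu_model, gpu_memory):
    # Known models get their full "Tesla ..." product string; anything
    # else falls back to "<model><memory>GB" so the assertion still holds.
    if gpu_model == "V100":
        return "Tesla V100-SXM2-" + str(gpu_memory) + "GB"
    if gpu_model == "A100":
        return "Tesla A100-SXM-" + str(gpu_memory) + "GB"
    if gpu_model == "A30":
        return "Tesla A30-SXM-" + str(gpu_memory) + "GB"
    return gpu_model + str(gpu_memory) + "GB"

# e.g. gpu_model_string("A100", 40) -> "Tesla A100-SXM-40GB"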
......@@ -527,6 +535,8 @@ class Cluster:
device["memory"] = memory
device["sp_gflops"] = sp_gflops
device["dp_gflops"] = dp_gflops
# hard code: default every generated device entry to the GPU type for now
device["type"] = "GPU"
global_id_to_device_type[global_id] = type
global_id_to_node[global_id] = i
devices.append(device)
......@@ -820,30 +830,82 @@ class Cluster:
return self.__str__()
def get_default_cluster():
def get_default_cluster(json_config=None):
def is_by_json_config(json_config):
if not json_config:
return False
if "cluster" not in json_config:
return False
else:
if "path" not in json_config["cluster"]:
if "num_nodes" not in json_config["cluster"]:
return False
if "num_gpus" not in json_config["cluster"]:
return False
if "gpu_model" not in json_config["cluster"]:
return False
if "gpu_memory" not in json_config["cluster"]:
return False
return True
else:
return True
cluster = Cluster()
local_device_count = os.getenv("PADDLE_LOCAL_SIZE")
if local_device_count is None:
local_device_count = 1
else:
local_device_count = int(local_device_count)
global_device_count = os.getenv("PADDLE_GLOBAL_SIZE")
if global_device_count is None:
node_count = 1
if json_config and is_by_json_config(json_config):
# Get GPU info by json config
if "path" in json_config["cluster"]:
cluster.build_from_file(json_config["cluster"]["path"])
return cluster
else:
node_count = json_config["cluster"]["num_nodes"]
local_device_count = json_config["cluster"]["num_gpus"]
gpu_model = json_config["cluster"]["gpu_model"]
memory = json_config["cluster"]["gpu_memory"]
else:
global_device_count = int(global_device_count)
assert global_device_count % local_device_count == 0
node_count = int(global_device_count) // local_device_count
# Get GPU info by get_device_properties
local_device_count = os.getenv("PADDLE_LOCAL_SIZE")
if local_device_count is None:
local_device_count = 1
else:
local_device_count = int(local_device_count)
global_device_count = os.getenv("PADDLE_GLOBAL_SIZE")
if global_device_count is None:
node_count = 1
else:
global_device_count = int(global_device_count)
assert global_device_count % local_device_count == 0
node_count = int(global_device_count) // local_device_count
gpu_info = paddle.device.cuda.get_device_properties()
assert gpu_info, "Auto parallel just runs on gpu now."
gpu_name = gpu_info.name
try:
re_result = re.split(r'[ , -]', gpu_name)
gpu_model = re_result[1]
memory = int(re_result[-1][:-2])
except:
memory = int(gpu_info.total_memory) // (1000**3)
gpu_model = gpu_name
print(
"Node Count: ",
node_count,
"Local Device Size: ",
local_device_count,
"GPU Model: ",
gpu_model,
"GPU Memory: ",
memory,
"World size: ",
paddle.distributed.get_world_size(),
flush=True,
)
cluster.gen_default_config_cluster(
node_count=node_count, device_count=local_device_count
node_count=node_count,
device_count=local_device_count,
gpu_model=gpu_model,
gpu_memory=memory,
)
return cluster
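To make the new json_config argument concrete, here is a hedged example of the two shapes get_default_cluster() accepts, using only the keys checked by is_by_json_config() above; all values are illustrative:

# Option 1: point at an existing cluster description file.
json_config = {"cluster": {"path": "/path/to/cluster.json"}}  # illustrative path

# Option 2: describe the cluster inline; when "path" is absent, all four
# keys below are required, and "num_gpus" is passed on as the per-node
# device count.
json_config = {
    "cluster": {
        "num_nodes": 2,
        "num_gpus": 8,
        "gpu_model": "A100",
        "gpu_memory": 40,
    }
}

cluster = get_default_cluster(json_config)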
......@@ -16,6 +16,7 @@ from collections import OrderedDict
from functools import reduce
import paddle
from paddle.utils.flops import flops
from ..cluster import LinkType
from ..dist_tensor import DistributedTensor
......@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
output_desc = OrderedDict()
# Get partitioned shape of input
input_var_desc = {}
for input_name in op.input_names:
var_name_list = op.input(input_name)
var_desc = []
input_var_desc[input_name] = []
for var_name in var_name_list:
var = get_var_with_recursion(
var_name, op.block, op.block.program
......@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
process,
shard_sizes,
)
var_desc.append((var.dtype, shape))
input_var_desc[input_name].append(shape)
# For special op such as embedding and its grad op
if (
......@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
relative_idx = relative_idx * per_part_size
desc["attrs"]["start_index"] = relative_idx
input_desc[input_name] = var_desc
desc["inputs"] = input_desc
desc["inputs"] = input_var_desc
for out_name in op.output_names:
var_name_list = op.output(out_name)
......@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
return desc
def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
def build_comm_costs_from_descs(
op_cost_class, ctx, processes, descs, cluster, is_dp=False
):
"""Build comm costs by descriptions"""
comm_context = CommContext(cluster)
group_ranks_list = []
......@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
comm_op_cost = op_cost_class(
op_desc=desc, comm_context=comm_context
)
if is_dp:
comm_op_cost.cost.time *= 0.9
comm_op_cost_list.append(comm_op_cost)
return comm_op_cost_list
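The new is_dp flag applies a flat 0.9 discount to the modeled time of data-parallel communication ops; the rationale is not stated in the diff, but the effect is just a fixed scaling, as in this sketch:

DP_COMM_DISCOUNT = 0.9  # factor used above when is_dp=True

def scaled_comm_time(base_time, is_dp=False):
    # Data-parallel gradient all-reduces are modeled as slightly cheaper
    # than the generic comm cost by the fixed discount factor.
    return base_time * DP_COMM_DISCOUNT if is_dp else base_time

# e.g. scaled_comm_time(100.0, is_dp=True) == 90.0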
......@@ -389,6 +394,7 @@ def build_dp_costs(
vars = dist_op.serial_op.block.vars
var_name = var_names[0]
has_found = False
is_input = True
for name in dist_op.serial_op.input_arg_names:
if var_name in name:
var_name = name
......@@ -400,6 +406,7 @@ def build_dp_costs(
if var_name in name:
var_name = name
has_found = True
is_input = False
break
if not has_found:
return
......@@ -418,6 +425,7 @@ def build_dp_costs(
processes,
c_allreduce_sum_descs,
cluster,
is_dp=True,
)
result.append(comm_cost_list)
......@@ -431,22 +439,11 @@ def build_dp_costs(
desc = {}
desc["op"] = op_type
desc["inputs"] = {}
if var_name in dist_attr.inputs_dist_attrs:
dims_mapping = dist_attr.get_input_dims_mapping(var_name)
elif var_name in dist_attr.outputs_dist_attrs:
dims_mapping = dist_attr.get_output_dims_mapping(var_name)
else:
raise AssertionError(
"cannot find dims_mapping for {} in {}".format(
var_name, dist_attr
)
)
# dims_mapping = (
# dist_attr.get_input_dims_mapping(var_name)
# if dist_attr.get_input_dims_mapping(var_name) is not None
# else dist_attr.get_output_dims_mapping(var_name)
# )
dims_mapping = (
dist_attr.get_input_dims_mapping(var_name)
if is_input
else dist_attr.get_output_dims_mapping(var_name)
)
var = get_var_with_recursion(
var_name,
dist_op.serial_op.block,
......@@ -493,8 +490,6 @@ class CommContext:
# if cluster has no info about those vars, it will be set by default
self.base_ring = None
self.base_tree = None
# self.base_inter_ring = None
# self.base_inter_tree = None
self.intra_ring = None
self.intra_tree = None
self.inter_ring = None
......@@ -508,8 +503,6 @@ class CommContext:
# set default
self.base_ring = 8.4
self.base_tree = 0.0
# self.base_inter_ring = 9.6
# self.base_inter_tree = 28
# NVL in default
self.intra_ring = 3.4
self.intra_tree = 28
......@@ -681,6 +674,8 @@ class Cost:
class OpCost:
OP_TYPE = "op"
def __init__(self, op=None, op_desc=None):
self._op = op
self._op_desc = op_desc
......@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
)
)
def calc_flops(self):
if not self.op_desc:
return 0
if "_grad" in self.__class__.OP_TYPE:
op_type = self.__class__.OP_TYPE[: len(self.__class__.OP_TYPE) - 5]
return 2 * flops(
op_type, self.op_desc["inputs"], self.op_desc["attrs"]
)
return flops(
self.__class__.OP_TYPE,
self.op_desc["inputs"],
self.op_desc["attrs"],
)
def calc_time(self):
flops_count = self.calc_flops()
return flops_count * 2.9e-7
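The new calc_flops/calc_time pair gives every CompOpCost a FLOPs-based estimate: gradient ops strip the trailing "_grad" and double the forward FLOPs, and time is FLOPs times a fixed 2.9e-7 constant. A self-contained sketch of the same rule, with flops_fn standing in for paddle.utils.flops.flops:

def estimate_comp_time(op_type, inputs, attrs, flops_fn):
    # Backward ops reuse the forward op's FLOPs formula and double it,
    # matching calc_flops() above (the trailing 5 characters "_grad" are stripped).
    if op_type.endswith("_grad"):
        flops_count = 2 * flops_fn(op_type[:-5], inputs, attrs)
    else:
        flops_count = flops_fn(op_type, inputs, attrs)
    return flops_count * 2.9e-7  # fixed per-FLOP time constant from calc_time()

# Illustrative use with a toy FLOPs function:
toy_flops = lambda op_type, inputs, attrs: 1000
assert abs(estimate_comp_time("matmul_v2_grad", {}, {}, toy_flops) - 5.8e-4) < 1e-12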
def register_op_cost(cls):
op_type = cls.OP_TYPE
......
......@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
super().__init__(op=op, op_desc=op_desc, comm_context=comm_context)
def calc_time(self):
return 0
return self.comm_count * 1 / (144 * 1e3)
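IdentityOpCost previously cost nothing; it now charges time proportional to comm_count divided by a fixed 144e3 constant, i.e. it is treated as a data movement bounded by an assumed effective throughput. A tiny sketch:

def identity_time(comm_count):
    # Matches the new calc_time above: linear in the communicated amount,
    # with 144e3 as the fixed divisor taken from the diff.
    return comm_count * 1 / (144 * 1e3)

# e.g. identity_time(288000) == 2.0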
@register_op_cost
......
......@@ -189,6 +189,9 @@ class CostEstimator:
# Calc dist op cost
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.process_ids
......@@ -225,6 +228,8 @@ class CostEstimator:
for rank in group_ranks:
self.local_cost(rank).time = (
max_time + comm_op_cost.time
if op.attr('op_role') != OpRole.Backward
else max_time + 0.9 * comm_op_cost.time
)
if rank not in self._bubble_time_mapping:
self._bubble_time_mapping[rank] = 0
......@@ -290,6 +295,7 @@ class CostEstimator:
self._ordered_ops.append([op.desc.id(), op])
self._ordered_ops.sort(key=lambda x: x[0])
parameters = set()
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader",
......@@ -298,11 +304,14 @@ class CostEstimator:
]:
continue
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
var_name
)
if var_name not in var_info:
var_info[var_name] = {}
key = _convert_pm_and_dm_to_str(
......@@ -311,6 +320,10 @@ class CostEstimator:
if key not in var_info[var_name]:
var_info[var_name][key] = {}
# It is even partition now
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
if "memory" not in var_info[var_name][key]:
var = dist_op.get_serial_input(var_name)
global_sizes = var.shape
......@@ -324,9 +337,16 @@ class CostEstimator:
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype
)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
if var.persistable:
name = var_name + key
if name not in parameters:
parameters.add(name)
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key][
"memory"
]
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
......@@ -339,6 +359,10 @@ class CostEstimator:
)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
if "memory" not in var_info[var_name][key]:
var = dist_op.get_serial_output(var_name)
global_sizes = var.shape
......@@ -352,11 +376,19 @@ class CostEstimator:
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype
)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
if var.persistable:
name = var_name + key
if name not in parameters:
parameters.add(name)
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key][
"memory"
]
has_used_vars = set()
not_calc_vars = set()
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader",
......@@ -367,6 +399,8 @@ class CostEstimator:
can_free_memories = {}
can_free_vars = set()
dist_op = dist_context.get_dist_op_for_program(op)
if not dist_op:
continue
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
......@@ -378,24 +412,30 @@ class CostEstimator:
has_used_var = var_name + key
var = dist_op.get_serial_input(var_name)
# Not used
if var_name + key not in has_used_vars:
if (
has_used_var not in has_used_vars
and has_used_var not in parameters
):
if has_used_var in not_calc_vars:
continue
has_used_vars.add(has_used_var)
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
# Used
else:
if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars:
can_free_vars.add(has_used_var)
if not var.persistable:
for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name
][key]["memory"]
if op_id == var_info[var_name][key]["position"][-1]:
if (
has_used_var not in can_free_vars
and not var.persistable
):
can_free_vars.add(has_used_var)
for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[var_name][
key
]["memory"]
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
......@@ -406,25 +446,36 @@ class CostEstimator:
)
has_used_var = var_name + key
var = dist_op.get_serial_output(var_name)
if (
op.type == "reshape2"
or op.type == "transpose2"
or op.type == "elementwise_add"
):
not_calc_vars.add(has_used_var)
continue
# Not used
if var_name + key not in has_used_vars:
if (
has_used_var not in has_used_vars
and has_used_var not in parameters
):
has_used_vars.add(has_used_var)
for process in process_mesh.process_ids:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
# Used
else:
if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars:
can_free_vars.add(has_used_var)
if not var.persistable:
for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name
][key]["memory"]
if op_id == var_info[var_name][key]["position"][-1]:
if (
has_used_var not in can_free_vars
and not var.persistable
):
can_free_vars.add(has_used_var)
for process in process_mesh.process_ids:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[var_name][
key
]["memory"]
# Calc peak memory
for process in memories:
......@@ -433,7 +484,6 @@ class CostEstimator:
else:
if memories[process] > self.max_memories[process]:
self.max_memories[process] = memories[process]
# Free memory
for process in can_free_memories:
if process in memories:
......@@ -513,7 +563,7 @@ class CostEstimator:
# Padding automatically
max_len = 0
header = ["Execution Time(ms)", "Max Memory(MiB)"]
header = ["Execution Time(us)", "Max Memory(MiB)"]
vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
for memory in vals + header:
if len(str(memory)) > max_len:
......
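Stepping back, the estimator changes above amount to three memory-accounting rules: persistable variables go into a parameters set and are charged exactly once; outputs of reshape2/transpose2/elementwise_add are put in not_calc_vars and add no new memory (they are treated as reusing their input's buffer); and non-persistable variables are released at the op recorded as their last "position". A simplified, standalone sketch of that peak-memory bookkeeping; the names and data layout here are illustrative, not the estimator's API:

def estimate_peak_memory(ops, var_bytes, last_use, persistable,
                         view_ops=("reshape2", "transpose2", "elementwise_add")):
    # ops: iterable of (op_id, op_type, input_vars, output_vars)
    # var_bytes: var -> size in bytes; last_use: var -> op_id of its final use
    current = sum(var_bytes[v] for v in persistable)  # parameters charged once
    peak = current
    live = set(persistable)
    skipped = set()
    for op_id, op_type, inputs, outputs in ops:
        freeable = 0
        for var in list(inputs) + list(outputs):
            if op_type in view_ops and var in outputs:
                skipped.add(var)  # view-like output: no new allocation
                continue
            if var in skipped or var in persistable:
                continue
            if var not in live:
                live.add(var)  # first use allocates
                current += var_bytes[var]
            elif last_use.get(var) == op_id:
                freeable += var_bytes[var]  # release after this op
        peak = max(peak, current)
        current -= freeable
    return peak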
......@@ -2716,6 +2716,8 @@ class Resharder:
)
# simplified processing: ignore union process mesh and output reshard
dist_op = self.dist_context.get_dist_op_for_program(op)
if not dist_tensor or not dist_op:
return reshard_op_cost
dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
tensor.name
)
......
......@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
modeling.init_global()
train_program = static.Program()
start_program = static.Program()
place = paddle.set_device("gpu")
batch_size = 8
sequence_len = 512
vocab_size = 1000
place = None
train_program, start_program, loss, gen_data = get_gpt_model(
train_program,
start_program,
......@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
sequence_len,
vocab_size,
)
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
RuleBasedTuner,
)
clip = paddle.nn.ClipGradByGlobalNorm(0.2)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
dist_context = DistributedContext(
serial_main_prog=train_program,
serial_startup_prog=start_program,
serial_optimizer=opt,
serial_loss=loss,
cluster=cluster,
)
dist_context.initialize()
tuner = RuleBasedTuner(dist_context)
tuner.cluster_operators()
tuner.gen_full_program()
tuner.match_program(tuner._dist_context.serial_main_program)
process_mesh = ProcessMesh([0, 1])
tuner.gen_fwd_sub_programs_by_clone()
tuner.complete_sub_fwd_programs(process_mesh)
tuner.complete_sub_bwd_programs()
tuner.tune()
if __name__ == "__main__":
......