Unverified commit d6011cb6, authored by caozhou, committed by GitHub

[Auto Parallel] Add o1 level tune (#52041)

* add tune o1 level

* add unittest
Parent 418b983c
@@ -14,6 +14,7 @@
 import json
 import os
+import re
 from enum import IntEnum, unique

 import paddle
@@ -449,7 +450,6 @@ class Cluster:
         npu_models = ["NPU"]
         dcu_models = ["DCU"]
         all_gpu_models = gpu_models + xpu_models + npu_models + dcu_models
-        assert gpu_model in all_gpu_models
         self._num_devices_per_machine = device_count

         def _convert_to_type(gpu_model):
@@ -462,6 +462,8 @@ class Cluster:
                 type = "NPU"
             elif gpu_model in dcu_models:
                 type = "DCU"
+            else:
+                type = "GPU"
             assert type is not None

             return type
@@ -470,6 +472,12 @@ class Cluster:
             model = None
             if gpu_model == "V100":
                 model = "Tesla V100-SXM2-" + str(gpu_memory) + "GB"
+            elif gpu_model == "A100":
+                model = "Tesla A100-SXM-" + str(gpu_memory) + "GB"
+            elif gpu_model == "A30":
+                model = "Tesla A30-SXM-" + str(gpu_memory) + "GB"
+            else:
+                model = gpu_model + str(gpu_memory) + "GB"
             assert model is not None

             return model
@@ -527,6 +535,8 @@ class Cluster:
                 device["memory"] = memory
                 device["sp_gflops"] = sp_gflops
                 device["dp_gflops"] = dp_gflops
+                # hard code
+                device["type"] = "GPU"
                 global_id_to_device_type[global_id] = type
                 global_id_to_node[global_id] = i
                 devices.append(device)
@@ -820,30 +830,82 @@ class Cluster:
         return self.__str__()


-def get_default_cluster():
+def get_default_cluster(json_config=None):
+    def is_by_json_config(json_config):
+        if not json_config:
+            return False
+        if "cluster" not in json_config:
+            return False
+        else:
+            if "path" not in json_config["cluster"]:
+                if "num_nodes" not in json_config["cluster"]:
+                    return False
+                if "num_gpus" not in json_config["cluster"]:
+                    return False
+                if "gpu_model" not in json_config["cluster"]:
+                    return False
+                if "gpu_memory" not in json_config["cluster"]:
+                    return False
+                return True
+            else:
+                return True
+
     cluster = Cluster()
-    local_device_count = os.getenv("PADDLE_LOCAL_SIZE")
-    if local_device_count is None:
-        local_device_count = 1
-    else:
-        local_device_count = int(local_device_count)
-    global_device_count = os.getenv("PADDLE_GLOBAL_SIZE")
-    if global_device_count is None:
-        node_count = 1
-    else:
-        global_device_count = int(global_device_count)
-        assert global_device_count % local_device_count == 0
-        node_count = int(global_device_count) // local_device_count
+    if json_config and is_by_json_config(json_config):
+        # Get GPU info by json config
+        if "path" in json_config["cluster"]:
+            cluster.build_from_file(json_config["cluster"]["path"])
+            return cluster
+        else:
+            node_count = json_config["cluster"]["num_nodes"]
+            local_device_count = json_config["cluster"]["num_gpus"]
+            gpu_model = json_config["cluster"]["gpu_model"]
+            memory = json_config["cluster"]["gpu_memory"]
+    else:
+        # Get GPU info by get_device_properties
+        local_device_count = os.getenv("PADDLE_LOCAL_SIZE")
+        if local_device_count is None:
+            local_device_count = 1
+        else:
+            local_device_count = int(local_device_count)
+
+        global_device_count = os.getenv("PADDLE_GLOBAL_SIZE")
+        if global_device_count is None:
+            node_count = 1
+        else:
+            global_device_count = int(global_device_count)
+            assert global_device_count % local_device_count == 0
+            node_count = int(global_device_count) // local_device_count
+
+        gpu_info = paddle.device.cuda.get_device_properties()
+        assert gpu_info, "Auto parallel just runs on gpu now."
+
+        gpu_name = gpu_info.name
+        try:
+            re_result = re.split(r'[ , -]', gpu_name)
+            gpu_model = re_result[1]
+            memory = int(re_result[-1][:-2])
+        except:
+            memory = int(gpu_info.total_memory) // (1000**3)
+            gpu_model = gpu_name
+
     print(
         "Node Count: ",
         node_count,
         "Local Device Size: ",
         local_device_count,
+        "GPU Model: ",
+        gpu_model,
+        "GPU Memory: ",
+        memory,
         "World size: ",
         paddle.distributed.get_world_size(),
         flush=True,
     )
     cluster.gen_default_config_cluster(
-        node_count=node_count, device_count=local_device_count
+        node_count=node_count,
+        device_count=local_device_count,
+        gpu_model=gpu_model,
+        gpu_memory=memory,
     )
     return cluster
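For reference, a minimal sketch of the two json_config layouts the new get_default_cluster accepts; the key names come from is_by_json_config above, while the import path and the topology file path are assumptions for illustration.

```python
# Hedged sketch: both layouts accepted by is_by_json_config above.
# The import path and the file path are assumptions, not taken from the diff.
from paddle.distributed.auto_parallel.cluster import get_default_cluster

# Describe a homogeneous GPU cluster inline.
cluster = get_default_cluster(
    {
        "cluster": {
            "num_nodes": 2,
            "num_gpus": 8,
            "gpu_model": "V100",
            "gpu_memory": 32,
        }
    }
)

# Or point at a pre-built cluster topology file (hypothetical path).
cluster = get_default_cluster({"cluster": {"path": "/path/to/cluster.json"}})
```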
@@ -16,6 +16,7 @@ from collections import OrderedDict
 from functools import reduce

 import paddle
+from paddle.utils.flops import flops

 from ..cluster import LinkType
 from ..dist_tensor import DistributedTensor
@@ -91,9 +92,10 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
     output_desc = OrderedDict()

     # Get partitioned shape of input
+    input_var_desc = {}
     for input_name in op.input_names:
         var_name_list = op.input(input_name)
-        var_desc = []
+        input_var_desc[input_name] = []
         for var_name in var_name_list:
             var = get_var_with_recursion(
                 var_name, op.block, op.block.program
@@ -112,7 +114,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
                 process,
                 shard_sizes,
             )
-            var_desc.append((var.dtype, shape))
+            input_var_desc[input_name].append(shape)

             # For special op such as embedding and its grad op
             if (
@@ -137,8 +139,7 @@ def build_comp_desc_from_dist_op(dist_op, dist_context):
                     relative_idx = relative_idx * per_part_size
                     desc["attrs"]["start_index"] = relative_idx

-        input_desc[input_name] = var_desc
-    desc["inputs"] = input_desc
+    desc["inputs"] = input_var_desc

     for out_name in op.output_names:
         var_name_list = op.output(out_name)
@@ -350,7 +351,9 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
     return desc


-def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
+def build_comm_costs_from_descs(
+    op_cost_class, ctx, processes, descs, cluster, is_dp=False
+):
     """Build comm costs by descriptions"""
     comm_context = CommContext(cluster)
     group_ranks_list = []
@@ -363,6 +366,8 @@ def build_comm_costs_from_descs(op_cost_class, ctx, processes, descs, cluster):
             comm_op_cost = op_cost_class(
                 op_desc=desc, comm_context=comm_context
            )
+            if is_dp:
+                comm_op_cost.cost.time *= 0.9
             comm_op_cost_list.append(comm_op_cost)
     return comm_op_cost_list
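The new is_dp flag scales the modeled communication time by 0.9, which reads as an assumption that the gradient allreduce issued for data parallelism partially overlaps with backward computation. A toy restatement of that discount, where the overlap_ratio parameter is hypothetical and 0.1 reproduces the hard-coded factor:

```python
# Toy sketch of the data-parallel discount applied above.
def discounted_dp_comm_time(raw_time, overlap_ratio=0.1):
    # overlap_ratio = 0.1 reproduces the hard-coded 0.9 multiplier.
    return raw_time * (1.0 - overlap_ratio)

print(discounted_dp_comm_time(100.0))  # 90.0
```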
@@ -389,6 +394,7 @@ def build_dp_costs(
     vars = dist_op.serial_op.block.vars
     var_name = var_names[0]
     has_found = False
+    is_input = True
     for name in dist_op.serial_op.input_arg_names:
         if var_name in name:
             var_name = name
@@ -400,6 +406,7 @@ def build_dp_costs(
         if var_name in name:
             var_name = name
             has_found = True
+            is_input = False
             break
     if not has_found:
         return
@@ -418,6 +425,7 @@ def build_dp_costs(
         processes,
         c_allreduce_sum_descs,
         cluster,
+        is_dp=True,
     )
     result.append(comm_cost_list)
@@ -431,22 +439,11 @@ def build_dp_costs(
         desc = {}
         desc["op"] = op_type
         desc["inputs"] = {}
-        if var_name in dist_attr.inputs_dist_attrs:
-            dims_mapping = dist_attr.get_input_dims_mapping(var_name)
-        elif var_name in dist_attr.outputs_dist_attrs:
-            dims_mapping = dist_attr.get_output_dims_mapping(var_name)
-        else:
-            raise AssertionError(
-                "cannot find dims_mapping for {} in {}".format(
-                    var_name, dist_attr
-                )
-            )
-        # dims_mapping = (
-        #     dist_attr.get_input_dims_mapping(var_name)
-        #     if dist_attr.get_input_dims_mapping(var_name) is not None
-        #     else dist_attr.get_output_dims_mapping(var_name)
-        # )
+        dims_mapping = (
+            dist_attr.get_input_dims_mapping(var_name)
+            if is_input
+            else dist_attr.get_output_dims_mapping(var_name)
+        )
         var = get_var_with_recursion(
             var_name,
             dist_op.serial_op.block,
@@ -493,8 +490,6 @@ class CommContext:
             # if cluster has no info about those vars, it will be set by default
             self.base_ring = None
             self.base_tree = None
-            # self.base_inter_ring = None
-            # self.base_inter_tree = None
             self.intra_ring = None
             self.intra_tree = None
             self.inter_ring = None
@@ -508,8 +503,6 @@ class CommContext:
             # set default
             self.base_ring = 8.4
             self.base_tree = 0.0
-            # self.base_inter_ring = 9.6
-            # self.base_inter_tree = 28
             # NVL in default
             self.intra_ring = 3.4
             self.intra_tree = 28
@@ -681,6 +674,8 @@ class Cost:


 class OpCost:
+    OP_TYPE = "op"
+
     def __init__(self, op=None, op_desc=None):
         self._op = op
         self._op_desc = op_desc
@@ -883,6 +878,24 @@ class CompOpCost(OpCost):
                 )
             )

+    def calc_flops(self):
+        if not self.op_desc:
+            return 0
+        if "_grad" in self.__class__.OP_TYPE:
+            op_type = self.__class__.OP_TYPE[: len(self.__class__.OP_TYPE) - 5]
+            return 2 * flops(
+                op_type, self.op_desc["inputs"], self.op_desc["attrs"]
+            )
+        return flops(
+            self.__class__.OP_TYPE,
+            self.op_desc["inputs"],
+            self.op_desc["attrs"],
+        )
+
+    def calc_time(self):
+        flops_count = self.calc_flops()
+        return flops_count * 2.9e-7
+

 def register_op_cost(cls):
     op_type = cls.OP_TYPE
...
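The new CompOpCost.calc_time models a computation op's time as a linear function of its FLOPs, with backward ops (OP_TYPE ending in "_grad") counted as twice the forward FLOPs. Below is a hedged sketch of the same estimate outside the class, using a hypothetical matmul_v2 descriptor shaped like the ones build_comp_desc_from_dist_op produces (per-input shape lists plus op attrs).

```python
# Hedged sketch of the FLOPs-based time model; the descriptor is hypothetical.
from paddle.utils.flops import flops

op_desc = {
    "op": "matmul_v2",
    "inputs": {"X": [[8, 1024, 1024]], "Y": [[1024, 4096]]},
    "attrs": {"trans_x": False, "trans_y": False},
}

forward_flops = flops(op_desc["op"], op_desc["inputs"], op_desc["attrs"])
backward_flops = 2 * forward_flops  # "_grad" ops are modeled as 2x forward
estimated_time = forward_flops * 2.9e-7  # same constant as calc_time above
print(forward_flops, estimated_time)
```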
@@ -140,7 +140,7 @@ class IdentityOpCost(CommOpCost):
         super().__init__(op=op, op_desc=op_desc, comm_context=comm_context)

     def calc_time(self):
-        return 0
+        return self.comm_count * 1 / (144 * 1e3)


 @register_op_cost
...
@@ -189,6 +189,9 @@ class CostEstimator:
             # Calc dist op cost
             dist_op = dist_context.get_dist_op_for_program(op)
+            if not dist_op:
+                continue
+
             op_dist_attr = dist_op.dist_attr
             processes = op_dist_attr.process_mesh.process_ids
@@ -225,6 +228,8 @@ class CostEstimator:
                 for rank in group_ranks:
                     self.local_cost(rank).time = (
                         max_time + comm_op_cost.time
+                        if op.attr('op_role') != OpRole.Backward
+                        else max_time + 0.9 * comm_op_cost.time
                     )
                     if rank not in self._bubble_time_mapping:
                         self._bubble_time_mapping[rank] = 0
@@ -290,6 +295,7 @@ class CostEstimator:
                 self._ordered_ops.append([op.desc.id(), op])
         self._ordered_ops.sort(key=lambda x: x[0])

+        parameters = set()
         for op_id, op in self._ordered_ops:
             if op.type in [
                 "create_py_reader",
@@ -298,11 +304,14 @@ class CostEstimator:
             ]:
                 continue
             dist_op = dist_context.get_dist_op_for_program(op)
+            if not dist_op:
+                continue
+
             process_mesh = dist_op.dist_attr.process_mesh
             for var_name in op.input_arg_names:
                 input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
                     var_name
                 )
                 if var_name not in var_info:
                     var_info[var_name] = {}
                 key = _convert_pm_and_dm_to_str(
@@ -311,6 +320,10 @@ class CostEstimator:
                 if key not in var_info[var_name]:
                     var_info[var_name][key] = {}
                 # It is even partition now
+                if "position" not in var_info[var_name][key]:
+                    var_info[var_name][key]["position"] = []
+                var_info[var_name][key]["position"].append(op_id)
+
                 if "memory" not in var_info[var_name][key]:
                     var = dist_op.get_serial_input(var_name)
                     global_sizes = var.shape
@@ -324,9 +337,16 @@ class CostEstimator:
                     var_info[var_name][key]["memory"] = self._calculate_bytes(
                         sizes, dtype
                     )
-                if "position" not in var_info[var_name][key]:
-                    var_info[var_name][key]["position"] = []
-                var_info[var_name][key]["position"].append(op_id)
+                    if var.persistable:
+                        name = var_name + key
+                        if name not in parameters:
+                            parameters.add(name)
+                            for process in process_mesh.process_ids:
+                                if process not in memories:
+                                    memories[process] = 0
+                                memories[process] += var_info[var_name][key][
+                                    "memory"
+                                ]

             for var_name in op.output_arg_names:
                 output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
@@ -339,6 +359,10 @@ class CostEstimator:
                 )
                 if key not in var_info[var_name]:
                     var_info[var_name][key] = {}
+                if "position" not in var_info[var_name][key]:
+                    var_info[var_name][key]["position"] = []
+                var_info[var_name][key]["position"].append(op_id)
+
                 if "memory" not in var_info[var_name][key]:
                     var = dist_op.get_serial_output(var_name)
                     global_sizes = var.shape
@@ -352,11 +376,19 @@ class CostEstimator:
                     var_info[var_name][key]["memory"] = self._calculate_bytes(
                         sizes, dtype
                     )
-                if "position" not in var_info[var_name][key]:
-                    var_info[var_name][key]["position"] = []
-                var_info[var_name][key]["position"].append(op_id)
+                    if var.persistable:
+                        name = var_name + key
+                        if name not in parameters:
+                            parameters.add(name)
+                            for process in process_mesh.process_ids:
+                                if process not in memories:
+                                    memories[process] = 0
+                                memories[process] += var_info[var_name][key][
+                                    "memory"
+                                ]

         has_used_vars = set()
+        not_calc_vars = set()
         for op_id, op in self._ordered_ops:
             if op.type in [
                 "create_py_reader",
@@ -367,6 +399,8 @@ class CostEstimator:
             can_free_memories = {}
             can_free_vars = set()
             dist_op = dist_context.get_dist_op_for_program(op)
+            if not dist_op:
+                continue
             process_mesh = dist_op.dist_attr.process_mesh
             for var_name in op.input_arg_names:
                 input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
@@ -378,24 +412,30 @@ class CostEstimator:
                 has_used_var = var_name + key
                 var = dist_op.get_serial_input(var_name)
                 # Not used
-                if var_name + key not in has_used_vars:
+                if (
+                    has_used_var not in has_used_vars
+                    and has_used_var not in parameters
+                ):
+                    if has_used_var in not_calc_vars:
+                        continue
                     has_used_vars.add(has_used_var)
                     for process in process_mesh.process_ids:
                         if process not in memories:
                             memories[process] = 0
                         memories[process] += var_info[var_name][key]["memory"]
                 # Used
-                else:
-                    if op_id == var_info[var_name][key]["position"][-1]:
-                        if has_used_var not in can_free_vars:
-                            can_free_vars.add(has_used_var)
-                            if not var.persistable:
-                                for process in process_mesh.process_ids:
-                                    if process not in can_free_memories:
-                                        can_free_memories[process] = 0
-                                    can_free_memories[process] += var_info[
-                                        var_name
-                                    ][key]["memory"]
+                if op_id == var_info[var_name][key]["position"][-1]:
+                    if (
+                        has_used_var not in can_free_vars
+                        and not var.persistable
+                    ):
+                        can_free_vars.add(has_used_var)
+                        for process in process_mesh.process_ids:
+                            if process not in can_free_memories:
+                                can_free_memories[process] = 0
+                            can_free_memories[process] += var_info[var_name][
+                                key
+                            ]["memory"]

             for var_name in op.output_arg_names:
                 output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
@@ -406,25 +446,36 @@ class CostEstimator:
                 )
                 has_used_var = var_name + key
                 var = dist_op.get_serial_output(var_name)
+                if (
+                    op.type == "reshape2"
+                    or op.type == "transpose2"
+                    or op.type == "elementwise_add"
+                ):
+                    not_calc_vars.add(has_used_var)
+                    continue
                 # Not used
-                if var_name + key not in has_used_vars:
+                if (
+                    has_used_var not in has_used_vars
+                    and has_used_var not in parameters
+                ):
                     has_used_vars.add(has_used_var)
                     for process in process_mesh.process_ids:
                         if process not in memories:
                             memories[process] = 0
                         memories[process] += var_info[var_name][key]["memory"]
                 # Used
-                else:
-                    if op_id == var_info[var_name][key]["position"][-1]:
-                        if has_used_var not in can_free_vars:
-                            can_free_vars.add(has_used_var)
-                            if not var.persistable:
-                                for process in process_mesh.process_ids:
-                                    if process not in can_free_memories:
-                                        can_free_memories[process] = 0
-                                    can_free_memories[process] += var_info[
-                                        var_name
-                                    ][key]["memory"]
+                if op_id == var_info[var_name][key]["position"][-1]:
+                    if (
+                        has_used_var not in can_free_vars
+                        and not var.persistable
+                    ):
+                        can_free_vars.add(has_used_var)
+                        for process in process_mesh.process_ids:
+                            if process not in can_free_memories:
+                                can_free_memories[process] = 0
+                            can_free_memories[process] += var_info[var_name][
+                                key
+                            ]["memory"]

             # Calc peak memory
             for process in memories:
@@ -433,7 +484,6 @@ class CostEstimator:
                 else:
                     if memories[process] > self.max_memories[process]:
                         self.max_memories[process] = memories[process]
-
             # Free memory
             for process in can_free_memories:
                 if process in memories:
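The bookkeeping above charges a variable's bytes to every process in its mesh at first use, keeps persistable parameters resident (charged once via the parameters set), tracks the running total for the peak, and frees non-persistable variables after the op recorded as their last "position". A self-contained toy sketch of that pattern on a single process, with hypothetical variable records:

```python
# Toy, single-process sketch of the peak-memory bookkeeping pattern above.
# Each record is (var_name, nbytes, persistable, last_use_op_id); all names
# and numbers are hypothetical.
def estimate_peak_memory(op_var_uses):
    current = peak = 0
    allocated = set()
    for op_id, used_vars in op_var_uses:
        freeable = 0
        for name, nbytes, persistable, last_use in used_vars:
            if name not in allocated:
                allocated.add(name)
                current += nbytes
            # Non-persistable vars can be freed after their last use.
            if not persistable and last_use == op_id:
                freeable += nbytes
        peak = max(peak, current)
        current -= freeable
    return peak


ops = [
    (0, [("w", 4096, True, 2), ("x", 1024, False, 1)]),
    (1, [("x", 1024, False, 1), ("y", 2048, False, 2)]),
    (2, [("w", 4096, True, 2), ("y", 2048, False, 2)]),
]
print(estimate_peak_memory(ops))  # 7168: w, x and y are live together at op 1
```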
@@ -513,7 +563,7 @@ class CostEstimator:
         # Padding automatically
         max_len = 0
-        header = ["Execution Time(ms)", "Max Memory(MiB)"]
+        header = ["Execution Time(us)", "Max Memory(MiB)"]
         vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
         for memory in vals + header:
             if len(str(memory)) > max_len:
...
@@ -2716,6 +2716,8 @@ class Resharder:
                 )
                 # simplified processing: ignore union process mesh and output reshard
                 dist_op = self.dist_context.get_dist_op_for_program(op)
+                if not dist_tensor or not dist_op:
+                    return reshard_op_cost
                 dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
                     tensor.name
                 )
...
@@ -100,10 +100,10 @@ class TestRuleBasedTuner(unittest.TestCase):
         modeling.init_global()
         train_program = static.Program()
         start_program = static.Program()
-        place = paddle.set_device("gpu")
         batch_size = 8
         sequence_len = 512
         vocab_size = 1000
+        place = None
         train_program, start_program, loss, gen_data = get_gpt_model(
             train_program,
             start_program,
@@ -112,31 +112,29 @@ class TestRuleBasedTuner(unittest.TestCase):
             sequence_len,
             vocab_size,
         )
+        from paddle.distributed.auto_parallel.cluster import Cluster
         from paddle.distributed.auto_parallel.dist_context import (
             DistributedContext,
         )
-        from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
             RuleBasedTuner,
         )

         clip = paddle.nn.ClipGradByGlobalNorm(0.2)
         opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
         cluster = Cluster()
+        cluster.gen_default_config_cluster(node_count=1, device_count=8)
         dist_context = DistributedContext(
             serial_main_prog=train_program,
             serial_startup_prog=start_program,
             serial_optimizer=opt,
             serial_loss=loss,
+            cluster=cluster,
         )
         dist_context.initialize()
         tuner = RuleBasedTuner(dist_context)
-        tuner.cluster_operators()
-        tuner.gen_full_program()
-        tuner.match_program(tuner._dist_context.serial_main_program)
-        process_mesh = ProcessMesh([0, 1])
-        tuner.gen_fwd_sub_programs_by_clone()
-        tuner.complete_sub_fwd_programs(process_mesh)
-        tuner.complete_sub_bwd_programs()
+        tuner.tune()


 if __name__ == "__main__":
...