Unverified · Commit 118a7415 · Authored by caozhou · Committed by GitHub

[Auto Parallel]Add o2 tune of rule based tuner (#52928)

* add o2 tune

* add unittest

* fix error

* set unittest timeout
Parent d7659ce4
@@ -13,12 +13,15 @@
# limitations under the License.
import json
import logging
import os
import re
from enum import IntEnum, unique
import paddle
from ..utils.log_utils import get_logger
@unique
class DeviceType(IntEnum):
@@ -830,6 +833,9 @@ class Cluster:
return self.__str__()
logger = get_logger(logging.INFO)
def get_default_cluster(json_config=None):
def is_by_json_config(json_config):
if not json_config:
@@ -889,18 +895,15 @@ def get_default_cluster(json_config=None):
memory = int(gpu_info.total_memory) // (1000**3)
gpu_model = gpu_name
-print(
-"Node Count: ",
-node_count,
-"Local Device Size: ",
-local_device_count,
-"GPU Model: ",
-gpu_model,
-"GPU Memory: ",
-memory,
-"World size: ",
-paddle.distributed.get_world_size(),
-flush=True,
-)
logger.info(
"Node Count: {}, Local Device Size: {}, GPU Model: {}, GPU Memory: {}GB, World size: {}, EndPoint: {}.".format(
node_count,
local_device_count,
gpu_model,
memory,
paddle.distributed.get_world_size(),
os.getenv("PADDLE_CURRENT_ENDPOINT", None),
)
)
cluster.gen_default_config_cluster(
node_count=node_count,
......
@@ -13,6 +13,7 @@
# limitations under the License.
import copy
import json
import logging
import numbers
import os
@@ -177,6 +178,23 @@ class Engine:
self._strategy = strategy or Strategy()
self._logger = get_logger(logging.INFO)
self._json_config = None
if cluster:
self._cluster = cluster
else:
if os.getenv("PADDLE_AUTO_PARALLEL_CONFIG"):
try:
path = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG")
with open(path, "r") as f:
self._json_config = json.load(f)
except Exception as e:
self._logger.info(
"Load json failed, please check json file, engine will run default config."
)
self._json_config = None
self._cluster = get_default_cluster(self._json_config)
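The added block above makes the engine read an optional JSON config whose path comes from the PADDLE_AUTO_PARALLEL_CONFIG environment variable; the planner later looks up a "tuner_load_path" entry in that config. A minimal usage sketch, assuming only that key (the file paths are hypothetical):

# Illustrative sketch, not part of this patch; paths are hypothetical.
import json
import os

config = {"tuner_load_path": "/tmp/auto_parallel_strategy.pkl"}
with open("/tmp/auto_parallel_config.json", "w") as f:
    json.dump(config, f)

# Set before constructing the Engine so the config is picked up here.
os.environ["PADDLE_AUTO_PARALLEL_CONFIG"] = "/tmp/auto_parallel_config.json"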
if os.getenv("POD_NAME"): if os.getenv("POD_NAME"):
self._logger.info( self._logger.info(
"Distribute training by paddle.distributed.launch" "Distribute training by paddle.distributed.launch"
...@@ -653,6 +671,7 @@ class Engine: ...@@ -653,6 +671,7 @@ class Engine:
fetch_vars, fetch_vars,
self._cluster, self._cluster,
self._strategy, self._strategy,
self._json_config,
)
self._fwd_dist_contexts[mode] = DistributedContext(
serial_main_prog,
@@ -663,6 +682,7 @@ class Engine:
fetch_vars,
self._cluster,
self._strategy,
self._json_config,
)
self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
self._fwd_main_progs[mode] = serial_main_prog.clone()
@@ -769,7 +789,7 @@ class Engine:
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
-if self._strategy.auto_mode == "full":
if self._strategy.auto_mode == "full_random":
auto_utils.initialize_pg_in_full_mode(
all_process_groups, self._cur_rank
)
......
@@ -12,9 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import pickle
import numpy as np
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from ..utils.log_utils import get_logger
from .completion import Completer
from .dist_context import get_default_distributed_context
from .tuner.parallel_tuner import ParallelTuner
from .tuner.rule_based_tuner import RuleBasedTuner
from .utils import is_naive_data_parallel
@@ -22,6 +38,7 @@ class Planner:
def __init__(self, mode, dist_context):
self._mode = mode
self._dist_context = dist_context
self._load = False
# NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it needs
# dependency of backward-forward ops in forward completion.
@@ -29,30 +46,135 @@ class Planner:
self._dist_context._dist_op_context = default_ctx.dist_op_context
self._dist_context.data_parallel = default_ctx.data_parallel
if not is_naive_data_parallel(self._dist_context):
# Use SSA graph for complex parallelism
self._dist_context.initialize(with_graph=True)
else:
# Use program for data parallel parallelism
self._dist_context.initialize(with_graph=False)
self._completer = Completer(self._dist_context)
self._strategy = dist_context.strategy
# set parallel tuner for auto search
-if self._strategy.auto_mode == "full":
if self._strategy.auto_mode == "full_random":
self._parallel_tuner = ParallelTuner(
self._dist_context, mode=self._mode
)
elif self._strategy.auto_mode == "full_rule_based":
self._parallel_tuner = RuleBasedTuner(
self._dist_context, mode=self._mode
)
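With the dispatch above, strategy.auto_mode now distinguishes "full_random" (the renamed "full", handled by ParallelTuner), "full_rule_based" (handled by RuleBasedTuner), and the default "semi" completion. A hedged sketch of selecting the rule-based path, mirroring the test at the bottom of this diff; the Strategy import path is an assumption:

# Illustrative sketch, not part of this patch; the import path is assumed.
from paddle.distributed.auto_parallel.strategy import Strategy

strategy = Strategy()
strategy.auto_mode = "full_rule_based"  # or "full_random" / "semi"
# A Planner whose dist_context carries this strategy will construct a
# RuleBasedTuner instead of a ParallelTuner.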
@property
def completer(self):
return self._completer
def plan(self):
-if self._strategy.auto_mode == "full":
-self._parallel_tuner.tune()
-else:
-self._completer.complete_forward_annotation()
logger = get_logger(logging.INFO)
path = None
if self._dist_context._json_config:
try:
path = self._dist_context._json_config["tuner_load_path"]
except:
path = None
if path and os.path.exists(path):
try:
with open(path, "rb") as f:
dist_attrs = pickle.load(f)
tensor_dist_attrs = dist_attrs["tensor"]
op_dist_attrs = dist_attrs["op"]
process_meshes = dist_attrs["process_meshes"]
cluster = dist_attrs["cluster"]
last_gpu_model = cluster.machines[0].devices[0].model
last_gpu_memory = cluster.machines[0].devices[0].memory
last_node_count = len(cluster.machines)
last_device_count = len(cluster.get_all_devices("GPU"))
gpu_model = (
self._dist_context.cluster.machines[0].devices[0].model
)
gpu_memory = (
self._dist_context.cluster.machines[0].devices[0].memory
)
node_count = len(self._dist_context.cluster.machines)
device_count = len(
self._dist_context.cluster.get_all_devices("GPU")
)
if (
gpu_model != last_gpu_model
or gpu_memory != last_gpu_memory
or last_node_count != node_count
or device_count != last_device_count
):
logger.info(
"The cluster {} nodes {} {} devices is different from the saved last cluster {} nodes {} {} devices, so we run the planner again.".format(
node_count,
device_count,
gpu_model,
last_node_count,
last_device_count,
last_gpu_model,
)
)
need_set_dist_attr = False
else:
need_set_dist_attr = True
except:
need_set_dist_attr = False
if need_set_dist_attr:
for key in op_dist_attrs:
serial_op = self._dist_context._dist_ops_for_program[
key
].serial_op
# clear dist attr
serial_op.dist_attr = OperatorDistAttr(serial_op.desc)
serial_op.dist_attr.parse_from_string(op_dist_attrs[key])
self._dist_context._dist_ops_for_program[
key
] = DistributedOperator(serial_op)
for key in tensor_dist_attrs:
serial_tensor = (
self._dist_context._dist_tensors_for_program[
key
].serial_tensor
)
# clear dist attr
serial_tensor.dist_attr = TensorDistAttr(serial_tensor.desc)
serial_tensor.dist_attr.parse_from_string(
tensor_dist_attrs[key]
)
self._dist_context._dist_tensors_for_program[
key
] = DistributedTensor(serial_tensor)
process_meshes = []
for item in dist_attrs["process_meshes"]:
process_ids = item[0]
shape = item[1]
process_meshes.append(
ProcessMesh(
np.array(process_ids).reshape(shape).tolist()
)
)
self._dist_context.process_meshes = process_meshes
self._load = True
logger.info(
f"The parallel strategy has been loaded from {path}"
)
if not self._load:
if self._strategy.auto_mode != "semi":
self._parallel_tuner.tune()
else:
self._completer.complete_forward_annotation()
if os.getenv("PADDLE_AUTO_PARALLEL_STAGE", "run") != "run":
quit()
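For reference, the pickle loaded above is expected to contain "tensor", "op", "process_meshes" and "cluster" entries, matching what save_strategy writes later in this diff; setting PADDLE_AUTO_PARALLEL_STAGE to anything other than "run" stops the process right after planning. A small offline inspection sketch under those assumptions (the path is hypothetical):

# Illustrative sketch, not part of this patch; the path is hypothetical.
import pickle

with open("/tmp/auto_parallel_strategy.pkl", "rb") as f:
    dist_attrs = pickle.load(f)

print("tensor dist attrs:", len(dist_attrs["tensor"]))
print("op dist attrs:", len(dist_attrs["op"]))
for process_ids, shape in dist_attrs["process_meshes"]:
    print("process mesh:", shape, process_ids)
print("machines in saved cluster:", len(dist_attrs["cluster"].machines))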
# parse forward sub block
self._dist_context.block_state.parse_forward_blocks(
self._dist_context.serial_main_program
......
@@ -35,6 +35,9 @@ from paddle.distributed.auto_parallel.dist_attribute import (
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
is_gradient_clip_op,
@@ -579,7 +582,6 @@ class GraphUtil:
def _match_core(src_node, tgt_node):
nonlocal not_matched
# not support one input name or output name corresponding to multiple vars # not support one input name or output name corresponding to multiple vars
if not_matched:
return
@@ -1126,13 +1128,6 @@ class RuleBasedTuner:
def level(self):
return self._level
-def convert_process_mesh_to_key(self, process_mesh):
-"""Convert process mesh object to str."""
-processes = ",".join([str(x) for x in process_mesh._process_ids])
-topology = ",".join([str(x) for x in process_mesh._shape])
-key = processes + ";" + topology
-return key
def gen_full_program(self):
"""Generate full program that contains backward and update phase program if mode is train."""
self.full_main_program = self.dist_context.serial_main_program.clone()
@@ -1878,6 +1873,13 @@
][parallelism][key]
self._complete_sub_update_program(sub_program_dist_context)
def convert_process_mesh_to_key(self, process_mesh):
"""Convert process mesh object to str."""
processes = ",".join([str(x) for x in process_mesh._process_ids])
topology = ",".join([str(x) for x in process_mesh._shape])
key = processes + ";" + topology
return key
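The key built above is just the comma-joined process ids and the comma-joined mesh shape, separated by a semicolon. A plain-Python illustration of the format (no paddle objects needed):

# Illustrative sketch, not part of this patch.
def mesh_key(process_ids, shape):
    # Same "<ids>;<shape>" format as convert_process_mesh_to_key.
    return ",".join(str(x) for x in process_ids) + ";" + ",".join(str(x) for x in shape)

assert mesh_key([0, 1, 2, 3], [2, 2]) == "0,1,2,3;2,2"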
def convert_device_mesh_to_key(self, device_mesh):
"""Convert device mesh object to str."""
processes = ",".join([str(x) for x in device_mesh.device_ids])
@@ -1894,6 +1896,168 @@
)
return global_cost.time, max_memory
def _local_stage_pass(self, start, end, process_mesh):
"""Get the best cost and the corresponding strategy of layers on the given process mesh."""
# convert process mesh to dict key
key = self.convert_process_mesh_to_key(process_mesh)
if start in self.stage_best_cost_of_pm:
if end in self.stage_best_cost_of_pm[start]:
if key in self.stage_best_cost_of_pm[start][end]:
return self.stage_best_cost_of_pm[start][end][key]["cost"]
assert end >= start
selective_parallelisms = (
["dp", "mp"] if len(process_mesh.shape) == 1 else ["dp_mp", "mp_dp"]
)
if start not in self.stage_best_cost_of_pm:
self.stage_best_cost_of_pm[start] = {}
if end not in self.stage_best_cost_of_pm[start]:
self.stage_best_cost_of_pm[start][end] = {}
if key not in self.stage_best_cost_of_pm[start][end]:
self.stage_best_cost_of_pm[start][end][key] = {}
if end == start:
dist_contexts_x = [DistributedContext(), DistributedContext()]
else:
dist_contexts_x = self.stage_best_cost_of_pm[start][end - 1][key][
"dist_context"
]
# Use beam search, the beam size is 2.
# When the process mesh is 1-D, the selective parallelism can be dp or mp.
# Because the first layer often contains more ops than other layers, using beam search can find a more accurate strategy.
count = 0
for dist_context_x in dist_contexts_x:
if end == start and count == 1:
break
for parallelism in selective_parallelisms:
dist_context_y = self.sub_programs_dist_context[end][
parallelism
][key]
dist_context = self.combine_dist_contexts(
[dist_context_x, dist_context_y]
)
if (
"dist_context"
not in self.stage_best_cost_of_pm[start][end][key]
):
self.stage_best_cost_of_pm[start][end][key][
"dist_context"
] = [None, None]
self.stage_best_cost_of_pm[start][end][key]["cost"] = [
sys.maxsize,
sys.maxsize,
]
# estimate cost and memory
cost, local_stage_memory = self._get_sub_program_cost(
dist_context
)
if local_stage_memory > 0.9 * self.cluster.machines[0].devices[
0
].memory * (1024**3):
cost = sys.maxsize
index = -1
for idx, item in enumerate(
self.stage_best_cost_of_pm[start][end][key]["cost"]
):
if cost <= item:
index = idx
break
if index == 0:
self.stage_best_cost_of_pm[start][end][key]["cost"][
1
] = self.stage_best_cost_of_pm[start][end][key]["cost"][0]
self.stage_best_cost_of_pm[start][end][key]["dist_context"][
1
] = self.stage_best_cost_of_pm[start][end][key][
"dist_context"
][
0
]
self.stage_best_cost_of_pm[start][end][key]["cost"][
0
] = cost
self.stage_best_cost_of_pm[start][end][key]["dist_context"][
0
] = dist_context
elif index == 1:
self.stage_best_cost_of_pm[start][end][key]["cost"][
1
] = cost
self.stage_best_cost_of_pm[start][end][key]["dist_context"][
1
] = dist_context
count += 1
if (
self.stage_best_cost_of_pm[start][end][key]["cost"][1]
< self.stage_best_cost_of_pm[start][end][key]["cost"][0]
):
self.stage_best_cost_of_pm[start][end][key][
"best_cost"
] = self.stage_best_cost_of_pm[start][end][key]["cost"][1]
self.stage_best_cost_of_pm[start][end][key][
"best_dist_context"
] = self.stage_best_cost_of_pm[start][end][key]["dist_context"][1]
else:
self.stage_best_cost_of_pm[start][end][key][
"best_cost"
] = self.stage_best_cost_of_pm[start][end][key]["cost"][0]
self.stage_best_cost_of_pm[start][end][key][
"best_dist_context"
] = self.stage_best_cost_of_pm[start][end][key]["dist_context"][0]
return self.stage_best_cost_of_pm[start][end][key]["best_cost"]
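To make the bookkeeping in _local_stage_pass easier to follow: it keeps the two cheapest (cost, dist_context) candidates per (start, end, process mesh), cheapest first. A simplified, self-contained sketch of that beam-size-2 update (not the tuner's exact code):

# Illustrative sketch, not part of this patch.
import sys

def update_beam(costs, strategies, cost, strategy):
    """costs/strategies are length-2 lists ordered cheapest first."""
    if cost <= costs[0]:
        costs[1], strategies[1] = costs[0], strategies[0]
        costs[0], strategies[0] = cost, strategy
    elif cost <= costs[1]:
        costs[1], strategies[1] = cost, strategy

costs, strategies = [sys.maxsize, sys.maxsize], [None, None]
for cand_cost, cand in [(30, "dp"), (10, "mp"), (20, "dp_mp")]:
    update_beam(costs, strategies, cand_cost, cand)
print(costs, strategies)  # [10, 20] ['mp', 'dp_mp']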
def local_stage_pass(self, start, end, device_mesh):
"""Get the best cost and the corresponding strategy of layers on the given device mesh."""
dm_key = self.convert_device_mesh_to_key(device_mesh)
device_mesh_shape = device_mesh.shape
if len(device_mesh_shape) == 1:
device_mesh_shape.insert(0, 1)
process_mesh_shapes = convert_to_process_meshes(device_mesh_shape)
best_cost = sys.maxsize
if start not in self.stage_best_cost_of_dm:
self.stage_best_cost_of_dm[start] = {}
if end not in self.stage_best_cost_of_dm[start]:
self.stage_best_cost_of_dm[start][end] = {}
if dm_key not in self.stage_best_cost_of_dm[start][end]:
self.stage_best_cost_of_dm[start][end][dm_key] = {}
for process_mesh_shape in process_mesh_shapes:
process_mesh = ProcessMesh(
np.array(device_mesh.device_ids)
.reshape(process_mesh_shape)
.tolist()
)
key = self.convert_process_mesh_to_key(process_mesh)
for i in range(start, end + 1):
self._local_stage_pass(start, i, process_mesh)
if (
self.stage_best_cost_of_pm[start][end][key]["best_cost"]
<= best_cost
):
best_cost = self.stage_best_cost_of_pm[start][end][key][
"best_cost"
]
self.stage_best_cost_of_dm[start][end][dm_key][
"cost"
] = best_cost
self.stage_best_cost_of_dm[start][end][dm_key][
"dist_context"
] = self.stage_best_cost_of_pm[start][end][key][
"best_dist_context"
]
return best_cost
def combine_dist_contexts(self, dist_contexts):
"""Combine the dist attr in dist contexts to one dist context."""
combined_dist_context = DistributedContext()
@@ -1927,7 +2091,7 @@ class RuleBasedTuner:
self.layers = self.cluster_operators()
end = time.time()
self._logger.info(
-"Cluster operators to {} layers in {}s.".format(
"Cluster operators to {} layers in {:.2f}s.".format(
len(self.layers), end - begin
)
)
@@ -1937,7 +2101,7 @@ class RuleBasedTuner:
self.gen_fwd_sub_programs_by_clone()
end = time.time()
self._logger.info(
-f"Generate programs of every layer in {end - begin}s."
f"Generate programs of every layer in {end - begin:.2f}s."
)
# step3: partition devices to device meshes
@@ -1948,7 +2112,7 @@ class RuleBasedTuner:
)
device_meshes_list = ClusterPartitionUtil.partition_cluster(n, m)
end = time.time()
-self._logger.info(f"Partition cluster in {end - begin}s.")
self._logger.info(f"Partition cluster in {end - begin:.2f}s.")
# step4: transform device mesh to process meshes
dm_idx = 0
@@ -1987,7 +2151,7 @@ class RuleBasedTuner:
begin = time.time()
self.gen_full_program()
end = time.time()
-self._logger.info(f"Generate full program in {end - begin}s.")
self._logger.info(f"Generate full program in {end - begin:.2f}s.")
# step6: complete forward sub programs
begin = time.time()
@@ -1995,7 +2159,7 @@ class RuleBasedTuner:
self.complete_sub_fwd_programs(process_mesh)
end = time.time()
self._logger.info(
-f"Complete all sub forward programs in {end - begin}s."
f"Complete all sub forward programs in {end - begin:.2f}s."
)
if self.mode == "train":
@@ -2004,7 +2168,9 @@ class RuleBasedTuner:
self.complete_sub_bwd_programs()
end = time.time()
self._logger.info(
-f"Complete all sub backward programs in {end - begin}s."
"Complete all sub backward programs in {:.2f}s.".format(
end - begin
)
)
# step8: complete update sub programs
@@ -2015,6 +2181,88 @@ class RuleBasedTuner:
f"Complete all sub update programs in {end - begin}s."
)
def layer_placement_pass(self, stages, layers, device_meshes):
"""Get the best cost and the corresponding strategy of the given layers on the stages which running on the devices."""
stage_layer_cost = [
[sys.maxsize for i in range(layers)] for j in range(stages)
]
# To get the balance among the stages, we select the minimum maximum cost of stages.
min_max_stage_costs = [
[None for i in range(layers)] for j in range(stages)
]
best_strategies = [[None for i in range(layers)] for j in range(stages)]
for s in range(len(device_meshes)):
for i in range(0, layers):
if s == 0:
stage_layer_cost[s][i] = self.local_stage_pass(
0, i, device_meshes[s]
)
min_max_stage_costs[s][i] = stage_layer_cost[s][i]
key = self.convert_device_mesh_to_key(device_meshes[s])
best_strategies[s][i] = self.stage_best_cost_of_dm[0][i][
key
]["dist_context"]
else:
min_cost = sys.maxsize
min_max_stage_cost = sys.maxsize
for j in range(0, i):
key = self.convert_device_mesh_to_key(device_meshes[s])
local_stage_cost = self.local_stage_pass(
j + 1, i, device_meshes[s]
)
dist_context = self.combine_dist_contexts(
[
best_strategies[s - 1][j],
self.stage_best_cost_of_dm[j + 1][i][key][
"dist_context"
],
]
)
cost, _ = self._get_sub_program_cost(dist_context)
max_stage_cost = (
min_max_stage_costs[s - 1][j]
if local_stage_cost < min_max_stage_costs[s - 1][j]
else local_stage_cost
)
if cost <= min_cost:
if cost == min_cost:
if max_stage_cost < min_max_stage_cost:
min_max_stage_cost = max_stage_cost
best_strategies[s][i] = dist_context
else:
break
else:
best_strategies[s][i] = dist_context
min_cost = cost
stage_layer_cost[s][i] = min_cost
min_max_stage_costs[s][i] = min_max_stage_cost
return (
stage_layer_cost[stages - 1][layers - 1],
best_strategies[stages - 1][layers - 1],
)
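Conceptually, layer_placement_pass is a dynamic program over (stage, last layer): stage s ending at layer i tries every split point j, reuses the best placement of layers 0..j on the previous stages, and prefers splits whose most expensive stage is cheapest. A simplified, self-contained sketch of that recurrence over scalar per-layer costs (the real code combines dist contexts and queries the cost model rather than summing numbers):

# Illustrative sketch, not part of this patch.
import sys

def segment_cost(layer_costs, start, end):
    # Stand-in for local_stage_pass: cost of layers [start, end] on one stage.
    return sum(layer_costs[start:end + 1])

def place_layers(layer_costs, num_stages):
    n = len(layer_costs)
    # best[s][i]: minimal "max stage cost" when layers 0..i use stages 0..s.
    best = [[sys.maxsize] * n for _ in range(num_stages)]
    for s in range(num_stages):
        for i in range(n):
            if s == 0:
                best[s][i] = segment_cost(layer_costs, 0, i)
            else:
                for j in range(i):
                    cand = max(best[s - 1][j], segment_cost(layer_costs, j + 1, i))
                    best[s][i] = min(best[s][i], cand)
    return best[num_stages - 1][n - 1]

print(place_layers([4, 3, 2, 6, 1], 2))  # -> 9, e.g. split [4, 3] | [2, 6, 1]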
def tune_o2(self):
"""The o2 level tuning."""
best_dist_context = None
best_cost = sys.maxsize
for device_meshes in self.device_meshes_list:
cost, dist_context = self.layer_placement_pass(
len(device_meshes), len(self.layers), device_meshes
)
if cost <= best_cost:
self._logger.info(
"O2 level: a better strategy has be found as follows: "
)
print_program_with_dist_attr(
self.full_main_program, best_dist_context
)
best_cost = cost
best_dist_context = dist_context
return best_dist_context
def tune_o1(self):
"""The o1 level tuning."""
best_cost = sys.maxsize
@@ -2082,7 +2330,7 @@ class RuleBasedTuner:
)
self._logger.info(
-"Cost Model: The max memory is {}GB and cost is {} when {} parallelism under process mesh shape {} on {} stages.".format(
"Cost Model: The max memory is {:.2f}GB and cost is {:.2f} when {} parallelism under process mesh shape {} on {} stages.".format(
memory / (1024**3),
cost,
parallelism,
@@ -2090,8 +2338,8 @@ class RuleBasedTuner:
len(device_meshes),
)
)
-# 15% buffer is reserved for memory cost
-if memory > 0.85 * self.cluster.machines[0].devices[
# 10% buffer is reserved safely for memory cost
if memory > 0.9 * self.cluster.machines[0].devices[
0
].memory * (1024**3):
cost = sys.maxsize
@@ -2100,7 +2348,7 @@ class RuleBasedTuner:
best_cost = cost
best_dist_context = dist_context_of_device_meshes
self._logger.info(
-"O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {}GB.".format(
"O1 level: a better strategy has been found that parallelism is {} under process mesh shape {} on {} stages with max memory {:.2f}GB.".format(
parallelism,
process_mesh_shape,
len(device_meshes),
@@ -2110,9 +2358,6 @@ class RuleBasedTuner:
return best_dist_context
-def tune_o2(self):
-return None
def save_strategy(self, best_dist_context, path):
dist_attrs = {"tensor": {}, "op": {}, "process_meshes": []}
for key in best_dist_context._dist_tensors_for_program:
@@ -2151,9 +2396,14 @@ class RuleBasedTuner:
begin = time.time()
self.match_program(self._dist_context.serial_main_program)
end = time.time()
-self._logger.info(f"Pattern match in {end - begin}s.")
self._logger.info(f"Pattern match in {end - begin:.2f}s.")
if self._use_dp:
total_rank = (
self._cluster.get_num_machines()
* self._cluster._num_devices_per_machine
)
get_world_process_group().add_ranks(list(range(total_rank)))
completer = Completer(self._dist_context)
completer.complete_forward_annotation()
print_program_with_dist_attr(
@@ -2213,7 +2463,7 @@ class RuleBasedTuner:
self._dist_context._process_meshes = best_dist_context._process_meshes
end = time.time()
-self._logger.info(f"Rule-based tuner end in {end - begin}s.")
self._logger.info(f"Rule-based tuner end in {end - begin:.2f}s.")
self._logger.info("The best strategy found is as follows: ")
print_program_with_dist_attr(self.full_main_program, best_dist_context)
......
@@ -85,6 +85,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_pass_base_list PROPERTIES TIMEOUT 20)
py_test_modules(test_fuse_adamw_pass MODULES test_fuse_adamw_pass)
set_tests_properties(test_fuse_adamw_pass PROPERTIES TIMEOUT 20)
py_test_modules(test_rule_based_tuner_o2 MODULES test_rule_based_tuner_o2)
set_tests_properties(test_rule_based_tuner_o2 PROPERTIES TIMEOUT 50)
# End of unittests WITH single card and timeout
# NOTE(zyl): unittests WITH single card and WITHOUT timeout
......
@@ -153,7 +153,7 @@ class TestParallelTunerFull(unittest.TestCase):
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
strategy = Strategy()
-strategy.auto_mode = "full"
strategy.auto_mode = "full_random"
dist_context = DistributedContext(
train_program,
start_program,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
from paddle import static
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
GPTForPretraining,
GPTModel,
GPTPretrainingCriterion,
)
def get_gpt_model(
train_program, start_program, place, batch_size, sequence_len, vocab_size
):
with static.program_guard(train_program, start_program):
tokens = paddle.static.data(
name="tokens", shape=[batch_size, sequence_len], dtype='int64'
)
position_ids = paddle.static.data(
name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
)
attention_mask = paddle.static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32',
)
labels = paddle.static.data(
name="labels", shape=[batch_size, sequence_len], dtype='int64'
)
loss_mask = paddle.static.data(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
)
gpt = GPTModel(
vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
)
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
def gen_data():
np.random.seed(2021)
tokens = []
position_ids = []
attention_mask = []
labels = []
loss_mask = []
for _ in range(batch_size):
tokens.append(np.random.randint(vocab_size, size=sequence_len))
position_ids.append(np.arange(sequence_len))
attention_mask.append([np.tril(np.ones(sequence_len))])
labels.append(np.random.randint(vocab_size, size=sequence_len))
loss_mask.append(np.ones(sequence_len))
return tokens, position_ids, attention_mask, labels, loss_mask
return train_program, start_program, loss, gen_data
class TestRuleBasedTuner(unittest.TestCase):
def test_gpt_o2(self):
modeling.init_global()
train_program = static.Program()
start_program = static.Program()
batch_size = 8
sequence_len = 512
vocab_size = 1000
place = None
train_program, start_program, loss, gen_data = get_gpt_model(
train_program,
start_program,
place,
batch_size,
sequence_len,
vocab_size,
)
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
RuleBasedTuner,
)
clip = paddle.nn.ClipGradByGlobalNorm(0.2)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
dist_context = DistributedContext(
serial_main_prog=train_program,
serial_startup_prog=start_program,
serial_optimizer=opt,
serial_loss=loss,
cluster=cluster,
)
dist_context.initialize()
tuner = RuleBasedTuner(dist_context, level="o2")
tuner.tune()
if __name__ == "__main__":
unittest.main()