Unverified commit 118a7415, authored by caozhou, committed by GitHub

[Auto Parallel]Add o2 tune of rule based tuner (#52928)

* add o2 tune

* add unittest

* fix error

* set unittest timeout
Parent d7659ce4
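The new "o2" level adds a stage/layer placement search (tune_o2 and layer_placement_pass below) on top of the existing rule-based completion. A condensed sketch of how it is driven, taken from the unit test added at the bottom of this diff; get_gpt_model and the model/cluster sizes come from that test, so this only runs where that helper is importable:

import paddle
from paddle import static
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import RuleBasedTuner

# Serial GPT program with a loss, built by the helper defined in the new unittest.
train_program, start_program, loss, gen_data = get_gpt_model(
    static.Program(), static.Program(), None, 8, 512, 1000
)

opt = paddle.optimizer.AdamW(
    learning_rate=0.00001, grad_clip=paddle.nn.ClipGradByGlobalNorm(0.2)
)

# Describe the target cluster: 1 node with 8 GPUs, as in the unittest.
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)

dist_context = DistributedContext(
    serial_main_prog=train_program,
    serial_startup_prog=start_program,
    serial_optimizer=opt,
    serial_loss=loss,
    cluster=cluster,
)
dist_context.initialize()

# level="o2" selects the new stage/layer placement search added by this commit.
tuner = RuleBasedTuner(dist_context, level="o2")
tuner.tune()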
......@@ -13,12 +13,15 @@
# limitations under the License.
import json
import logging
import os
import re
from enum import IntEnum, unique
import paddle
from ..utils.log_utils import get_logger
@unique
class DeviceType(IntEnum):
......@@ -830,6 +833,9 @@ class Cluster:
return self.__str__()
logger = get_logger(logging.INFO)
def get_default_cluster(json_config=None):
def is_by_json_config(json_config):
if not json_config:
......@@ -889,18 +895,15 @@ def get_default_cluster(json_config=None):
memory = int(gpu_info.total_memory) // (1000**3)
gpu_model = gpu_name
print(
"Node Count: ",
node_count,
"Local Device Size: ",
local_device_count,
"GPU Model: ",
gpu_model,
"GPU Memory: ",
memory,
"World size: ",
paddle.distributed.get_world_size(),
flush=True,
logger.info(
"Node Count: {}, Local Device Size: {}, GPU Model: {}, GPU Memory: {}GB, World size: {}, EndPoint: {}.".format(
node_count,
local_device_count,
gpu_model,
memory,
paddle.distributed.get_world_size(),
os.getenv("PADDLE_CURRENT_ENDPOINT", None),
)
)
cluster.gen_default_config_cluster(
node_count=node_count,
......
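For reference, the values logged above describe the default cluster; a minimal sketch of generating one directly. The explicit node/device counts mirror the unit tests in this diff, auto-detected values depend on the local machine (GPUs must be visible), and the schema of the new optional json_config argument is not shown in this diff:

from paddle.distributed.auto_parallel.cluster import Cluster, get_default_cluster

# Auto-detect the local machine; a json_config dict can supply the values instead.
cluster = get_default_cluster()

# Or build an explicit description, as the unit tests in this diff do.
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
print(len(cluster.machines), "node(s),", len(cluster.get_all_devices("GPU")), "GPUs")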
......@@ -13,6 +13,7 @@
# limitations under the License.
import copy
import json
import logging
import numbers
import os
......@@ -177,6 +178,23 @@ class Engine:
self._strategy = strategy or Strategy()
self._logger = get_logger(logging.INFO)
self._json_config = None
if cluster:
self._cluster = cluster
else:
if os.getenv("PADDLE_AUTO_PARALLEL_CONFIG"):
try:
path = os.getenv("PADDLE_AUTO_PARALLEL_CONFIG")
with open(path, "r") as f:
self._json_config = json.load(f)
except Exception:
self._logger.info(
"Failed to load the JSON config file; please check it. The engine will run with the default config."
)
self._json_config = None
self._cluster = get_default_cluster(self._json_config)
if os.getenv("POD_NAME"):
self._logger.info(
"Distribute training by paddle.distributed.launch"
......@@ -653,6 +671,7 @@ class Engine:
fetch_vars,
self._cluster,
self._strategy,
self._json_config,
)
self._fwd_dist_contexts[mode] = DistributedContext(
serial_main_prog,
......@@ -663,6 +682,7 @@ class Engine:
fetch_vars,
self._cluster,
self._strategy,
self._json_config,
)
self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
self._fwd_main_progs[mode] = serial_main_prog.clone()
......@@ -769,7 +789,7 @@ class Engine:
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
if self._strategy.auto_mode == "full":
if self._strategy.auto_mode == "full_random":
auto_utils.initialize_pg_in_full_mode(
all_process_groups, self._cur_rank
)
......
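PADDLE_AUTO_PARALLEL_CONFIG points the engine at a JSON file that is parsed into _json_config and threaded through to the default cluster and the planner. A hedged sketch of providing one; the only key this diff is seen reading from it is "tuner_load_path" (consumed by the planner below), so the file names and any other content are illustrative:

import json
import os

# Illustrative config file; "tuner_load_path" is where the planner looks for a
# previously saved parallel strategy (see planner.py below).
config = {"tuner_load_path": "/path/to/saved_strategy.pkl"}
with open("/tmp/auto_parallel_config.json", "w") as f:
    json.dump(config, f)

# Set before constructing the Engine; if loading fails, the engine logs a
# message and falls back to the default config, as the try/except above shows.
os.environ["PADDLE_AUTO_PARALLEL_CONFIG"] = "/tmp/auto_parallel_config.json"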
......@@ -12,9 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import pickle
import numpy as np
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from ..utils.log_utils import get_logger
from .completion import Completer
from .dist_context import get_default_distributed_context
from .tuner.parallel_tuner import ParallelTuner
from .tuner.rule_based_tuner import RuleBasedTuner
from .utils import is_naive_data_parallel
......@@ -22,6 +38,7 @@ class Planner:
def __init__(self, mode, dist_context):
self._mode = mode
self._dist_context = dist_context
self._load = False
# NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need
# dependency of backward-forward ops in forward completion.
......@@ -29,30 +46,135 @@ class Planner:
self._dist_context._dist_op_context = default_ctx.dist_op_context
self._dist_context.data_parallel = default_ctx.data_parallel
if not is_naive_data_parallel(self._dist_context):
# Use SSA graph for complex parallelism
self._dist_context.initialize(with_graph=True)
else:
# Use program for data parallel parallelism
self._dist_context.initialize(with_graph=False)
self._completer = Completer(self._dist_context)
self._strategy = dist_context.strategy
# set parallel tuner for auto search
if self._strategy.auto_mode == "full":
if self._strategy.auto_mode == "full_random":
self._parallel_tuner = ParallelTuner(
self._dist_context, mode=self._mode
)
elif self._strategy.auto_mode == "full_rule_based":
self._parallel_tuner = RuleBasedTuner(
self._dist_context, mode=self._mode
)
@property
def completer(self):
return self._completer
def plan(self):
if self._strategy.auto_mode == "full":
self._parallel_tuner.tune()
else:
self._completer.complete_forward_annotation()
logger = get_logger(logging.INFO)
path = None
if self._dist_context._json_config:
try:
path = self._dist_context._json_config["tuner_load_path"]
except Exception:
path = None
if path and os.path.exists(path):
try:
with open(path, "rb") as f:
dist_attrs = pickle.load(f)
tensor_dist_attrs = dist_attrs["tensor"]
op_dist_attrs = dist_attrs["op"]
process_meshes = dist_attrs["process_meshes"]
cluster = dist_attrs["cluster"]
last_gpu_model = cluster.machines[0].devices[0].model
last_gpu_memory = cluster.machines[0].devices[0].memory
last_node_count = len(cluster.machines)
last_device_count = len(cluster.get_all_devices("GPU"))
gpu_model = (
self._dist_context.cluster.machines[0].devices[0].model
)
gpu_memory = (
self._dist_context.cluster.machines[0].devices[0].memory
)
node_count = len(self._dist_context.cluster.machines)
device_count = len(
self._dist_context.cluster.get_all_devices("GPU")
)
if (
gpu_model != last_gpu_model
or gpu_memory != last_gpu_memory
or last_node_count != node_count
or device_count != last_device_count
):
logger.info(
"The current cluster ({} nodes, {} {} devices) differs from the last saved cluster ({} nodes, {} {} devices), so the planner runs again.".format(
node_count,
device_count,
gpu_model,
last_node_count,
last_device_count,
last_gpu_model,
)
)
need_set_dist_attr = False
else:
need_set_dist_attr = True
except Exception:
need_set_dist_attr = False
if need_set_dist_attr:
for key in op_dist_attrs:
serial_op = self._dist_context._dist_ops_for_program[
key
].serial_op
# clear dist attr
serial_op.dist_attr = OperatorDistAttr(serial_op.desc)
serial_op.dist_attr.parse_from_string(op_dist_attrs[key])
self._dist_context._dist_ops_for_program[
key
] = DistributedOperator(serial_op)
for key in tensor_dist_attrs:
serial_tensor = (
self._dist_context._dist_tensors_for_program[
key
].serial_tensor
)
# clear dist attr
serial_tensor.dist_attr = TensorDistAttr(serial_tensor.desc)
serial_tensor.dist_attr.parse_from_string(
tensor_dist_attrs[key]
)
self._dist_context._dist_tensors_for_program[
key
] = DistributedTensor(serial_tensor)
process_meshes = []
for item in dist_attrs["process_meshes"]:
process_ids = item[0]
shape = item[1]
process_meshes.append(
ProcessMesh(
np.array(process_ids).reshape(shape).tolist()
)
)
self._dist_context.process_meshes = process_meshes
self._load = True
logger.info(
f"The parallel strategy has been loaded from {path}"
)
if not self._load:
if self._strategy.auto_mode != "semi":
self._parallel_tuner.tune()
else:
self._completer.complete_forward_annotation()
if os.getenv("PADDLE_AUTO_PARALLEL_STAGE", "run") != "run":
quit()
# parse forward sub block
self._dist_context.block_state.parse_forward_blocks(
self._dist_context.serial_main_program
......
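Based on the keys the loading code above reads (and save_strategy below writes), a saved strategy file can be inspected roughly as follows; the path is illustrative, and the structure of each entry follows only what the load path shows (dist attrs stored as strings, process meshes as (process_ids, shape) pairs):

import pickle

import numpy as np
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

with open("/path/to/saved_strategy.pkl", "rb") as f:  # illustrative path
    dist_attrs = pickle.load(f)

# Serialized dist attrs, keyed by tensor/op ids and parsed back via parse_from_string.
print(len(dist_attrs["tensor"]), "tensor dist attrs")
print(len(dist_attrs["op"]), "op dist attrs")

# The cluster the strategy was tuned for; the planner re-tunes if it differs.
cluster = dist_attrs["cluster"]
print(len(cluster.machines), "nodes,", len(cluster.get_all_devices("GPU")), "GPUs")

# Process meshes are saved as (process_ids, shape) pairs.
meshes = [
    ProcessMesh(np.array(ids).reshape(shape).tolist())
    for ids, shape in dist_attrs["process_meshes"]
]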
......@@ -35,6 +35,9 @@ from paddle.distributed.auto_parallel.dist_attribute import (
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
is_gradient_clip_op,
......@@ -579,7 +582,6 @@ class GraphUtil:
def _match_core(src_node, tgt_node):
nonlocal not_matched
# not support one input name or output name corresponding to multiple vars
if not_matched:
return
......@@ -1126,13 +1128,6 @@ class RuleBasedTuner:
def level(self):
return self._level
def convert_process_mesh_to_key(self, process_mesh):
"""Convert process mesh object to str."""
processes = ",".join([str(x) for x in process_mesh._process_ids])
topology = ",".join([str(x) for x in process_mesh._shape])
key = processes + ";" + topology
return key
def gen_full_program(self):
"""Generate full program that contain backward and update phase program if mode is train."""
self.full_main_program = self.dist_context.serial_main_program.clone()
......@@ -1878,6 +1873,13 @@ class RuleBasedTuner:
][parallelism][key]
self._complete_sub_update_program(sub_program_dist_context)
def convert_process_mesh_to_key(self, process_mesh):
"""Convert process mesh object to str."""
processes = ",".join([str(x) for x in process_mesh._process_ids])
topology = ",".join([str(x) for x in process_mesh._shape])
key = processes + ";" + topology
return key
def convert_device_mesh_to_key(self, device_mesh):
"""Convert device mesh object to str."""
processes = ",".join([str(x) for x in device_mesh.device_ids])
......@@ -1894,6 +1896,168 @@ class RuleBasedTuner:
)
return global_cost.time, max_memory
def _local_stage_pass(self, start, end, process_mesh):
"""Get the best cost and the corresponding strategy of layers on the given process mesh."""
# convert process mesh to dict key
key = self.convert_process_mesh_to_key(process_mesh)
if start in self.stage_best_cost_of_pm:
if end in self.stage_best_cost_of_pm[start]:
if key in self.stage_best_cost_of_pm[start][end]:
return self.stage_best_cost_of_pm[start][end][key]["cost"]
assert end >= start
selective_parallelisms = (
["dp", "mp"] if len(process_mesh.shape) == 1 else ["dp_mp", "mp_dp"]
)
if start not in self.stage_best_cost_of_pm:
self.stage_best_cost_of_pm[start] = {}
if end not in self.stage_best_cost_of_pm[start]:
self.stage_best_cost_of_pm[start][end] = {}
if key not in self.stage_best_cost_of_pm[start][end]:
self.stage_best_cost_of_pm[start][end][key] = {}
if end == start:
dist_contexts_x = [DistributedContext(), DistributedContext()]
else:
dist_contexts_x = self.stage_best_cost_of_pm[start][end - 1][key][
"dist_context"
]
# Use beam search with a beam size of 2.
# When the process mesh is 1-D, the selective parallelism can be dp or mp.
# Because the first layer often contains more ops than the other layers, beam search can find a more accurate strategy.
count = 0
for dist_context_x in dist_contexts_x:
if end == start and count == 1:
break
for parallelism in selective_parallelisms:
dist_context_y = self.sub_programs_dist_context[end][
parallelism
][key]
dist_context = self.combine_dist_contexts(
[dist_context_x, dist_context_y]
)
if (
"dist_context"
not in self.stage_best_cost_of_pm[start][end][key]
):
self.stage_best_cost_of_pm[start][end][key][
"dist_context"
] = [None, None]
self.stage_best_cost_of_pm[start][end][key]["cost"] = [
sys.maxsize,
sys.maxsize,
]
# estimate cost and memory
cost, local_stage_memory = self._get_sub_program_cost(
dist_context
)
if local_stage_memory > 0.9 * self.cluster.machines[0].devices[
0
].memory * (1024**3):
cost = sys.maxsize
index = -1
for idx, item in enumerate(
self.stage_best_cost_of_pm[start][end][key]["cost"]
):
if cost <= item:
index = idx
break
if index == 0:
self.stage_best_cost_of_pm[start][end][key]["cost"][
1
] = self.stage_best_cost_of_pm[start][end][key]["cost"][0]
self.stage_best_cost_of_pm[start][end][key]["dist_context"][
1
] = self.stage_best_cost_of_pm[start][end][key][
"dist_context"
][
0
]
self.stage_best_cost_of_pm[start][end][key]["cost"][
0
] = cost
self.stage_best_cost_of_pm[start][end][key]["dist_context"][
0
] = dist_context
elif index == 1:
self.stage_best_cost_of_pm[start][end][key]["cost"][
1
] = cost
self.stage_best_cost_of_pm[start][end][key]["dist_context"][
1
] = dist_context
count += 1
if (
self.stage_best_cost_of_pm[start][end][key]["cost"][1]
< self.stage_best_cost_of_pm[start][end][key]["cost"][0]
):
self.stage_best_cost_of_pm[start][end][key][
"best_cost"
] = self.stage_best_cost_of_pm[start][end][key]["cost"][1]
self.stage_best_cost_of_pm[start][end][key][
"best_dist_context"
] = self.stage_best_cost_of_pm[start][end][key]["dist_context"][1]
else:
self.stage_best_cost_of_pm[start][end][key][
"best_cost"
] = self.stage_best_cost_of_pm[start][end][key]["cost"][0]
self.stage_best_cost_of_pm[start][end][key][
"best_dist_context"
] = self.stage_best_cost_of_pm[start][end][key]["dist_context"][0]
return self.stage_best_cost_of_pm[start][end][key]["best_cost"]
def local_stage_pass(self, start, end, device_mesh):
"""Get the best cost and the corresponding strategy of layers on the given device mesh."""
dm_key = self.convert_device_mesh_to_key(device_mesh)
device_mesh_shape = device_mesh.shape
if len(device_mesh_shape) == 1:
device_mesh_shape.insert(0, 1)
process_mesh_shapes = convert_to_process_meshes(device_mesh_shape)
best_cost = sys.maxsize
if start not in self.stage_best_cost_of_dm:
self.stage_best_cost_of_dm[start] = {}
if end not in self.stage_best_cost_of_dm[start]:
self.stage_best_cost_of_dm[start][end] = {}
if dm_key not in self.stage_best_cost_of_dm[start][end]:
self.stage_best_cost_of_dm[start][end][dm_key] = {}
for process_mesh_shape in process_mesh_shapes:
process_mesh = ProcessMesh(
np.array(device_mesh.device_ids)
.reshape(process_mesh_shape)
.tolist()
)
key = self.convert_process_mesh_to_key(process_mesh)
for i in range(start, end + 1):
self._local_stage_pass(start, i, process_mesh)
if (
self.stage_best_cost_of_pm[start][end][key]["best_cost"]
<= best_cost
):
best_cost = self.stage_best_cost_of_pm[start][end][key][
"best_cost"
]
self.stage_best_cost_of_dm[start][end][dm_key][
"cost"
] = best_cost
self.stage_best_cost_of_dm[start][end][dm_key][
"dist_context"
] = self.stage_best_cost_of_pm[start][end][key][
"best_dist_context"
]
return best_cost
def combine_dist_contexts(self, dist_contexts):
"""Combine the dist attr in dist contexts to one dist context."""
combined_dist_context = DistributedContext()
......@@ -1927,7 +2091,7 @@ class RuleBasedTuner:
self.layers = self.cluster_operators()
end = time.time()
self._logger.info(
"Cluster operators to {} layers in {}s.".format(
"Cluster operators to {} layers in {:.2f}s.".format(
len(self.layers), end - begin
)
)
......@@ -1937,7 +2101,7 @@ class RuleBasedTuner:
self.gen_fwd_sub_programs_by_clone()
end = time.time()
self._logger.info(
f"Generate programs of every layer in {end - begin}s."
f"Generate programs of every layer in {end - begin:.2f}s."
)
# step3: partition devices to device meshes
......@@ -1948,7 +2112,7 @@ class RuleBasedTuner:
)
device_meshes_list = ClusterPartitionUtil.partition_cluster(n, m)
end = time.time()
self._logger.info(f"Partition cluster in {end - begin}s.")
self._logger.info(f"Partition cluster in {end - begin:.2f}s.")
# step4: transform device mesh to process meshes
dm_idx = 0
......@@ -1987,7 +2151,7 @@ class RuleBasedTuner:
begin = time.time()
self.gen_full_program()
end = time.time()
self._logger.info(f"Generate full program in {end - begin}s.")
self._logger.info(f"Generate full program in {end - begin:.2f}s.")
# step6: complete forward sub programs
begin = time.time()
......@@ -1995,7 +2159,7 @@ class RuleBasedTuner:
self.complete_sub_fwd_programs(process_mesh)
end = time.time()
self._logger.info(
f"Complete all sub forward programs in {end - begin}s."
f"Complete all sub forward programs in {end - begin:.2f}s."
)
if self.mode == "train":
......@@ -2004,7 +2168,9 @@ class RuleBasedTuner:
self.complete_sub_bwd_programs()
end = time.time()
self._logger.info(
f"Complete all sub backward programs in {end - begin}s."
"Complete all sub backward programs in {:.2f}s.".format(
end - begin
)
)
# step8: complete update sub programs
......@@ -2015,6 +2181,88 @@ class RuleBasedTuner:
f"Complete all sub update programs in {end - begin}s."
)
def layer_placement_pass(self, stages, layers, device_meshes):
"""Get the best cost and the corresponding strategy of the given layers on the stages which running on the devices."""
stage_layer_cost = [
[sys.maxsize for i in range(layers)] for j in range(stages)
]
# To balance the stages, we select the split that minimizes the maximum per-stage cost.
min_max_stage_costs = [
[None for i in range(layers)] for j in range(stages)
]
best_strategies = [[None for i in range(layers)] for j in range(stages)]
for s in range(len(device_meshes)):
for i in range(0, layers):
if s == 0:
stage_layer_cost[s][i] = self.local_stage_pass(
0, i, device_meshes[s]
)
min_max_stage_costs[s][i] = stage_layer_cost[s][i]
key = self.convert_device_mesh_to_key(device_meshes[s])
best_strategies[s][i] = self.stage_best_cost_of_dm[0][i][
key
]["dist_context"]
else:
min_cost = sys.maxsize
min_max_stage_cost = sys.maxsize
for j in range(0, i):
key = self.convert_device_mesh_to_key(device_meshes[s])
local_stage_cost = self.local_stage_pass(
j + 1, i, device_meshes[s]
)
dist_context = self.combine_dist_contexts(
[
best_strategies[s - 1][j],
self.stage_best_cost_of_dm[j + 1][i][key][
"dist_context"
],
]
)
cost, _ = self._get_sub_program_cost(dist_context)
max_stage_cost = (
min_max_stage_costs[s - 1][j]
if local_stage_cost < min_max_stage_costs[s - 1][j]
else local_stage_cost
)
if cost <= min_cost:
if cost == min_cost:
if max_stage_cost < min_max_stage_cost:
min_max_stage_cost = max_stage_cost
best_strategies[s][i] = dist_context
else:
break
else:
best_strategies[s][i] = dist_context
min_cost = cost
stage_layer_cost[s][i] = min_cost
min_max_stage_costs[s][i] = min_max_stage_cost
return (
stage_layer_cost[stages - 1][layers - 1],
best_strategies[stages - 1][layers - 1],
)
def tune_o2(self):
"""The o2 level tuning."""
best_dist_context = None
best_cost = sys.maxsize
for device_meshes in self.device_meshes_list:
cost, dist_context = self.layer_placement_pass(
len(device_meshes), len(self.layers), device_meshes
)
if cost <= best_cost:
self._logger.info(
"O2 level: a better strategy has been found as follows: "
)
print_program_with_dist_attr(
self.full_main_program, dist_context
)
best_cost = cost
best_dist_context = dist_context
return best_dist_context
def tune_o1(self):
"""The o1 level tuning."""
best_cost = sys.maxsize
......@@ -2082,7 +2330,7 @@ class RuleBasedTuner:
)
self._logger.info(
"Cost Model: The max memory is {}GB and cost is {} when {} parallelism under process mesh shape {} on {} stages.".format(
"Cost Model: The max memory is {:.2f}GB and cost is {:.2f} when {} parallelism under process mesh shape {} on {} stages.".format(
memory / (1024**3),
cost,
parallelism,
......@@ -2090,8 +2338,8 @@ class RuleBasedTuner:
len(device_meshes),
)
)
# 15% buffer is reserved for memory cost
if memory > 0.85 * self.cluster.machines[0].devices[
# 10% buffer is reserved for memory cost
if memory > 0.9 * self.cluster.machines[0].devices[
0
].memory * (1024**3):
cost = sys.maxsize
......@@ -2100,7 +2348,7 @@ class RuleBasedTuner:
best_cost = cost
best_dist_context = dist_context_of_device_meshes
self._logger.info(
"O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {}GB.".format(
"O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {:.2f}GB.".format(
parallelism,
process_mesh_shape,
len(device_meshes),
......@@ -2110,9 +2358,6 @@ class RuleBasedTuner:
return best_dist_context
def tune_o2(self):
return None
def save_strategy(self, best_dist_context, path):
dist_attrs = {"tensor": {}, "op": {}, "process_meshes": []}
for key in best_dist_context._dist_tensors_for_program:
......@@ -2151,9 +2396,14 @@ class RuleBasedTuner:
begin = time.time()
self.match_program(self._dist_context.serial_main_program)
end = time.time()
self._logger.info(f"Pattern match in {end - begin}s.")
self._logger.info(f"Pattern match in {end - begin:.2f}s.")
if self._use_dp:
total_rank = (
self._cluster.get_num_machines()
* self._cluster._num_devices_per_machine
)
get_world_process_group().add_ranks(list(range(total_rank)))
completer = Completer(self._dist_context)
completer.complete_forward_annotation()
print_program_with_dist_attr(
......@@ -2213,7 +2463,7 @@ class RuleBasedTuner:
self._dist_context._process_meshes = best_dist_context._process_meshes
end = time.time()
self._logger.info(f"Rule-based tuner end in {end - begin}s.")
self._logger.info(f"Rule-based tuner end in {end - begin:.2f}s.")
self._logger.info("The best strategy found is as follows: ")
print_program_with_dist_attr(self.full_main_program, best_dist_context)
......
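To make the stage-partition recurrence in layer_placement_pass easier to follow, here is a toy, self-contained sketch with scalar costs standing in for the estimated program costs. The function and cost model are illustrative only (the real code estimates costs from combined distributed contexts), but the min-total-cost with min-max-stage-cost tie-break mirrors the loop structure above:

import sys


def toy_layer_placement(layer_costs, stage_speeds):
    """Toy DP mirroring layer_placement_pass: minimize the total cost, and among
    equal-cost splits prefer the one whose largest per-stage cost is smallest."""
    layers, stages = len(layer_costs), len(stage_speeds)

    def local_stage_cost(start, end, s):
        # Stand-in for local_stage_pass: cost of layers [start, end] on stage s.
        return sum(layer_costs[start : end + 1]) / stage_speeds[s]

    cost = [[sys.maxsize] * layers for _ in range(stages)]  # stage_layer_cost
    worst = [[None] * layers for _ in range(stages)]        # min_max_stage_costs
    split = [[None] * layers for _ in range(stages)]        # stands in for best_strategies

    for s in range(stages):
        for i in range(layers):
            if s == 0:
                cost[s][i] = local_stage_cost(0, i, 0)
                worst[s][i] = cost[s][i]
                split[s][i] = [(0, i)]
                continue
            for j in range(i):  # stage s takes layers j+1..i, stages 0..s-1 take 0..j
                if cost[s - 1][j] == sys.maxsize:
                    continue  # no feasible split of layers 0..j over s stages
                last = local_stage_cost(j + 1, i, s)
                total = cost[s - 1][j] + last  # stand-in for _get_sub_program_cost
                stage_max = max(worst[s - 1][j], last)
                if total < cost[s][i] or (
                    total == cost[s][i] and stage_max < worst[s][i]
                ):
                    cost[s][i], worst[s][i] = total, stage_max
                    split[s][i] = split[s - 1][j] + [(j + 1, i)]
    return cost[stages - 1][layers - 1], split[stages - 1][layers - 1]


# Six layers on two equally fast stages: both splits cost the same in total,
# so the tie-break picks the more balanced one: (23.0, [(0, 3), (4, 5)]).
print(toy_layer_placement([3, 1, 4, 1, 5, 9], [1.0, 1.0]))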
......@@ -85,6 +85,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_pass_base_list PROPERTIES TIMEOUT 20)
py_test_modules(test_fuse_adamw_pass MODULES test_fuse_adamw_pass)
set_tests_properties(test_fuse_adamw_pass PROPERTIES TIMEOUT 20)
py_test_modules(test_rule_based_tuner_o2 MODULES test_rule_based_tuner_o2)
set_tests_properties(test_rule_based_tuner_o2 PROPERTIES TIMEOUT 50)
# End of unittests WITH single card and timeout
# NOTE(zyl): unittests WITH single card and WITHOUT timeout
......
......@@ -153,7 +153,7 @@ class TestParallelTunerFull(unittest.TestCase):
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
strategy = Strategy()
strategy.auto_mode = "full"
strategy.auto_mode = "full_random"
dist_context = DistributedContext(
train_program,
start_program,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
from paddle import static
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
GPTForPretraining,
GPTModel,
GPTPretrainingCriterion,
)
def get_gpt_model(
train_program, start_program, place, batch_size, sequence_len, vocab_size
):
with static.program_guard(train_program, start_program):
tokens = paddle.static.data(
name="tokens", shape=[batch_size, sequence_len], dtype='int64'
)
position_ids = paddle.static.data(
name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
)
attention_mask = paddle.static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32',
)
labels = paddle.static.data(
name="labels", shape=[batch_size, sequence_len], dtype='int64'
)
loss_mask = paddle.static.data(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
)
gpt = GPTModel(
vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
)
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
def gen_data():
np.random.seed(2021)
tokens = []
position_ids = []
attention_mask = []
labels = []
loss_mask = []
for _ in range(batch_size):
tokens.append(np.random.randint(vocab_size, size=sequence_len))
position_ids.append(np.arange(sequence_len))
attention_mask.append([np.tril(np.ones(sequence_len))])
labels.append(np.random.randint(vocab_size, size=sequence_len))
loss_mask.append(np.ones(sequence_len))
return tokens, position_ids, attention_mask, labels, loss_mask
return train_program, start_program, loss, gen_data
class TestRuleBasedTuner(unittest.TestCase):
def test_gpt_o2(self):
modeling.init_global()
train_program = static.Program()
start_program = static.Program()
batch_size = 8
sequence_len = 512
vocab_size = 1000
place = None
train_program, start_program, loss, gen_data = get_gpt_model(
train_program,
start_program,
place,
batch_size,
sequence_len,
vocab_size,
)
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.dist_context import (
DistributedContext,
)
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
RuleBasedTuner,
)
clip = paddle.nn.ClipGradByGlobalNorm(0.2)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
cluster = Cluster()
cluster.gen_default_config_cluster(node_count=1, device_count=8)
dist_context = DistributedContext(
serial_main_prog=train_program,
serial_startup_prog=start_program,
serial_optimizer=opt,
serial_loss=loss,
cluster=cluster,
)
dist_context.initialize()
tuner = RuleBasedTuner(dist_context, level="o2")
tuner.tune()
if __name__ == "__main__":
unittest.main()