Unverified commit 3649099f authored by zhaoyingli, committed by GitHub

[AutoParallel] add collate_fn for dist_loader (#45053)

* add collate_fn

* fix number of inputs
Parent 8788513b
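For reference, the user-facing change is a new collate_fn argument on Engine.fit/evaluate/predict that is forwarded to the distributed dataloader, plus the rename of use_program_cache to use_cache. A minimal usage sketch follows; it assumes the model, input specs, optimizer, loss and dataset are set up as in the auto_parallel tests further down in this diff (mlp, inputs_spec, labels_spec, dist_strategy, optimizer, loss and train_dataset are placeholders, not defined here).

from paddle.distributed.auto_parallel.engine import Engine
from paddle.fluid.dataloader.collate import default_collate_fn

# Assumed to exist already: mlp, inputs_spec, labels_spec, dist_strategy, optimizer, loss, train_dataset.
engine = Engine(mlp, inputs_spec=inputs_spec, labels_spec=labels_spec,
                strategy=dist_strategy)
engine.prepare(optimizer=optimizer, loss=loss)
# collate_fn is passed through to the distributed dataloader; use_cache replaces
# the old use_program_cache flag on fit/evaluate/predict.
engine.fit(train_dataset, batch_size=64, collate_fn=default_collate_fn, use_cache=True)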
......@@ -1300,6 +1300,10 @@ class Completer:
def complete_update_annotation(self, serial_main_program):
"""Complete the annotation of vars and ops in the update phase for parallel program."""
# Copy the dist tensors and dist ops annotated by users from the default context
# global mesh
from paddle.distributed.auto_parallel.process_group import get_world_process_group
world_ranks = get_world_process_group().ranks
# Notice: serial_main_program is actually a dist_main_program of current rank,
# and must be passed into this function.
......@@ -1371,7 +1375,7 @@ class Completer:
if not learning_rate_completed:
learning_rate_completed = True
var_dist_attr = TensorDistributedAttribute()
var_dist_attr.process_mesh = ref_process_mesh
var_dist_attr.process_mesh = world_ranks
var_dist_attr.dims_mapping = [-1]
self._dist_context.set_tensor_dist_attr_for_program(
learning_var, var_dist_attr)
......
......@@ -17,7 +17,8 @@ import numpy as np
import paddle
from .utils import to_list
from paddle.fluid.layers.utils import flatten
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.io import DataLoader, BatchSampler, IterableDataset
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn
class DistributedDataLoader(metaclass=abc.ABCMeta):
......@@ -29,14 +30,32 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
data_parallel_world_size=None,
data_parallel_rank=None,
drop_last=False):
if isinstance(dataset, IterableDataset):
raise TypeError("IterableDataset is not supported.")
else:
self.dataset_kind = _DatasetKind.MAP
self.dataset = dataset
self.batch_size = batch_size
self.epochs = epochs
self.data_parallel_world_size = data_parallel_world_size
self.data_parallel_rank = data_parallel_rank
self.drop_last = drop_last
if data_parallel_world_size is not None and batch_size is not None:
assert batch_size % data_parallel_world_size == 0
if batch_size is None:
self.batch_size = None
self.batch_sampler = None
else:
if data_parallel_world_size is not None:
assert batch_size % data_parallel_world_size == 0, \
"'batch_size' must be divisible by data parallel size"
self.batch_size = batch_size
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank
@abc.abstractmethod
def __iter__(self):
......@@ -46,6 +65,16 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
def __next__(self):
raise NotImplementedError
@property
def index_sampler(self):
if self.auto_collate_batch:
return self.batch_sampler
else:
if self.dataset_kind == _DatasetKind.MAP:
return list(range(len(self.dataset)))
else:
raise TypeError("Only support datasets in map-style.")
class NonIterableGeneratorLoader(DistributedDataLoader):
......@@ -56,21 +85,29 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
data_parallel_world_size=None,
data_parallel_rank=None,
drop_last=False):
self.feed_list = feed_list
self.places = places
self.steps_per_epoch = steps_per_epoch
self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank
super(NonIterableGeneratorLoader,
self).__init__(dataset, batch_size, epochs,
data_parallel_world_size, data_parallel_rank,
drop_last)
self._inner_dataloader = self._create_inner_dataloader()
if self.auto_collate_batch:
self.collate_fn = collate_fn or default_collate_fn
else:
self.collate_fn = collate_fn or default_convert_fn
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_last)
self._steps = self._infer_steps()
self._inner_dataloader = self._create_inner_dataloader()
def __iter__(self):
self._cur_step = 0
......@@ -101,31 +138,25 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
def _create_inner_dataloader(self):
def sample_data_generator():
batch_data = None
for step, data in enumerate(self.dataset):
data = flatten(data)
if batch_data is None:
batch_data = [[] for i in range(len(data))]
for idx in range(len(data)):
batch_data[idx].append(data[idx])
if (step + 1) % self.batch_size == 0:
partial_data = []
for d in batch_data:
array = np.array(d)
partial_data.append(
np.split(array, self.dp_world_size)[self.dp_rank])
yield partial_data[:len(self.feed_list)]
batch_data = None
for indices in self.sampler_iter:
assert len(indices) % self.dp_world_size == 0, \
"Please set batch_size to be divisible by data parallel size"
n = len(indices) // self.dp_world_size
cur_indices = [
indices[i:i + n] for i in range(0, len(indices), n)
]
batch = self.dataset_fetcher.fetch(cur_indices[self.dp_rank])
yield batch[:len(self.feed_list)]
def batch_data_generator():
for data in self.dataset:
data = flatten(data)
for indices in self.sampler_iter:
partial_data = []
for d in data:
assert d.shape[0] % self.dp_world_size == 0, \
"Please padding dataset with data parallel size"
batch = self.dataset_fetcher.fetch(indices)
for data in batch:
assert data.shape[0] % self.dp_world_size == 0, \
"Please padding dataset's batch_size to be divisible by data parallel size"
partial_data.append(
np.split(d, self.dp_world_size)[self.dp_rank])
np.split(data, self.dp_world_size)[self.dp_rank])
yield partial_data[:len(self.feed_list)]
dataloader = paddle.fluid.io.DataLoader.from_generator(
......
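The key step in the new sample_data_generator above is splitting each batch of sampler indices evenly across data-parallel ranks before the dataset fetcher materializes the samples. A standalone sketch of just that splitting logic, with illustrative numbers (8 indices, 2 ranks) that are not taken from the diff:

# Illustrative values: one batch of 8 sample indices, 2 data-parallel ranks, this process is rank 0.
indices = list(range(8))
dp_world_size, dp_rank = 2, 0
assert len(indices) % dp_world_size == 0, \
    "batch_size must be divisible by the data parallel size"
n = len(indices) // dp_world_size
cur_indices = [indices[i:i + n] for i in range(0, len(indices), n)]
# cur_indices == [[0, 1, 2, 3], [4, 5, 6, 7]]; each rank fetches only its own slice.
print(cur_indices[dp_rank])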
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import copy
import logging
from collections import defaultdict
......@@ -306,6 +307,7 @@ class Engine:
mode].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
self._optimizer = self._dist_contexts[mode].serial_optimizer
if self._nranks > 1:
# Traverse different rank programs and traverse each op of them,
......@@ -403,7 +405,8 @@ class Engine:
epochs=1,
fetches=None,
steps_per_epoch=None,
use_program_cache=False,
collate_fn=None,
use_cache=False,
return_numpy=True):
# TODO: callbacks
# TODO: evaluate after training
......@@ -417,18 +420,24 @@ class Engine:
assert self.mode in self._dist_main_progs, \
"train model is not ready, please call `engine.prepare()` first."
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch)
epochs, steps_per_epoch,
collate_fn)
usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
lr_scheduler = self.get_lr_scheduler(self.main_program)
for epoch in range(epochs):
train_logs = {"epoch": epoch}
for step, _ in enumerate(train_dataloader):
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
use_program_cache=use_cache,
return_numpy=return_numpy)
if lr_scheduler is not None:
lr_scheduler.step()
train_logs["lr"] = self._optimizer.get_lr()
train_logs["step"] = step
# inner fetches
if fetch_loss:
......@@ -444,7 +453,8 @@ class Engine:
eval_data,
batch_size=1,
fetches=None,
use_program_cache=False,
collate_fn=None,
use_cache=False,
return_numpy=True):
self.mode = 'eval'
if not self._mode_init_states[self.mode]:
......@@ -452,7 +462,9 @@ class Engine:
assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size)
eval_dataloader = self._create_dataloader(eval_data,
batch_size,
collate_fn=collate_fn)
usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
......@@ -464,7 +476,7 @@ class Engine:
eval_logs = {"step": step}
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
use_program_cache=use_cache,
return_numpy=return_numpy)
# inner fetches
if fetch_loss:
......@@ -489,7 +501,8 @@ class Engine:
test_data,
batch_size=1,
fetches=None,
use_program_cache=False,
collate_fn=None,
use_cache=False,
return_numpy=True):
self.mode = 'predict'
if not self._mode_init_states[self.mode]:
......@@ -497,7 +510,9 @@ class Engine:
assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare()` first."
test_dataloader = self._create_dataloader(test_data, batch_size)
test_dataloader = self._create_dataloader(test_data,
batch_size,
collate_fn=collate_fn)
usr_fetch = self._validate_fetches(fetches)
fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
......@@ -508,7 +523,7 @@ class Engine:
predict_logs = {"step": step}
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
use_program_cache=use_cache,
return_numpy=return_numpy)
outputs.append(outs[:len(fetch_outputs)])
for i, out in enumerate(outs):
......@@ -521,7 +536,8 @@ class Engine:
dataset,
batch_size,
epochs=1,
steps_per_epoch=None):
steps_per_epoch=None,
collate_fn=None):
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
dist_context = self._dist_contexts[self.mode]
......@@ -554,6 +570,7 @@ class Engine:
batch_size,
epochs,
steps_per_epoch,
collate_fn,
data_parallel_world_size=self._input_split_size,
data_parallel_rank=self._input_split_rank)
......@@ -645,12 +662,11 @@ class Engine:
config = self.strategy.recompute_configs
# extract ckpts by specific model
self.model
if isinstance(self.model, paddle.nn.Layer):
if hasattr(
self.model, "model"
) and self.model.model.__class__.__name__ == 'GPTForPretraining':
exact_ckpts = self.model.model.gpt.checkpoints
self.model, "gpt"
) and self.model.__class__.__name__ == 'GPTForPretraining':
exact_ckpts = self.model.gpt.checkpoints
else:
exact_ckpts = config["checkpoints"]
......@@ -659,7 +675,7 @@ class Engine:
config["checkpoints"] = exact_ckpts[:]
self.strategy.recompute_configs = config
logs = {
'Model Class': self.model.model.__class__.__name__,
'Model Class': self.model.__class__.__name__,
'Applied Recompute ckpts': exact_ckpts
}
self._logger.info(logs)
......@@ -699,6 +715,15 @@ class Engine:
self._saver.load(path, dist_main_prog, dist_context, strict,
load_optimizer)
@staticmethod
def get_lr_scheduler(program):
lr_sheduler = None
if hasattr(program, 'lr_sheduler'):
from paddle.optimizer.lr import LRScheduler
lr_sheduler = program.lr_sheduler
assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
return lr_sheduler
@property
def mode(self):
return self._mode
......
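Engine.fit now retrieves the program's LR scheduler through the get_lr_scheduler helper, calls step() once per optimizer step, and records optimizer.get_lr() in the training logs. As a standalone illustration of what that schedule produces, the sketch below runs the same CosineAnnealingDecay configuration that the updated test uses; the three-iteration loop is only for demonstration.

from paddle.optimizer.lr import CosineAnnealingDecay

# Same schedule as the updated test: learning_rate=0.00001, T_max=10.
scheduler = CosineAnnealingDecay(learning_rate=0.00001, T_max=10)
for step in range(3):
    print(step, scheduler.get_lr())  # the value Engine.fit logs via optimizer.get_lr()
    scheduler.step()                 # Engine.fit calls this once per optimizer step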
......@@ -149,6 +149,7 @@ class Parallelizer:
paddle.enable_static()
else:
optimizer = copy.deepcopy(optimizer)
self._dist_context._serial_optimizer = optimizer
with program_guard(main_program, startup_program):
optimizer_ops = optimizer.apply_gradients(params_grads)
self._completer.complete_update_annotation(main_program)
......
......@@ -363,11 +363,15 @@ class OptimizationTuner:
profile_args = " ".join([
"--rank",
str(self.rank), "--device_id",
str(self.device_id), "--ctx_filename", ctx_path,
str(self.rank),
"--device_id",
str(self.device_id),
"--ctx_filename",
ctx_path,
"--profile_start_step",
str(self._config.profile_start_step), "--profile_end_step",
str(self._config.profile_end_step)
str(self._config.profile_start_step),
"--profile_end_step",
str(self._config.profile_end_step),
])
cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args
cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args)
......
......@@ -31,6 +31,8 @@ from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
paddle.enable_static()
global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
......@@ -106,19 +108,18 @@ def train(fetch):
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00001,
T_max=10)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
......
......@@ -145,7 +145,7 @@ def main():
labels_spec=labels_spec,
strategy=dist_strategy)
engine.prepare(optimizer=optimizer, loss=loss_func)
res = engine.fit(train_dataset, batch_size=None)
engine.fit(train_dataset, batch_size=None)
dist_context = engine.dist_context
block = engine.main_program.global_block()
......
......@@ -282,13 +282,16 @@ class TestMLPReshard(unittest.TestCase):
if op.type == "gelu_grad":
op_need_check = op
break
# print_program_with_dist_attr(dist_main_prog, dist_context)
# grad op should have dist attr
self.assertTrue(
check_backward_dist_attr(dist_context, dist_main_prog,
op_need_check))
# clear _g_process_group_map
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
def test_mlp_pp(self):
global _global_parallel_strategy
_global_parallel_strategy = "pp"
......@@ -305,29 +308,35 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 1
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
_g_process_group_map[0] = ProcessGroup(0, [])
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
# parameter initialization of every rank should be different in the pipeline scene
self.assertTrue(check_initialization(dist_startup_prog, rank_id))
# clear _g_process_group_map
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
def test_mlp_pp_diff_process_mesh(self):
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0])
global PP_MESH_1
PP_MESH_1 = auto.ProcessMesh(mesh=[1])
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
dist_context = DistributedContext()
rank_id = 1
dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
train_program, startup_program, dist_context, rank_id, True)
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
_g_process_group_map[0] = ProcessGroup(0, [])
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
......@@ -335,6 +344,10 @@ class TestMLPReshard(unittest.TestCase):
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
self.assertTrue(check_initialization(dist_startup_prog, rank_id))
# clear _g_process_group_map
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
def test_mlp_dp(self):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
......@@ -350,12 +363,16 @@ class TestMLPReshard(unittest.TestCase):
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
# send and recv should not exist in dp scene.
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))
# all parameters should be initialized in dp scene
self.assertTrue(check_initialization_for_dp(dist_startup_prog))
# clear _g_process_group_map
_g_process_group_map.clear()
_g_process_group_map[0] = ProcessGroup(0, [])
if __name__ == "__main__":
unittest.main()