From 3649099f5b0535adebaf372330e866348944556d Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Mon, 15 Aug 2022 15:17:56 +0800 Subject: [PATCH] [AutoParallel] add collate_fn for dist_loader (#45053) * add collate_fn * fix number of inputs --- .../distributed/auto_parallel/completion.py | 6 +- .../distributed/auto_parallel/dist_loader.py | 91 +++++++++++++------ .../distributed/auto_parallel/engine.py | 55 ++++++++--- .../auto_parallel/parallelizer_v2.py | 1 + .../auto_parallel/tuner/optimization_tuner.py | 12 ++- .../unittests/auto_parallel/engine_api.py | 17 ++-- .../auto_parallel/high_order_grad.py | 2 +- .../unittests/test_auto_parallel_reshard.py | 35 +++++-- 8 files changed, 151 insertions(+), 68 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 4eb2f45cc18..64a3a96ae6b 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -1300,6 +1300,10 @@ class Completer: def complete_update_annotation(self, serial_main_program): """Complete the annotation of vars and ops in the update phase for parallel program.""" + # Copy the dist tensors and dist ops annotated by users from the default context + # global mesh + from paddle.distributed.auto_parallel.process_group import get_world_process_group + world_ranks = get_world_process_group().ranks # Notice: serial_main_program is actually a dist_main_program of current rank, # and must be passed into this function. @@ -1371,7 +1375,7 @@ class Completer: if not learning_rate_completed: learning_rate_completed = True var_dist_attr = TensorDistributedAttribute() - var_dist_attr.process_mesh = ref_process_mesh + var_dist_attr.process_mesh = world_ranks var_dist_attr.dims_mapping = [-1] self._dist_context.set_tensor_dist_attr_for_program( learning_var, var_dist_attr) diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py index 03cc340fecd..02bccb66920 100644 --- a/python/paddle/distributed/auto_parallel/dist_loader.py +++ b/python/paddle/distributed/auto_parallel/dist_loader.py @@ -17,7 +17,8 @@ import numpy as np import paddle from .utils import to_list from paddle.fluid.layers.utils import flatten -from paddle.io import DataLoader, DistributedBatchSampler +from paddle.io import DataLoader, BatchSampler, IterableDataset +from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn class DistributedDataLoader(metaclass=abc.ABCMeta): @@ -29,14 +30,32 @@ class DistributedDataLoader(metaclass=abc.ABCMeta): data_parallel_world_size=None, data_parallel_rank=None, drop_last=False): + if isinstance(dataset, IterableDataset): + raise TypeError("IterableDataset is not supported.") + else: + self.dataset_kind = _DatasetKind.MAP + self.dataset = dataset - self.batch_size = batch_size self.epochs = epochs - self.data_parallel_world_size = data_parallel_world_size - self.data_parallel_rank = data_parallel_rank self.drop_lost = drop_last - if data_parallel_world_size is not None and batch_size is not None: - assert batch_size % data_parallel_world_size == 0 + + if batch_size is None: + self.batch_size = None + self.batch_sampler = None + else: + if data_parallel_world_size is not None: + assert batch_size % data_parallel_world_size == 0, \ + "'batch_size' must be divisible by data parallel size" + self.batch_size = batch_size + self.batch_sampler = BatchSampler(dataset, + batch_size=batch_size, + shuffle=False, + drop_last=drop_last) + + self.auto_collate_batch = self.batch_sampler is not None + self.sampler_iter = iter(self.index_sampler) + self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size + self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank @abc.abstractmethod def __iter__(self): @@ -46,6 +65,16 @@ class DistributedDataLoader(metaclass=abc.ABCMeta): def __next__(self): raise NotImplementedError + @property + def index_sampler(self): + if self.auto_collate_batch: + return self.batch_sampler + else: + if self.dataset_kind == _DatasetKind.MAP: + return list(range(len(self.dataset))) + else: + raise TypeError("Only support datasets in map-style.") + class NonIterableGeneratorLoader(DistributedDataLoader): @@ -56,21 +85,29 @@ class NonIterableGeneratorLoader(DistributedDataLoader): batch_size=1, epochs=1, steps_per_epoch=None, + collate_fn=None, data_parallel_world_size=None, data_parallel_rank=None, drop_last=False): self.feed_list = feed_list self.places = places self.steps_per_epoch = steps_per_epoch - self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size - self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank super(NonIterableGeneratorLoader, self).__init__(dataset, batch_size, epochs, data_parallel_world_size, data_parallel_rank, drop_last) - self._inner_dataloader = self._create_inner_dataloader() + + if self.auto_collate_batch: + self.collate_fn = collate_fn or default_collate_fn + else: + self.collate_fn = collate_fn or default_convert_fn + self.dataset_fetcher = _DatasetKind.create_fetcher( + self.dataset_kind, self.dataset, self.auto_collate_batch, + self.collate_fn, self.drop_lost) + self._steps = self._infer_steps() + self._inner_dataloader = self._create_inner_dataloader() def __iter__(self): self._cur_step = 0 @@ -101,31 +138,25 @@ class NonIterableGeneratorLoader(DistributedDataLoader): def _create_inner_dataloader(self): def sample_data_generator(): - batch_data = None - for step, data in enumerate(self.dataset): - data = flatten(data) - if batch_data is None: - batch_data = [[] for i in range(len(data))] - for idx in range(len(data)): - batch_data[idx].append(data[idx]) - if (step + 1) % self.batch_size == 0: - partial_data = [] - for d in batch_data: - array = np.array(d) - partial_data.append( - np.split(array, self.dp_world_size)[self.dp_rank]) - yield partial_data[:len(self.feed_list)] - batch_data = None + for indices in self.sampler_iter: + assert len(indices) % self.dp_world_size == 0, \ + "Please set batch_size to be divisible by data parallel size" + n = len(indices) // self.dp_world_size + cur_indices = [ + indices[i:i + n] for i in range(0, len(indices), n) + ] + batch = self.dataset_fetcher.fetch(cur_indices[self.dp_rank]) + yield batch[:len(self.feed_list)] def batch_data_generator(): - for data in self.dataset: - data = flatten(data) + for indices in self.sampler_iter: partial_data = [] - for d in data: - assert d.shape[0] % self.dp_world_size == 0, \ - "Please padding dataset with data parallel size" + batch = self.dataset_fetcher.fetch(indices) + for data in batch: + assert data.shape[0] % self.dp_world_size == 0, \ + "Please padding dataset's batch_size to be divisible by data parallel size" partial_data.append( - np.split(d, self.dp_world_size)[self.dp_rank]) + np.split(data, self.dp_world_size)[self.dp_rank]) yield partial_data[:len(self.feed_list)] dataloader = paddle.fluid.io.DataLoader.from_generator( diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index cd76b3dfcd3..3b38a169a5c 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time import copy import logging from collections import defaultdict @@ -306,6 +307,7 @@ class Engine: mode].dist_startup_programs self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars + self._optimizer = self._dist_contexts[mode].serial_optimizer if self._nranks > 1: # Traverse different rank programs and traverse each op of them, @@ -403,7 +405,8 @@ class Engine: epochs=1, fetches=None, steps_per_epoch=None, - use_program_cache=False, + collate_fn=None, + use_cache=False, return_numpy=True): # TODO: callbacks # TODO: evaluate after training @@ -417,18 +420,24 @@ class Engine: assert self.mode in self._dist_main_progs, \ "train model is not ready, please call `engine.prepare()` first." train_dataloader = self._create_dataloader(train_data, batch_size, - epochs, steps_per_epoch) + epochs, steps_per_epoch, + collate_fn) usr_fetch = self._validate_fetches(fetches) fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch) + lr_scheduler = self.get_lr_scheduler(self.main_program) + for epoch in range(epochs): train_logs = {"epoch": epoch} for step, _ in enumerate(train_dataloader): outs = self._executor.run(self.main_program, fetch_list=fetch_list, - use_program_cache=use_program_cache, + use_program_cache=use_cache, return_numpy=return_numpy) + if lr_scheduler is not None: + lr_scheduler.step() + train_logs["lr"] = self._optimizer.get_lr() train_logs["step"] = step # inner fetches if fetch_loss: @@ -444,7 +453,8 @@ class Engine: eval_data, batch_size=1, fetches=None, - use_program_cache=False, + collate_fn=None, + use_cache=False, return_numpy=True): self.mode = 'eval' if not self._mode_init_states[self.mode]: @@ -452,7 +462,9 @@ class Engine: assert self.mode in self._dist_main_progs, \ "eval model is not ready, please call `engine.prepare()` first." - eval_dataloader = self._create_dataloader(eval_data, batch_size) + eval_dataloader = self._create_dataloader(eval_data, + batch_size, + collate_fn=collate_fn) usr_fetch = self._validate_fetches(fetches) fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) @@ -464,7 +476,7 @@ class Engine: eval_logs = {"step": step} outs = self._executor.run(self.main_program, fetch_list=fetch_list, - use_program_cache=use_program_cache, + use_program_cache=use_cache, return_numpy=return_numpy) # inner fetches if fetch_loss: @@ -489,7 +501,8 @@ class Engine: test_data, batch_size=1, fetches=None, - use_program_cache=False, + collate_fn=None, + use_cache=False, return_numpy=True): self.mode = 'predict' if not self._mode_init_states[self.mode]: @@ -497,7 +510,9 @@ class Engine: assert self.mode in self._dist_main_progs, \ "predict model is not ready, please call `engine.prepare()` first." - test_dataloader = self._create_dataloader(test_data, batch_size) + test_dataloader = self._create_dataloader(test_data, + batch_size, + collate_fn=collate_fn) usr_fetch = self._validate_fetches(fetches) fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"]) @@ -508,7 +523,7 @@ class Engine: predict_logs = {"step": step} outs = self._executor.run(self.main_program, fetch_list=fetch_list, - use_program_cache=use_program_cache, + use_program_cache=use_cache, return_numpy=return_numpy) outputs.append(outs[:len(fetch_outputs)]) for i, out in enumerate(outs): @@ -521,7 +536,8 @@ class Engine: dataset, batch_size, epochs=1, - steps_per_epoch=None): + steps_per_epoch=None, + collate_fn=None): dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank] dist_context = self._dist_contexts[self.mode] @@ -554,6 +570,7 @@ class Engine: batch_size, epochs, steps_per_epoch, + collate_fn, data_parallel_world_size=self._input_split_size, data_parallel_rank=self._input_split_rank) @@ -645,12 +662,11 @@ class Engine: config = self.strategy.recompute_configs # extract ckpts by specific model - self.model if isinstance(self.model, paddle.nn.Layer): if hasattr( - self.model, "model" - ) and self.model.model.__class__.__name__ == 'GPTForPretraining': - exact_ckpts = self.model.model.gpt.checkpoints + self.model, "gpt" + ) and self.model.__class__.__name__ == 'GPTForPretraining': + exact_ckpts = self.model.gpt.checkpoints else: exact_ckpts = config["checkpoints"] @@ -659,7 +675,7 @@ class Engine: config["checkpoints"] = exact_ckpts[:] self.strategy.recompute_configs = config logs = { - 'Model Class': self.model.model.__class__.__name__, + 'Model Class': self.model.__class__.__name__, 'Applied Recompute ckpts': exact_ckpts } self._logger.info(logs) @@ -699,6 +715,15 @@ class Engine: self._saver.load(path, dist_main_prog, dist_context, strict, load_optimizer) + @staticmethod + def get_lr_scheduler(program): + lr_sheduler = None + if hasattr(program, 'lr_sheduler'): + from paddle.optimizer.lr import LRScheduler + lr_sheduler = program.lr_sheduler + assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler" + return lr_sheduler + @property def mode(self): return self._mode diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index e6b30e03680..8dd26dd7678 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -149,6 +149,7 @@ class Parallelizer: paddle.enable_static() else: optimizer = copy.deepcopy(optimizer) + self._dist_context._serial_optimizer = optimizer with program_guard(main_program, startup_program): optimizer_ops = optimizer.apply_gradients(params_grads) self._completer.complete_update_annotation(main_program) diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index bb50e2fb9c5..ec50371c7ec 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -363,11 +363,15 @@ class OptimizationTuner: profile_args = " ".join([ "--rank", - str(self.rank), "--device_id", - str(self.device_id), "--ctx_filename", ctx_path, + str(self.rank), + "--device_id", + str(self.device_id), + "--ctx_filename", + ctx_path, "--profile_start_step", - str(self._config.profile_start_step), "--profile_end_step", - str(self._config.profile_end_step) + str(self._config.profile_start_step), + "--profile_end_step", + str(self._config.profile_end_step), ]) cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index 9335d7d9d2e..99d130e8756 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -31,6 +31,8 @@ from paddle.static import InputSpec from paddle.distributed import fleet import paddle.distributed.auto_parallel as auto from paddle.distributed.auto_parallel.engine import Engine +from paddle.optimizer.lr import CosineAnnealingDecay +from paddle.fluid.dataloader.collate import default_collate_fn paddle.enable_static() global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) @@ -106,19 +108,18 @@ def train(fetch): dropout_ratio=0.1, initializer_range=0.02) loss = paddle.nn.CrossEntropyLoss() - optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) + scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00001, + T_max=10) + optimizer = paddle.optimizer.Adam(learning_rate=scheduler, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') labels_spec = InputSpec([batch_size], 'int64', 'label') dist_strategy = fleet.DistributedStrategy() - dist_strategy.amp = False - dist_strategy.pipeline = False - dist_strategy.recompute = False dist_strategy.semi_auto = True fleet.init(is_collective=True, strategy=dist_strategy) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py index 1de44e91a78..9ab49b30d9d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py @@ -145,7 +145,7 @@ def main(): labels_spec=labels_spec, strategy=dist_strategy) engine.prepare(optimizer=optimizer, loss=loss_func) - res = engine.fit(train_dataset, batch_size=None) + engine.fit(train_dataset, batch_size=None) dist_context = engine.dist_context block = engine.main_program.global_block() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index 4df770df696..51926286acc 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -282,13 +282,16 @@ class TestMLPReshard(unittest.TestCase): if op.type == "gelu_grad": op_need_check = op break - # print_program_with_dist_attr(dist_main_prog, dist_context) # grad op should have dist attr self.assertTrue( check_backward_dist_attr(dist_context, dist_main_prog, op_need_check)) + # clear _g_process_group_map + _g_process_group_map.clear() + _g_process_group_map[0] = ProcessGroup(0, []) + def test_mlp_pp(self): global _global_parallel_strategy _global_parallel_strategy = "pp" @@ -305,29 +308,35 @@ class TestMLPReshard(unittest.TestCase): rank_id = 1 dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, rank_id) - for key in list(_g_process_group_map.keys()): - del _g_process_group_map[key] - _g_process_group_map[0] = ProcessGroup(0, []) resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, dist_context, dist_params_grads) resharder.reshard() # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) - # parameter initialization of every rank should be different in the pipeline scene self.assertTrue(check_initialization(dist_startup_prog, rank_id)) + # clear _g_process_group_map + _g_process_group_map.clear() + _g_process_group_map[0] = ProcessGroup(0, []) + def test_mlp_pp_diff_process_mesh(self): + global _global_parallel_strategy + _global_parallel_strategy = "pp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) + global PP_MESH_0 + PP_MESH_0 = auto.ProcessMesh(mesh=[0]) + global PP_MESH_1 + PP_MESH_1 = auto.ProcessMesh(mesh=[1]) + train_program = paddle.static.Program() startup_program = paddle.static.Program() dist_context = DistributedContext() rank_id = 1 dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( train_program, startup_program, dist_context, rank_id, True) - for key in list(_g_process_group_map.keys()): - del _g_process_group_map[key] - _g_process_group_map[0] = ProcessGroup(0, []) resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, dist_context, dist_params_grads) resharder.reshard() @@ -335,6 +344,10 @@ class TestMLPReshard(unittest.TestCase): self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_initialization(dist_startup_prog, rank_id)) + # clear _g_process_group_map + _g_process_group_map.clear() + _g_process_group_map[0] = ProcessGroup(0, []) + def test_mlp_dp(self): global _global_parallel_strategy _global_parallel_strategy = "dp" @@ -350,12 +363,16 @@ class TestMLPReshard(unittest.TestCase): resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, dist_context, dist_params_grads) resharder.reshard() + # send and recv should not exist in dp scene. self.assertFalse(check_send_recv_result(dist_main_prog, rank_id)) - # all parameters should be initialized in dp scene self.assertTrue(check_initialization_for_dp(dist_startup_prog)) + # clear _g_process_group_map + _g_process_group_map.clear() + _g_process_group_map[0] = ProcessGroup(0, []) + if __name__ == "__main__": unittest.main() -- GitLab