From b77fa1d91137af942c0788232ca5ef54cbccc7b6 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Wed, 7 Sep 2022 11:36:30 +0800 Subject: [PATCH] [Auto Parallel] Support Iterable dataset for auto parallel (#45518) * support iterable dataset for auto parallel * add split_data proto * fix unittest bug * fix recompute bug * update cmake --- .../framework/distributed_strategy.proto | 1 + .../distributed/auto_parallel/dist_loader.py | 116 ++++++++--- .../distributed/auto_parallel/engine.py | 60 ++++-- .../fleet/base/distributed_strategy.py | 22 ++ .../unittests/auto_parallel/CMakeLists.txt | 5 + .../auto_parallel/iterable_dataset.py | 191 ++++++++++++++++++ .../auto_parallel/test_iterable_dataset.py | 49 +++++ 7 files changed, 393 insertions(+), 51 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_iterable_dataset.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 7c02c9bab73..3fd7a994a62 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -343,6 +343,7 @@ message DistributedStrategy { optional bool is_fl_ps_mode = 39 [ default = false ]; optional bool with_coordinator = 40 [ default = false ]; optional bool qat = 41 [ default = false ]; + optional bool split_data = 42 [ default = true ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py index 02bccb66920..5645235cb71 100644 --- a/python/paddle/distributed/auto_parallel/dist_loader.py +++ b/python/paddle/distributed/auto_parallel/dist_loader.py @@ -14,10 +14,13 @@ import abc import numpy as np +from functools import wraps + import paddle from .utils import to_list from paddle.fluid.layers.utils import flatten from paddle.io import DataLoader, BatchSampler, IterableDataset +from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn @@ -29,33 +32,41 @@ class DistributedDataLoader(metaclass=abc.ABCMeta): epochs=1, data_parallel_world_size=None, data_parallel_rank=None, - drop_last=False): + drop_last=False, + split_data=True): if isinstance(dataset, IterableDataset): - raise TypeError("IterableDataset is not supported.") + self.dataset_kind = _DatasetKind.ITER else: self.dataset_kind = _DatasetKind.MAP self.dataset = dataset self.epochs = epochs self.drop_lost = drop_last + self.data_parallel_world_size = data_parallel_world_size + self.data_parallel_rank = data_parallel_rank + self.split_data = split_data if batch_size is None: self.batch_size = None self.batch_sampler = None else: if data_parallel_world_size is not None: - assert batch_size % data_parallel_world_size == 0, \ - "'batch_size' must be divisible by data parallel size" + for dp_world_size in data_parallel_world_size: + if dp_world_size is not None: + assert batch_size % dp_world_size == 0, \ + "batch_size must be divisible by dp_world_size value {}".format(str(dp_world_size)) self.batch_size = batch_size - self.batch_sampler = BatchSampler(dataset, - batch_size=batch_size, - shuffle=False, - drop_last=drop_last) + if isinstance(dataset, IterableDataset): + self.batch_sampler = _InfiniteIterableSampler( + dataset, batch_size) + else: + self.batch_sampler = BatchSampler(dataset, + batch_size=batch_size, + shuffle=False, + drop_last=drop_last) self.auto_collate_batch = self.batch_sampler is not None self.sampler_iter = iter(self.index_sampler) - self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size - self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank @abc.abstractmethod def __iter__(self): @@ -73,7 +84,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta): if self.dataset_kind == _DatasetKind.MAP: return list(range(len(self.dataset))) else: - raise TypeError("Only support datasets in map-style.") + return _InfiniteIterableSampler(self.dataset, 1) class NonIterableGeneratorLoader(DistributedDataLoader): @@ -88,7 +99,8 @@ class NonIterableGeneratorLoader(DistributedDataLoader): collate_fn=None, data_parallel_world_size=None, data_parallel_rank=None, - drop_last=False): + drop_last=False, + split_data=True): self.feed_list = feed_list self.places = places self.steps_per_epoch = steps_per_epoch @@ -96,7 +108,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader): super(NonIterableGeneratorLoader, self).__init__(dataset, batch_size, epochs, data_parallel_world_size, data_parallel_rank, - drop_last) + drop_last, split_data) if self.auto_collate_batch: self.collate_fn = collate_fn or default_collate_fn @@ -115,17 +127,22 @@ class NonIterableGeneratorLoader(DistributedDataLoader): return self def __next__(self): - if self._cur_step < self._steps: + if not self._steps: + self._cur_step += 1 + elif self._cur_step < self._steps: self._cur_step += 1 else: self._inner_dataloader.reset() + self.sampler_iter = iter(self.index_sampler) raise StopIteration def _infer_steps(self): if self.steps_per_epoch is not None: return self.steps_per_epoch try: - if self.batch_size is None: + if isinstance(self.dataset, IterableDataset): + steps_per_epoch = None + elif self.batch_size is None: steps_per_epoch = len(self.dataset) else: steps_per_epoch = len(self.dataset) // self.batch_size @@ -138,26 +155,61 @@ class NonIterableGeneratorLoader(DistributedDataLoader): def _create_inner_dataloader(self): def sample_data_generator(): - for indices in self.sampler_iter: - assert len(indices) % self.dp_world_size == 0, \ - "Please set batch_size to be divisible by data parallel size" - n = len(indices) // self.dp_world_size - cur_indices = [ - indices[i:i + n] for i in range(0, len(indices), n) - ] - batch = self.dataset_fetcher.fetch(cur_indices[self.dp_rank]) - yield batch[:len(self.feed_list)] + while True: + try: + indices = next(self.sampler_iter) + batch = self.dataset_fetcher.fetch(indices) + if batch is None: break + + except StopIteration: + self.dataset_fetcher = _DatasetKind.create_fetcher( + self.dataset_kind, self.dataset, + self.auto_collate_batch, self.collate_fn, + self.drop_lost) + break + + partial_data = [] + for i, d in enumerate(batch[:len(self.feed_list)]): + array = np.array(d) + if not self.split_data: + partial_data.append(array) + elif self.dp_world_sizes[i] is not None: + partial_data.append( + np.split(array, + self.dp_world_sizes[i])[self.dp_ranks[i]]) + else: + partial_data.append(array) + yield partial_data def batch_data_generator(): - for indices in self.sampler_iter: + while True: + try: + indices = next(self.sampler_iter) + + batch = self.dataset_fetcher.fetch(indices) + if batch is None: break + except StopIteration: + break + partial_data = [] - batch = self.dataset_fetcher.fetch(indices) - for data in batch: - assert data.shape[0] % self.dp_world_size == 0, \ - "Please padding dataset's batch_size to be divisible by data parallel size" - partial_data.append( - np.split(data, self.dp_world_size)[self.dp_rank]) - yield partial_data[:len(self.feed_list)] + for i, d in enumerate(batch[:len(self.feed_list)]): + array = np.array(d) + if not self.split_data: + partial_data.append(array) + elif self.dp_world_sizes[i] is not None: + partial_data.append( + np.split(array, + self.dp_world_sizes[i])[self.dp_ranks[i]]) + else: + partial_data.append(array) + yield partial_data + + self.dp_world_sizes = [ + 1 for _ in range(len(self.feed_list)) + ] if self.data_parallel_world_size is None else self.data_parallel_world_size + self.dp_ranks = [ + 0 for _ in range(len(self.feed_list)) + ] if self.data_parallel_rank is None else self.data_parallel_rank dataloader = paddle.fluid.io.DataLoader.from_generator( feed_list=self.feed_list, capacity=70, iterable=False) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 11953aa085d..8d1a1488ac7 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -233,8 +233,8 @@ class Engine: assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset." batch_size = self._user_tuning_config["batch_size"] dataset = self._user_tuning_config["dataset"] - dataset.dp_world_size = self._input_split_size - dataset.dp_rank = self._input_split_rank + dataset.dp_world_size = self.dp_world_sizes + dataset.dp_rank = self.dp_ranks from .tuner.optimization_tuner import OptimizationTuner self._optimization_tuner = OptimizationTuner(self._user_tuning_config, @@ -272,8 +272,13 @@ class Engine: if var.name in block.vars: feed_list.append(block.vars[var.name]) - self._input_split_size, self._input_split_rank = self._get_input_split_info( - feed_list[0], self._dist_contexts[mode]) + self.dp_world_sizes = [] + self.dp_ranks = [] + for feed_var in feed_list: + dp_world_size, dp_rank = self._get_input_split_info( + feed_var, self._dist_contexts[mode]) + self.dp_world_sizes.append(dp_world_size) + self.dp_ranks.append(dp_rank) def _parallel(self, mode, all_ranks): # Parallelize program based on the planner's results @@ -440,15 +445,23 @@ class Engine: for epoch in range(epochs): train_logs = {"epoch: {:d} ": epoch} for step, _ in enumerate(train_dataloader): + try: + outs = self._executor.run(self.main_program, + fetch_list=fetch_list, + use_program_cache=use_cache, + return_numpy=return_numpy) + except fluid.core.EOFException: + break - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_cache, - return_numpy=return_numpy) train_logs["step: {:d} "] = step if lr_scheduler is not None: lr_scheduler.step() - train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr() + try: + train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr() + except: + train_logs[ + "lr: {:5e} "] = self._lr_optimizer._learning_rate.get_lr( + ) # inner fetches if fetch_loss: train_logs["loss: {:9f} "] = outs[0][0] @@ -486,10 +499,13 @@ class Engine: for step, _ in enumerate(eval_dataloader): eval_logs = {"step: {:d} ": step} - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_cache, - return_numpy=return_numpy) + try: + outs = self._executor.run(self.main_program, + fetch_list=fetch_list, + use_program_cache=use_cache, + return_numpy=return_numpy) + except fluid.core.EOFException: + break # inner fetches if fetch_loss: eval_logs["loss: {:9f} "] = outs[0][0] @@ -534,10 +550,13 @@ class Engine: outputs = [] for step, _ in enumerate(test_dataloader): predict_logs = {"step: {:d} ": step} - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_cache, - return_numpy=return_numpy) + try: + outs = self._executor.run(self.main_program, + fetch_list=fetch_list, + use_program_cache=use_cache, + return_numpy=return_numpy) + except fluid.core.EOFException: + break outputs.append(outs[:len(fetch_outputs)]) for i, out in enumerate(outs): predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out @@ -586,8 +605,9 @@ class Engine: epochs, steps_per_epoch, collate_fn, - data_parallel_world_size=self._input_split_size, - data_parallel_rank=self._input_split_rank) + data_parallel_world_size=self.dp_world_sizes, + data_parallel_rank=self.dp_ranks, + split_data=self.strategy.split_data) # move read op from the end of program to the start of program new_op_size = len(dist_main_block.ops) @@ -682,6 +702,8 @@ class Engine: self.model, "gpt" ) and self.model.__class__.__name__ == 'GPTForPretraining': exact_ckpts = self.model.gpt.checkpoints + else: + exact_ckpts = config["checkpoints"] else: exact_ckpts = config["checkpoints"] diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index b83d97d1d35..2a11dd7eace 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -1997,6 +1997,28 @@ class DistributedStrategy(object): else: print("WARNING: auto-search should have value of bool type") + @property + def split_data(self): + """ + Indicating whether we split the data. If True, we split the data. + Default Value: True + Examples: + .. code-block:: python + import paddle + paddle.enable_static() + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.split_data = True + """ + return self.strategy.split_data + + @split_data.setter + def split_data(self, flag): + if isinstance(flag, bool): + self.strategy.split_data = flag + else: + print("WARNING: split_data should have value of bool type") + @property def qat(self): """ diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 422b3db42c3..27f86dc9f10 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -72,4 +72,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip) py_test_modules(test_quantization MODULES test_quantization) py_test_modules(test_dist_matmul MODULES test_dist_matmul) + + py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS + ${dist_ENVS}) + set_tests_properties(test_iterable_dataset + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py b/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py new file mode 100644 index 00000000000..4ca3d14f716 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py @@ -0,0 +1,191 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import time +import tempfile +import copy +import os +import numpy as np +import subprocess +import paddle +import paddle.nn as nn +import paddle.fluid as fluid +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +from paddle.fluid import layers +from paddle.io import Dataset, IterableDataset, DataLoader +from paddle.static import InputSpec +from paddle.distributed import fleet +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.engine import Engine +from paddle.optimizer.lr import CosineAnnealingDecay +from paddle.fluid.dataloader.collate import default_collate_fn + +paddle.enable_static() +global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) +PP_MESH_0 = auto.ProcessMesh([0]) +PP_MESH_1 = auto.ProcessMesh([1]) +batch_size = 2 +batch_num = 10 +hidden_size = 1024 +sequence_len = 512 +image_size = hidden_size +class_num = 10 + +paddle.seed(44) + + +class MyDataset(IterableDataset): + + def __init__(self, num_samples): + super(MyDataset, self).__init__() + self.num_samples = num_samples + + def __iter__(self): + for i in range(self.num_samples): + input = np.random.uniform(size=image_size).astype("float32") + label = np.random.randint(0, class_num - 1, dtype="int64") + yield input, label + + +class MyDataset1(Dataset): + + def __init__(self, num_samples): + super(MyDataset1, self).__init__() + self.num_samples = num_samples + self.data = [] + for i in range(self.num_samples): + input1 = np.random.uniform(size=image_size).astype("float32") + label1 = np.array(np.random.randint(0, class_num - 1, + dtype="int64")) + input2 = np.random.uniform(size=image_size).astype("float32") + label2 = np.array(np.random.randint(0, class_num - 1, + dtype="int64")) + input = np.stack((input1, input2)) + label = np.stack((label1, label2)) + self.data.append((input, label)) + + def __getitem__(self, idx): + return self.data[idx] + + def __len__(self): + return len(self.data) + + +class MLPLayer(nn.Layer): + + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear(d_model, + dim_feedforward, + weight_attr, + bias_attr=bias_attr) + self.linear1 = nn.Linear(dim_feedforward, + d_model, + weight_attr, + bias_attr=bias_attr) + self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") + + def forward(self, input): + out = auto.shard_op(self.norm, dist_attr={"process_mesh": + PP_MESH_0})(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = auto.shard_op(self.linear1, dist_attr={"process_mesh": + PP_MESH_1})(out) + out = self.dropout(out) + out = self.linear2(out) + self.out = out + return out + + +def train(fetch): + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') + labels_spec = InputSpec([batch_size], 'int64', 'label') + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + dist_strategy.split_data = True + fleet.init(is_collective=True, strategy=dist_strategy) + + # init engine + engine = Engine(mlp, + inputs_spec=inputs_spec, + labels_spec=labels_spec, + strategy=dist_strategy) + engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) + + # fetch + if fetch: + fetches = {'out': mlp.out} + else: + fetches = None + + # train + train_dataset = MyDataset(batch_num * batch_size) + train_dataset1 = MyDataset1(batch_num) + engine.fit(train_dataset, + epochs=2, + batch_size=batch_size, + steps_per_epoch=batch_num, + fetches=fetches) + + engine.fit(train_dataset1, + epochs=2, + batch_size=None, + steps_per_epoch=batch_num, + fetches=fetches) + + # eval + eval_dataset = MyDataset(batch_size) + engine.evaluate(eval_dataset, batch_size, fetches=fetches) + + # predict + test_dataset = MyDataset(batch_size) + engine.predict(test_dataset, batch_size, fetches=fetches) + + # save + temp_dir = tempfile.TemporaryDirectory() + model_filename = os.path.join(temp_dir.name, 'mlp_inf') + engine.save(model_filename, training=False, mode='predict') + temp_dir.cleanup() + + +if __name__ == "__main__": + train(fetch=True) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_iterable_dataset.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_iterable_dataset.py new file mode 100644 index 00000000000..7e990d88fa9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_iterable_dataset.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestEngineAPI(unittest.TestCase): + + def test_engine_api(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "iterable_dataset.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir", + tmp_dir.name, launch_model_path + ] + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() -- GitLab