From 5595fdbbd20b80190b30ab2f77329f2c0c4cfdc4 Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Tue, 22 Feb 2022 11:37:11 +0800 Subject: [PATCH] [Auto Parallel] Add the high-level Engine API (#39709) * [Auto Parallel] Add the high-level Engine API * Update the test cmakefile --- .../distributed/auto_parallel/dist_context.py | 20 +- .../distributed/auto_parallel/dist_loader.py | 93 ++++++ .../distributed/auto_parallel/engine.py | 309 ++++++++++++++++++ .../unittests/auto_parallel/CMakeLists.txt | 1 + .../auto_parallel/test_engine_api.py | 135 ++++++++ 5 files changed, 552 insertions(+), 6 deletions(-) create mode 100644 python/paddle/distributed/auto_parallel/dist_loader.py create mode 100644 python/paddle/distributed/auto_parallel/engine.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index e06811df88..caf220646b 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -45,9 +45,13 @@ class DistributedContext: One auto-parallel run should use its own DistributedContext to avoid interfering other run. """ - def __init__(self, program=None): + def __init__(self, + serial_main_prog=None, + serial_startup_prog=None, + dist_main_progs=None, + dist_startup_progs=None): # Program related data members - self._serial_program = program + self._serial_program = serial_main_prog self._is_initialized_for_program = False self._dist_tensors_for_program = {} self._dist_ops_for_program = {} @@ -65,8 +69,12 @@ class DistributedContext: self._tensor_id_to_tensor_node_ids = {} # Distributed programs - self._dist_main_programs = {} - self._dist_startup_programs = {} + self._dist_main_programs = dist_main_progs + if not self._dist_main_programs: + self._dist_main_programs = {} + self._dist_startup_programs = dist_startup_progs + if not self._dist_startup_programs: + self._dist_startup_programs = {} @property def serial_program(self): @@ -78,8 +86,8 @@ class DistributedContext: @serial_program.setter def serial_program(self, program): - assert self._serial_program is None, \ - "This distributed context has already been realted to a serial program" + # assert self._serial_program is None, \ + # "This distributed context has already been realted to a serial program" self._serial_program = program @property diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py new file mode 100644 index 0000000000..92deeffd2c --- /dev/null +++ b/python/paddle/distributed/auto_parallel/dist_loader.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import abc +import numpy as np +import paddle +from paddle.io import DataLoader, DistributedBatchSampler + + +class DistributedDataLoader(metaclass=abc.ABCMeta): + def __init__(self, + dataset, + batch_size=1, + epochs=1, + data_parallel_world_size=None, + data_parallel_rank=None, + drop_last=False): + self.dataset = dataset + self.batch_size = batch_size + self.epochs = epochs + self.data_parallel_world_size = data_parallel_world_size + self.data_parallel_rank = data_parallel_rank + self.drop_lost = drop_last + if data_parallel_world_size is not None: + assert batch_size % data_parallel_world_size == 0 + + @abc.abstractmethod + def __iter__(self): + raise NotImplementedError + + @abc.abstractmethod + def __next__(self): + raise NotImplementedError + + +class NonIterableGeneratorLoader(DistributedDataLoader): + def __init__(self, + dataset, + feed_list, + places, + batch_size=1, + epochs=1, + steps_per_epoch=1000, + data_parallel_world_size=None, + data_parallel_rank=None, + drop_last=False): + self.feed_list = feed_list + self.places = places + self.steps_per_epoch = steps_per_epoch + super(NonIterableGeneratorLoader, self).__init__( + dataset, batch_size, epochs, data_parallel_world_size, + data_parallel_rank, drop_last) + self._inner_dataloader = self._create_inner_dataloader() + + def __iter__(self): + self._cur_step = 0 + self._inner_dataloader.start() + return self + + def __next__(self): + if self._cur_step < self.steps_per_epoch: + self._cur_step += 1 + else: + self._inner_dataloader.reset() + raise StopIteration + + def _create_inner_dataloader(self): + def data_generator(): + batch_data = None + for step, data in enumerate(self.dataset): + if batch_data is None: + batch_data = [[] for i in range(len(data))] + for idx, data_item in enumerate(data): + batch_data[idx].append(np.array(data_item)) + if (step + 1) % self.batch_size == 0: + yield batch_data[0], batch_data[1] + batch_data = None + + dataloader = paddle.fluid.io.DataLoader.from_generator( + feed_list=self.feed_list, capacity=70, iterable=False) + dataloader.set_batch_generator(data_generator, self.places) + return dataloader diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py new file mode 100644 index 0000000000..98b76056a1 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -0,0 +1,309 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import logging +from collections import defaultdict + +import paddle +from paddle import fluid +from paddle.io import Dataset +from paddle.fluid.backward import append_backward +import paddle.fluid.core as core +from paddle.static import InputSpec +from paddle.fluid import program_guard +from paddle.fluid.framework import Operator +from paddle.fluid.framework import _current_expected_place as _get_device +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.distributed.passes import new_pass, PassContext +from paddle.distributed.utils import get_logger + +from .dist_loader import NonIterableGeneratorLoader +from .dist_op import DistributedOperator +from .dist_tensor import DistributedTensor +from .dist_context import DistributedContext +from .dist_context import get_default_distributed_context +from .dist_context import set_default_distributed_context +from .process_group import get_all_process_groups +from .process_group import get_process_group +from .process_group import get_world_process_group +from .process_group import _g_process_group_map, ProcessGroup +from .completion import Completer +from .partitioner import Partitioner +from .reshard import reshard, HAS_SENT, HAS_RECV, HAS_ALLGATHER +from .cluster import Cluster +from .mapper import mapping +from .planner import Planner +from .utils import make_data_unshard +from .utils import set_grad_var_shape +from .utils import print_program_with_dist_attr +from .utils import SerialProgramInfo + +paddle.enable_static() + + +def to_list(value): + if value is None: + return value + if isinstance(value, (list, tuple)): + return list(value) + return [value] + + +class Engine: + def __init__(self, model=None, data_spec=None, cluster=None, strategy=None): + self.model = model + self.data_spec = data_spec + self.cluster = cluster + self.strategy = strategy + self._executor = None + self._orig_main_prog = fluid.default_main_program() + self._orig_startup_prog = fluid.default_startup_program() + self._serial_main_progs = {} + self._serial_startup_progs = {} + self._dist_main_progs = defaultdict(dict) + self._dist_startup_progs = defaultdict(dict) + self._orig_dist_context = get_default_distributed_context() + self._dist_contexts = {} + self._pass_contexts = {} + self._cur_rank = paddle.distributed.get_rank() + self._logger = get_logger(logging.INFO) + + def prepare(self, + optimizer=None, + loss=None, + metrics=None, + mode="train", + all_ranks=False): + self.optimizer = optimizer + self.loss = loss + self.metrics = metrics + self.mode = mode + self._build() + self._plan() + if not all_ranks: + self._parallel(self._cur_rank) + else: + world_process_group = get_world_process_group() + all_ranks = world_process_group.ranks + for rank in all_ranks: + self._parallel(rank) + place = _get_device() + if isinstance(place, fluid.CUDAPlace): + self._place = fluid.CUDAPlace(ParallelEnv().dev_id) + if self._executor is None: + self._executor = fluid.Executor(place) + + def _build(self): + serial_main_prog = self._serial_main_progs.get(self.mode, None) + if serial_main_prog is not None: + return + + serial_main_prog = self._orig_main_prog.clone() + serial_startup_prog = self._orig_startup_prog.clone() + with fluid.program_guard(serial_main_prog, serial_startup_prog): + inputs_spec = self.data_spec[0] + labels_spec = self.data_spec[1] + inputs = [s._create_feed_layer() for s in to_list(inputs_spec)] + labels = [s._create_feed_layer() for s in to_list(labels_spec)] + self._input_vars = inputs + self._label_vars = labels + feed_list = 
self._input_vars + self._label_vars + outputs = to_list(self.model(*inputs)) + if self.mode != "predict" and self.loss: + loss = self.loss(*(outputs + labels)) + self._loss_var = loss + + self._serial_main_progs[self.mode] = serial_main_prog + self._serial_startup_progs[self.mode] = serial_startup_prog + self._dist_contexts[self.mode] = DistributedContext( + serial_main_prog, serial_startup_prog, + self._dist_main_progs[self.mode], + self._dist_startup_progs[self.mode]) + self._pass_contexts[self.mode] = PassContext() + + def _plan(self): + # Complete the distributed annotation + serial_main_prog = self._serial_main_progs[self.mode] + self._completer = Completer(self._dist_contexts[self.mode]) + self._completer.complete_forward_annotation(serial_main_prog) + # TODO: add auto planner process + + def _parallel(self, rank): + serial_main_program = self._serial_main_progs[self.mode] + serial_startup_program = self._serial_startup_progs[self.mode] + dist_context = self._dist_contexts[self.mode] + if self.mode != "predict" and self.loss: + # Generate backward + serial_loss = self._loss_var + params_grads = self._generate_backward( + serial_main_program, serial_startup_program, serial_loss) + # Apply pre optimization passes + self._apply_pre_optimization(serial_main_program, + serial_startup_program, serial_loss, + params_grads) + # Do logical partition + partitioner = Partitioner(dist_context, rank) + dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( + serial_main_program, serial_startup_program, params_grads) + # Generate optimizer + self._generate_optimizer(dist_main_prog, dist_startup_prog, + dist_params_grads) + # Do reshard process + set_grad_var_shape(dist_main_prog, dist_context) + make_data_unshard(dist_main_prog, dist_startup_prog, dist_context) + reshard(dist_main_prog, dist_startup_prog, rank, dist_context, + dist_params_grads) + # Apply post optimization passes + self._apply_post_optimization(dist_main_prog, dist_startup_prog, + rank, dist_params_grads) + self._dist_main_progs[self.mode][rank] = dist_main_prog + self._dist_startup_progs[self.mode][rank] = dist_startup_prog + + def _generate_backward(self, main_program, startup_program, loss): + with program_guard(main_program, startup_program): + params_grads = append_backward( + loss, + distop_context=self._dist_contexts[self.mode].dist_op_context) + self._completer.complete_backward_annotation(main_program) + return params_grads + + def _generate_optimizer(self, main_program, startup_program, params_grads): + with program_guard(main_program, startup_program): + optimizer_ops = copy.deepcopy(self.optimizer).apply_gradients( + params_grads) + self._completer.complete_update_annotation(main_program) + return optimizer_ops + + def _apply_pre_optimization(self, main_program, startup_program, loss, + params_grads): + # apply amp pass + if self.strategy.amp: + config = copy.deepcopy(self.strategy.amp_configs) + config["dist_context"] = self._dist_contexts[self.mode] + config["params_grads"] = params_grads + config["loss"] = loss + auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) + auto_parallel_amp_pass.apply([main_program], [startup_program], + self._pass_contexts[self.mode]) + + # apply recompute pass + if self.strategy.recompute: + config = copy.deepcopy(self.strategy.recompute_configs) + config["dist_context"] = self._dist_contexts[self.mode] + config["no_grad_set"] = None + config["loss"] = loss + auto_parallel_recompute_pass = new_pass("auto_parallel_recompute", + config) + 
auto_parallel_recompute_pass.apply([main_program], + [startup_program], + self._pass_contexts[self.mode]) + + def _apply_post_optimization(self, main_program, startup_program, rank, + params_grads): + if self.strategy.sharding: + config = copy.deepcopy(self.strategy.sharding_configs) + config["dist_context"] = self._dist_contexts[self.mode] + config["params_grads"] = params_grads + config["global_rank"] = rank + auto_parallel_sharding_pass = new_pass("auto_parallel_sharding", + config) + auto_parallel_sharding_pass.apply([main_program], + [startup_program], + self._pass_contexts[self.mode]) + + if self.strategy.gradient_merge: + config = copy.deepcopy(self.strategy.gradient_merge_configs) + config["dist_context"] = self._dist_contexts[self.mode] + config["params_grads"] = params_grads + auto_parallel_gradient_merge_pass = new_pass( + "auto_parallel_gradient_merge_pass", config) + auto_parallel_gradient_merge_pass.apply( + [main_program], [startup_program], + self._pass_contexts[self.mode]) + + def fit(self, train_data, batch_size=1, epochs=1, steps_per_epoch=1000): + assert isinstance(train_data, Dataset) + assert steps_per_epoch is not None + train_dataloader = self._create_dataloader(train_data, batch_size, + epochs, steps_per_epoch) + self._init_communication() + dist_startup_prog = self._dist_startup_progs["train"][self._cur_rank] + self._executor.run(dist_startup_prog) + for epoch in range(epochs): + # train_dataloader.start() + # for step in range(steps_per_epoch): + # logs = self.train_step(None) + # self._logger.info(logs) + # train_dataloader.reset() + for step, data in enumerate(train_dataloader): + logs = self._train_step(data) + train_logs = { + "train_" + name: val + for name, val in logs.items() + } + self._logger.info(logs) + + def _train_step(self, data): + logs = {} + dist_main_prog = self._dist_main_progs["train"][self._cur_rank] + if self._loss_var.name not in dist_main_prog.global_block().vars: + loss = self._executor.run(dist_main_prog) + logs["loss"] = None + else: + fetch_list = self._loss_var + loss = self._executor.run(dist_main_prog, fetch_list=fetch_list) + logs["loss"] = loss + return logs + + def _create_dataloader(self, dataset, batch_size, epochs, steps_per_epoch): + feed_list = self._input_vars + self._label_vars + dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] + dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank] + dist_context = self._dist_contexts[self.mode] + dist_main_block = dist_main_prog.global_block() + op_size = len(dist_main_block.ops) + places = paddle.static.cuda_places() + with fluid.program_guard(dist_main_prog, dist_startup_prog): + dataloader = NonIterableGeneratorLoader( + dataset, feed_list, places, batch_size, epochs, steps_per_epoch) + new_op_size = len(dist_main_block.ops) + for idx in range(new_op_size - 1, op_size - 1, -1): + op = dist_main_block.ops[new_op_size - 1] + new_op_desc = dist_main_block.desc._prepend_op() + new_op_desc.copy_from(op.desc) + new_op = Operator( + dist_main_block, new_op_desc, type=new_op_desc.type()) + dist_main_block.ops.insert(0, new_op) + dist_op = DistributedOperator(new_op) + dist_context.add_dist_op_for_program(dist_op) + for _ in range(new_op_size - op_size): + dist_main_block._remove_op(new_op_size, sync=False) + dist_main_block._sync_with_cpp() + return dataloader + + def _init_communication(self): + # Traverse different rank programs and traverse each op of them, + # instantiate communication by process_mapping. 
+ all_process_groups = get_all_process_groups() + for process_group in all_process_groups: + if self._cur_rank not in process_group.ranks: + continue + process_group.instantiate() + + # def save(self, path, training=True): + # pass + + # def load(self, path, strict=True, load_optimizer=True): + # pass diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 220611be18..0a9eaf34ba 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -7,4 +7,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU) set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) py_test_modules(test_relaunch_with_gpt_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS}) set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240) + py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS}) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py new file mode 100644 index 0000000000..0fc1ea4103 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_engine_api.py @@ -0,0 +1,135 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import time +import paddle.fluid as fluid +import copy +import os +import numpy as np +import subprocess +import paddle +import paddle.nn as nn +import paddle.fluid as fluid +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +from paddle.fluid import layers +from paddle.io import Dataset, IterableDataset, DataLoader +from paddle.static import InputSpec +from paddle.distributed import fleet +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.engine import Engine + +paddle.enable_static() +global_process_mesh = auto.ProcessMesh(mesh=[0]) +batch_size = 1 +batch_num = 10 +hidden_size = 1024 +sequence_len = 512 +image_size = hidden_size +class_num = 10 + +paddle.seed(44) + + +class MyDataset(Dataset): + def __init__(self, num_samples): + super(MyDataset, self).__init__() + self.num_samples = num_samples + + def __getitem__(self, index): + input = np.random.uniform(size=image_size).astype("float32") + label = np.random.randint(0, class_num - 1, dtype="int64") + return input, label + + def __len__(self): + return self.num_samples + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) + # self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + # self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") + + def forward(self, input): + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": global_process_mesh, + "dims_mappig": [-1] + }) + # out = self.norm(input) + out = self.linear0(input) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + # out = self.dropout(out) + out = self.linear2(out) + return out + + +class TestEngineAPI(unittest.TestCase): + def test_engine_api(self): + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + dataset = MyDataset(batch_num * batch_size) + data_spec = [ + InputSpec([batch_size, hidden_size], 'float32', 'x'), + InputSpec([batch_size], 'int64', 'label') + ] + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.amp = False + dist_strategy.pipeline = False + dist_strategy.recompute = False + # init parallel optimizer + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) + + engine = Engine(mlp, data_spec, strategy=dist_strategy) + engine.prepare(optimizer, loss) + engine.fit(dataset, + batch_size=batch_size, + steps_per_epoch=batch_num * batch_size) + + +if __name__ == "__main__": + unittest.main() -- GitLab
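
For reference, the condensed sketch below shows how the Engine API introduced by this patch is driven end to end. It is a minimal, hypothetical example paraphrased from test_engine_api.py above and is not part of the patch itself: SimpleNet, RandomDataset, and the hyperparameter values are illustrative stand-ins for the MLPLayer/MyDataset used in the test, and a single-GPU build is assumed (the generated dataloader places data on cuda_places()).

# Minimal usage sketch (illustrative, not part of this patch); assumes a
# single-GPU build. SimpleNet and RandomDataset are hypothetical stand-ins
# for the MLPLayer/MyDataset defined in test_engine_api.py above.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.distributed.auto_parallel as auto
from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.engine import Engine

paddle.enable_static()

batch_size, batch_num, hidden_size, class_num = 2, 10, 64, 10
process_mesh = auto.ProcessMesh(mesh=[0])  # single-card mesh


class RandomDataset(Dataset):
    # Yields (feature, label) samples; Engine.fit batches them internally.
    def __init__(self, num_samples):
        super(RandomDataset, self).__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        x = np.random.uniform(size=hidden_size).astype("float32")
        y = np.random.randint(0, class_num, dtype="int64")
        return x, y

    def __len__(self):
        return self.num_samples


class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = nn.Linear(hidden_size, class_num)

    def forward(self, x):
        # Annotate the input as fully replicated on the process mesh; the
        # Completer fills in the remaining distributed attributes.
        auto.shard_tensor(
            x,
            dist_attr={
                "process_mesh": process_mesh,
                "dims_mapping": [-1, -1]
            })
        return self.linear(x)


data_spec = [
    InputSpec([batch_size, hidden_size], 'float32', 'x'),
    InputSpec([batch_size], 'int64', 'label')
]

dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True  # let auto-parallel complete the annotations
fleet.init(is_collective=True, strategy=dist_strategy)

# prepare() builds the serial program, completes the distributed annotation,
# and partitions/reshards it for the current rank; fit() wraps the dataset in
# a NonIterableGeneratorLoader and runs the training loop on the distributed
# program.
engine = Engine(SimpleNet(), data_spec, strategy=dist_strategy)
engine.prepare(
    optimizer=paddle.fluid.optimizer.AdamOptimizer(learning_rate=1e-4),
    loss=paddle.nn.CrossEntropyLoss())
engine.fit(
    RandomDataset(batch_num * batch_size),
    batch_size=batch_size,
    steps_per_epoch=batch_num)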