diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 19a5b001abbb7f05799751b392ce7416f1a72a87..d18c05a058eea52550b08cec074ab7c35b085b85 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -27,6 +27,7 @@ from .dist_op import DistributedOperator from .dist_attribute import TensorDistributedAttribute from .dist_attribute import OperatorDistributedAttribute from .process_mesh import ProcessMesh +from .process_group import get_world_process_group from paddle.distributed.fleet.meta_optimizers.common import OpRole @@ -765,16 +766,29 @@ class Completer: else: self._dist_context._serial_main_program = serial_main_program - self._dist_context.initialize() + start_time = time.time() + # print("start time", start_time, flush=True) + if not self._dist_context.data_parallel: + self._dist_context.initialize(with_graph=True) + + # self._dist_context.validate_dist_attr_for_program() + + self._prepare() + + self._update_process_mesh() - self._prepare() + self._update_dims_mapping() - self._update_process_mesh() + # Copy the corresponding distributed attribute from graph to serial_main_program + self._dist_context.copy_dist_attr_from_graph_to_program() + else: + self._dist_context.initialize(with_graph=False) - self._update_dims_mapping() + # A fast and special completion for data parallel + self._update_dist_attr_for_dp() - # Copy the corresponding distributed attribute from graph to serial_main_program - self._dist_context.copy_dist_attr_from_graph_to_program() + # print_program_with_dist_attr(self._dist_context.serial_main_program, + # self._dist_context) # NOTE:[HighOrderGrad] update vars and ops distributed attribute in high order gradient self._complete_high_order_grad_annotation(serial_main_program) @@ -784,8 +798,107 @@ class Completer: self._dist_context.validate_dist_attr_for_program() + end_time = time.time() + # 
def _update_dist_attr_for_dp(self):
    """Complete distributed attributes with the fast data-parallel path.

    Every distributed tensor and operator tracked by ``self._dist_context``
    is assigned one process mesh built from the world process group's
    ranks. For each non-parameter input/output of every operator the dims
    mapping is rewritten to ``[0, -1, ...]`` (``[-1, 0, -1, ...]`` for
    ``XShape`` arguments) and mirrored onto the tensor's own attribute.
    A compatible distributed operator implementation is then searched for;
    when none validates, the operator's attribute is reset to the copy
    taken right after the process mesh was assigned.

    NOTE(review): the source this was recovered from had lost its
    indentation. The nesting below (tensor attribute updated only for
    non-parameter arguments; backup restored after each failed
    implementation) is inferred from the logic - confirm against the
    upstream file.
    """
    # TODO: we must ensure the world process group contains all ranks
    ranks = get_world_process_group().ranks
    process_mesh = ProcessMesh(ranks)

    def dp_dims_mapping(old_dims_mapping, is_xshape):
        """Return the rewritten dims mapping, or None to leave it as is."""
        rank = len(old_dims_mapping)
        if is_xshape:
            # XShape arguments get the 0 in second position; presumably
            # because they carry an extra leading dimension - TODO confirm.
            if rank > 1:
                return [-1, 0] + [-1] * (rank - 2)
            return None
        if rank > 0:
            return [0] + [-1] * (rank - 1)
        return None

    def shard_args(arg_names, xshape_arg_names, get_serial_tensor,
                   get_dims_mapping, set_dims_mapping):
        """Apply the data-parallel mapping to one side (inputs or outputs)."""
        for arg_name in arg_names:
            serial_tensor = get_serial_tensor(arg_name)
            if serial_tensor.is_parameter:
                # Parameters keep whatever mapping they already have.
                continue
            new_dims_mapping = dp_dims_mapping(
                get_dims_mapping(arg_name), arg_name in xshape_arg_names)
            if new_dims_mapping is not None:
                set_dims_mapping(arg_name, new_dims_mapping)
            # Set tensor's dims_mapping by the op's
            tensor_dist_attr = \
                self._dist_context.get_tensor_dist_attr_for_program(
                    serial_tensor)
            tensor_dist_attr.dims_mapping = get_dims_mapping(arg_name)

    for dist_tensor in self._dist_context._dist_tensors_for_program.values():
        dist_tensor.dist_attr.process_mesh = process_mesh

    for dist_op in self._dist_context._dist_ops_for_program.values():
        serial_op = dist_op.serial_op
        op_desc = serial_op.desc
        op_dist_attr = dist_op.dist_attr
        op_dist_attr.process_mesh = process_mesh
        # Snapshot taken before any dims mapping is touched; used as the
        # fallback when no implementation turns out to be compatible.
        original_op_dist_attr = copy.deepcopy(op_dist_attr)

        input_xshape_arg_names = []
        if "XShape" in op_desc.input_names():
            input_xshape_arg_names = op_desc.input("XShape")
        shard_args(serial_op.input_arg_names, input_xshape_arg_names,
                   dist_op.get_serial_input,
                   op_dist_attr.get_input_dims_mapping,
                   op_dist_attr.set_input_dims_mapping)

        output_xshape_arg_names = []
        if "XShape" in op_desc.output_names():
            output_xshape_arg_names = op_desc.output("XShape")
        shard_args(serial_op.output_arg_names, output_xshape_arg_names,
                   dist_op.get_serial_output,
                   op_dist_attr.get_output_dims_mapping,
                   op_dist_attr.set_output_dims_mapping)

        op_dist_impls = find_compatible_distributed_operator_impls(
            dist_op, partial=False)
        if op_dist_impls is None:
            dist_op.dist_attr = original_op_dist_attr
            continue

        # Snapshot of the data-parallel mapping, restored after every
        # implementation that fails the compatibility check.
        backup_op_dist_attr = copy.deepcopy(op_dist_attr)
        for op_dist_impl in op_dist_impls:
            op_dist_impl.update_dims_mapping(dist_op)
            if op_dist_impl.is_auto_compatible(dist_op) \
                    and dist_op.validate_dist_attr():
                op_dist_attr.impl_type = op_dist_impl.type
                op_dist_attr.impl_idx = op_dist_impl.idx
                break
            dist_op.dist_attr = backup_op_dist_attr
        else:
            # No implementation accepted the data-parallel mapping.
            dist_op.dist_attr = original_op_dist_attr
# self._tensor_id_to_tensor_node_ids = {} self._is_initialized = False + #TODO: need a better way to remove the following flag self._need_copy_dist_attr_to_graph = False self._backup_pass_context_stack = [] self._backup_block_state_stack = [] @@ -121,6 +122,9 @@ class DistributedContext: # flag whether scale gradient with dp size self._gradient_scale = True + # A flag indicates whether the used parallelism is data parallel + self._data_parallel = False + @property def serial_main_program(self): return self._serial_main_program @@ -198,6 +202,14 @@ class DistributedContext: def gradient_scale(self, gs): self._gradient_scale = gs + @property + def data_parallel(self): + return self._data_parallel + + @data_parallel.setter + def data_parallel(self, dp): + self._data_parallel = dp + def _backup_serial_info(self, mode): self._backup_serial_main_program_stack.append( self._serial_main_program.clone()) @@ -335,7 +347,7 @@ class DistributedContext: if dist: self._restore_dist_info(dist_mode) - def initialize(self): + def initialize(self, with_graph=True): if not self._is_initialized: if not self._serial_main_program: self._serial_main_program = self._original_serial_main_program @@ -366,13 +378,16 @@ class DistributedContext: self._dist_ops_for_program) self._tensors_ids = list(self._dist_tensors_for_program.keys()) self._ops_ids = list(self._dist_ops_for_program.keys()) - set_flags({"FLAGS_convert_all_blocks": True}) - self._serial_graph = framework.IrGraph( - core.Graph(self._serial_main_program.desc)) - self._init_dist_attr_for_graph() self._is_initialized = True - self._need_copy_dist_attr_to_graph = False - if self._need_copy_dist_attr_to_graph: + + if with_graph: + set_flags({"FLAGS_convert_all_blocks": True}) + self._serial_graph = framework.IrGraph( + core.Graph(self._serial_main_program.desc)) + self._init_dist_attr_for_graph() + self._need_copy_dist_attr_to_graph = False + + if self._need_copy_dist_attr_to_graph and with_graph: 
self.copy_dist_attr_from_program_to_graph() def add_process_mesh(self, process_mesh): @@ -522,6 +537,8 @@ class DistributedContext: self._process_meshes = copy.deepcopy(default_ctx.process_meshes) else: default_ctx = self + # Copy the data parallel flag from the default context + self._data_parallel = default_ctx.data_parallel for block in self._serial_main_program.blocks: for tensor in block.vars.values(): # Copy the distributed tensors in the default context diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index dcdd098dcd9cc008fea89bfebc89731e1274c080..555b3ff6cd992a8d463730645dfbe9663590ed2c 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -44,7 +44,7 @@ from .dist_saver import DistributedSaver from .dist_loader import NonIterableGeneratorLoader from .utils import make_data_unshard, set_grad_var_shape from .utils import print_program_with_dist_attr, to_list -from .process_group import get_all_process_groups, get_world_process_group +from .process_group import new_process_group, get_all_process_groups, get_world_process_group from .dist_context import DistributedContext, get_default_distributed_context @@ -155,8 +155,10 @@ class Engine: default_ctx = get_default_distributed_context() if not default_ctx.has_annotation or self._default_strategy: - inputs = [self._set_data_parallel(var) for var in inputs] - labels = [self._set_data_parallel(var) for var in labels] + # We build the world process group because the data parallel + # needs all ranks by default. 
+ new_process_group(list(range(self._nranks))) + default_ctx.data_parallel = True # self._feed_vars[mode] = {"inputs": inputs, "labels": labels} feed_vars = {"inputs": inputs, "labels": labels} diff --git a/python/paddle/distributed/auto_parallel/planner_v2.py b/python/paddle/distributed/auto_parallel/planner_v2.py index 77496ed3e6d20b04b5f1d7e53289c340243ee955..90b840c5943bcef056b0d5cd7f1d9a6044c572c5 100755 --- a/python/paddle/distributed/auto_parallel/planner_v2.py +++ b/python/paddle/distributed/auto_parallel/planner_v2.py @@ -27,10 +27,14 @@ class Planner: # NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need # dependency of backward-forward ops in forward completion. - # TODO: The id mapping will be lost if we clone the original program. default_ctx = get_default_distributed_context() self._dist_context._dist_op_context = default_ctx.dist_op_context - self._dist_context.initialize() + if not default_ctx.data_parallel: + # Use SSA graph for complex parallelism + self._dist_context.initialize(with_graph=True) + else: + # Use program for data parallelism + self._dist_context.initialize(with_graph=False) self._completer = Completer(self._dist_context) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 10498bf48e9d0589aeb4818027a1841a29829a0e..e0eb04e2535c5a53e528ff9d12563f3502926148 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -20,6 +20,11 @@ if(WITH_DISTRIBUTE AND WITH_GPU) set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) + py_test_modules(test_engine_api_dp MODULES test_engine_api_dp ENVS + ${dist_ENVS}) + set_tests_properties(test_engine_api_dp + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) + py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS}) 
set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index ae69ee087686a4acc78d0e962e22e48783db89ef..ec757c03478defee80b56322ee5f5f533a466934 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -90,7 +90,7 @@ class MLPLayer(nn.Layer): def forward(self, input): out = auto.shard_op(self.norm, dist_attr={"process_mesh": PP_MESH_0})(input)[0] - out = self.linear0(input) + out = self.linear0(out) out = F.gelu(out, approximate=True) out = auto.shard_op(self.linear1, dist_attr={"process_mesh": PP_MESH_1})(out)[0] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4c8a9986cef26a474a25839c24f42cf186c5d8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py @@ -0,0 +1,157 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Data-parallel smoke script for the auto-parallel ``Engine`` API.

Builds a small MLP, then runs fit / evaluate / predict / save through
``Engine`` without any user-provided sharding annotation. It is meant to be
started by the distributed launcher (see ``test_engine_api_dp.py``).
"""

import unittest
import time
import tempfile
import copy
import os
import numpy as np
import subprocess
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine

paddle.enable_static()
batch_size = 1
batch_num = 10
hidden_size = 1024
sequence_len = 512
image_size = hidden_size
class_num = 10

paddle.seed(44)


class MyDataset(Dataset):
    """Synthetic dataset yielding random ``(feature, label)`` pairs."""

    def __init__(self, num_samples):
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """Return a fresh random sample; ``index`` is not used."""
        sample = np.random.uniform(size=image_size).astype("float32")
        # NOTE(review): randint's upper bound is exclusive, so labels fall in
        # [0, class_num - 2], while the model's last layer emits one value
        # per sample. Left unchanged - confirm the intended label range.
        label = np.random.randint(0, class_num - 1, dtype="int64")
        return sample, label

    def __len__(self):
        return self.num_samples


class MLPLayer(nn.Layer):
    """LayerNorm -> Linear -> GELU -> Linear -> Dropout -> Linear(1)."""

    def __init__(self,
                 hidden_size=1024,
                 intermediate_size=4 * 1024,
                 dropout_ratio=0.1,
                 initializer_range=0.02):
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")

    def forward(self, input):
        """Run the MLP and remember the result on ``self.out``."""
        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)
        out = self.dropout(out)
        out = self.linear2(out)
        # Kept on the instance so train() can pass it to the engine as a
        # fetch target.
        self.out = out
        return out


def train(fetch):
    """Run fit, evaluate, predict and save through the Engine API.

    Args:
        fetch: when truthy, the model's final output (``mlp.out``) is passed
            to the engine as an extra fetch named ``'out'``; otherwise no
            extra fetches are requested.
    """
    mlp = MLPLayer(hidden_size=hidden_size,
                   intermediate_size=4 * hidden_size,
                   dropout_ratio=0.1,
                   initializer_range=0.02)
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
                                                     beta1=0.9,
                                                     beta2=0.999,
                                                     epsilon=1e-08,
                                                     grad_clip=None)

    inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
    labels_spec = InputSpec([batch_size], 'int64', 'label')

    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.amp = False
    dist_strategy.pipeline = False
    dist_strategy.recompute = False
    # init parallel optimizer
    dist_strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=dist_strategy)

    # init engine
    engine = Engine(mlp,
                    inputs_spec=inputs_spec,
                    labels_spec=labels_spec,
                    strategy=dist_strategy)
    engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())

    # fetch
    if fetch:
        fetches = {'out': mlp.out}
    else:
        fetches = None

    # train
    train_dataset = MyDataset(batch_num * batch_size)
    engine.fit(train_dataset,
               batch_size=batch_size,
               steps_per_epoch=batch_num * batch_size,
               fetches=fetches)

    # eval
    eval_dataset = MyDataset(batch_size)
    engine.evaluate(eval_dataset, batch_size, fetches=fetches)

    # predict
    test_dataset = MyDataset(batch_size)
    engine.predict(test_dataset, batch_size, fetches=fetches)

    # save
    # The context manager removes the directory even when save() raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        model_filename = os.path.join(temp_dir, 'mlp_inf')
        engine.save(model_filename, training=False, mode='predict')


if __name__ == "__main__":
    train(True)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage


class TestEngineAPI(unittest.TestCase):
    """Launch ``engine_api_dp.py`` on two GPUs and expect a clean exit."""

    def test_engine_api(self):
        """Run the data-parallel engine script through ``-m launch``.

        When the ``WITH_COVERAGE`` environment variable is ``"ON"`` the
        launcher is itself started under ``coverage run``.
        """
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "engine_api_dp.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        # The context manager removes the log directory even when the launch
        # raises or the return-code assertion fails.
        with tempfile.TemporaryDirectory() as log_dir:
            cmd = [sys.executable, "-u"] + coverage_args + [
                "-m", "launch", "--gpus", "0,1", "--log_dir", log_dir,
                launch_model_path
            ]

            # run() waits for the child to finish; check=False keeps the
            # failure report in the assertion below.
            completed = subprocess.run(cmd, check=False)
            self.assertEqual(completed.returncode, 0)


if __name__ == "__main__":
    unittest.main()