From 85bb1a85cdb3bc9927f5047dc81e25f0d7ada844 Mon Sep 17 00:00:00 2001
From: Guoxia Wang
Date: Wed, 13 Oct 2021 15:02:41 +0800
Subject: [PATCH] support auto parallel data shard (#36055)

---
 .../distributed/auto_parallel/parallelizer.py |   3 +
 .../paddle/distributed/auto_parallel/utils.py |  37 ++++
 .../distributed/fleet/base/fleet_base.py      |   1 +
 .../fluid/tests/unittests/CMakeLists.txt      |   3 +
 .../unittests/auto_parallel_data_unshard.py   | 179 ++++++++++++++++++
 .../test_auto_parallel_data_unshard.py        |  29 +++
 6 files changed, 252 insertions(+)
 create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_auto_parallel_data_unshard.py

diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 2994d35ef92..1437dbb2f90 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -20,6 +20,7 @@
 from .context import get_default_distributed_context
 from .completion import complete_annotation, complete_backward_annotation
 from .partitioner import Partitioner
 from .process import get_all_process_groups
+from .utils import make_data_unshard
 from .reshard import reshard
 
@@ -95,6 +96,8 @@ class AutoParallelizer:
             self._remove_distributed_attrs(partitioned_main_prog)
         complete_backward_annotation(partitioned_main_prog,
                                      self._dist_context)
+
+        make_data_unshard(partitioned_main_prog, partitioned_startup_prog)
 
         reshard(partitioned_main_prog, partitioned_startup_prog, rank,
                 self._dist_context)
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index 547495fb848..a81ff699189 100755
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -277,3 +277,40 @@ def _linear_idx2coordinate(mesh_shape, linear_idx):
 
     # row major order
     return coordinate
+
+
+def _get_unshard_dist_shape(var, dist_attr):
+    var_shape = var.shape
+    mapping = dist_attr.get_dims_mapping()
+    mesh = dist_attr.get_process_mesh().topology
+    assert len(var_shape) == len(
+        mapping
+    ), "variable shape [{}] and dim_mapping [{}] do not match!".format(
+        var_shape, mapping)
+    new_shape = []
+    for idx in range(len(var_shape)):
+        if var_shape[idx] == -1 or mapping[idx] == -1:
+            new_shape.append(var_shape[idx])
+        else:
+            new_shape.append(var_shape[idx] * mesh[mapping[idx]])
+
+    return new_shape
+
+
+def make_data_unshard(dist_main_prog, dist_startup_prog):
+    from .context import get_default_distributed_context
+    dist_context = get_default_distributed_context()
+
+    for var in dist_main_prog.list_vars():
+        if var.is_data:
+            tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
+                var)
+            inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr)
+            var.desc.set_shape(inverse_shape)
+            dim_mapping = tensor_dist_attr.get_dims_mapping()
+            dim_mapping = [-1] * len(dim_mapping)
+            tensor_dist_attr.set_dims_mapping(dim_mapping)
+            dist_context.set_tensor_distributed_attr_for_program(
+                var, tensor_dist_attr)
+            var._set_attr('dim_mapping' + core.kAutoParallelSuffix(),
+                          dim_mapping)
diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index 687295b1f2c..544c79a0b39 100755
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -1423,6 +1423,7 @@ class Fleet(object):
             auto_parallelizer = AutoParallelizer(self)
             optimize_ops, params_grads, dist_startup_prog, dist_main_prog = auto_parallelizer.parallelize(
                 loss, startup_program, parameter_list, no_grad_set)
+
             return optimize_ops, params_grads, dist_startup_prog, dist_main_prog
 
         # compile time
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 33cd236a7d0..f883d7a80a4 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -36,6 +36,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
 list(APPEND DIST_TEST_OPS test_hybrid_parallel_inference_helper)
 list(APPEND DIST_TEST_OPS test_parallel_class_center_sample)
 list(APPEND DIST_TEST_OPS test_parallel_margin_cross_entropy)
+list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard)
 set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
 #remove distribute unittests.
 list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
@@ -233,6 +234,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp)
     LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp)
+    LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_data_unshard)
 elseif(WITH_GPU)
     if (${CUDNN_VERSION} VERSION_LESS 7100)
         LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
@@ -1001,6 +1003,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
     set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_class_center_sample PROPERTIES TIMEOUT 120)
     set_tests_properties(test_parallel_margin_cross_entropy PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_auto_parallel_data_unshard PROPERTIES TIMEOUT 120)
     if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
         set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120)
         set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py b/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py
new file mode 100644
index 00000000000..367d9858626
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+
+import copy
+import numpy as np
+import random
+
+import paddle
+import paddle.nn as nn
+import paddle.fluid.core as core
+import paddle.distributed.auto_parallel as auto
+import paddle.nn.functional as F
+from paddle.distributed import fleet
+
+paddle.enable_static()
+paddle.distributed.init_parallel_env()
+
+
+class TestDataUnshard(unittest.TestCase):
+    def test_dp2pp1mp1(self):
+        def create_model(train_program, start_program):
+            with paddle.static.program_guard(train_program, start_program):
+
+                ROOT_MESH = auto.ProcessMesh([0, 1])
+                MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH)
+                input = paddle.static.data(name='input', shape=[2, 8])
+                label = paddle.static.data(name='label', shape=[2, 8])
+
+                weight_attr = paddle.ParamAttr(
+                    initializer=nn.initializer.Normal(
+                        mean=0.0, std=0.02))
+                linear0 = nn.Linear(8, 8, weight_attr)
+                linear1 = nn.Linear(8, 8, weight_attr)
+
+                auto.shard_tensor(input, MESH_0, dim_mapping=[0, -1])
+                auto.shard_tensor(label, MESH_0, dim_mapping=[0, -1])
+                auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, -1])
+                auto.shard_tensor(linear1.weight, MESH_0, dim_mapping=[-1, -1])
+
+                linear0_out = linear0(input)
+                gelu_out = F.gelu(linear0_out)
+                linear1_out = linear1(gelu_out)
+                error_cost = paddle.nn.functional.square_error_cost(linear1_out,
+                                                                    label)
+                loss = paddle.mean(error_cost)
+                return train_program, start_program, loss, input, label
+
+        train_program = paddle.static.Program()
+        start_program = paddle.static.Program()
+        # serial program
+        train_program, start_program, loss, input, label = create_model(
+            train_program, start_program)
+
+        dist_strategy = fleet.DistributedStrategy()
+        dist_strategy.semi_auto = True
+        fleet.init(is_collective=True, strategy=dist_strategy)
+        optimizer = paddle.fluid.optimizer.AdamOptimizer(
+            learning_rate=0.00001,
+            beta1=0.9,
+            beta2=0.999,
+            epsilon=1e-08,
+            grad_clip=None)
+
+        optimizer = fleet.distributed_optimizer(optimizer)
+        _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
+            loss, start_program)
+
+        worker_index = paddle.distributed.get_rank()
+        paddle.seed(worker_index + 2021)
+        random.seed(worker_index + 2021)
+        np.random.seed(worker_index + 2021)
+
+        place = paddle.set_device("gpu")
+        exe = paddle.static.Executor(place)
+        exe.run(distributed_startup_program)
+
+        input_data = np.array(range(2 * 8)).reshape([2, 8]).astype("float32")
+        label_data = np.random.randint(0, 10, [2, 8]).astype("float32")
+
+        fetchs = [loss.name, 'input@RESHARD_0']
+        loss_np, shard_data_np = exe.run(
+            distributed_main_program,
+            feed={"input": input_data,
+                  "label": label_data},
+            fetch_list=fetchs)
+        desired = input_data[worker_index].reshape(shard_data_np.shape)
+        np.testing.assert_allclose(shard_data_np, desired)
+
+    def dp1pp1mp2(self):
+        def create_model(train_program, start_program):
+            with paddle.static.program_guard(train_program, start_program):
+
+                ROOT_MESH = auto.ProcessMesh([0, 1])
+                MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH)
+                input = paddle.static.data(name='input', shape=[8, 8])
+                label = paddle.static.data(name='label', shape=[8, 8])
+
+                weight_attr = paddle.ParamAttr(
+                    initializer=nn.initializer.Normal(
+                        mean=0.0, std=0.02))
+                linear0 = nn.Linear(8, 8, weight_attr)
+                linear1 = nn.Linear(8, 8, weight_attr)
+
+                auto.shard_tensor(input, MESH_0, dim_mapping=[-1, -1])
+                auto.shard_tensor(label, MESH_0, dim_mapping=[-1, -1])
+
+                auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, 0])
+                auto.shard_tensor(linear1.weight,
+                                  MESH_0, dim_mapping=[0, -1])
+
+                linear0_out = linear0(input)
+                gelu_out = F.gelu(linear0_out)
+
+                linear1_out = linear1(gelu_out)
+
+                error_cost = paddle.nn.functional.square_error_cost(linear1_out,
+                                                                    label)
+                loss = paddle.mean(error_cost)
+                return train_program, start_program, loss, input, label
+
+        train_program = paddle.static.Program()
+        start_program = paddle.static.Program()
+        # serial program
+        train_program, start_program, loss, input, label = create_model(
+            train_program, start_program)
+
+        dist_strategy = fleet.DistributedStrategy()
+        dist_strategy.semi_auto = True
+        fleet.init(is_collective=True, strategy=dist_strategy)
+        optimizer = paddle.fluid.optimizer.AdamOptimizer(
+            learning_rate=0.00001,
+            beta1=0.9,
+            beta2=0.999,
+            epsilon=1e-08,
+            grad_clip=None)
+
+        optimizer = fleet.distributed_optimizer(optimizer)
+        _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
+            loss, start_program)
+
+        worker_index = paddle.distributed.get_rank()
+        paddle.seed(worker_index + 2021)
+        random.seed(worker_index + 2021)
+        np.random.seed(worker_index + 2021)
+
+        place = paddle.set_device("gpu")
+        exe = paddle.static.Executor(place)
+        exe.run(distributed_startup_program)
+
+        input_data = np.array(range(8 * 8)).reshape([8, 8]).astype("float32")
+        label_data = np.random.randint(0, 10, [8, 8]).astype("float32")
+
+        fetchs = [loss.name, 'input']
+        loss_np, shard_data_np = exe.run(
+            distributed_main_program,
+            feed={"input": input_data,
+                  "label": label_data},
+            fetch_list=fetchs)
+
+        desired = input_data.reshape(shard_data_np.shape)
+        np.testing.assert_allclose(shard_data_np, desired)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_data_unshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_data_unshard.py
new file mode 100644
index 00000000000..6cc953dfdee
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_data_unshard.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle.fluid as fluid
+
+from test_parallel_dygraph_dataparallel import TestMultipleGpus
+
+
+class TestAutoParallelDataUnshard(TestMultipleGpus):
+    def test_auto_parallel_data_unshard(self):
+        self.run_mnist_2gpu('auto_parallel_data_unshard.py')
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
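For orientation, the core of this patch is the shape arithmetic in _get_unshard_dist_shape: a data variable's per-rank (sharded) shape is turned back into its global shape by multiplying each dimension mapped to a process-mesh axis by that axis's length, while dynamic (-1) and unsharded (mapping -1) dimensions are left untouched; make_data_unshard then resets the variable's dims_mapping to all -1 so the full batch can be fed on every rank, and reshard() later inserts the actual split. The following is a minimal, self-contained sketch of that arithmetic; unshard_shape and its parameter names are illustrative only and are not part of the patch.

# Illustrative sketch (not from the patch): recover a data variable's global
# shape from its per-rank shard shape, its dims_mapping, and the process-mesh
# topology, mirroring _get_unshard_dist_shape above.
def unshard_shape(shard_shape, dims_mapping, mesh_topology):
    assert len(shard_shape) == len(dims_mapping), \
        "shape and dims_mapping do not match"
    global_shape = []
    for dim, mapping in zip(shard_shape, dims_mapping):
        if dim == -1 or mapping == -1:
            # dynamic dimension, or dimension not sharded: keep as-is
            global_shape.append(dim)
        else:
            # dimension split across mesh axis `mapping`: multiply back
            # by the number of processes along that axis
            global_shape.append(dim * mesh_topology[mapping])
    return global_shape

# In the test_dp2pp1mp1 case above, each of the two ranks holds a [1, 8]
# shard of 'input' with dims_mapping [0, -1] on a mesh of topology [2],
# so the recovered feed shape is the full [2, 8] batch:
print(unshard_shape([1, 8], [0, -1], [2]))  # -> [2, 8]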