From bc153701da60fd042335177c6bf2f145f38fc90b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=B6=E6=B3=BD=E4=BC=9F?= <93245647+littsk@users.noreply.github.com>
Date: Wed, 19 Jul 2023 14:32:11 +0800
Subject: [PATCH] add sequence parallel utils to fleet utils (#55462)

---
 .../distributed/fleet/utils/__init__.py       |   1 +
 .../fleet/utils/sequence_parallel_utils.py    | 461 +++++++++++++++++
 ...arallel_mp_model_with_sequence_parallel.py | 477 ++++++++++++++++++
 .../test_parallel_dygraph_tensor_parallel.py  |   5 +
 4 files changed, 944 insertions(+)
 create mode 100644 python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
 create mode 100644 test/collective/fleet/hybrid_parallel_mp_model_with_sequence_parallel.py

diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py
index 25b1c153165..0ad0d6256ab 100644
--- a/python/paddle/distributed/fleet/utils/__init__.py
+++ b/python/paddle/distributed/fleet/utils/__init__.py
@@ -23,6 +23,7 @@ from . import log_util  # noqa: F401
 from . import hybrid_parallel_util  # noqa: F401
 from . import tensor_parallel_utils  # noqa: F401
 from . import mix_precision_utils  # noqa: F401
+from . import sequence_parallel_utils
 
 __all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"]  # noqa
 
diff --git a/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py b/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
new file mode 100644
index 00000000000..1e7f5e93785
--- /dev/null
+++ b/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
@@ -0,0 +1,461 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
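+
+"""Utilities for sequence parallelism on top of the model-parallel group.
+
+Activations are kept in [s, b, h] layout and split along the sequence
+dimension s across the model-parallel ranks, so each rank holds an
+[s/n, b, h] shard (n is the model-parallel degree). This module provides
+the scatter/all_gather/reduce_scatter primitives and their autograd-aware
+PyLayer wrappers, sequence-parallel variants of the column- and
+row-parallel linear layers, and helpers that all_reduce the gradients of
+parameters which stay replicated across ranks (e.g. LayerNorm weights).
+"""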
+
+import paddle
+from paddle import distributed as dist
+from paddle.autograd import PyLayer
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
+from paddle.distributed.fleet.utils.hybrid_parallel_util import (
+    fused_allreduce_gradients_with_group,
+)
+from paddle.fluid import core
+from paddle.nn import Layer
+from paddle.nn import functional as F
+
+####################################################
+#                                                  #
+#        Distributed Communication Operator        #
+#                                                  #
+####################################################
+
+
+def scatter(input):
+    hcg = fleet.get_hybrid_communicate_group()
+    group = hcg.get_model_parallel_group()
+    parallelism = group.nranks
+    rank = group.rank
+    seq_len = input.shape[0]
+    assert (
+        seq_len % parallelism == 0
+    ), "Input sequence length {} can't be divided exactly by sequence parallelism {}".format(
+        seq_len, parallelism
+    )
+    interval = seq_len // parallelism
+    input = paddle.slice(
+        input, axes=[0], starts=[interval * rank], ends=[interval * (rank + 1)]
+    )
+    return input
+
+
+def all_gather(input):
+    hcg = fleet.get_hybrid_communicate_group()
+    group = hcg.get_model_parallel_group()
+    parallelism = group.nranks
+    output_shape = input.shape
+    output_shape[0] = output_shape[0] * parallelism
+    output = paddle.empty(shape=output_shape, dtype=input.dtype)
+    group.process_group.all_gather(input, output).wait()
+    return output
+
+
+def reduce_scatter(input):
+    hcg = fleet.get_hybrid_communicate_group()
+    group = hcg.get_model_parallel_group()
+    parallelism = group.nranks
+    output_shape = input.shape
+    assert (
+        input.shape[0] % parallelism == 0
+    ), "Input sequence length {} can't be divided exactly by sequence parallelism {}".format(
+        input.shape[0], parallelism
+    )
+    output_shape[0] = output_shape[0] // parallelism
+    output = paddle.empty(shape=output_shape, dtype=input.dtype)
+    dist.stream.reduce_scatter(
+        output, input, op=dist.ReduceOp.SUM, group=group, sync_op=True
+    )
+    return output
+
+
+class ScatterOp(PyLayer):
+    # input shape: [s, b, h], n is mp parallelism
+    # after forward shape: [s/n, b, h]
+    @staticmethod
+    def forward(ctx, input):
+        return scatter(input)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return all_gather(grad)
+
+
+class GatherOp(PyLayer):
+    # input shape: [s/n, b, h], n is mp parallelism
+    # after forward shape: [s, b, h]
+    @staticmethod
+    def forward(ctx, input):
+        return all_gather(input)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return scatter(grad)
+
+
+# All gather along the first dim during the forward pass,
+# reduce-scatter along the first dim during the backward pass.
+class AllGatherOp(PyLayer):
+    # input shape: [s/n, b, h], n is mp parallelism
+    # after forward shape: [s, b, h]
+    @staticmethod
+    def forward(ctx, input):
+        return all_gather(input)
+
+    # grad shape: [s, b, h], n is mp parallelism
+    # after backward shape: [s/n, b, h]
+    @staticmethod
+    def backward(ctx, grad):
+        return reduce_scatter(grad)
+
+
+# Reduce-scatter along the first dim during the forward pass,
+# all gather along the first dim during the backward pass.
+class ReduceScatterOp(PyLayer):
+    # input shape: [s, b, h], n is mp parallelism
+    # after forward shape: [s/n, b, h]
+    @staticmethod
+    def forward(ctx, input):
+        return reduce_scatter(input)
+
+    # grad shape: [s/n, b, h], n is mp parallelism
+    # after backward shape: [s, b, h]
+    @staticmethod
+    def backward(ctx, grad):
+        return all_gather(grad)
+
+
+###################################################
+#                                                 #
+#        Modified Parallel Linear Operator        #
+#                                                 #
+###################################################
+
+
+def mark_as_sequence_parallel_parameter(parameter):
+    parameter.sequence_parallel = True
+
+
+def is_sequence_parallel_parameter(parameter):
+    return getattr(parameter, "sequence_parallel", False)
+
+
+def create_fused_allreduce_gradient_hook(parameter_list, accumulation_steps):
+    hcg = fleet.get_hybrid_communicate_group()
+    group = hcg.get_model_parallel_group()
+
+    step = [0]
+    accumulation_steps *= len(parameter_list)
+
+    def __impl__(grad):
+        step[0] += 1
+        if step[0] == accumulation_steps:
+            step[0] = 0
+            fused_allreduce_gradients_with_group(
+                parameter_list, group=group, scale=1.0
+            )
+        return grad
+
+    return __impl__
+
+
+def create_non_fused_allreduce_gradient_hook(param, accumulation_steps):
+    hcg = fleet.get_hybrid_communicate_group()
+    pg = hcg.get_model_parallel_group().process_group
+    step = [0]
+
+    @paddle.autograd.no_grad()
+    def __impl__():
+        step[0] += 1
+        if (step[0] % accumulation_steps) == 0:
+            if hasattr(param, "main_grad"):
+                pg.allreduce(param.main_grad).wait()
+            else:
+                pg.allreduce(param.grad).wait()
+
+    return __impl__
+
+
+def register_sequence_parallel_allreduce_hooks(
+    model, accumulation_steps, fuse_sequence_parallel_allreduce
+):
+    if accumulation_steps <= 0 or not paddle.distributed.is_initialized():
+        return
+
+    mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
+    if mp_group.nranks <= 1:
+        return
+
+    params = []
+    for p in model.parameters():
+        if is_sequence_parallel_parameter(p):
+            params.append(p)
+
+    if fuse_sequence_parallel_allreduce:
+        hook = create_fused_allreduce_gradient_hook(params, accumulation_steps)
+        for p in params:
+            p._register_backward_hook(hook)
+    else:
+        for p in params:
+            hook = create_non_fused_allreduce_gradient_hook(
+                p, accumulation_steps
+            )
+            p._register_backward_hook(hook)
+
+
+def is_fused_matmul_bias_supported():
+    if (
+        paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
+    ) or paddle.is_compiled_with_xpu():
+        return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
+    else:
+        return False
+
+
+class ColumnSequenceParallelLinear(Layer):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        weight_attr=None,
+        has_bias=None,
+        gather_output=True,
+        fuse_matmul_bias=False,
+        mp_group=None,
+        name=None,
+    ):
+        super().__init__()
+
+        hcg = fleet.get_hybrid_communicate_group()
+        self.model_parallel_group = (
+            hcg.get_model_parallel_group() if mp_group is None else mp_group
+        )
+        self.world_size = (
+            hcg.get_model_parallel_group().nranks
+            if mp_group is None
+            else mp_group.nranks
+        )
+        self._name = name
+        self.is_mp = self.world_size > 1
+
+        assert (
+            gather_output is False
+        ), "If sequence_parallel is True, gather_output must be False"
+
+        self.gather_output = gather_output
+        assert out_features % self.world_size == 0, (
+            "Number of columns of the weight for linear ({}) must be"
+            " divisible by model parallel size ({})".format(
+                out_features, self.world_size
+            )
+        )
+        self.output_size_per_partition = out_features // self.world_size
+
+        self._weight_attr = weight_attr
+        self._dtype = self._helper.get_default_dtype()
+
+        if self.is_mp and paddle.in_dynamic_mode():
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[in_features, self.output_size_per_partition],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False,
+                )
+        else:
+            self.weight = self.create_parameter(
+                shape=[in_features, self.output_size_per_partition],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False,
+            )
+
+        self.weight.is_distributed = self.is_mp
+
+        if has_bias:
+            # initialize bias to zero like Megatron
+            self.bias = self.create_parameter(
+                shape=[self.output_size_per_partition],
+                attr=paddle.nn.initializer.Constant(value=0.0),
+                dtype=self._dtype,
+                is_bias=True,
+            )
+            self.bias.is_distributed = self.is_mp
+        else:
+            self.bias = None
+
+        self.linear = F.linear
+
+        if fuse_matmul_bias:
+            if not is_fused_matmul_bias_supported():
+                raise NotImplementedError(
+                    "You set fuse_matmul_bias=True in ColumnSequenceParallelLinear, "
+                    "however, the paddle you are using does not support this "
+                    "operation. Please set fuse_matmul_bias=False, or use "
+                    "paddle compiled with cuda 11.6 or higher, or the xpu "
+                    "version."
+                )
+            from paddle.incubate.nn.functional import fused_linear
+
+            self.linear = fused_linear
+
+    def forward(self, x):
+        # sequence parallelism reuses the model parallel group;
+        # with sequence parallelism the input shape is [s, b, h],
+        # otherwise it is [b, s, h]
+        if self.is_mp:
+            input_parallel = AllGatherOp.apply(x)
+        else:
+            input_parallel = x
+        output = self.linear(
+            input_parallel, self.weight, self.bias, name=self._name
+        )
+        return output
+
+
+class MPScale(PyLayer):
+    @staticmethod
+    def forward(ctx, x, mp_degree):
+        out = paddle.scale(x, 1.0 / mp_degree)
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        return dout
+
+
+class RowSequenceParallelLinear(Layer):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        weight_attr=None,
+        has_bias=True,
+        input_is_parallel=False,
+        fuse_matmul_bias=False,
+        mp_group=None,
+        name=None,
+    ):
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        assert (
+            input_is_parallel is True
+        ), "If sequence_parallel is True, input_is_parallel must be True."
+
+        self.input_is_parallel = input_is_parallel
+        self._weight_attr = weight_attr
+        self._dtype = self._helper.get_default_dtype()
+        self._name = name
+
+        hcg = fleet.get_hybrid_communicate_group()
+        self.model_parallel_group = (
+            hcg.get_model_parallel_group() if mp_group is None else mp_group
+        )
+        self.world_size = (
+            hcg.get_model_parallel_group().nranks
+            if mp_group is None
+            else mp_group.nranks
+        )
+        self.rank = (
+            hcg.get_model_parallel_group().rank
+            if mp_group is None
+            else mp_group.rank
+        )
+
+        self.is_mp = self.world_size > 1
+        assert in_features % self.world_size == 0, (
+            "Number of rows of the weight for linear ({}) must be"
+            " divisible by model parallel size ({})".format(
+                in_features, self.world_size
+            )
+        )
+
+        self.input_size_per_partition = in_features // self.world_size
+
+        if self.is_mp and paddle.in_dynamic_mode():
+            with get_rng_state_tracker().rng_state():
+                self.weight = self.create_parameter(
+                    shape=[self.input_size_per_partition, self.out_features],
+                    attr=self._weight_attr,
+                    dtype=self._dtype,
+                    is_bias=False,
+                )
+        else:
+            self.weight = self.create_parameter(
+                shape=[self.input_size_per_partition, self.out_features],
+                attr=self._weight_attr,
+                dtype=self._dtype,
+                is_bias=False,
+            )
+
+        self.weight.is_distributed = self.is_mp
+
+        # with sequence parallelism, the bias stays replicated across mp
+        # ranks, so a hook is registered to all_reduce its gradient
+        if has_bias:
+            self.bias = self.create_parameter(
+                shape=[self.out_features],
+                attr=paddle.nn.initializer.Constant(value=0.0),
+                dtype=self._dtype,
+                is_bias=True,
+            )
+            if self.is_mp:
+                mark_as_sequence_parallel_parameter(self.bias)
+        else:
+            self.bias = None
+
+        self.linear = F.linear
+
+        self.mp_scale = None
+        if fuse_matmul_bias:
+            if not is_fused_matmul_bias_supported():
+                raise NotImplementedError(
+                    "You set fuse_matmul_bias=True in RowSequenceParallelLinear, "
+                    "however, the paddle you are using does not support this "
+                    "operation. Please set fuse_matmul_bias=False, or use "
+                    "paddle compiled with cuda 11.6 or higher."
+                )
+            from paddle.incubate.nn.functional import fused_linear
+
+            self.linear = fused_linear
+            if self.is_mp and has_bias:
+                self.mp_scale = MPScale.apply
+
+    def forward(self, x):
+        input_parallel = x
+        if self.is_mp:
+            if self.mp_scale is not None:
+                bias = self.mp_scale(self.bias, self.world_size)
+            else:
+                bias = None
+            output_parallel = self.linear(
+                input_parallel, self.weight, bias, name=self._name
+            )
+            output_ = ReduceScatterOp.apply(output_parallel)
+            # if self.bias is not None, sequence parallel relies on the
+            # registered hook to all_reduce the gradient of self.bias
+            if bias is None and self.bias is not None:
+                output = output_ + self.bias
+            else:
+                output = output_
+        else:
+            output = self.linear(
+                input_parallel, self.weight, self.bias, name=self._name
+            )
+        return output
diff --git a/test/collective/fleet/hybrid_parallel_mp_model_with_sequence_parallel.py b/test/collective/fleet/hybrid_parallel_mp_model_with_sequence_parallel.py
new file mode 100644
index 00000000000..fa78482601f
--- /dev/null
+++ b/test/collective/fleet/hybrid_parallel_mp_model_with_sequence_parallel.py
@@ -0,0 +1,477 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.distributed import fleet
+from paddle.distributed.fleet.utils import sequence_parallel_utils as spu
+
+
+def set_random_seed(seed, dp_id, rank_id):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed + dp_id)
+    paddle.seed(seed + rank_id)
+
+
+vocab_size = 20
+hidden_size = 10
+inner_size = 8
+output_size = 10
+seq_length = 2
+batch_size = 4
+
+
+def parallel_matmul(lm_output, logit_weights, parallel_output):
+    hcg = fleet.get_hybrid_communicate_group()
+    model_parallel_group = hcg.get_model_parallel_group()
+    world_size = hcg.get_model_parallel_world_size()
+    rank = hcg.get_model_parallel_rank()
+
+    if world_size > 1:
+        input_parallel = paddle.distributed.collective._c_identity(
+            lm_output, group=model_parallel_group
+        )
+
+        logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True)
+
+        if parallel_output:
+            return logits
+
+        return paddle.distributed.collective._c_concat(
+            logits, group=model_parallel_group
+        )
+    else:
+        logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
+        return logits
+
+
+class SimpleSPNet(paddle.nn.Layer):
+    def __init__(
+        self,
+        vocab_size,
+        hidden_size,
+        inner_size,
+        output_size,
+        np_fc1,
+        np_fc2,
+        mp_id,
+    ):
+        super().__init__()
+
+        if mp_id == 0:
+            init_fc1_data = np_fc1[:, : (inner_size // 2)]
+            init_fc2_data = np_fc2[: (inner_size // 2), :]
+        else:
+            init_fc1_data = np_fc1[:, (inner_size // 2) :]
+            init_fc2_data = np_fc2[(inner_size // 2) :, :]
+
+        self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
+            vocab_size,
+            hidden_size,
+            weight_attr=paddle.nn.initializer.Constant(value=0.5),
+        )
+
+        self.linear1 = spu.ColumnSequenceParallelLinear(
+            hidden_size,
+            inner_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(init_fc1_data)
+            ),
+            gather_output=False,
+            has_bias=True,
+        )
+
+        self.linear2 = spu.RowSequenceParallelLinear(
+            inner_size,
+            hidden_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(init_fc2_data)
+            ),
+            input_is_parallel=True,
+            has_bias=True,
+        )
+
+        self.linear3 = paddle.nn.Linear(
+            hidden_size,
+            output_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+            bias_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+        )
+
+        self.norm = paddle.nn.LayerNorm(hidden_size, epsilon=1e-5)
+        # the LayerNorm weight and bias are replicated across mp ranks, so
+        # register hooks to all_reduce their gradients
+        spu.mark_as_sequence_parallel_parameter(self.norm.weight)
+        spu.mark_as_sequence_parallel_parameter(self.norm.bias)
+
+        spu.register_sequence_parallel_allreduce_hooks(self, 1, False)
+
+    def forward(self, x):
+        x = self.embedding(x)
+
+        x = paddle.transpose(x, perm=[1, 0, 2])
+        x = spu.ScatterOp.apply(x)
+
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.norm(x)
+        x = self.linear3(x)
+
+        x = paddle.transpose(x, perm=[1, 0, 2])
+
+        x = parallel_matmul(x, self.embedding.weight, False)
+        return x
+
+
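+# Shape walkthrough for SimpleSPNet.forward with mp degree n, following the
+# layer definitions above:
+#   embedding output  [b, s, h] --transpose--> [s, b, h]
+#   ScatterOp:        [s, b, h] --> [s/n, b, h]
+#   linear1 (column): all-gathers to [s, b, h], matmul --> [s, b, inner/n]
+#   linear2 (row):    matmul --> partial [s, b, h],
+#                     reduce-scatter --> [s/n, b, h]
+#   norm / linear3:   act on the local [s/n, b, .] shard
+#   transpose + parallel_matmul then produce the logits
+
+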
+class SimpleDPNet(paddle.nn.Layer):
+    def __init__(
+        self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+    ):
+        super().__init__()
+        self.linear1 = paddle.nn.Linear(
+            hidden_size,
+            inner_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(np_fc1)
+            ),
+            bias_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+        )
+
+        self.linear2 = paddle.nn.Linear(
+            inner_size,
+            hidden_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(np_fc2)
+            ),
+            bias_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+        )
+
+        self.linear3 = paddle.nn.Linear(
+            hidden_size,
+            output_size,
+            weight_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+            bias_attr=paddle.framework.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(0.0)
+            ),
+        )
+
+        self.norm = paddle.nn.LayerNorm(hidden_size, epsilon=1e-5)
+
+        self.embedding = paddle.nn.Embedding(
+            vocab_size,
+            hidden_size,
+            weight_attr=paddle.nn.initializer.Constant(value=0.5),
+        )
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.norm(x)
+        x = self.linear3(x)
+        x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
+        return x
+
+
+class TestDistSPSyncTraining(unittest.TestCase):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "sync_param": False,
+                "sync_grad": False,
+                "sync_moment": False,
+            },
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def build_model_optimizer_train(
+        self,
+        batches,
+        fp16=False,
+        amp_level="O1",
+        mp_sync_param=False,
+        mp_sync_grad=False,
+        mp_sync_moment=False,
+    ):
+        hcg = fleet.get_hybrid_communicate_group()
+        world_size = hcg.get_model_parallel_world_size()
+        mp_id = hcg.get_model_parallel_rank()
+        dp_id = hcg.get_data_parallel_rank()
+        rank_id = dist.get_rank()
+        paddle.seed(2023)
+        np.random.seed(2023)
+        random.seed(2023)
+        set_random_seed(1024, dp_id, rank_id)
+
+        np_fc1 = np.random.random_sample((hidden_size, inner_size))
+        np_fc2 = np.random.random_sample((inner_size, hidden_size))
+
+        model = SimpleSPNet(
+            vocab_size,
+            hidden_size,
+            inner_size,
+            output_size,
+            np_fc1,
+            np_fc2,
+            mp_id,
+        )
+        optimizer = paddle.optimizer.AdamW(
+            learning_rate=0.1, parameters=model.parameters()
+        )
+
+        if fp16 and amp_level == "O2":
+            model, optimizer = paddle.amp.decorate(
+                models=model, optimizers=optimizer, level='O2'
+            )
+
+        strategy = fleet.fleet._user_defined_strategy
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "sync_param": mp_sync_param,
+                "sync_grad": mp_sync_grad,
+                "sync_moment": mp_sync_moment,
+            },
+        }
+
+        model = fleet.distributed_model(model)
+        optimizer = fleet.distributed_optimizer(optimizer)
+        return self.train_batch(batches, model, optimizer, fp16, amp_level)
+
+    def train_batch(
+        self, batches, model, optimizer, fp16=False, amp_level="O1"
+    ):
+        losses = []
+        if fp16:
+            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+            scaler = fleet.distributed_scaler(scaler)
+        for batch in batches:
+            with paddle.amp.auto_cast(enable=fp16, level=amp_level):
+                output = model(batch)
+                loss = output.mean()
+                losses.append(loss.numpy())
+            if fp16:
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                optimizer.step()
+            optimizer.clear_grad()
+        return losses
+
+    def mp_sync_base(
+        self, mp_sync_param=False, mp_sync_grad=False, mp_sync_moment=False
+    ):
+        batches = []
+        for _ in range(5):
+            np_data = np.random.randint(
+                0,
+                vocab_size,
+                (
+                    batch_size,
+                    seq_length,
+                ),
+            )
+            batches.append(paddle.to_tensor(np_data))
+
+        losses = self.build_model_optimizer_train(batches)
+        losses_sync = self.build_model_optimizer_train(
+            batches,
+            mp_sync_param=mp_sync_param,
+            mp_sync_grad=mp_sync_grad,
+            mp_sync_moment=mp_sync_moment,
+        )
+
+        for i in range(len(losses)):
+            np.testing.assert_allclose(
+                losses[i], losses_sync[i], rtol=1e-5, atol=1e-5
+            )
+
+        # test fp16 O1
+        losses_fp16 = self.build_model_optimizer_train(batches, fp16=True)
+        losses_sync_fp16 = self.build_model_optimizer_train(
+            batches,
+            fp16=True,
+            mp_sync_param=mp_sync_param,
+            mp_sync_grad=mp_sync_grad,
+            mp_sync_moment=mp_sync_moment,
+        )
+
+        for i in range(len(losses_fp16)):
+            np.testing.assert_allclose(
+                losses_fp16[i], losses_sync_fp16[i], rtol=1e-5, atol=1e-5
+            )
+
+        # test fp16 O2
+        losses_fp16_O2 = self.build_model_optimizer_train(
+            batches, fp16=True, amp_level="O2"
+        )
+        losses_sync_fp16_O2 = self.build_model_optimizer_train(
+            batches,
+            fp16=True,
+            amp_level="O2",
+            mp_sync_param=mp_sync_param,
+            mp_sync_grad=mp_sync_grad,
+            mp_sync_moment=mp_sync_moment,
+        )
+
+        for i in range(len(losses_fp16_O2)):
+            np.testing.assert_allclose(
+                losses_fp16_O2[i], losses_sync_fp16_O2[i], rtol=1e-5, atol=1e-5
+            )
+
+    def test_mp_sync_param(self):
+        self.mp_sync_base(mp_sync_param=True)
+
+    def test_mp_sync_grad(self):
+        self.mp_sync_base(mp_sync_grad=True)
+
+    def test_mp_sync_moment(self):
+        self.mp_sync_base(mp_sync_moment=True)
+
+    def test_mp_sync_all(self):
+        self.mp_sync_base(
+            mp_sync_param=True, mp_sync_grad=True, mp_sync_moment=True
+        )
+
+
+class TestDistSPSyncModelTraining(TestDistSPSyncTraining):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "sync_param": False,
+                "sync_grad": False,
+                "sync_moment": False,
+                "sync_mode": "average",
+                "sync_param_name": ["embedding", "layer_norm", ".b_"],
+            },
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+
+class TestDistSPTraining(unittest.TestCase):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def train_batch(self, batch, model, optimizer, is_mp):
+        output = model(batch)
+        loss = output.mean()
+        loss.backward()  # do backward
+        optimizer.step()  # update parameters
+        optimizer.clear_grad()
+        return loss
+
+    def build_optimizer(self, model):
+        optimizer = paddle.optimizer.SGD(
+            learning_rate=0.001, parameters=model.parameters()
+        )
+        return optimizer
+
+    def build_model_optimizer(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        world_size = hcg.get_model_parallel_world_size()
+        mp_id = hcg.get_model_parallel_rank()
+        dp_id = hcg.get_data_parallel_rank()
+        rank_id = dist.get_rank()
+        set_random_seed(1024, dp_id, rank_id)
+
+        np_fc1 = np.random.random_sample((hidden_size, inner_size))
+        np_fc2 = np.random.random_sample((inner_size, hidden_size))
+
+        model_a = SimpleSPNet(
+            vocab_size,
+            hidden_size,
+            inner_size,
+            output_size,
+            np_fc1,
+            np_fc2,
+            mp_id,
+        )
+        optimizer_a = self.build_optimizer(model_a)
+        model_a = fleet.distributed_model(model_a)
+        optimizer_a = fleet.distributed_optimizer(optimizer_a)
+
+        model_b = SimpleDPNet(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+        optimizer_b = self.build_optimizer(model_b)
+
+        return model_a, optimizer_a, model_b, optimizer_b
+
+    def test_mp_model(self):
+        (
+            model_a,
+            optimizer_a,
+            model_b,
+            optimizer_b,
+        ) = self.build_model_optimizer()
+
+        for _ in range(5):
+            np_data = np.random.randint(
+                0,
+                vocab_size,
+                (
+                    batch_size,
+                    seq_length,
+                ),
+            )
+            batch = paddle.to_tensor(np_data)
+            loss_a = self.train_batch(batch, model_a, optimizer_a, True)
+            loss_b = self.train_batch(batch, model_b, optimizer_b, False)
+
+            np.testing.assert_allclose(
+                loss_a.numpy(), loss_b.numpy(), rtol=1e-5, atol=1e-5
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/collective/fleet/test_parallel_dygraph_tensor_parallel.py b/test/collective/fleet/test_parallel_dygraph_tensor_parallel.py
index 769e9c29727..47ab325c73d 100644
--- a/test/collective/fleet/test_parallel_dygraph_tensor_parallel.py
+++ b/test/collective/fleet/test_parallel_dygraph_tensor_parallel.py
@@ -24,6 +24,11 @@ class TestHybridParallel(TestMultipleGpus):
     def test_hybrid_parallel_mp_model(self):
         self.run_mnist_2gpu('hybrid_parallel_mp_model.py')
 
+    def test_hybrid_parallel_mp_model_with_sequence_parallel(self):
+        self.run_mnist_2gpu(
+            'hybrid_parallel_mp_model_with_sequence_parallel.py'
+        )
+
     def test_hybrid_parallel_mp_amp(self):
         self.run_mnist_2gpu('hybrid_parallel_mp_amp.py')
 
-- 
GitLab
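
Usage sketch (illustrative only, not part of the patch): a minimal layer built from the new utilities, distilled from the test above. It assumes fleet.init has already been run with a hybrid strategy whose mp_degree > 1, and must be launched via paddle.distributed.launch; the SPBlock name and the layer sizes are hypothetical.

    import paddle
    from paddle.distributed.fleet.utils import sequence_parallel_utils as spu


    class SPBlock(paddle.nn.Layer):
        def __init__(self, hidden_size, inner_size):
            super().__init__()
            # splits out_features across mp ranks; all-gathers its input
            self.fc1 = spu.ColumnSequenceParallelLinear(
                hidden_size, inner_size, gather_output=False, has_bias=True
            )
            # splits in_features across mp ranks; reduce-scatters its output
            self.fc2 = spu.RowSequenceParallelLinear(
                inner_size, hidden_size, input_is_parallel=True, has_bias=True
            )
            self.norm = paddle.nn.LayerNorm(hidden_size)
            # LayerNorm parameters stay replicated, so mark them for the
            # gradient all-reduce hooks
            spu.mark_as_sequence_parallel_parameter(self.norm.weight)
            spu.mark_as_sequence_parallel_parameter(self.norm.bias)

        def forward(self, x):  # x: [b, s, h]
            x = paddle.transpose(x, perm=[1, 0, 2])  # -> [s, b, h]
            x = spu.ScatterOp.apply(x)               # -> [s/n, b, h]
            x = self.norm(self.fc2(self.fc1(x)))     # -> [s/n, b, h]
            x = spu.GatherOp.apply(x)                # -> [s, b, h]
            return paddle.transpose(x, perm=[1, 0, 2])  # -> [b, s, h]


    # after the model is built, register the per-parameter all-reduce hooks
    # (accumulation_steps=1, non-fused variant, mirroring the test above)
    # model = SPBlock(hidden_size=10, inner_size=8)
    # spu.register_sequence_parallel_allreduce_hooks(model, 1, False)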