Unverified commit bc153701, authored by 陶泽伟, committed by GitHub

add sequence parallel utils to fleet utils (#55462)

Parent 4c4d3185
@@ -23,6 +23,7 @@ from . import log_util # noqa: F401
from . import hybrid_parallel_util # noqa: F401
from . import tensor_parallel_utils # noqa: F401
from . import mix_precision_utils # noqa: F401
from . import sequence_parallel_utils
__all__ = ["LocalFS", "recompute", "DistributedInfer", "HDFSClient"] # noqa
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import distributed as dist
from paddle.autograd import PyLayer
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
fused_allreduce_gradients_with_group,
)
from paddle.fluid import core
from paddle.nn import Layer
from paddle.nn import functional as F
####################################################
# #
# Distributed Communication Operator #
# #
####################################################
def scatter(input):
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_model_parallel_group()
parallelism = group.nranks
rank = group.rank
seq_len = input.shape[0]
assert (
seq_len % parallelism == 0
), "Input sequence length {} can't be divided exactly by sequence parallelism {}".format(
seq_len, parallelism
)
interval = seq_len // parallelism
input = paddle.slice(
input, axes=[0], starts=[interval * rank], ends=[interval * (rank + 1)]
)
return input
def all_gather(input):
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_model_parallel_group()
parallelism = group.nranks
output_shape = input.shape
output_shape[0] = output_shape[0] * parallelism
output = paddle.empty(shape=output_shape, dtype=input.dtype)
group.process_group.all_gather(input, output).wait()
return output
def reduce_scatter(input):
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_model_parallel_group()
parallelism = group.nranks
output_shape = input.shape
assert (
input.shape[0] % parallelism == 0
), "Input sequence length {} can't be divided exactly by sequence parallelism {}".format(
input.shape[0], parallelism
)
output_shape[0] = output_shape[0] // parallelism
output = paddle.empty(shape=output_shape, dtype=input.dtype)
dist.stream.reduce_scatter(
output, input, op=dist.ReduceOp.SUM, group=group, sync_op=True
)
return output
class ScatterOp(PyLayer):
# input shape: [s, b, h], n is mp parallelism
# after forward shape: [s/n, b, h]
@staticmethod
def forward(ctx, input):
return scatter(input)
@staticmethod
def backward(ctx, grad):
return all_gather(grad)
class GatherOp(PyLayer):
# input shape: [s/n, b, h], n is mp parallelism
# after forward shape: [s, b, h]
@staticmethod
def forward(ctx, input):
return all_gather(input)
@staticmethod
def backward(ctx, grad):
return scatter(grad)
# All gather along the first dim during forward pass
# All reduce and scatter along the first dim during backward pass
class AllGatherOp(PyLayer):
# input shape: [s/n, b, h], n is mp parallelism
# after forward shape: [s, b, h]
@staticmethod
def forward(ctx, input):
return all_gather(input)
    # grad shape: [s, b, h], n is mp parallelism
    # after backward shape: [s/n, b, h]
@staticmethod
def backward(ctx, grad):
return reduce_scatter(grad)
# All reduce and scatter along the first dim during forward pass
# All gather along the first dim during backward pass
class ReduceScatterOp(PyLayer):
# input shape: [s, b, h], n is mp parallelism
# after forward shape: [s/n, b, h]
@staticmethod
def forward(ctx, input):
return reduce_scatter(input)
    # grad shape: [s/n, b, h], n is mp parallelism
    # after backward shape: [s, b, h]
@staticmethod
def backward(ctx, grad):
return all_gather(grad)
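# A minimal round-trip sketch of the ops above (an illustrative assumption: fleet.init
# has been called with mp_degree > 1 so the hybrid communicate group exists, and the
# sequence length is divisible by the model-parallel degree). ScatterOp slices the
# sequence dimension across model-parallel ranks and GatherOp all-gathers it back, so
# composing them reproduces the original [s, b, h] tensor; gradients take the reverse path.
def _sequence_parallel_round_trip_demo(hidden):
    # hidden: [s, b, h], sequence-major activation
    local_hidden = ScatterOp.apply(hidden)  # [s/n, b, h] slice owned by this rank
    restored = GatherOp.apply(local_hidden)  # all-gather back to [s, b, h]
    return restored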
###################################################
# #
# Modified Parallel Linear Operator #
# #
###################################################
def mark_as_sequence_parallel_parameter(parameter):
parameter.sequence_parallel = True
def is_sequence_parallel_parameter(parameter):
return getattr(parameter, "sequence_parallel", False)
def create_fused_allreduce_gradient_hook(parameter_list, accumulation_steps):
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_model_parallel_group()
step = [0]
accumulation_steps *= len(parameter_list)
def __impl__(grad):
step[0] += 1
if step[0] == accumulation_steps:
step[0] = 0
fused_allreduce_gradients_with_group(
parameter_list, group=group, scale=1.0
)
return grad
return __impl__
def create_non_fused_allreduce_gradient_hook(param, accumulation_steps):
hcg = fleet.get_hybrid_communicate_group()
pg = hcg.get_model_parallel_group().process_group
step = [0]
@paddle.autograd.no_grad()
def __impl__():
step[0] += 1
if (step[0] % accumulation_steps) == 0:
if hasattr(param, "main_grad"):
pg.allreduce(param.main_grad).wait()
else:
pg.allreduce(param.grad).wait()
return __impl__
def register_sequence_parallel_allreduce_hooks(
model, accumulation_steps, fuse_sequence_parallel_allreduce
):
if accumulation_steps <= 0 or not paddle.distributed.is_initialized():
return
mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
if mp_group.nranks <= 1:
return
params = []
for p in model.parameters():
if is_sequence_parallel_parameter(p):
params.append(p)
if fuse_sequence_parallel_allreduce:
hook = create_fused_allreduce_gradient_hook(params, accumulation_steps)
for p in params:
p._register_backward_hook(hook)
else:
for p in params:
hook = create_non_fused_allreduce_gradient_hook(
p, accumulation_steps
)
p._register_backward_hook(hook)
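# A minimal wiring sketch for the helpers above (mirroring the unit test further down
# in this change; `model.norm` is an assumed attribute used only for illustration).
# Parameters that are replicated across the model-parallel group but receive partial
# gradients under sequence parallelism (e.g. LayerNorm weight and bias) are marked
# first, then a single call registers the backward all-reduce hooks for all of them.
def _register_sequence_parallel_hooks_demo(model, accumulation_steps=1):
    mark_as_sequence_parallel_parameter(model.norm.weight)
    mark_as_sequence_parallel_parameter(model.norm.bias)
    register_sequence_parallel_allreduce_hooks(
        model,
        accumulation_steps,
        fuse_sequence_parallel_allreduce=False,
    )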
def is_fused_matmul_bias_supported():
    if (
        paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
    ) or paddle.is_compiled_with_xpu():
        return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
    else:
        return False
class ColumnSequenceParallelLinear(Layer):
def __init__(
self,
in_features,
out_features,
weight_attr=None,
has_bias=None,
gather_output=True,
fuse_matmul_bias=False,
mp_group=None,
name=None,
):
super().__init__()
hcg = fleet.get_hybrid_communicate_group()
self.model_parallel_group = (
hcg.get_model_parallel_group() if mp_group is None else mp_group
)
self.world_size = (
hcg.get_model_parallel_group().nranks
if mp_group is None
else mp_group.nranks
)
self._name = name
self.is_mp = self.world_size > 1
        assert (
            gather_output is False
        ), "If sequence_parallel is True, gather_output must be False"
self.gather_output = gather_output
        assert out_features % self.world_size == 0, (
            "Number of columns of the linear weight ({}) must be"
            " divisible by model parallel size ({})".format(
                out_features, self.world_size
            )
        )
self.output_size_per_partition = out_features // self.world_size
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
else:
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
        self.weight.is_distributed = self.is_mp
if has_bias:
# initialize bias to zero like Megatron
self.bias = self.create_parameter(
shape=[self.output_size_per_partition],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype,
is_bias=True,
)
            self.bias.is_distributed = self.is_mp
else:
self.bias = None
self.linear = F.linear
if fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
                raise NotImplementedError(
                    "You set fuse_matmul_bias=True in ColumnSequenceParallelLinear, "
                    "but the Paddle you are using does not support this operation. "
                    "Please set fuse_matmul_bias=False, or use Paddle compiled "
                    "with CUDA 11.6 or higher, or the XPU version."
                )
from paddle.incubate.nn.functional import fused_linear
self.linear = fused_linear
def forward(self, x):
        # sequence parallelism reuses the model parallel group and degree
        # if sequence parallel is enabled, the input layout is [s/n, b, h]
        # (scattered along the sequence dimension); otherwise it is [b, s, h]
if self.is_mp:
input_parallel = AllGatherOp.apply(x)
else:
input_parallel = x
output = self.linear(
input_parallel, self.weight, self.bias, name=self._name
)
return output
class MPScale(PyLayer):
@staticmethod
def forward(ctx, x, mp_degree):
out = paddle.scale(x, 1.0 / mp_degree)
return out
@staticmethod
def backward(ctx, dout):
return dout
class RowSequenceParallelLinear(Layer):
def __init__(
self,
in_features,
out_features,
weight_attr=None,
has_bias=True,
input_is_parallel=False,
fuse_matmul_bias=False,
mp_group=None,
name=None,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
        assert (
            input_is_parallel is True
        ), "If sequence_parallel is True, input_is_parallel must be True"
self.input_is_parallel = input_is_parallel
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
self._name = name
hcg = fleet.get_hybrid_communicate_group()
self.model_parallel_group = (
hcg.get_model_parallel_group() if mp_group is None else mp_group
)
self.world_size = (
hcg.get_model_parallel_group().nranks
if mp_group is None
else mp_group.nranks
)
self.rank = (
hcg.get_model_parallel_group().rank
if mp_group is None
else mp_group.rank
)
self.is_mp = self.world_size > 1
        assert in_features % self.world_size == 0, (
            "Number of rows of the linear weight ({}) must be"
            " divisible by model parallel size ({})".format(
                in_features, self.world_size
            )
        )
self.input_size_per_partition = in_features // self.world_size
if self.is_mp and paddle.in_dynamic_mode():
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
else:
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
        self.weight.is_distributed = self.is_mp
        # if sequence parallel is enabled, register a hook to
        # all-reduce the gradient of the (replicated) bias
if has_bias:
self.bias = self.create_parameter(
shape=[self.out_features],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype,
is_bias=True,
)
if self.is_mp:
mark_as_sequence_parallel_parameter(self.bias)
else:
self.bias = None
self.linear = F.linear
self.mp_scale = None
if fuse_matmul_bias:
if not is_fused_matmul_bias_supported():
                raise NotImplementedError(
                    "You set fuse_matmul_bias=True in RowSequenceParallelLinear, "
                    "but the Paddle you are using does not support this operation. "
                    "Please set fuse_matmul_bias=False or use Paddle compiled "
                    "with CUDA 11.6 or higher."
                )
from paddle.incubate.nn.functional import fused_linear
self.linear = fused_linear
if self.is_mp and has_bias:
self.mp_scale = MPScale.apply
def forward(self, x):
input_parallel = x
if self.is_mp:
if self.mp_scale is not None:
bias = self.mp_scale(self.bias, self.world_size)
else:
bias = None
output_parallel = self.linear(
input_parallel, self.weight, bias, name=self._name
)
output_ = ReduceScatterOp.apply(output_parallel)
            # if self.bias is not None, sequence parallel relies on the
            # registered hook to all-reduce the gradient of self.bias
if bias is None and self.bias is not None:
output = output_ + self.bias
else:
output = output_
else:
output = self.linear(
input_parallel, self.weight, self.bias, name=self._name
)
return output
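# An end-to-end usage sketch of the two layers above (assumptions: fleet.init was
# called with mp_degree > 1, `inner_size` is divisible by the model-parallel degree,
# and the input has already been scattered along the sequence dimension, e.g. via
# ScatterOp). ColumnSequenceParallelLinear all-gathers the sequence internally and
# RowSequenceParallelLinear reduce-scatters it back, so activations stay [s/n, b, h]
# outside the pair.
def _column_row_sequence_parallel_demo(hidden_size, inner_size):
    fc1 = ColumnSequenceParallelLinear(
        hidden_size, inner_size, has_bias=True, gather_output=False
    )
    fc2 = RowSequenceParallelLinear(
        inner_size, hidden_size, has_bias=True, input_is_parallel=True
    )
    def forward(x):
        # x: [s/n, b, hidden_size], already scattered along the sequence dimension
        return fc2(fc1(x))
    return forward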
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.fleet.utils import sequence_parallel_utils as spu
def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + rank_id)
vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
def parallel_matmul(lm_output, logit_weights, parallel_output):
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
world_size = hcg.get_model_parallel_world_size()
rank = hcg.get_model_parallel_rank()
if world_size > 1:
input_parallel = paddle.distributed.collective._c_identity(
lm_output, group=model_parallel_group
)
logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True)
if parallel_output:
return logits
return paddle.distributed.collective._c_concat(
logits, group=model_parallel_group
)
else:
logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
return logits
class SimpleSPNet(paddle.nn.Layer):
def __init__(
self,
vocab_size,
hidden_size,
inner_size,
output_size,
np_fc1,
np_fc2,
mp_id,
):
super().__init__()
if mp_id == 0:
init_fc1_data = np_fc1[:, : (inner_size // 2)]
init_fc2_data = np_fc2[: (inner_size // 2), :]
else:
init_fc1_data = np_fc1[:, (inner_size // 2) :]
init_fc2_data = np_fc2[(inner_size // 2) :, :]
self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5),
)
self.linear1 = spu.ColumnSequenceParallelLinear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc1_data)
),
gather_output=False,
has_bias=True,
)
self.linear2 = spu.RowSequenceParallelLinear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc2_data)
),
input_is_parallel=True,
has_bias=True,
)
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.norm = paddle.nn.LayerNorm(hidden_size, epsilon=1e-5)
        # since sequence parallel is enabled, register hooks to
        # all-reduce the gradients of the LayerNorm weight and bias
spu.mark_as_sequence_parallel_parameter(self.norm.weight)
spu.mark_as_sequence_parallel_parameter(self.norm.bias)
spu.register_sequence_parallel_allreduce_hooks(self, 1, False)
def forward(self, x):
x = self.embedding(x)
x = paddle.transpose(x, perm=[1, 0, 2])
x = spu.ScatterOp.apply(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.norm(x)
x = self.linear3(x)
x = paddle.transpose(x, perm=[1, 0, 2])
x = parallel_matmul(x, self.embedding.weight, False)
return x
class SimpleDPNet(paddle.nn.Layer):
def __init__(
self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
):
super().__init__()
self.linear1 = paddle.nn.Linear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc1)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.linear2 = paddle.nn.Linear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc2)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.norm = paddle.nn.LayerNorm(hidden_size, epsilon=1e-5)
self.embedding = paddle.nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5),
)
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.norm(x)
x = self.linear3(x)
x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
return x
class TestDistSPSyncTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
self.data_parallel_size = 1
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
"mp_configs": {
"sync_param": False,
"sync_grad": False,
"sync_moment": False,
},
}
fleet.init(is_collective=True, strategy=strategy)
def build_model_optimizer_train(
self,
batchs,
fp16=False,
amp_level="O1",
mp_sync_param=False,
mp_sync_grad=False,
mp_sync_moment=False,
):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
mp_id = hcg.get_model_parallel_rank()
dp_id = hcg.get_data_parallel_rank()
rank_id = dist.get_rank()
paddle.seed(2023)
np.random.seed(2023)
random.seed(2023)
set_random_seed(1024, dp_id, rank_id)
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model = SimpleSPNet(
vocab_size,
hidden_size,
inner_size,
output_size,
np_fc1,
np_fc2,
mp_id,
)
optimizer = paddle.optimizer.AdamW(
learning_rate=0.1, parameters=model.parameters()
)
if fp16 and amp_level == "O2":
model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level='O2'
)
strategy = fleet.fleet._user_defined_strategy
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
"mp_configs": {
"sync_param": mp_sync_param,
"sync_grad": mp_sync_grad,
"sync_moment": mp_sync_moment,
},
}
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
return self.train_batch(batchs, model, optimizer, fp16, amp_level)
def train_batch(self, batchs, model, optimizer, fp16=False, amp_level="O1"):
losses = []
if fp16:
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
scaler = fleet.distributed_scaler(scaler)
for batch in batchs:
with paddle.amp.auto_cast(enable=fp16, level=amp_level):
output = model(batch)
loss = output.mean()
losses.append(loss.numpy())
if fp16:
scaled = scaler.scale(loss)
scaled.backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
optimizer.clear_grad()
return losses
def mp_sync_base(
self, mp_sync_param=False, mp_sync_grad=False, mp_sync_moment=False
):
batchs = []
for _ in range(5):
np_data = np.random.randint(
0,
vocab_size,
(
batch_size,
seq_length,
),
)
batchs.append(paddle.to_tensor(np_data))
losses = self.build_model_optimizer_train(batchs)
losses_sync = self.build_model_optimizer_train(
batchs,
mp_sync_param=mp_sync_param,
mp_sync_grad=mp_sync_grad,
mp_sync_moment=mp_sync_moment,
)
for i in range(len(losses)):
np.testing.assert_allclose(
losses[i], losses_sync[i], rtol=1e-5, atol=1e-5
)
# test fp16 O1
losses_fp16 = self.build_model_optimizer_train(batchs, fp16=True)
losses_sync_fp16 = self.build_model_optimizer_train(
batchs,
fp16=True,
mp_sync_param=mp_sync_param,
mp_sync_grad=mp_sync_grad,
mp_sync_moment=mp_sync_moment,
)
for i in range(len(losses_fp16)):
np.testing.assert_allclose(
losses_fp16[i], losses_sync_fp16[i], rtol=1e-5, atol=1e-5
)
# test fp16 O2
losses_fp16_O2 = self.build_model_optimizer_train(
batchs, fp16=True, amp_level="O2"
)
losses_sync_fp16_O2 = self.build_model_optimizer_train(
batchs,
fp16=True,
amp_level="O2",
mp_sync_param=mp_sync_param,
mp_sync_grad=mp_sync_grad,
mp_sync_moment=mp_sync_moment,
)
for i in range(len(losses_fp16_O2)):
np.testing.assert_allclose(
losses_fp16_O2[i], losses_sync_fp16_O2[i], rtol=1e-5, atol=1e-5
)
def test_mp_sync_param(self):
self.mp_sync_base(mp_sync_param=True)
def test_mp_sync_grad(self):
self.mp_sync_base(mp_sync_grad=True)
def test_mp_sync_moment(self):
self.mp_sync_base(mp_sync_moment=True)
def test_mp_sync_all(self):
self.mp_sync_base(
mp_sync_param=True, mp_sync_grad=True, mp_sync_moment=True
)
class TestDistSPSyncModelTraning(TestDistSPSyncTraning):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
self.data_parallel_size = 1
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
"mp_configs": {
"sync_param": False,
"sync_grad": False,
"sync_moment": False,
"sync_mode": "average",
"sync_param_name": ["embedding", "layer_norm", ".b_"],
},
}
fleet.init(is_collective=True, strategy=strategy)
class TestDistSPTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
self.data_parallel_size = 1
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
def train_batch(self, batch, model, optimizer, is_mp):
output = model(batch)
loss = output.mean()
loss.backward() # do backward
optimizer.step() # update parameters
optimizer.clear_grad()
return loss
def build_optimizer(self, model):
optimizer = paddle.optimizer.SGD(
learning_rate=0.001, parameters=model.parameters()
)
return optimizer
def build_model_optimizer(self):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
mp_id = hcg.get_model_parallel_rank()
dp_id = hcg.get_data_parallel_rank()
rank_id = dist.get_rank()
set_random_seed(1024, dp_id, rank_id)
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model_a = SimpleSPNet(
vocab_size,
hidden_size,
inner_size,
output_size,
np_fc1,
np_fc2,
mp_id,
)
optimizer_a = self.build_optimizer(model_a)
model_a = fleet.distributed_model(model_a)
optimizer_a = fleet.distributed_optimizer(optimizer_a)
model_b = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
optimizer_b = self.build_optimizer(model_b)
return model_a, optimizer_a, model_b, optimizer_b
def test_mp_model(self):
(
model_a,
optimizer_a,
model_b,
optimizer_b,
) = self.build_model_optimizer()
for _ in range(5):
np_data = np.random.randint(
0,
vocab_size,
(
batch_size,
seq_length,
),
)
batch = paddle.to_tensor(np_data)
loss_a = self.train_batch(batch, model_a, optimizer_a, True)
loss_b = self.train_batch(batch, model_b, optimizer_b, False)
np.testing.assert_allclose(
loss_a.numpy(), loss_b.numpy(), rtol=1e-5, atol=1e-5
)
if __name__ == "__main__":
unittest.main()
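# A hedged note on running this test by hand (the CI registration below drives it
# through TestMultipleGpus.run_mnist_2gpu, and the file is presumably
# hybrid_parallel_mp_model_with_sequence_parallel.py): launch it on two GPUs with
# something like
#   python -m paddle.distributed.launch --gpus=0,1 hybrid_parallel_mp_model_with_sequence_parallel.py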
@@ -24,6 +24,11 @@ class TestHybridParallel(TestMultipleGpus):
def test_hybrid_parallel_mp_model(self):
self.run_mnist_2gpu('hybrid_parallel_mp_model.py')
def test_hybrid_parallel_mp_model_with_sequence_parallel(self):
self.run_mnist_2gpu(
'hybrid_parallel_mp_model_with_sequence_parallel.py'
)
def test_hybrid_parallel_mp_amp(self):
self.run_mnist_2gpu('hybrid_parallel_mp_amp.py')