Unverified commit 99f60188, authored by ShenLiang, committed by GitHub

support unbalanced data for pipeline (#47199)

* add unbalanced data

* fix utest
Parent: bafa890a
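What "unbalanced data" means here: previously, a pipeline stage's input had to be a single batched tensor that was sliced into accumulate_steps equally sized micro batches. With this change, the input may instead be a list containing one tensor per micro batch, which the stage selects by micro-batch id, so micro batches no longer need to share one shape. Below is a minimal standalone sketch of that dispatch rule; it uses numpy arrays in place of paddle tensors, and load_micro_batch is an illustrative name, not the Paddle API.

import numpy as np

def load_micro_batch(inputs, cache_id, micro_batch_size, accumulate_steps):
    # New list path: one entry per micro batch, selected by id; entries may
    # have different shapes (the "unbalanced" case).
    if isinstance(inputs, list):
        assert len(inputs) == accumulate_steps
        return inputs[cache_id]
    # Original batched path: the batch dimension must split evenly.
    assert inputs.shape[0] == micro_batch_size * accumulate_steps
    begin = cache_id * micro_batch_size
    return inputs[begin:begin + micro_batch_size]

# Balanced: an (8, 4) batch yields 4 equal micro batches; id 1 is rows 2:4.
batched = np.arange(32).reshape(8, 4)
print(load_micro_batch(batched, 1, micro_batch_size=2, accumulate_steps=4).shape)

# Unbalanced: per-micro-batch arrays with differing row counts are accepted.
ragged = [np.zeros((2, 4)), np.zeros((3, 4)), np.zeros((1, 4)), np.zeros((2, 4))]
print(load_micro_batch(ragged, 2, micro_batch_size=2, accumulate_steps=4).shape)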
@@ -355,51 +355,55 @@ class PipelineParallel(MetaParallelBase):
         input_tensor_grad = input_tensor.grad
         return input_tensor_grad
 
-    def _load_micro_batch(self, cache_id):
-        inputs = self.data
+    def _check_data_vaild(self, data):
+        batch_size = data.shape[0]
+        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
+            "batch_size needs to be divisible by micro_batch_size. Currently, "
+            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
+            % (batch_size, self.micro_batch_size, self.accumulate_steps)
+        )
+
+    def _load_micro_batch_impl(self, inputs, cache_id):
         begin = cache_id * self.micro_batch_size
         end = begin + self.micro_batch_size
 
         # The virtual first and last pipeline stage need data, all others don't need.
+        if isinstance(inputs, tuple):
+            output = []
+            for data in inputs:
+                if isinstance(data, list):
+                    assert (
+                        len(data) == self.accumulate_steps
+                    ), "length of data should be %d, but it is %d" % (
+                        self.accumulate_steps,
+                        len(data),
+                    )
+                    output.append(data[cache_id].detach())
+                else:
+                    self._check_data_vaild(data)
+                    output.append(data[begin:end, :].detach())
+            return tuple(output)
+        elif isinstance(inputs, list):
+            assert (
+                len(inputs) == self.accumulate_steps
+            ), "length of data should be %d, but it is %d" % (
+                self.accumulate_steps,
+                len(inputs),
+            )
+            return inputs[cache_id].detach()
+        else:
+            self._check_data_vaild(inputs)
+            return inputs[begin:end, :].detach()
+
+    def _load_micro_batch(self, cache_id):
+        inputs = self.data
         if self.is_pipeline_first_stage():
             assert len(inputs) == 2, "length of input should be 2"
-            if isinstance(inputs[0], tuple):
-                assert (
-                    len(inputs[0]) > 1
-                ), "If you use tuple for input data, it should have at least two inputs."
-                batch_size = inputs[0][0].shape[0]
-                assert (
-                    self.micro_batch_size * self.accumulate_steps == batch_size
-                ), (
-                    "batch_size needs to be divisible by micro_batch_size. Currently, "
-                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
-                    % (batch_size, self.micro_batch_size, self.accumulate_steps)
-                )
-                data = [input[begin:end, :].detach() for input in inputs[0]]
-                return tuple(data)
-            else:
-                batch_size = inputs[0].shape[0]
-                assert (
-                    self.micro_batch_size * self.accumulate_steps == batch_size
-                )
-                return inputs[0][begin:end, :].detach()
+            return self._load_micro_batch_impl(inputs[0], cache_id)
         elif self.is_pipeline_last_stage():
             assert len(inputs) == 2, "length of input should be 2"
-            if isinstance(inputs[1], tuple):
-                batch_size = inputs[1][0].shape[0]
-                assert (
-                    self.micro_batch_size * self.accumulate_steps == batch_size
-                )
-                data = [input[begin:end, :].detach() for input in inputs[1]]
-                return tuple(data)
-            else:
-                batch_size = inputs[1].shape[0]
-                assert (
-                    self.micro_batch_size * self.accumulate_steps == batch_size
-                )
-                return inputs[1][begin:end, :].detach()
+            return self._load_micro_batch_impl(inputs[1], cache_id)
         else:
             # No data input is required for other stages
             inputs = None
 
     def _broadcast_final_loss(self):
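For callers, the visible effect is in how data is handed to train_batch/eval_batch on the first and last stages: instead of one stacked tensor, a list of accumulate_steps tensors can be passed. A hedged usage sketch follows; the varying sequence lengths are illustrative (the list path above selects entries by micro-batch id without a shape check), and model, optimizer, and scheduler are assumed to be set up as in the test below.

import numpy as np
import paddle

# Illustrative: four micro batches whose sequence lengths differ.
micro_batches = [
    paddle.to_tensor(np.random.randint(0, 100, size=[2, seq_len]))
    for seq_len in (16, 24, 8, 16)  # assumed per-micro-batch lengths
]
# The two list elements are [input_data, label_data], as in the test below:
# loss = model.train_batch([micro_batches, micro_batches], optimizer, scheduler)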
New file: hybrid_parallel_pp_transformer_unbalanced_data.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_transformer import (
    TestDistPPTraning,
    set_random_seed,
    ModelPipe,
    batch_size,
    length,
    micro_batch_size,
    vocab_size,
)


class TestDistPPTraningUnbalancedData(TestDistPPTraning):
    def test_pp_model(self):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        topology = hcg.topology()
        set_random_seed(1024, dp_id, rank_id)

        model = ModelPipe(topology)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True
        )
        optimizer = paddle.optimizer.SGD(
            learning_rate=scheduler, parameters=model.parameters()
        )

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

        for step_id in range(5):
            x = []
            for _ in range(batch_size // micro_batch_size):
                size = micro_batch_size
                x_data = np.random.randint(0, vocab_size, size=[size, length])
                x.append(paddle.to_tensor(x_data))
            e_loss = model.eval_batch([x, x], True)
            loss = model.train_batch([x, x], optimizer, scheduler)

            # TODO(shenliang03) add utest for loss
            if pp_id != 0:
                np.testing.assert_allclose(loss.numpy(), e_loss.numpy())


if __name__ == "__main__":
    unittest.main()
@@ -64,6 +64,13 @@ class TestHybridPipeParallel(TestMultipleGpus):
         self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py')
         self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py', eager_mode=False)
 
+    def test_hybrid_parallel_transformer_unbalanced_data(self):
+        self.run_mnist_2gpu('hybrid_parallel_pp_transformer_unbalanced_data.py')
+        self.run_mnist_2gpu(
+            'hybrid_parallel_pp_transformer_unbalanced_data.py',
+            eager_mode=False,
+        )
+
 
 if __name__ == "__main__":
     os.environ["FLAGS_enable_eager_mode"] = "1"