Unverified commit d4bf8b1a, authored by ShenLiang, committed by GitHub

support unbalanced data for pipeline (#47199) (#47569)

* add unbalanced data

* fix utest
Parent ba4fbe71
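
For context on the commit title, below is a minimal, standalone sketch of what "unbalanced" pipeline input data looks like. It uses numpy only; vocab_size, length and the micro-batch sizes are illustrative assumptions, not values taken from the test files in this commit.

import numpy as np

# Illustrative values only; the real test imports these from
# hybrid_parallel_pp_transformer.
vocab_size = 1000
length = 16

# "Unbalanced" means the micro-batches within one pipeline step do not all
# share the same leading (batch) dimension.
micro_batch_sizes = [2, 4, 3, 5]
micro_batches = [
    np.random.randint(0, vocab_size, size=[size, length])
    for size in micro_batch_sizes
]
for i, mb in enumerate(micro_batches):
    print(f"micro-batch {i}: shape {mb.shape}")
# A list like this is what gets wrapped in paddle.to_tensor and fed to the
# pipeline-parallel model's train_batch()/eval_batch(); per the commit title,
# the pipeline now supports such unequal micro-batch sizes.
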
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_transformer import (
    TestDistPPTraning,
    set_random_seed,
    ModelPipe,
    batch_size,
    length,
    micro_batch_size,
    vocab_size,
)


class TestDistPPTraningUnbalancedData(TestDistPPTraning):
    def test_pp_model(self):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        topology = hcg.topology()
        set_random_seed(1024, dp_id, rank_id)

        # Build the pipeline-parallel transformer and wrap the model and
        # optimizer for hybrid-parallel execution.
        model = ModelPipe(topology)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True
        )
        optimizer = paddle.optimizer.SGD(
            learning_rate=scheduler, parameters=model.parameters()
        )

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)

        for step_id in range(5):
            # Assemble the micro-batches for this step.
            x = []
            for _ in range(batch_size // micro_batch_size):
                size = micro_batch_size
                x_data = np.random.randint(0, vocab_size, size=[size, length])
                x.append(paddle.to_tensor(x_data))
            e_loss = model.eval_batch([x, x], True)
            loss = model.train_batch([x, x], optimizer, scheduler)

            # TODO(shenliang03) add utest for loss
            # Compare train and eval losses on the non-first pipeline stage.
            if pp_id != 0:
                np.testing.assert_allclose(loss.numpy(), e_loss.numpy())


if __name__ == "__main__":
    unittest.main()
@@ -22,13 +22,14 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus

 class TestHybridPipeParallel(TestMultipleGpus):
     def test_hybrid_parallel_pp_layer(self):
         self.run_mnist_2gpu(
-            os.path.abspath('../../hybrid_parallel_pp_layer.py'))
+            os.path.abspath('../../hybrid_parallel_pp_layer.py')
+        )
         self.run_mnist_2gpu(
             os.path.abspath('../../hybrid_parallel_pp_layer.py'),
-            eager_mode=False)
+            eager_mode=False,
+        )

     def test_hybrid_parallel_pp_tuple_inputs(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_embedding.py')
@@ -36,8 +37,9 @@ class TestHybridPipeParallel(TestMultipleGpus):

     def test_hybrid_parallel_shared_weight(self):
         self.run_mnist_2gpu('hybrid_parallel_shared_weight.py')
-        self.run_mnist_2gpu('hybrid_parallel_shared_weight.py',
-                            eager_mode=False)
+        self.run_mnist_2gpu(
+            'hybrid_parallel_shared_weight.py', eager_mode=False
+        )

     def test_pipeline_parallel_amp(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_amp.py')
@@ -49,8 +51,9 @@ class TestHybridPipeParallel(TestMultipleGpus):

     def test_hybrid_parallel_transformer(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py')
-        self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py',
-                            eager_mode=False)
+        self.run_mnist_2gpu(
+            'hybrid_parallel_pp_transformer.py', eager_mode=False
+        )

     def test_hybrid_parallel_save_load(self):
         self.run_mnist_2gpu('hybrid_parallel_pp_save_load.py')
@@ -64,6 +67,13 @@ class TestHybridPipeParallel(TestMultipleGpus):
         self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py')
         self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py', eager_mode=False)

+    def test_hybrid_parallel_transformer_unbalanced_data(self):
+        self.run_mnist_2gpu('hybrid_parallel_pp_transformer_unbalanced_data.py')
+        self.run_mnist_2gpu(
+            'hybrid_parallel_pp_transformer_unbalanced_data.py',
+            eager_mode=False,
+        )
+
 if __name__ == "__main__":
     os.environ["FLAGS_enable_eager_mode"] = "1"
     ...