未验证 提交 0b911330 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel] Add amp support for pipeline_parallel (#33951)

* add amp support for pp

* add amp untest
上级 2ef6188b
...@@ -30,8 +30,8 @@ class HybridParallelGradScaler: ...@@ -30,8 +30,8 @@ class HybridParallelGradScaler:
def __init__(self, scaler, hcg): def __init__(self, scaler, hcg):
self._scaler = scaler self._scaler = scaler
self._hcg = hcg self._hcg = hcg
self._is_mp = ( self._use_dp_mode = (
self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL) self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL)
def scale(self, var): def scale(self, var):
return self._scaler.scale(var) return self._scaler.scale(var)
...@@ -67,7 +67,7 @@ class HybridParallelGradScaler: ...@@ -67,7 +67,7 @@ class HybridParallelGradScaler:
core.ops.check_finite_and_unscale(param_grads, self._scale, param_grads, core.ops.check_finite_and_unscale(param_grads, self._scale, param_grads,
self._found_inf) self._found_inf)
# allreduce_max found_inf in check_group # allreduce_max found_inf in check_group
if self._is_mp: if not self._use_dp_mode:
self._found_inf = paddle.cast(self._found_inf, dtype="int32") self._found_inf = paddle.cast(self._found_inf, dtype="int32")
# TODO(shenliang03) Since the minimize call in the optimizer is # TODO(shenliang03) Since the minimize call in the optimizer is
# after the gradscaler, check_finite needs to synchronize global # after the gradscaler, check_finite needs to synchronize global
......
...@@ -106,11 +106,12 @@ class PipelineParallel(MetaParallelBase): ...@@ -106,11 +106,12 @@ class PipelineParallel(MetaParallelBase):
group=self.pp_group) group=self.pp_group)
return loss return loss
def train_batch(self, data, optimizer, lr_scheduler=None): def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
assert isinstance(optimizer, HybridParallelOptimizer), ( assert isinstance(optimizer, HybridParallelOptimizer), (
'optimizer should be HybridParallelOptimizer subclass.') 'optimizer should be HybridParallelOptimizer subclass.')
self.optimizer = optimizer self.optimizer = optimizer
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
self.scaler = scaler
assert fluid.framework._dygraph_tracer()._has_grad, ( assert fluid.framework._dygraph_tracer()._has_grad, (
'Please enable the generation of gradients.') 'Please enable the generation of gradients.')
...@@ -143,8 +144,8 @@ class PipelineParallel(MetaParallelBase): ...@@ -143,8 +144,8 @@ class PipelineParallel(MetaParallelBase):
self._layers.allreduce_shared_weight_gradients() self._layers.allreduce_shared_weight_gradients()
# optimizer # optimizer
self._step()
self.train_loss = self._reduce_final_loss() self.train_loss = self._reduce_final_loss()
self._step()
return self.train_loss return self.train_loss
def _forward(self, cache_id): def _forward(self, cache_id):
...@@ -192,7 +193,12 @@ class PipelineParallel(MetaParallelBase): ...@@ -192,7 +193,12 @@ class PipelineParallel(MetaParallelBase):
def _backward(self, cache_id): def _backward(self, cache_id):
if self.is_last_stage: if self.is_last_stage:
if self.scaler:
paddle.autograd.backward(
self.scaler.scale(self.caches['outputs'][cache_id]))
else:
paddle.autograd.backward(self.caches['outputs'][cache_id]) paddle.autograd.backward(self.caches['outputs'][cache_id])
self._send_gradients(cache_id) self._send_gradients(cache_id)
return return
self._recv_gradients(cache_id) self._recv_gradients(cache_id)
...@@ -441,6 +447,9 @@ class PipelineParallel(MetaParallelBase): ...@@ -441,6 +447,9 @@ class PipelineParallel(MetaParallelBase):
p2p.recv(d, self.next_stage_id) p2p.recv(d, self.next_stage_id)
def _step(self): def _step(self):
if self.scaler:
self.scaler.minimize(self.optimizer, self.train_loss)
else:
self.optimizer.step() self.optimizer.step()
self.optimizer.clear_grad() self.optimizer.clear_grad()
if self.lr_scheduler: if self.lr_scheduler:
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_layer import AlexNetPipeDesc, AlexNet
def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + dp_id)
batch_size = 4
micro_batch_size = 2
class TestDistPPTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size
}
fleet.init(is_collective=True, strategy=strategy)
def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
set_random_seed(1024, dp_id, rank_id)
#construct model a
model_a = AlexNet(10)
scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
parameters=model_a.parameters())
scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5)
param_len = len(model_a.parameters())
parameters = []
for param in model_a.parameters():
parameters.append(param.numpy())
# construct model b
model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
parameters=model_b.parameters())
model_b = fleet.distributed_model(model_b)
optimizer_b = fleet.distributed_optimizer(optimizer_b)
scaler_b = paddle.amp.GradScaler(init_loss_scaling=2**5)
scaler_b = fleet.distributed_scaler(scaler_b)
for idx, param in enumerate(model_b.parameters()):
param.set_value(parameters[idx + pp_id * (param_len // 2)])
# construct reader
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True)
for step_id, data in enumerate(train_reader()):
x_data = np.array([x[0] for x in data]).astype('float32').reshape(
batch_size, 1, 28, 28)
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
batch_size, 1)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
img.stop_gradient = True
label.stop_gradient = True
if step_id >= 5:
return True
with paddle.amp.auto_cast():
loss_a = model_a(img, label)
scaler_a.scale(loss_a).backward()
scaler_a.minimize(optimizer_a, loss_a)
optimizer_a.clear_grad()
scheduler_a.step()
with paddle.amp.auto_cast():
loss_b = model_b.train_batch(
[img, label], optimizer_b, scheduler_b, scaler=scaler_b)
print("loss: ", loss_a.numpy(), loss_b.numpy())
np.testing.assert_allclose(
loss_a.numpy(), loss_b.numpy(), rtol=5e-5)
if __name__ == "__main__":
unittest.main()
...@@ -30,6 +30,9 @@ class TestHybridPipeParallel(TestMultipleGpus): ...@@ -30,6 +30,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
def test_hybrid_parallel_pp_tuple_inputs(self): def test_hybrid_parallel_pp_tuple_inputs(self):
self.run_mnist_2gpu('hybrid_parallel_shared_weight.py') self.run_mnist_2gpu('hybrid_parallel_shared_weight.py')
def test_pipeline_parallel(self):
self.run_mnist_2gpu('hybrid_parallel_pp_amp.py')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册