Unverified commit a00f5bd4, authored by Yuang Liu, committed by GitHub

Optimize check finite when using sharding comm overlap. (#55766)

Parent dc82fa96
......
@@ -96,20 +96,17 @@ class DygraphShardingOptimizer:
self._rank2params = self._partition_parameters()
self._param2rank = self._map_param_to_rank()
if not self.tensor_fusion:
self._set_inner_opt_attr(
'_parameter_list', self._rank2params[self._sharding_rank]
)
self._set_inner_opt_attr(
'_param_groups', self._rank2params[self._sharding_rank]
)
if not self.tensor_fusion and not self.comm_overlap:
local_params = self._rank2params[self._sharding_rank]
self._set_inner_opt_attr('_parameter_list', local_params)
self._set_inner_opt_attr('_param_groups', local_params)
else:
self._tensor_fusion()
decay_params = [
p.name for p in self._rank2decay[self._sharding_rank]
]
fused_params = self._rank2fused[self._sharding_rank]
local_fused_params = self._rank2fused[self._sharding_rank]
apply_decay_param_fun = lambda x: x in decay_params
all_fused_params = []
......
@@ -118,8 +115,15 @@ class DygraphShardingOptimizer:
self._parameter_list = all_fused_params
self._param_groups = all_fused_params
self._set_inner_opt_attr('_parameter_list', fused_params)
self._set_inner_opt_attr('_param_groups', fused_params)
self._set_inner_opt_attr('_parameter_list', local_fused_params)
self._set_inner_opt_attr('_param_groups', local_fused_params)
if self.comm_overlap:
# Only set the local params for check_finite when comm overlap is enabled.
# With comm overlap, all grads have been communicated before check_finite,
# so the sharding ranks together still cover all grads' info at check_finite.
# Without comm overlap, grads are communicated after check_finite, which
# means each sharding rank has to run check_finite over all grads.
self._local_parameter_list = local_fused_params
origin_decay_param_fun = getattr(
self._inner_opt, '_apply_decay_param_fun', None
)
......
@@ -127,6 +131,14 @@
self._set_inner_opt_attr(
'_apply_decay_param_fun', apply_decay_param_fun
)
# Note: during the tensor fusion of parameters, the allocator applies for some
# extra GPU memory for the fused big parameters. This extra memory becomes
# useless once the fusion is done, but Paddle's allocator does not release it;
# it keeps that part in its memory pool. So after tensor fusion, the 'reserved'
# memory increases while the 'allocated' memory stays unchanged. To avoid
# failures in some other applications (such as some nvtx operations), we
# manually let the allocator release the cached memory here.
paddle.device.cuda.empty_cache()
def clear_grad(self, set_to_zero=True):
"""
......
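As a hedged aside on the `empty_cache()` call added above (not part of this change): the gap between reserved and allocated memory described in the note can be observed directly. A minimal standalone sketch, assuming a CUDA build of Paddle recent enough to expose `paddle.device.cuda.memory_allocated` / `memory_reserved`:

```python
import paddle

# Assumption: a CUDA build of Paddle with at least one visible GPU.
paddle.set_device("gpu:0")

# Simulate the temporary buffers created while fusing parameters.
tmp = [paddle.randn([1024, 1024]) for _ in range(8)]
del tmp  # the tensors are freed, but the allocator keeps their blocks cached

print("allocated:", paddle.device.cuda.memory_allocated())  # drops after del
print("reserved :", paddle.device.cuda.memory_reserved())   # cached pool still held

paddle.device.cuda.empty_cache()  # hand the cached blocks back to the device
print("reserved after empty_cache:", paddle.device.cuda.memory_reserved())
```

The drop in reserved memory after `empty_cache()` is what frees headroom for the other applications (such as nvtx operations) mentioned in the comment.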
......
@@ -47,20 +47,35 @@ def distributed_scaler(scaler):
else:
param_grads_fp32.append(param._grad_ivar())
else:
param_grads = [
param._grad_ivar()
for param in optimizer._parameter_list
if param._grad_ivar() is not None
]
strategy = fleet.fleet._user_defined_strategy
sharding_stage_1_overlap = strategy.hybrid_configs[
'sharding_configs'
].comm_overlap
if sharding_stage_1_overlap:
# If sharding stage 1 enables comm overlap and loss scaling is needed, we have to wait for all comm tasks here.
# If no loss scaling is needed, the wait for all comm tasks is done in the optimizer step.
assert hasattr(optimizer, "_comm_buffers")
assert hasattr(optimizer, "_sharding_enable")
if optimizer._sharding_enable:
# disable origin grad reduce in hybrid optimizer step
optimizer._sharding_enable = False
for buffer in optimizer._comm_buffers:
buffer.scale_grads()
# For sharding stage 1 with comm overlap, each rank only has to check finite for the params it is responsible for.
# For now, only sharding stage 1 has this attr; this can be promoted to stage 2 and stage 3 later.
assert hasattr(optimizer, "_local_parameter_list")
parameters = optimizer._local_parameter_list
else:
parameters = optimizer._parameter_list
param_grads_fp16 = [
param._grad_ivar()
for param in optimizer._parameter_list
for param in parameters
if (param._grad_ivar() is not None)
and (param._grad_ivar().dtype == core.VarDesc.VarType.FP16)
]
param_grads_fp32 = [
param._grad_ivar()
for param in optimizer._parameter_list
for param in parameters
if (param._grad_ivar() is not None)
and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32)
]
......
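To make the new branch easier to follow, here is a hedged sketch of the parameter selection as a standalone helper. `collect_grads_for_check_finite` is a hypothetical name, the `core` import path can differ across Paddle versions, and the real logic lives inline in `distributed_scaler` rather than in a helper:

```python
from paddle.framework import core  # import path may differ across Paddle versions


def collect_grads_for_check_finite(optimizer, sharding_stage_1_overlap):
    # Hypothetical helper mirroring the branch above: with sharding stage 1
    # comm overlap, grads are already reduced before check_finite, so each
    # rank only needs to inspect the slice of params it owns.
    if sharding_stage_1_overlap:
        parameters = optimizer._local_parameter_list
    else:
        parameters = optimizer._parameter_list

    grads = [p._grad_ivar() for p in parameters if p._grad_ivar() is not None]
    param_grads_fp16 = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
    param_grads_fp32 = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
    return param_grads_fp16, param_grads_fp32
```

Checking only the locally owned slice is still safe because the per-rank found-inf flags are aggregated across ranks afterwards, so a non-finite grad owned by any rank becomes visible to all of them.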
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

import paddle
from paddle.distributed import fleet

vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
STEPS = 10


class SimpleDPNet(paddle.nn.Layer):
    def __init__(self, vocab_size, hidden_size, inner_size, output_size):
        super().__init__()
        self.linear1 = paddle.nn.Linear(hidden_size, inner_size)
        self.linear2 = paddle.nn.Linear(inner_size, hidden_size)
        self.linear3 = paddle.nn.Linear(hidden_size, output_size)
        self.embedding = paddle.nn.Embedding(vocab_size, hidden_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
        return x


class TestDistSharding(unittest.TestCase):
    def setUp(self):
        self.strategy = fleet.DistributedStrategy()
        self.strategy.hybrid_configs = {
            "sharding_degree": 2,
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": 1,
        }
        self.strategy.hybrid_configs["sharding_configs"].tensor_fusion = True
        self.strategy.hybrid_configs["sharding_configs"].comm_overlap = True
        self.strategy.hybrid_configs["sharding_configs"].accumulate_steps = 1
        fleet.init(is_collective=True, strategy=self.strategy)
        self.data = np.random.randint(
            0,
            vocab_size,
            (
                batch_size,
                seq_length,
            ),
        )

        if paddle.distributed.get_rank() == 0:
            self.batch_sharding = paddle.to_tensor(self.data[:2])
        else:
            self.batch_sharding = paddle.to_tensor(self.data[2:])

    def build_optimizer(self, model):
        clip = paddle.nn.ClipGradByGlobalNorm(0.5)
        optimizer = paddle.optimizer.AdamW(
            parameters=model.parameters(),
            learning_rate=0.001,
            weight_decay=0.001,
            grad_clip=clip,
        )
        return optimizer

    def build_model_optimizer(self):
        model = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size)
        optimizer = self.build_optimizer(model)
        model, optimizer = paddle.amp.decorate(
            model, optimizers=optimizer, level="O2", dtype="float16"
        )
        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
        scaler = fleet.distributed_scaler(scaler)
        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)
        return model, optimizer, scaler

    def sharding_model(self):
        model, optimizer, scaler = self.build_model_optimizer()

        for idx in range(STEPS):
            with paddle.amp.auto_cast(enable=True, level='O2'):
                output = model(self.batch_sharding)
                loss = output.mean()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.clear_grad()

    def test_sharding_adam(self):
        self.sharding_model()


if __name__ == "__main__":
    unittest.main()
......
@@ -33,6 +33,9 @@ class TestHybridParallel(TestMultipleGpus):
def test_hybrid_parallel_sharding_tensor_fusion(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion.py')
def test_hybrid_parallel_sharding_tensor_fusion_amp(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_model_with_fusion_amp.py')
def test_hybrid_parallel_sharding_state_dict(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
......