Unverified commit 503f422e, authored by wuhuachaocoding, committed by GitHub

add mp_sync config. (#53254)

Parent 00f747f2
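In brief: this commit adds a string sync_mode option to the model-parallel sync switches ("broadcast", the default, copies from a source rank; any other value, e.g. the "average" used in the tests, falls back to all_reduce plus a 1/nranks scale), and makes the list of matched parameter names configurable through sync_param_name. A minimal usage sketch, mirroring the tests further down (the degrees are illustrative, and this must run under a distributed launcher such as paddle.distributed.launch):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 2,
        "pp_degree": 1,
        "mp_configs": {
            "sync_param": True,        # sync matched params after opt.step()
            "sync_grad": False,        # sync grads before opt.step()
            "sync_moment": False,      # sync Adam/AdamW moments after opt.step()
            "sync_mode": "broadcast",  # "average" selects all_reduce + 1/nranks
            "sync_param_name": ["embedding", "layer_norm", ".b_"],
        },
    }
    fleet.init(is_collective=True, strategy=strategy)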
@@ -55,6 +55,7 @@ message MpConfig {
   optional bool sync_param = 1 [ default = false ];
   optional bool sync_grad = 2 [ default = false ];
   optional bool sync_moment = 3 [ default = false ];
+  optional string sync_mode = 4 [ default = 'broadcast' ];
 }
 
 message PpConfig {
......
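The new field only distinguishes "broadcast" from everything else. A standalone sketch of the two paths, matching the _insert_sync helper added to the optimizer below (sync_tensor is a hypothetical name for illustration; group is a paddle.distributed process group):

    import paddle.distributed as dist

    def sync_tensor(t, src, group, mode):
        # Hedged re-statement of the semantics, not the shipped helper.
        if mode == "broadcast":
            # Rank `src` overwrites every other rank's copy of t.
            dist.broadcast(t, src=src, group=group, sync_op=True)
        else:
            # Any other mode: element-wise mean across the group.
            dist.all_reduce(t, group=group, sync_op=True)
            t.scale_(1.0 / group.nranks)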
@@ -146,6 +146,8 @@ class DistributedStrategy:
             self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])
 
         self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
+        self.sync_param_name = ["embedding", "layer_norm", ".b_"]
+
         self.__lock_attr = True
         logger.info("distributed strategy initialized")
@@ -1698,6 +1700,10 @@ class DistributedStrategy:
             )
 
         if "mp_configs" in configs:
+            if "sync_param_name" in configs["mp_configs"]:
+                self.sync_param_name = configs["mp_configs"]["sync_param_name"]
+                configs["mp_configs"].pop("sync_param_name")
+
             assign_configs_value(
                 self.strategy.hybrid_configs.mp_configs, configs["mp_configs"]
             )
......
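Note that sync_param_name is popped before assign_configs_value runs: the protobuf MpConfig message has no such field, so the name list lives on the Python DistributedStrategy object instead. A quick check of that behavior, mirroring the strategy unit test at the end of the diff (the list here is illustrative):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "mp_degree": 2,
        "mp_configs": {
            "sync_param": True,
            "sync_param_name": ["embedding", ".b_"],
        },
    }
    # The name list never reaches the protobuf message; it is stored on
    # the Python object and read back from there.
    assert strategy.sync_param_name == ["embedding", ".b_"]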
@@ -303,9 +303,20 @@ class HybridParallelOptimizer:
                 inner_opt._grad_clip, hcg
             )
 
-    def _filter_fn(self, param):
+    def _insert_sync(self, sync_var, src, mp_group, sync_mode):
+        if sync_mode == "broadcast":
+            paddle.distributed.broadcast(
+                sync_var, src=src, group=mp_group, sync_op=True
+            )
+        else:
+            paddle.distributed.all_reduce(
+                sync_var, group=mp_group, sync_op=True
+            )
+            sync_var.scale_(1.0 / mp_group.nranks)
+
+    def _filter_fn(self, param, strategy):
         p_name = param.name
-        tar_param = ["embedding", "layer_norm", ".b_"]
+        tar_param = strategy.sync_param_name
         if param.is_distributed is False:
             for tar in tar_param:
                 if tar in p_name:
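The filter keeps its shape; only the name list now comes from the strategy. A parameter qualifies for sync only if it is not split across mp ranks (param.is_distributed is False) and its name contains one of the configured substrings; the tail of the loop is elided above and presumably returns True on a match. A standalone sketch of that predicate (would_sync is a hypothetical name for illustration):

    def would_sync(p_name, is_distributed, sync_param_name):
        # Row/column-parallel weights are never synced.
        if is_distributed:
            return False
        return any(tar in p_name for tar in sync_param_name)

    names = ["embedding", "layer_norm", ".b_"]
    assert would_sync("layer_norm_0.b_0", False, names)
    assert not would_sync("linear_0.w_0", False, names)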
@@ -329,26 +340,48 @@ class HybridParallelOptimizer:
             or mp_configs.sync_moment
         ):
             params = sorted(
-                [p for p in parameters_list if self._filter_fn(p)],
+                [
+                    p
+                    for p in parameters_list
+                    if self._filter_fn(p, fleet.fleet._user_defined_strategy)
+                ],
                 key=lambda p: p.name,
             )
 
         # Grad sync before opt
         if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad:
             for p in params:
-                if p.grad is None:
-                    continue
-                paddle.distributed.broadcast(
-                    p.grad, src=src_rank, group=mp_group, sync_op=True
-                )
+                if hasattr(p, "main_grad") and p.main_grad is not None:
+                    assert p.grad is None
+                    self._insert_sync(
+                        p.main_grad, src_rank, mp_group, mp_configs.sync_mode
+                    )
+                elif p.grad is not None:
+                    self._insert_sync(
+                        p.grad, src_rank, mp_group, mp_configs.sync_mode
+                    )
 
         self._inner_opt.step()
 
         if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param:
             for p in params:
-                paddle.distributed.broadcast(
-                    p, src=src_rank, group=mp_group, sync_op=True
-                )
+                # Param sync after opt
+                self._insert_sync(p, src_rank, mp_group, mp_configs.sync_mode)
+
+                # Master param sync after opt
+                if (
+                    hasattr(self._inner_opt, "_multi_precision")
+                    and self._inner_opt._multi_precision
+                    and p.name in self._inner_opt._master_weights
+                ):
+                    self._insert_sync(
+                        self._inner_opt._master_weights[p.name],
+                        src_rank,
+                        mp_group,
+                        mp_configs.sync_mode,
+                    )
 
+        # Moment sync after opt
         if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment:
             for p in params:
                 # Only the optimizer states of Adam and AdamW are synced for now.
@@ -357,28 +390,30 @@ class HybridParallelOptimizer:
                     (paddle.optimizer.Adam, paddle.optimizer.AdamW),
                 ):
                     if (
-                        self._inner_opt._multi_precision
-                        and p.name in self._master_weights
+                        p.name
+                        in self._inner_opt._accumulators[
+                            self._inner_opt._moment1_acc_str
+                        ]
                     ):
-                        paddle.distributed.broadcast(
-                            self._inner_opt._master_weights[p.name],
-                            src=src_rank,
-                            group=mp_group,
-                            sync_op=True,
-                        )
-                    moment1 = self._inner_opt._get_accumulator(
-                        self._inner_opt._moment1_acc_str, p
-                    )
-                    moment2 = self._inner_opt._get_accumulator(
-                        self._inner_opt._moment2_acc_str, p
-                    )
-                    paddle.distributed.broadcast(
-                        moment1, src=src_rank, group=mp_group, sync_op=True
-                    )
-                    paddle.distributed.broadcast(
-                        moment2, src=src_rank, group=mp_group, sync_op=True
-                    )
+                        moment1 = self._inner_opt._get_accumulator(
+                            self._inner_opt._moment1_acc_str, p
+                        )
+                        self._insert_sync(
+                            moment1, src_rank, mp_group, mp_configs.sync_mode
+                        )
+                    if (
+                        p.name
+                        in self._inner_opt._accumulators[
+                            self._inner_opt._moment2_acc_str
+                        ]
+                    ):
+                        moment2 = self._inner_opt._get_accumulator(
+                            self._inner_opt._moment2_acc_str, p
+                        )
+                        self._insert_sync(
+                            moment2, src_rank, mp_group, mp_configs.sync_mode
+                        )
 
     @no_grad()
     @framework.dygraph_only
......
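The new membership checks on _accumulators guard against moment buffers that do not exist yet: Adam/AdamW create them lazily, so a parameter the optimizer has never stepped has no moment1/moment2 entry to fetch. A single-process toy illustration of that behavior, using private optimizer attributes (an assumption about the current implementation, not a documented API):

    import paddle

    linear = paddle.nn.Linear(4, 4)
    opt = paddle.optimizer.Adam(parameters=linear.parameters())
    # Before any step, no moment buffers have been created.
    assert len(opt._accumulators[opt._moment1_acc_str]) == 0
    loss = linear(paddle.ones([2, 4])).mean()
    loss.backward()
    opt.step()
    # After one step, the stepped params have moment1 entries.
    assert linear.weight.name in opt._accumulators[opt._moment1_acc_str]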
@@ -202,6 +202,7 @@ class TestDistMPSyncTraning(unittest.TestCase):
         self,
         batchs,
         fp16=False,
+        amp_level="O1",
         mp_sync_param=False,
         mp_sync_grad=False,
         mp_sync_moment=False,
@@ -232,6 +233,11 @@ class TestDistMPSyncTraning(unittest.TestCase):
             learning_rate=0.1, parameters=model.parameters()
         )
 
+        if fp16 and amp_level == "O2":
+            model, optimizer = paddle.amp.decorate(
+                models=model, optimizers=optimizer, level='O2'
+            )
+
         strategy = fleet.fleet._user_defined_strategy
         strategy.hybrid_configs = {
             "dp_degree": self.data_parallel_size,
@@ -246,15 +252,15 @@ class TestDistMPSyncTraning(unittest.TestCase):
         model = fleet.distributed_model(model)
         optimizer = fleet.distributed_optimizer(optimizer)
 
-        return self.train_batch(batchs, model, optimizer, fp16)
+        return self.train_batch(batchs, model, optimizer, fp16, amp_level)
 
-    def train_batch(self, batchs, model, optimizer, fp16=False):
+    def train_batch(self, batchs, model, optimizer, fp16=False, amp_level="O1"):
         losses = []
         if fp16:
             scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
             scaler = fleet.distributed_scaler(scaler)
         for batch in batchs:
-            with paddle.amp.auto_cast(enable=fp16, level='O1'):
+            with paddle.amp.auto_cast(enable=fp16, level=amp_level):
                 output = model(batch)
                 loss = output.mean()
                 losses.append(loss.numpy())
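The O1/O2 switch threaded through the test is the standard Paddle AMP pattern: O2 keeps fp16 parameters plus fp32 master weights, which is why the optimizer change above also syncs _master_weights. A condensed single-card sketch of that pattern (assumes a GPU; the model and shapes are illustrative, not from the diff):

    import paddle

    model = paddle.nn.Linear(8, 8)
    opt = paddle.optimizer.AdamW(learning_rate=0.1, parameters=model.parameters())
    # O2 decorates the model/optimizer so params run in fp16 with fp32 masters.
    model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
    with paddle.amp.auto_cast(enable=True, level='O2'):
        loss = model(paddle.rand([4, 8])).mean()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()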
@@ -295,7 +301,7 @@ class TestDistMPSyncTraning(unittest.TestCase):
         for i in range(len(losses)):
             np.testing.assert_allclose(losses[i], losses_sync[i], rtol=1e-6)
 
-        # test fp16
+        # test fp16 O1
         losses_fp16 = self.build_model_optimizer_train(batchs, fp16=True)
         losses_sync_fp16 = self.build_model_optimizer_train(
             batchs,
@@ -310,6 +316,24 @@ class TestDistMPSyncTraning(unittest.TestCase):
                 losses_fp16[i], losses_sync_fp16[i], rtol=1e-6
             )
 
+        # test fp16 O2
+        losses_fp16_O2 = self.build_model_optimizer_train(
+            batchs, fp16=True, amp_level="O2"
+        )
+        losses_sync_fp16_O2 = self.build_model_optimizer_train(
+            batchs,
+            fp16=True,
+            amp_level="O2",
+            mp_sync_param=mp_sync_param,
+            mp_sync_grad=mp_sync_grad,
+            mp_sync_moment=mp_sync_moment,
+        )
+
+        for i in range(len(losses_fp16_O2)):
+            np.testing.assert_allclose(
+                losses_fp16_O2[i], losses_sync_fp16_O2[i], rtol=1e-6
+            )
+
     def test_mp_sync_param(self):
         self.mp_sync_base(mp_sync_param=True)
@@ -325,6 +349,26 @@ class TestDistMPSyncTraning(unittest.TestCase):
         )
 
+class TestDistMPSyncModelTraning(TestDistMPSyncTraning):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 2
+        self.data_parallel_size = 1
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": 1,
+            "mp_configs": {
+                "sync_param": False,
+                "sync_grad": False,
+                "sync_moment": False,
+                "sync_mode": "average",
+                "sync_param_name": ["embedding", "layer_norm", ".b_"],
+            },
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+
 class TestDistMPTraning(unittest.TestCase):
     def setUp(self):
         strategy = fleet.DistributedStrategy()
......
@@ -84,6 +84,36 @@ class TestStrategyConfig(unittest.TestCase):
         self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
         self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
 
+    def test_hybrid_parallel_mp_configs(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 4,
+            "mp_configs": {
+                "sync_param": True,
+                "sync_grad": False,
+                "sync_moment": False,
+                "sync_mode": "broadcast",
+                "sync_param_name": ["embedding", "layer_norm", ".w", ".b_"],
+            },
+        }
+        self.assertEqual(strategy.hybrid_configs["dp_degree"], 1)
+        self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
+        self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
+        self.assertEqual(strategy.hybrid_configs["mp_configs"].sync_param, True)
+        self.assertEqual(strategy.hybrid_configs["mp_configs"].sync_grad, False)
+        self.assertEqual(
+            strategy.hybrid_configs["mp_configs"].sync_moment, False
+        )
+        self.assertEqual(
+            strategy.hybrid_configs["mp_configs"].sync_mode, "broadcast"
+        )
+        self.assertEqual(
+            strategy.sync_param_name, ["embedding", "layer_norm", ".w", ".b_"]
+        )
+
     def test_hybrid_parallel_configs_order(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.hybrid_configs = {
......