未验证 提交 4236351c 编写于 作者: W wuhuachaocoding 提交者: GitHub

add mp_sync config. (#53204)

Co-authored-by: Ngongweibao <gongweibao@baidu.com>
上级 1f45b313
...@@ -55,6 +55,7 @@ message MpConfig { ...@@ -55,6 +55,7 @@ message MpConfig {
optional bool sync_param= 1 [ default = false ]; optional bool sync_param= 1 [ default = false ];
optional bool sync_grad= 2 [ default = false ]; optional bool sync_grad= 2 [ default = false ];
optional bool sync_moment= 3 [ default = false ]; optional bool sync_moment= 3 [ default = false ];
optional string sync_mode= 4 [ default = 'broadcast' ];
} }
message PpConfig { message PpConfig {
......
...@@ -146,6 +146,8 @@ class DistributedStrategy: ...@@ -146,6 +146,8 @@ class DistributedStrategy:
self.strategy.sync_nccl_allreduce = bool(_global_flags()[key]) self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])
self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp'] self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
self.sync_param_name = ["embedding", "layer_norm", ".b_"]
self.__lock_attr = True self.__lock_attr = True
logger.info("distributed strategy initialized") logger.info("distributed strategy initialized")
...@@ -1698,6 +1700,10 @@ class DistributedStrategy: ...@@ -1698,6 +1700,10 @@ class DistributedStrategy:
) )
if "mp_configs" in configs: if "mp_configs" in configs:
if "sync_param_name" in configs["mp_configs"]:
self.sync_param_name = configs["mp_configs"]["sync_param_name"]
configs["mp_configs"].pop("sync_param_name")
assign_configs_value( assign_configs_value(
self.strategy.hybrid_configs.mp_configs, configs["mp_configs"] self.strategy.hybrid_configs.mp_configs, configs["mp_configs"]
) )
......
...@@ -298,9 +298,20 @@ class HybridParallelOptimizer: ...@@ -298,9 +298,20 @@ class HybridParallelOptimizer:
inner_opt._grad_clip, hcg inner_opt._grad_clip, hcg
) )
def _insert_sync(self, sync_var, src, mp_group, sync_mode):
    """Synchronise ``sync_var`` in place across the ranks of ``mp_group``.

    Args:
        sync_var: Tensor to synchronise; it is overwritten/scaled in place.
        src: Source rank passed to the broadcast; not used when averaging.
        mp_group: Communication group; its ``nranks`` is the divisor used
            when averaging.
        sync_mode: ``"broadcast"`` copies the value held by ``src`` to the
            group. Any other value takes the averaging path.
    """
    if sync_mode == "broadcast":
        paddle.distributed.broadcast(
            sync_var, src=src, group=mp_group, sync_op=True
        )
        return

    # Every mode other than "broadcast" lands here: all_reduce (with its
    # default reduction op) and then divide by the group size.
    paddle.distributed.all_reduce(sync_var, group=mp_group, sync_op=True)
    sync_var.scale_(1.0 / mp_group.nranks)
def _filter_fn(self, param, strategy):
p_name = param.name p_name = param.name
tar_param = ["embedding", "layer_norm", ".b_"] tar_param = strategy.sync_param_name
if param.is_distributed is False: if param.is_distributed is False:
for tar in tar_param: for tar in tar_param:
if tar in p_name: if tar in p_name:
...@@ -324,26 +335,48 @@ class HybridParallelOptimizer: ...@@ -324,26 +335,48 @@ class HybridParallelOptimizer:
or mp_configs.sync_moment or mp_configs.sync_moment
): ):
params = sorted( params = sorted(
[p for p in parameters_list if self._filter_fn(p)], [
p
for p in parameters_list
if self._filter_fn(p, fleet.fleet._user_defined_strategy)
],
key=lambda p: p.name, key=lambda p: p.name,
) )
# Grad sync before opt
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad: if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad:
for p in params: for p in params:
if p.grad is None: if hasattr(p, "main_grad") and p.main_grad is not None:
continue assert p.grad is None
paddle.distributed.broadcast( self._insert_sync(
p.grad, src=src_rank, group=mp_group, sync_op=True p.main_grad, src_rank, mp_group, mp_configs.sync_mode
) )
elif p.grad is not None:
self._insert_sync(
p.grad, src_rank, mp_group, mp_configs.sync_mode
)
self._inner_opt.step() self._inner_opt.step()
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param: if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param:
for p in params: for p in params:
paddle.distributed.broadcast( # Param sync after opt
p, src=src_rank, group=mp_group, sync_op=True self._insert_sync(p, src_rank, mp_group, mp_configs.sync_mode)
)
# Master param sync after opt
if (
hasattr(self._inner_opt, "_multi_precision")
and self._inner_opt._multi_precision
and p.name in self._inner_opt._master_weights
):
self._insert_sync(
self._inner_opt._master_weights[p.name],
src_rank,
mp_group,
mp_configs.sync_mode,
)
# Moment sync after opt
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment: if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment:
for p in params: for p in params:
# support opt state of adam and adamw to broadcast now. # support opt state of adam and adamw to broadcast now.
...@@ -352,28 +385,30 @@ class HybridParallelOptimizer: ...@@ -352,28 +385,30 @@ class HybridParallelOptimizer:
(paddle.optimizer.Adam, paddle.optimizer.AdamW), (paddle.optimizer.Adam, paddle.optimizer.AdamW),
): ):
if ( if (
self._inner_opt._multi_precision p.name
and p.name in self._master_weights in self._inner_opt._accumulators[
self._inner_opt._moment1_acc_str
]
): ):
paddle.distributed.broadcast( moment1 = self._inner_opt._get_accumulator(
self._inner_opt._master_weights[p.name], self._inner_opt._moment1_acc_str, p
src=src_rank, )
group=mp_group, self._insert_sync(
sync_op=True, moment1, src_rank, mp_group, mp_configs.sync_mode
) )
moment1 = self._inner_opt._get_accumulator( if (
self._inner_opt._moment1_acc_str, p p.name
) in self._inner_opt._accumulators[
moment2 = self._inner_opt._get_accumulator( self._inner_opt._moment2_acc_str
self._inner_opt._moment2_acc_str, p ]
) ):
paddle.distributed.broadcast( moment2 = self._inner_opt._get_accumulator(
moment1, src=src_rank, group=mp_group, sync_op=True self._inner_opt._moment2_acc_str, p
) )
paddle.distributed.broadcast( self._insert_sync(
moment2, src=src_rank, group=mp_group, sync_op=True moment2, src_rank, mp_group, mp_configs.sync_mode
) )
@no_grad() @no_grad()
@framework.dygraph_only @framework.dygraph_only
......
...@@ -202,6 +202,7 @@ class TestDistMPSyncTraning(unittest.TestCase): ...@@ -202,6 +202,7 @@ class TestDistMPSyncTraning(unittest.TestCase):
self, self,
batchs, batchs,
fp16=False, fp16=False,
amp_level="O1",
mp_sync_param=False, mp_sync_param=False,
mp_sync_grad=False, mp_sync_grad=False,
mp_sync_moment=False, mp_sync_moment=False,
...@@ -232,6 +233,11 @@ class TestDistMPSyncTraning(unittest.TestCase): ...@@ -232,6 +233,11 @@ class TestDistMPSyncTraning(unittest.TestCase):
learning_rate=0.1, parameters=model.parameters() learning_rate=0.1, parameters=model.parameters()
) )
if fp16 and amp_level == "O2":
model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level='O2'
)
strategy = fleet.fleet._user_defined_strategy strategy = fleet.fleet._user_defined_strategy
strategy.hybrid_configs = { strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size, "dp_degree": self.data_parallel_size,
...@@ -246,15 +252,15 @@ class TestDistMPSyncTraning(unittest.TestCase): ...@@ -246,15 +252,15 @@ class TestDistMPSyncTraning(unittest.TestCase):
model = fleet.distributed_model(model) model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer) optimizer = fleet.distributed_optimizer(optimizer)
return self.train_batch(batchs, model, optimizer, fp16) return self.train_batch(batchs, model, optimizer, fp16, amp_level)
def train_batch(self, batchs, model, optimizer, fp16=False): def train_batch(self, batchs, model, optimizer, fp16=False, amp_level="O1"):
losses = [] losses = []
if fp16: if fp16:
scaler = paddle.amp.GradScaler(init_loss_scaling=1024) scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
scaler = fleet.distributed_scaler(scaler) scaler = fleet.distributed_scaler(scaler)
for batch in batchs: for batch in batchs:
with paddle.amp.auto_cast(enable=fp16, level='O1'): with paddle.amp.auto_cast(enable=fp16, level=amp_level):
output = model(batch) output = model(batch)
loss = output.mean() loss = output.mean()
losses.append(loss.numpy()) losses.append(loss.numpy())
...@@ -295,7 +301,7 @@ class TestDistMPSyncTraning(unittest.TestCase): ...@@ -295,7 +301,7 @@ class TestDistMPSyncTraning(unittest.TestCase):
for i in range(len(losses)): for i in range(len(losses)):
np.testing.assert_allclose(losses[i], losses_sync[i], rtol=1e-6) np.testing.assert_allclose(losses[i], losses_sync[i], rtol=1e-6)
# test fp16 # test fp16 O1
losses_fp16 = self.build_model_optimizer_train(batchs, fp16=True) losses_fp16 = self.build_model_optimizer_train(batchs, fp16=True)
losses_sync_fp16 = self.build_model_optimizer_train( losses_sync_fp16 = self.build_model_optimizer_train(
batchs, batchs,
...@@ -310,6 +316,24 @@ class TestDistMPSyncTraning(unittest.TestCase): ...@@ -310,6 +316,24 @@ class TestDistMPSyncTraning(unittest.TestCase):
losses_fp16[i], losses_sync_fp16[i], rtol=1e-6 losses_fp16[i], losses_sync_fp16[i], rtol=1e-6
) )
# test fp16 O2
losses_fp16_O2 = self.build_model_optimizer_train(
batchs, fp16=True, amp_level="O2"
)
losses_sync_fp16_O2 = self.build_model_optimizer_train(
batchs,
fp16=True,
amp_level="O2",
mp_sync_param=mp_sync_param,
mp_sync_grad=mp_sync_grad,
mp_sync_moment=mp_sync_moment,
)
for i in range(len(losses_fp16_O2)):
np.testing.assert_allclose(
losses_fp16_O2[i], losses_sync_fp16_O2[i], rtol=1e-6
)
def test_mp_sync_param(self): def test_mp_sync_param(self):
self.mp_sync_base(mp_sync_param=True) self.mp_sync_base(mp_sync_param=True)
...@@ -325,6 +349,26 @@ class TestDistMPSyncTraning(unittest.TestCase): ...@@ -325,6 +349,26 @@ class TestDistMPSyncTraning(unittest.TestCase):
) )
class TestDistMPSyncModelTraning(TestDistMPSyncTraning):
    """Variant of ``TestDistMPSyncTraning`` whose fleet strategy is
    initialised with ``sync_mode="average"``.

    Only ``setUp`` is overridden; everything else comes from the parent
    class.
    """

    def setUp(self):
        """Initialise fleet with mp_degree=2, dp_degree=1, pp_degree=1."""
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 2
        self.data_parallel_size = 1

        # All sync switches start disabled here; only the mode and the
        # parameter-name patterns are set at init time.
        mp_sync_settings = {
            "sync_param": False,
            "sync_grad": False,
            "sync_moment": False,
            "sync_mode": "average",
            "sync_param_name": ["embedding", "layer_norm", ".b_"],
        }
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": 1,
            "mp_configs": mp_sync_settings,
        }
        fleet.init(is_collective=True, strategy=strategy)
class TestDistMPTraning(unittest.TestCase): class TestDistMPTraning(unittest.TestCase):
def setUp(self): def setUp(self):
strategy = fleet.DistributedStrategy() strategy = fleet.DistributedStrategy()
......
...@@ -84,6 +84,36 @@ class TestStrategyConfig(unittest.TestCase): ...@@ -84,6 +84,36 @@ class TestStrategyConfig(unittest.TestCase):
self.assertEqual(strategy.hybrid_configs["mp_degree"], 2) self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
self.assertEqual(strategy.hybrid_configs["pp_degree"], 4) self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
def test_hybrid_parallel_mp_configs(self):
    """Check that ``mp_configs`` values assigned through ``hybrid_configs``
    can be read back, and that ``sync_param_name`` is exposed on the
    strategy object itself.
    """
    expected_degrees = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 4}
    expected_mp_fields = {
        "sync_param": True,
        "sync_grad": False,
        "sync_moment": False,
        "sync_mode": "broadcast",
    }
    expected_names = ["embedding", "layer_norm", ".w", ".b_"]

    strategy = paddle.distributed.fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        **expected_degrees,
        # Pass a copy so the comparison below is against an independent list.
        "mp_configs": {
            **expected_mp_fields,
            "sync_param_name": list(expected_names),
        },
    }

    for degree_key, degree in expected_degrees.items():
        self.assertEqual(strategy.hybrid_configs[degree_key], degree)

    for field, value in expected_mp_fields.items():
        self.assertEqual(
            getattr(strategy.hybrid_configs["mp_configs"], field), value
        )

    self.assertEqual(strategy.sync_param_name, expected_names)
def test_hybrid_parallel_configs_order(self): def test_hybrid_parallel_configs_order(self):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.hybrid_configs = { strategy.hybrid_configs = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册