未验证 提交 6b74cf76 编写于 作者: W wuhuachaocoding 提交者: GitHub

mp sync params & grads & opt states. (#51428)

上级 f80a0fe9
......@@ -50,11 +50,19 @@ message ShardingConfig {
optional bool enable_tuning = 15 [ default = false ]; // incubate for auto parallel
}
// for dygraph
message MpConfig {
optional bool sync_param= 1 [ default = false ];
optional bool sync_grad= 2 [ default = false ];
optional bool sync_moment= 3 [ default = false ];
}
message HybridConfig {
optional int32 dp_degree = 1 [ default = -1 ];
optional int32 mp_degree = 2 [ default = 1 ];
optional int32 pp_degree = 3 [ default = 1 ];
optional int32 sharding_degree = 4 [ default = 1 ];
optional MpConfig mp_configs = 5;
}
message AMPConfig {
......
......@@ -1696,6 +1696,12 @@ class DistributedStrategy:
check_configs_key(
self.strategy.hybrid_configs, hybrid_config, "hybrid_configs"
)
if "mp_configs" in configs:
assign_configs_value(
self.strategy.hybrid_configs.mp_configs, configs["mp_configs"]
)
configs.pop("mp_configs")
assign_configs_value(self.strategy.hybrid_configs, configs)
@property
......
......@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import framework
from paddle.autograd import no_grad
from paddle.distributed import fleet
from paddle.framework import core
from paddle.nn import ClipGradByGlobalNorm, clip
......@@ -292,6 +294,83 @@ class HybridParallelOptimizer:
self._inner_opt._grad_clip, hcg
)
def _filter_fn(self, param):
p_name = param.name
tar_param = ["embedding", "layer_norm", ".b_"]
if param.is_distributed is False:
for tar in tar_param:
if tar in p_name:
return True
return False
def _step(self, parameters_list):
mp_group = self._hcg.get_model_parallel_group()
src_rank = self._hcg.get_model_parallel_group_src_rank()
params = None
mp_configs = None
if mp_group.nranks > 1:
mp_configs = fleet.fleet._user_defined_strategy.hybrid_configs[
"mp_configs"
]
if mp_configs and (
mp_configs.sync_param
or mp_configs.sync_grad
or mp_configs.sync_moment
):
params = sorted(
[p for p in parameters_list if self._filter_fn(p)],
key=lambda p: p.name,
)
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_grad:
for p in params:
if p.grad is None:
continue
paddle.distributed.broadcast(
p.grad, src=src_rank, group=mp_group, sync_op=True
)
self._inner_opt.step()
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_param:
for p in params:
paddle.distributed.broadcast(
p, src=src_rank, group=mp_group, sync_op=True
)
if mp_group.nranks > 1 and mp_configs and mp_configs.sync_moment:
for p in params:
# support opt state of adam and adamw to broadcast now.
if isinstance(
self._inner_opt,
(paddle.optimizer.Adam, paddle.optimizer.AdamW),
):
if (
self._inner_opt._multi_precision
and p.name in self._master_weights
):
paddle.distributed.broadcast(
self._inner_opt._master_weights[p.name],
src=src_rank,
group=mp_group,
sync_op=True,
)
moment1 = self._inner_opt._get_accumulator(
self._inner_opt._moment1_acc_str, p
)
moment2 = self._inner_opt._get_accumulator(
self._inner_opt._moment2_acc_str, p
)
paddle.distributed.broadcast(
moment1, src=src_rank, group=mp_group, sync_op=True
)
paddle.distributed.broadcast(
moment2, src=src_rank, group=mp_group, sync_op=True
)
@no_grad()
@framework.dygraph_only
def step(self):
......@@ -302,7 +381,7 @@ class HybridParallelOptimizer:
if self._dp_enable:
fused_allreduce_gradients(list(parameters_list), self._hcg)
self._inner_opt.step()
self._step(parameters_list)
@no_grad()
def minimize(
......
......@@ -181,6 +181,150 @@ class SimpleDPNet(paddle.nn.Layer):
return x
class TestDistMPSyncTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
self.data_parallel_size = 1
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
"mp_configs": {
"sync_param": False,
"sync_grad": False,
"sync_moment": False,
},
}
fleet.init(is_collective=True, strategy=strategy)
def build_model_optimizer_train(
self,
batchs,
fp16=False,
mp_sync_param=False,
mp_sync_grad=False,
mp_sync_moment=False,
):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
mp_id = hcg.get_model_parallel_rank()
dp_id = hcg.get_data_parallel_rank()
rank_id = dist.get_rank()
paddle.seed(2023)
np.random.seed(2023)
random.seed(2023)
set_random_seed(1024, dp_id, rank_id)
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model = SimpleMPNet(
vocab_size,
hidden_size,
inner_size,
output_size,
np_fc1,
np_fc2,
mp_id,
)
optimizer = paddle.optimizer.AdamW(
learning_rate=0.1, parameters=model.parameters()
)
strategy = fleet.fleet._user_defined_strategy
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": 1,
"mp_configs": {
"sync_param": mp_sync_param,
"sync_grad": mp_sync_grad,
"sync_moment": mp_sync_moment,
},
}
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
return self.train_batch(batchs, model, optimizer, fp16)
def train_batch(self, batchs, model, optimizer, fp16=False):
losses = []
if fp16:
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
scaler = fleet.distributed_scaler(scaler)
for batch in batchs:
with paddle.amp.auto_cast(enable=fp16, level='O1'):
output = model(batch)
loss = output.mean()
losses.append(loss.numpy())
if fp16:
scaled = scaler.scale(loss)
scaled.backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
optimizer.clear_grad()
return losses
def mp_sync_base(
self, mp_sync_param=False, mp_sync_grad=False, mp_sync_moment=False
):
batchs = []
for _ in range(5):
np_data = np.random.randint(
0,
vocab_size,
(
batch_size,
seq_length,
),
)
batchs.append(paddle.to_tensor(np_data))
losses = self.build_model_optimizer_train(batchs)
losses_sync = self.build_model_optimizer_train(
batchs,
mp_sync_param=mp_sync_param,
mp_sync_grad=mp_sync_grad,
mp_sync_moment=mp_sync_moment,
)
for i in range(len(losses)):
np.testing.assert_allclose(losses[i], losses_sync[i], rtol=1e-6)
# test fp16
losses_fp16 = self.build_model_optimizer_train(batchs, fp16=True)
losses_sync_fp16 = self.build_model_optimizer_train(
batchs,
fp16=True,
mp_sync_param=mp_sync_param,
mp_sync_grad=mp_sync_grad,
mp_sync_moment=mp_sync_moment,
)
for i in range(len(losses_fp16)):
np.testing.assert_allclose(
losses_fp16[i], losses_sync_fp16[i], rtol=1e-6
)
def test_mp_sync_param(self):
self.mp_sync_base(mp_sync_param=True)
def test_mp_sync_grad(self):
self.mp_sync_base(mp_sync_grad=True)
def test_mp_sync_moment(self):
self.mp_sync_base(mp_sync_moment=True)
def test_mp_sync_all(self):
self.mp_sync_base(
mp_sync_param=True, mp_sync_grad=True, mp_sync_moment=True
)
class TestDistMPTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册