Unverified commit c3ebe13b authored by ShenLiang, committed by GitHub

[Distributed]Support param_group in sharding-stage1 (#56626)

* support param group in sharding

* fix utest
Parent f2562f19
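
With this change, the optimizer wrapped by sharding stage 1 accepts parameter groups, each carrying its own learning_rate and weight_decay. Below is a minimal usage sketch modelled on the unit test added in this commit; the TinyNet layer is a placeholder, and the script is assumed to be launched on 2 GPUs with paddle.distributed.launch.

import paddle
from paddle.distributed import fleet

class TinyNet(paddle.nn.Layer):
    # small stand-in model for illustration only
    def __init__(self):
        super().__init__()
        self.linear1 = paddle.nn.Linear(16, 16)
        self.linear2 = paddle.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear2(self.linear1(x))

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "sharding_degree": 2,
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)

model = TinyNet()
param_groups = [
    {"params": model.linear1.parameters(), "learning_rate": 0.1, "weight_decay": 1e-4},
    {"params": model.linear2.parameters(), "learning_rate": 0.01, "weight_decay": 2e-2},
]
# ClipGradByGlobalNorm is rejected together with param_groups (see the assert added
# in this commit), so a per-parameter clip such as ClipGradByNorm is used here.
optimizer = paddle.optimizer.AdamW(
    parameters=param_groups,
    learning_rate=0.001,
    weight_decay=1e-5,
    grad_clip=paddle.nn.ClipGradByNorm(0.7),
)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)  # wraps DygraphShardingOptimizer
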
......@@ -15,12 +15,14 @@
######
import os
from collections import defaultdict
from distutils.util import strtobool
from functools import reduce
import paddle
from paddle import framework
from paddle.fluid.dygraph import base as imperative_base
from paddle.nn import ClipGradByGlobalNorm
from ...utils.log_util import logger
......@@ -55,11 +57,6 @@ class DygraphShardingOptimizer:
# 4. option to choose fused comm (needs more GPU memory) or un-fused comm
def __init__(self, optimizer, hcg):
# TODO(pangengzheng): support param_groups
if isinstance(optimizer._parameter_list[0], dict):
raise TypeError(
"Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter"
)
if not hasattr(optimizer, '_apply_optimize') or not callable(
optimizer._apply_optimize
):
......@@ -67,8 +64,20 @@ class DygraphShardingOptimizer:
"the optimzier object should have _apply_optimize function"
)
# self._parameter_list holds the parameters of the whole model
self._parameter_list = optimizer._parameter_list
self._using_param_groups = isinstance(
optimizer._parameter_list[0], dict
)
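# flatten param_groups into a single parameter list and record each parameter's
# group id so the groups can be rebuilt per sharding rank after partitioning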
self._parameter_list = []
self._param_2_group_id = {}
if self._using_param_groups:
for idx, param_group in enumerate(optimizer._param_groups):
for param in param_group['params']:
self._param_2_group_id[id(param)] = idx
self._parameter_list.append(param)
else:
self._parameter_list = optimizer._parameter_list
self._inner_opt = optimizer
self._hcg = hcg
self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
......@@ -77,18 +86,35 @@ class DygraphShardingOptimizer:
self._rank2params = self._partition_parameters()
self._param2rank = self._map_param_to_rank()
self._set_inner_opt_attr(
'_parameter_list', self._rank2params[self._sharding_rank]
)
self._set_inner_opt_attr(
'_param_groups', self._rank2params[self._sharding_rank]
)
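# rebuild param_groups so that each group keeps its hyperparameters but contains
# only the parameters assigned to the current sharding rank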
if self._using_param_groups:
param_groups = [
{"params": []} for _ in range(len(optimizer._param_groups))
]
for idx, pg in enumerate(optimizer._param_groups):
param_groups[idx].update(
{k: v for k, v in pg.items() if k != 'params'}
)
for param in self._rank2params[self._sharding_rank]:
group_id = self._param_2_group_id[id(param)]
param_groups[group_id]['params'].append(param)
self._set_inner_opt_attr('_param_groups', param_groups)
self._set_inner_opt_attr(
'_parameter_list', self._rank2params[self._sharding_rank]
)
self._param_groups = self._parameter_list
else:
self._set_inner_opt_attr(
'_param_groups', self._rank2params[self._sharding_rank]
)
self._set_inner_opt_attr(
'_parameter_list', self._rank2params[self._sharding_rank]
)
def clear_grad(self, set_to_zero=True):
"""
should clear grad for all parameters in model
"""
#
for p in self._parameter_list:
if hasattr(p, "main_grad") and p.main_grad is not None:
assert p._grad_ivar() is None
......@@ -225,6 +251,9 @@ class DygraphShardingOptimizer:
# NOTE: in dygraph mode, the only difference between step and minimize is that minimize
# allows the user to customize the parameters to be updated on each step
assert (
not self._using_param_groups
), "minimize() is not support if using param_groups"
input_param_names = {param.name for param in parameters}
parameters = list(
filter(
......@@ -250,7 +279,7 @@ class DygraphShardingOptimizer:
# otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params
# TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
origin_clip = self._inner_opt._grad_clip
if not isinstance(self._parameter_list[0], dict):
if not self._using_param_groups:
params_grads = []
for param in self._parameter_list:
if (
......@@ -286,6 +315,35 @@ class DygraphShardingOptimizer:
if g_shard_norm_align_dp:
# restore the grad clip
self._set_inner_opt_attr('_grad_clip', origin_clip)
else:
# optimize parameters in groups
for param_group in self._inner_opt._param_groups:
params_grads = defaultdict(lambda: [])
# TODO(shenliang03): support ClipGradByGlobalNorm in sharding when using param_groups
grad_clip = param_group['grad_clip']
assert not isinstance(
grad_clip, ClipGradByGlobalNorm
), "ClipGradByGlobalNorm is not support if using param_groups in sharding"
for param in param_group['params']:
if param.stop_gradient:
continue
grad_var = param._grad_ivar()
if (
hasattr(param, "main_grad")
and param.main_grad is not None
):
grad_var = param.main_grad
params_grads['params'].append((param, grad_var))
params_grads.update(
{k: v for k, v in param_group.items() if k != 'params'}
)
self._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)
# sync parameters across sharding ranks
self._sharding_sync_parameters()
......
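
For reference, here is a minimal, framework-free sketch of the regrouping performed in __init__ above: given the original param_groups and the subset of parameters assigned to one sharding rank, the groups are rebuilt with their hyperparameters intact. Param and rebuild_param_groups are illustrative names only and are not part of the Paddle API.

from collections import namedtuple

Param = namedtuple("Param", ["name"])  # stand-in for a Paddle parameter

def rebuild_param_groups(original_groups, rank_local_params, param_2_group_id):
    # copy every hyperparameter except 'params' into fresh, empty groups
    rebuilt = []
    for group in original_groups:
        new_group = {k: v for k, v in group.items() if k != "params"}
        new_group["params"] = []
        rebuilt.append(new_group)
    # route each rank-local parameter back into the group it came from
    for param in rank_local_params:
        rebuilt[param_2_group_id[id(param)]]["params"].append(param)
    return rebuilt

w1, b1, w2 = Param("w1"), Param("b1"), Param("w2")
groups = [
    {"params": [w1, b1], "learning_rate": 0.1, "weight_decay": 1e-4},
    {"params": [w2], "learning_rate": 0.01, "weight_decay": 2e-2},
]
param_2_group_id = {id(p): i for i, g in enumerate(groups) for p in g["params"]}
# pretend this sharding rank was assigned only w1 and w2
print(rebuild_param_groups(groups, [w1, w2], param_2_group_id))
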
......@@ -18,7 +18,6 @@ import unittest
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
......@@ -57,72 +56,6 @@ def parallel_matmul(lm_output, logit_weights, parallel_output):
return logits
class SimpleMPNet(paddle.nn.Layer):
def __init__(
self,
vocab_size,
hidden_size,
inner_size,
output_size,
np_fc1,
np_fc2,
mp_id,
):
super().__init__()
if mp_id == 0:
init_fc1_data = np_fc1[:, : (inner_size // 2)]
init_fc2_data = np_fc2[: (inner_size // 2), :]
else:
init_fc1_data = np_fc1[:, (inner_size // 2) :]
init_fc2_data = np_fc2[(inner_size // 2) :, :]
self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc1_data)
),
gather_output=False,
has_bias=True,
)
self.linear2 = fleet.meta_parallel.RowParallelLinear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc2_data)
),
input_is_parallel=True,
has_bias=True,
)
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)
),
)
self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5),
)
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = parallel_matmul(x, self.embedding.weight, False)
return x
class SimpleDPNet(paddle.nn.Layer):
def __init__(
self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
......@@ -230,12 +163,6 @@ class TestDistMPTraning(unittest.TestCase):
return optimizer
def build_model_optimizer(self, Optimizer="adam"):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
sharding_id = hcg.get_sharding_parallel_rank()
dp_id = hcg.get_data_parallel_rank()
rank_id = dist.get_rank()
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from hybrid_parallel_sharding_model import SimpleDPNet
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
STEPS = 10
class TestDistMPTraning(unittest.TestCase):
def setUp(self):
random.seed(2021)
np.random.seed(2021)
paddle.seed(2021)
self.strategy = fleet.DistributedStrategy()
self.strategy.hybrid_configs = {
"sharding_degree": 2,
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=self.strategy)
self.data = [
np.random.randint(
0,
vocab_size,
(
batch_size,
seq_length,
),
)
for _ in range(STEPS)
]
def train_batch(self, batch, model, optimizer):
output = model(batch)
loss = output.mean()
loss.backward() # do backward
optimizer.step() # update parameters
optimizer.clear_grad()
return loss
def build_optimizer(self, model, strategy=None, Optimizer="adam"):
clip = paddle.nn.ClipGradByNorm(0.7)
param_groups = [
{
"params": model.linear1.parameters(),
"weight_decay": 0.0001,
"learning_rate": 0.1,
},
{
"params": model.linear2.parameters(),
"weight_decay": 0.020,
"learning_rate": 0.01,
},
{
"params": model.linear3.parameters(),
"weight_decay": 0.1,
"learning_rate": 0.1,
},
]
if Optimizer == "adam":
optimizer = paddle.optimizer.AdamW(
parameters=param_groups,
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
)
else:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
parameters=model.parameters(),
grad_clip=clip,
)
return optimizer
def build_model_optimizer(self, Optimizer="adam"):
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model_a = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
optimizer_a = self.build_optimizer(
model_a,
strategy=self.strategy,
Optimizer=Optimizer,
)
model_a = fleet.distributed_model(model_a)
optimizer_a = fleet.distributed_optimizer(optimizer_a)
model_b = SimpleDPNet(
vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
)
optimizer_b = self.build_optimizer(
model_b,
strategy=self.strategy,
Optimizer=Optimizer,
)
return model_a, optimizer_a, model_b, optimizer_b
def sharding_model(self, Optimizer, sharded_accumulators):
model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
Optimizer=Optimizer
)
self.assertTrue(
isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer)
)
for idx in range(STEPS):
if idx > 1:
self.assertTrue(
set(optimizer_a._inner_opt._inner_opt.state_dict().keys())
== sharded_accumulators[paddle.distributed.get_rank()]
)
if paddle.distributed.get_rank() == 0:
batch_sharding = paddle.to_tensor(self.data[idx][:2])
else:
batch_sharding = paddle.to_tensor(self.data[idx][2:])
batch_single = paddle.to_tensor(self.data[idx])
loss_a = self.train_batch(batch_sharding, model_a, optimizer_a)
loss_b = self.train_batch(batch_single, model_b, optimizer_b)
np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())
for j in range(len(model_a.parameters())):
np.testing.assert_allclose(
model_a.parameters()[j].numpy(),
model_b.parameters()[j].numpy(),
rtol=1e-6,
)
def test_sharding_adam(self):
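# expected optimizer accumulator names held by each sharding rank (index = sharding rank)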
sharded_accumulators = [
{
'linear_0.b_0_moment1_0',
'linear_1.b_0_moment1_0',
'linear_2.w_0_moment1_0',
'linear_2.b_0_moment1_0',
'linear_0.b_0_moment2_0',
'linear_1.b_0_moment2_0',
'linear_2.w_0_moment2_0',
'linear_2.b_0_moment2_0',
'linear_0.b_0_beta1_pow_acc_0',
'linear_1.b_0_beta1_pow_acc_0',
'linear_2.w_0_beta1_pow_acc_0',
'linear_2.b_0_beta1_pow_acc_0',
'linear_0.b_0_beta2_pow_acc_0',
'linear_1.b_0_beta2_pow_acc_0',
'linear_2.w_0_beta2_pow_acc_0',
'linear_2.b_0_beta2_pow_acc_0',
},
{
'linear_0.w_0_moment1_0',
'linear_1.w_0_moment1_0',
'linear_0.w_0_moment2_0',
'linear_1.w_0_moment2_0',
'linear_0.w_0_beta1_pow_acc_0',
'linear_1.w_0_beta1_pow_acc_0',
'linear_0.w_0_beta2_pow_acc_0',
'linear_1.w_0_beta2_pow_acc_0',
},
]
self.sharding_model(
Optimizer="adam",
sharded_accumulators=sharded_accumulators,
)
if __name__ == "__main__":
unittest.main()
......@@ -31,6 +31,10 @@ class TestHybridParallel(TestMultipleGpus):
os.environ["FLAGS_shard_norm_align_dp"] = "1"
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
def test_hybrid_parallel_sharding_param_group(self):
# test sharding with param_groups
self.run_mnist_2gpu('hybrid_parallel_sharding_param_group.py')
def test_hybrid_parallel_sharding_state_dict(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
......