Unverified commit c3ebe13b, authored by S ShenLiang, committed by GitHub

[Distributed] Support param_group in sharding-stage1 (#56626)

* support param group in sharding

* fix unit test
Parent: f2562f19
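Before this change, DygraphShardingOptimizer raised a TypeError whenever the inner optimizer was built from param_groups (a list of dicts); this commit teaches sharding stage 1 to accept them. The hunks below modify the dygraph sharding optimizer (dygraph_sharding_optimizer.py, judging by the import path used in the tests). As a hedged sketch of how the new path is meant to be driven from user code, mirroring the new hybrid_parallel_sharding_param_group.py test (the placeholder model, group hyper-parameters, and 2-GPU launch here are assumptions, not part of the commit):

import paddle
from paddle.distributed import fleet

# Assumed 2-way sharding setup, launched on 2 GPUs (e.g. via paddle.distributed.launch).
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "sharding_degree": 2,
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)

# Placeholder layers standing in for the SimpleDPNet used by the tests.
linear1 = paddle.nn.Linear(10, 8)
linear2 = paddle.nn.Linear(8, 10)
model = paddle.nn.Sequential(linear1, linear2)

# Parameters passed as groups (a list of dicts), each with its own hyper-parameters;
# exactly the input shape the removed TypeError used to reject.
param_groups = [
    {"params": linear1.parameters(), "learning_rate": 0.1, "weight_decay": 1e-4},
    {"params": linear2.parameters(), "learning_rate": 0.01, "weight_decay": 2e-2},
]
optimizer = paddle.optimizer.AdamW(
    parameters=param_groups,
    learning_rate=0.001,
    # Per the new assertion in step(), ClipGradByGlobalNorm is not allowed with
    # param_groups; the new test uses ClipGradByNorm instead.
    grad_clip=paddle.nn.ClipGradByNorm(0.7),
)

model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)  # wraps DygraphShardingOptimizer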
@@ -15,12 +15,14 @@
 ######
 import os
+from collections import defaultdict
 from distutils.util import strtobool
 from functools import reduce
 import paddle
 from paddle import framework
 from paddle.fluid.dygraph import base as imperative_base
+from paddle.nn import ClipGradByGlobalNorm
 from ...utils.log_util import logger
@@ -55,11 +57,6 @@ class DygraphShardingOptimizer:
     # 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm
     def __init__(self, optimizer, hcg):
-        # TODO(pangengzheng): support param_groups
-        if isinstance(optimizer._parameter_list[0], dict):
-            raise TypeError(
-                "Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter"
-            )
         if not hasattr(optimizer, '_apply_optimize') or not callable(
             optimizer._apply_optimize
         ):
@@ -67,8 +64,20 @@ class DygraphShardingOptimizer:
                 "the optimzier object should have _apply_optimize function"
             )
-        # the self._parameter_list holds the whole model paramters
-        self._parameter_list = optimizer._parameter_list
+        self._using_param_groups = isinstance(
+            optimizer._parameter_list[0], dict
+        )
+        self._parameter_list = []
+        self._param_2_group_id = {}
+        if self._using_param_groups:
+            for idx, param_group in enumerate(optimizer._param_groups):
+                for param in param_group['params']:
+                    self._param_2_group_id[id(param)] = idx
+                    self._parameter_list.append(param)
+        else:
+            self._parameter_list = optimizer._parameter_list
         self._inner_opt = optimizer
         self._hcg = hcg
         self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
@@ -77,18 +86,35 @@ class DygraphShardingOptimizer:
         self._rank2params = self._partition_parameters()
         self._param2rank = self._map_param_to_rank()
-        self._set_inner_opt_attr(
-            '_parameter_list', self._rank2params[self._sharding_rank]
-        )
-        self._set_inner_opt_attr(
-            '_param_groups', self._rank2params[self._sharding_rank]
-        )
+        if self._using_param_groups:
+            param_groups = [
+                {"params": []} for _ in range(len(optimizer._param_groups))
+            ]
+            for idx, pg in enumerate(optimizer._param_groups):
+                param_groups[idx].update(
+                    {k: v for k, v in pg.items() if k != 'params'}
+                )
+            for param in self._rank2params[self._sharding_rank]:
+                group_id = self._param_2_group_id[id(param)]
+                param_groups[group_id]['params'].append(param)
+            self._set_inner_opt_attr('_param_groups', param_groups)
+            self._set_inner_opt_attr(
+                '_parameter_list', self._rank2params[self._sharding_rank]
+            )
+            self._param_groups = self._parameter_list
+        else:
+            self._set_inner_opt_attr(
+                '_param_groups', self._rank2params[self._sharding_rank]
+            )
+            self._set_inner_opt_attr(
+                '_parameter_list', self._rank2params[self._sharding_rank]
+            )

     def clear_grad(self, set_to_zero=True):
         """
         should clear grad for all parameters in model
         """
         #
         for p in self._parameter_list:
             if hasattr(p, "main_grad") and p.main_grad is not None:
                 assert p._grad_ivar() is None
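In __init__, parameters are still partitioned across sharding ranks as before (_partition_parameters), but when param_groups are in use the locally owned parameters are re-bucketed into groups that keep each group's hyper-parameters while dropping parameters owned by other ranks. A toy, self-contained illustration of that regrouping follows (not the commit's code; plain strings stand in for Parameters, so the id(param) keys of the real code become the strings themselves, and all names are illustrative):

# Toy sketch of the __init__ regrouping, assuming two groups and one local rank.
original_groups = [
    {"params": ["linear_0.w_0", "linear_0.b_0"], "learning_rate": 0.1},
    {"params": ["linear_1.w_0", "linear_1.b_0"], "learning_rate": 0.01},
]
# _param_2_group_id: which group each parameter came from (real code keys by id(param)).
param_to_group_id = {
    p: idx for idx, group in enumerate(original_groups) for p in group["params"]
}
# Pretend the sharding partitioner assigned these parameters to the local rank.
local_params = ["linear_0.b_0", "linear_1.w_0"]

# Rebuild per-rank groups: copy every key except 'params', then refill 'params'
# with only the locally owned parameters.
local_groups = [{"params": []} for _ in original_groups]
for idx, group in enumerate(original_groups):
    local_groups[idx].update({k: v for k, v in group.items() if k != "params"})
for p in local_params:
    local_groups[param_to_group_id[p]]["params"].append(p)

print(local_groups)
# [{'params': ['linear_0.b_0'], 'learning_rate': 0.1},
#  {'params': ['linear_1.w_0'], 'learning_rate': 0.01}]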
@@ -225,6 +251,9 @@ class DygraphShardingOptimizer:
         # NOTE in dygraph mode, the only different between step and minimize is that minimize
         # allow user to customize the parameters for updating on each step
+        assert (
+            not self._using_param_groups
+        ), "minimize() is not support if using param_groups"
         input_param_names = {param.name for param in parameters}
         parameters = list(
             filter(
@@ -250,7 +279,7 @@ class DygraphShardingOptimizer:
         # otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params
         # TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
         origin_clip = self._inner_opt._grad_clip
-        if not isinstance(self._parameter_list[0], dict):
+        if not self._using_param_groups:
             params_grads = []
             for param in self._parameter_list:
                 if (
@@ -286,6 +315,35 @@ class DygraphShardingOptimizer:
             if g_shard_norm_align_dp:
                 # restore the grad clip
                 self._set_inner_opt_attr('_grad_clip', origin_clip)
+        else:
+            # optimize parameters in groups
+            for param_group in self._inner_opt._param_groups:
+                params_grads = defaultdict(lambda: [])
+                # TODO(shenliang03): support ClipGradByGlobalNorm in sharding when using param_groups
+                grad_clip = param_group['grad_clip']
+                assert not isinstance(
+                    grad_clip, ClipGradByGlobalNorm
+                ), "ClipGradByGlobalNorm is not support if using param_groups in sharding"
+                for param in param_group['params']:
+                    if param.stop_gradient:
+                        continue
+                    grad_var = param._grad_ivar()
+                    if (
+                        hasattr(param, "main_grad")
+                        and param.main_grad is not None
+                    ):
+                        grad_var = param.main_grad
+                    params_grads['params'].append((param, grad_var))
+                params_grads.update(
+                    {k: v for k, v in param_group.items() if k != 'params'}
+                )
+                self._apply_optimize(
+                    loss=None, startup_program=None, params_grads=params_grads
+                )
         # sync parameters across sharding ranks
         self._sharding_sync_parameters()
...
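In step(), the non-param_groups path is unchanged apart from testing the new _using_param_groups flag; the new else branch walks the inner optimizer's param groups one by one, builds a params_grads payload per group, and hands each payload to _apply_optimize. ClipGradByGlobalNorm is asserted out for now because global-norm clipping across sharded param groups is not supported yet (see the TODO). A small self-contained sketch of the per-group payload (not the commit's code; plain strings stand in for Parameters and gradient tensors):

from collections import defaultdict

# One param group as the inner optimizer stores it (stand-in values).
param_group = {
    "params": ["linear_0.w_0", "linear_0.b_0"],  # stand-ins for paddle Parameters
    "learning_rate": 0.1,
    "weight_decay": 1e-4,
}
grads = {"linear_0.w_0": "grad_w", "linear_0.b_0": "grad_b"}  # grad or main_grad

# Mirror of the step() logic: pair each parameter with its gradient, then copy the
# group's hyper-parameters alongside the 'params' entry.
params_grads = defaultdict(list)
for p in param_group["params"]:
    params_grads["params"].append((p, grads[p]))
params_grads.update({k: v for k, v in param_group.items() if k != "params"})

print(dict(params_grads))
# {'params': [('linear_0.w_0', 'grad_w'), ('linear_0.b_0', 'grad_b')],
#  'learning_rate': 0.1, 'weight_decay': 0.0001}

The remaining hunks update the tests: hybrid_parallel_sharding_model.py drops the SimpleMPNet helper and the paddle.distributed import, a new hybrid_parallel_sharding_param_group.py trains a sharded model against a single-process baseline using param groups, and the 2-GPU launcher gains a test_hybrid_parallel_sharding_param_group entry.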
@@ -18,7 +18,6 @@ import unittest
 import numpy as np
 import paddle
-import paddle.distributed as dist
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizer,
@@ -57,72 +56,6 @@ def parallel_matmul(lm_output, logit_weights, parallel_output):
     return logits

-class SimpleMPNet(paddle.nn.Layer):
-    def __init__(
-        self,
-        vocab_size,
-        hidden_size,
-        inner_size,
-        output_size,
-        np_fc1,
-        np_fc2,
-        mp_id,
-    ):
-        super().__init__()
-        if mp_id == 0:
-            init_fc1_data = np_fc1[:, : (inner_size // 2)]
-            init_fc2_data = np_fc2[: (inner_size // 2), :]
-        else:
-            init_fc1_data = np_fc1[:, (inner_size // 2) :]
-            init_fc2_data = np_fc2[(inner_size // 2) :, :]
-        self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
-            hidden_size,
-            inner_size,
-            weight_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Assign(init_fc1_data)
-            ),
-            gather_output=False,
-            has_bias=True,
-        )
-        self.linear2 = fleet.meta_parallel.RowParallelLinear(
-            inner_size,
-            hidden_size,
-            weight_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Assign(init_fc2_data)
-            ),
-            input_is_parallel=True,
-            has_bias=True,
-        )
-        self.linear3 = paddle.nn.Linear(
-            hidden_size,
-            output_size,
-            weight_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(0.0)
-            ),
-            bias_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(0.0)
-            ),
-        )
-        self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
-            vocab_size,
-            hidden_size,
-            weight_attr=paddle.nn.initializer.Constant(value=0.5),
-        )
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = parallel_matmul(x, self.embedding.weight, False)
-        return x

 class SimpleDPNet(paddle.nn.Layer):
     def __init__(
         self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
@@ -230,12 +163,6 @@ class TestDistMPTraning(unittest.TestCase):
         return optimizer

     def build_model_optimizer(self, Optimizer="adam"):
-        hcg = fleet.get_hybrid_communicate_group()
-        word_size = hcg.get_model_parallel_world_size()
-        sharding_id = hcg.get_sharding_parallel_rank()
-        dp_id = hcg.get_data_parallel_rank()
-        rank_id = dist.get_rank()
         np_fc1 = np.random.random_sample((hidden_size, inner_size))
         np_fc2 = np.random.random_sample((inner_size, hidden_size))
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from hybrid_parallel_sharding_model import SimpleDPNet
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
DygraphShardingOptimizer,
)
vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
STEPS = 10
class TestDistMPTraning(unittest.TestCase):
    def setUp(self):
        random.seed(2021)
        np.random.seed(2021)
        paddle.seed(2021)

        self.strategy = fleet.DistributedStrategy()
        self.strategy.hybrid_configs = {
            "sharding_degree": 2,
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": 1,
        }
        fleet.init(is_collective=True, strategy=self.strategy)
        self.data = [
            np.random.randint(
                0,
                vocab_size,
                (
                    batch_size,
                    seq_length,
                ),
            )
            for _ in range(STEPS)
        ]

    def train_batch(self, batch, model, optimizer):
        output = model(batch)
        loss = output.mean()
        loss.backward()  # do backward
        optimizer.step()  # update parameters
        optimizer.clear_grad()
        return loss

    def build_optimizer(self, model, strategy=None, Optimizer="adam"):
        clip = paddle.nn.ClipGradByNorm(0.7)
        param_groups = [
            {
                "params": model.linear1.parameters(),
                "weight_decay": 0.0001,
                "learning_rate": 0.1,
            },
            {
                "params": model.linear2.parameters(),
                "weight_decay": 0.020,
                "learning_rate": 0.01,
            },
            {
                "params": model.linear3.parameters(),
                "weight_decay": 0.1,
                "learning_rate": 0.1,
            },
        ]
        if Optimizer == "adam":
            optimizer = paddle.optimizer.AdamW(
                parameters=param_groups,
                learning_rate=0.001,
                weight_decay=0.00001,
                grad_clip=clip,
            )
        else:
            optimizer = paddle.optimizer.Momentum(
                learning_rate=0.001,
                parameters=model.parameters(),
                grad_clip=clip,
            )
        return optimizer

    def build_model_optimizer(self, Optimizer="adam"):
        np_fc1 = np.random.random_sample((hidden_size, inner_size))
        np_fc2 = np.random.random_sample((inner_size, hidden_size))

        model_a = SimpleDPNet(
            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
        )
        optimizer_a = self.build_optimizer(
            model_a,
            strategy=self.strategy,
            Optimizer=Optimizer,
        )
        model_a = fleet.distributed_model(model_a)
        optimizer_a = fleet.distributed_optimizer(optimizer_a)

        model_b = SimpleDPNet(
            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
        )
        optimizer_b = self.build_optimizer(
            model_b,
            strategy=self.strategy,
            Optimizer=Optimizer,
        )

        return model_a, optimizer_a, model_b, optimizer_b

    def sharding_model(self, Optimizer, sharded_accumulators):
        model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
            Optimizer=Optimizer
        )

        self.assertTrue(
            isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer)
        )

        for idx in range(STEPS):
            if idx > 1:
                self.assertTrue(
                    set(optimizer_a._inner_opt._inner_opt.state_dict().keys())
                    == sharded_accumulators[paddle.distributed.get_rank()]
                )

            if paddle.distributed.get_rank() == 0:
                batch_sharding = paddle.to_tensor(self.data[idx][:2])
            else:
                batch_sharding = paddle.to_tensor(self.data[idx][2:])

            batch_single = paddle.to_tensor(self.data[idx])
            loss_a = self.train_batch(batch_sharding, model_a, optimizer_a)
            loss_b = self.train_batch(batch_single, model_b, optimizer_b)

            np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())

            for j in range(len(model_a.parameters())):
                np.testing.assert_allclose(
                    model_a.parameters()[j].numpy(),
                    model_b.parameters()[j].numpy(),
                    rtol=1e-6,
                )

    def test_sharding_adam(self):
        sharded_accumulators = [
            {
                'linear_0.b_0_moment1_0',
                'linear_1.b_0_moment1_0',
                'linear_2.w_0_moment1_0',
                'linear_2.b_0_moment1_0',
                'linear_0.b_0_moment2_0',
                'linear_1.b_0_moment2_0',
                'linear_2.w_0_moment2_0',
                'linear_2.b_0_moment2_0',
                'linear_0.b_0_beta1_pow_acc_0',
                'linear_1.b_0_beta1_pow_acc_0',
                'linear_2.w_0_beta1_pow_acc_0',
                'linear_2.b_0_beta1_pow_acc_0',
                'linear_0.b_0_beta2_pow_acc_0',
                'linear_1.b_0_beta2_pow_acc_0',
                'linear_2.w_0_beta2_pow_acc_0',
                'linear_2.b_0_beta2_pow_acc_0',
            },
            {
                'linear_0.w_0_moment1_0',
                'linear_1.w_0_moment1_0',
                'linear_0.w_0_moment2_0',
                'linear_1.w_0_moment2_0',
                'linear_0.w_0_beta1_pow_acc_0',
                'linear_1.w_0_beta1_pow_acc_0',
                'linear_0.w_0_beta2_pow_acc_0',
                'linear_1.w_0_beta2_pow_acc_0',
            },
        ]
        self.sharding_model(
            Optimizer="adam",
            sharded_accumulators=sharded_accumulators,
        )


if __name__ == "__main__":
    unittest.main()
@@ -31,6 +31,10 @@ class TestHybridParallel(TestMultipleGpus):
         os.environ["FLAGS_shard_norm_align_dp"] = "1"
         self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')

+    def test_hybrid_parallel_sharding_param_group(self):
+        # test shard grad reduce
+        self.run_mnist_2gpu('hybrid_parallel_sharding_param_group.py')
+
     def test_hybrid_parallel_sharding_state_dict(self):
         self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
...