diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index f725a4351f2dce4cb1d7eff83c0614ed9aaa7dab..c7d1267298f8504d8f9e759d7c3d2bea266dc8c0 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -15,12 +15,14 @@
 ######
 import os
+from collections import defaultdict
 from distutils.util import strtobool
 from functools import reduce
 
 import paddle
 from paddle import framework
 from paddle.fluid.dygraph import base as imperative_base
+from paddle.nn import ClipGradByGlobalNorm
 
 from ...utils.log_util import logger
 
 
@@ -55,11 +57,6 @@ class DygraphShardingOptimizer:
     # 4. option to choose fuse comm (more GPU MEM need) or un-fuse comm
 
     def __init__(self, optimizer, hcg):
-        # TODO(pangengzheng): support param_groups
-        if isinstance(optimizer._parameter_list[0], dict):
-            raise TypeError(
-                "Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter"
-            )
         if not hasattr(optimizer, '_apply_optimize') or not callable(
             optimizer._apply_optimize
         ):
@@ -67,8 +64,20 @@ class DygraphShardingOptimizer:
                 "the optimzier object should have _apply_optimize function"
            )
 
-        # the self._parameter_list holds the whole model paramters
-        self._parameter_list = optimizer._parameter_list
+        self._using_param_groups = isinstance(
+            optimizer._parameter_list[0], dict
+        )
+
+        self._parameter_list = []
+        self._param_2_group_id = {}
+        if self._using_param_groups:
+            for idx, param_group in enumerate(optimizer._param_groups):
+                for param in param_group['params']:
+                    self._param_2_group_id[id(param)] = idx
+                    self._parameter_list.append(param)
+        else:
+            self._parameter_list = optimizer._parameter_list
+
         self._inner_opt = optimizer
         self._hcg = hcg
         self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
@@ -77,18 +86,35 @@ class DygraphShardingOptimizer:
         self._rank2params = self._partition_parameters()
         self._param2rank = self._map_param_to_rank()
 
-        self._set_inner_opt_attr(
-            '_parameter_list', self._rank2params[self._sharding_rank]
-        )
-        self._set_inner_opt_attr(
-            '_param_groups', self._rank2params[self._sharding_rank]
-        )
+        if self._using_param_groups:
+            param_groups = [
+                {"params": []} for _ in range(len(optimizer._param_groups))
+            ]
+            for idx, pg in enumerate(optimizer._param_groups):
+                param_groups[idx].update(
+                    {k: v for k, v in pg.items() if k != 'params'}
+                )
+            for param in self._rank2params[self._sharding_rank]:
+                group_id = self._param_2_group_id[id(param)]
+                param_groups[group_id]['params'].append(param)
+
+            self._set_inner_opt_attr('_param_groups', param_groups)
+            self._set_inner_opt_attr(
+                '_parameter_list', self._rank2params[self._sharding_rank]
+            )
+            self._param_groups = self._parameter_list
+        else:
+            self._set_inner_opt_attr(
+                '_param_groups', self._rank2params[self._sharding_rank]
+            )
+            self._set_inner_opt_attr(
+                '_parameter_list', self._rank2params[self._sharding_rank]
+            )
 
     def clear_grad(self, set_to_zero=True):
         """
         should clear grad for all parameters in model
         """
-        #
         for p in self._parameter_list:
             if hasattr(p, "main_grad") and p.main_grad is not None:
                 assert p._grad_ivar() is None
@@ -225,6 +251,9 @@ class DygraphShardingOptimizer:
 
         # NOTE in dygraph mode, the only different between step and minimize is that minimize
         # allow user to customize the parameters for updating on each step
+        assert (
+            not self._using_param_groups
+        ), "minimize() is not supported when using param_groups"
         input_param_names = {param.name for param in parameters}
         parameters = list(
             filter(
@@ -250,7 +279,7 @@ class DygraphShardingOptimizer:
         # otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params
         # TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp.
         origin_clip = self._inner_opt._grad_clip
-        if not isinstance(self._parameter_list[0], dict):
+        if not self._using_param_groups:
             params_grads = []
             for param in self._parameter_list:
                 if (
@@ -286,6 +315,35 @@ class DygraphShardingOptimizer:
             if g_shard_norm_align_dp:
                 # restore the grad clip
                 self._set_inner_opt_attr('_grad_clip', origin_clip)
+        else:
+            # optimize parameters in groups
+            for param_group in self._inner_opt._param_groups:
+                params_grads = defaultdict(lambda: [])
+
+                # TODO(shenliang03): support ClipGradByGlobalNorm in sharding when using param_groups
+                grad_clip = param_group['grad_clip']
+                assert not isinstance(
+                    grad_clip, ClipGradByGlobalNorm
+                ), "ClipGradByGlobalNorm is not supported when using param_groups in sharding"
+
+                for param in param_group['params']:
+                    if param.stop_gradient:
+                        continue
+
+                    grad_var = param._grad_ivar()
+                    if (
+                        hasattr(param, "main_grad")
+                        and param.main_grad is not None
+                    ):
+                        grad_var = param.main_grad
+
+                    params_grads['params'].append((param, grad_var))
+                params_grads.update(
+                    {k: v for k, v in param_group.items() if k != 'params'}
+                )
+                self._apply_optimize(
+                    loss=None, startup_program=None, params_grads=params_grads
+                )
 
         # sync parameters across sharding ranks
         self._sharding_sync_parameters()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model.py
index 82c132df99635fdc33a1fb48db2b76f81f7c0bd4..b22f03d4882c2cc95854ca94cb39dae5a7fa36cc 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_model.py
@@ -18,7 +18,6 @@ import unittest
 import numpy as np
 import paddle
-import paddle.distributed as dist
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizer,
 )
@@ -57,72 +56,6 @@ def parallel_matmul(lm_output, logit_weights, parallel_output):
     return logits
 
 
-class SimpleMPNet(paddle.nn.Layer):
-    def __init__(
-        self,
-        vocab_size,
-        hidden_size,
-        inner_size,
-        output_size,
-        np_fc1,
-        np_fc2,
-        mp_id,
-    ):
-        super().__init__()
-
-        if mp_id == 0:
-            init_fc1_data = np_fc1[:, : (inner_size // 2)]
-            init_fc2_data = np_fc2[: (inner_size // 2), :]
-        else:
-            init_fc1_data = np_fc1[:, (inner_size // 2) :]
-            init_fc2_data = np_fc2[(inner_size // 2) :, :]
-
-        self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
-            hidden_size,
-            inner_size,
-            weight_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Assign(init_fc1_data)
-            ),
-            gather_output=False,
-            has_bias=True,
-        )
-
-        self.linear2 = fleet.meta_parallel.RowParallelLinear(
-            inner_size,
-            hidden_size,
-            weight_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Assign(init_fc2_data)
-            ),
-            input_is_parallel=True,
-            has_bias=True,
-        )
-
-        self.linear3 = paddle.nn.Linear(
-            hidden_size,
-            output_size,
-            weight_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(0.0)
-            ),
-            bias_attr=paddle.framework.ParamAttr(
-                initializer=paddle.nn.initializer.Constant(0.0)
-            ),
-        )
-
-        self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
-            vocab_size,
-            hidden_size,
-            weight_attr=paddle.nn.initializer.Constant(value=0.5),
-        )
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = parallel_matmul(x, self.embedding.weight, False)
-        return x
-
-
 class SimpleDPNet(paddle.nn.Layer):
     def __init__(
         self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
@@ -230,12 +163,6 @@ class TestDistMPTraning(unittest.TestCase):
         return optimizer
 
     def build_model_optimizer(self, Optimizer="adam"):
-        hcg = fleet.get_hybrid_communicate_group()
-        word_size = hcg.get_model_parallel_world_size()
-        sharding_id = hcg.get_sharding_parallel_rank()
-        dp_id = hcg.get_data_parallel_rank()
-        rank_id = dist.get_rank()
-
         np_fc1 = np.random.random_sample((hidden_size, inner_size))
         np_fc2 = np.random.random_sample((inner_size, hidden_size))
 
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_param_group.py b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_param_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a70e04b0576ac385ee79d68c807286df8865d4e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/hybrid_parallel_sharding_param_group.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+from hybrid_parallel_sharding_model import SimpleDPNet
+
+import paddle
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+    DygraphShardingOptimizer,
+)
+
+vocab_size = 20
+hidden_size = 10
+inner_size = 8
+output_size = 10
+seq_length = 2
+batch_size = 4
+STEPS = 10
+
+
+class TestDistMPTraning(unittest.TestCase):
+    def setUp(self):
+        random.seed(2021)
+        np.random.seed(2021)
+        paddle.seed(2021)
+
+        self.strategy = fleet.DistributedStrategy()
+        self.strategy.hybrid_configs = {
+            "sharding_degree": 2,
+            "dp_degree": 1,
+            "mp_degree": 1,
+            "pp_degree": 1,
+        }
+        fleet.init(is_collective=True, strategy=self.strategy)
+        self.data = [
+            np.random.randint(
+                0,
+                vocab_size,
+                (
+                    batch_size,
+                    seq_length,
+                ),
+            )
+            for _ in range(STEPS)
+        ]
+
+    def train_batch(self, batch, model, optimizer):
+        output = model(batch)
+        loss = output.mean()
+        loss.backward()  # do backward
+        optimizer.step()  # update parameters
+        optimizer.clear_grad()
+        return loss
+
+    def build_optimizer(self, model, strategy=None, Optimizer="adam"):
+        clip = paddle.nn.ClipGradByNorm(0.7)
+        param_groups = [
+            {
+                "params": model.linear1.parameters(),
+                "weight_decay": 0.0001,
+                "learning_rate": 0.1,
+            },
+            {
+                "params": model.linear2.parameters(),
+                "weight_decay": 0.020,
+                "learning_rate": 0.01,
+            },
+            {
+                "params": model.linear3.parameters(),
+                "weight_decay": 0.1,
+                "learning_rate": 0.1,
+            },
+        ]
+
+        if Optimizer == "adam":
+            optimizer = paddle.optimizer.AdamW(
+                parameters=param_groups,
+                learning_rate=0.001,
+                weight_decay=0.00001,
+                grad_clip=clip,
+            )
+        else:
+            optimizer = paddle.optimizer.Momentum(
+                learning_rate=0.001,
+                parameters=model.parameters(),
+                grad_clip=clip,
+            )
+        return optimizer
+
+    def build_model_optimizer(self, Optimizer="adam"):
+        np_fc1 = np.random.random_sample((hidden_size, inner_size))
+        np_fc2 = np.random.random_sample((inner_size, hidden_size))
+
+        model_a = SimpleDPNet(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+        optimizer_a = self.build_optimizer(
+            model_a,
+            strategy=self.strategy,
+            Optimizer=Optimizer,
+        )
+        model_a = fleet.distributed_model(model_a)
+        optimizer_a = fleet.distributed_optimizer(optimizer_a)
+
+        model_b = SimpleDPNet(
+            vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2
+        )
+        optimizer_b = self.build_optimizer(
+            model_b,
+            strategy=self.strategy,
+            Optimizer=Optimizer,
+        )
+
+        return model_a, optimizer_a, model_b, optimizer_b
+
+    def sharding_model(self, Optimizer, sharded_accumulators):
+        model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
+            Optimizer=Optimizer
+        )
+
+        self.assertTrue(
+            isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer)
+        )
+
+        for idx in range(STEPS):
+            if idx > 1:
+                self.assertTrue(
+                    set(optimizer_a._inner_opt._inner_opt.state_dict().keys())
+                    == sharded_accumulators[paddle.distributed.get_rank()]
+                )
+
+            if paddle.distributed.get_rank() == 0:
+                batch_sharding = paddle.to_tensor(self.data[idx][:2])
+            else:
+                batch_sharding = paddle.to_tensor(self.data[idx][2:])
+
+            batch_single = paddle.to_tensor(self.data[idx])
+            loss_a = self.train_batch(batch_sharding, model_a, optimizer_a)
+            loss_b = self.train_batch(batch_single, model_b, optimizer_b)
+
+            np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())
+            for j in range(len(model_a.parameters())):
+                np.testing.assert_allclose(
+                    model_a.parameters()[j].numpy(),
+                    model_b.parameters()[j].numpy(),
+                    rtol=1e-6,
+                )
+
+    def test_sharding_adam(self):
+        sharded_accumulators = [
+            {
+                'linear_0.b_0_moment1_0',
+                'linear_1.b_0_moment1_0',
+                'linear_2.w_0_moment1_0',
+                'linear_2.b_0_moment1_0',
+                'linear_0.b_0_moment2_0',
+                'linear_1.b_0_moment2_0',
+                'linear_2.w_0_moment2_0',
+                'linear_2.b_0_moment2_0',
+                'linear_0.b_0_beta1_pow_acc_0',
+                'linear_1.b_0_beta1_pow_acc_0',
+                'linear_2.w_0_beta1_pow_acc_0',
+                'linear_2.b_0_beta1_pow_acc_0',
+                'linear_0.b_0_beta2_pow_acc_0',
+                'linear_1.b_0_beta2_pow_acc_0',
+                'linear_2.w_0_beta2_pow_acc_0',
+                'linear_2.b_0_beta2_pow_acc_0',
+            },
+            {
+                'linear_0.w_0_moment1_0',
+                'linear_1.w_0_moment1_0',
+                'linear_0.w_0_moment2_0',
+                'linear_1.w_0_moment2_0',
+                'linear_0.w_0_beta1_pow_acc_0',
+                'linear_1.w_0_beta1_pow_acc_0',
+                'linear_0.w_0_beta2_pow_acc_0',
+                'linear_1.w_0_beta2_pow_acc_0',
+            },
+        ]
+
+        self.sharding_model(
+            Optimizer="adam",
+            sharded_accumulators=sharded_accumulators,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py
index a7014be0c2806d6f07acf3dc7cbb9d293743b5dd..cd60e61a1589872c04c87bf331a940ef4831cab3 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_parallel_dygraph_sharding_parallel.py
@@ -31,6 +31,10 @@ class TestHybridParallel(TestMultipleGpus):
         os.environ["FLAGS_shard_norm_align_dp"] = "1"
         self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
 
+    def test_hybrid_parallel_sharding_param_group(self):
+        # test sharding with param_groups
+        self.run_mnist_2gpu('hybrid_parallel_sharding_param_group.py')
+
     def test_hybrid_parallel_sharding_state_dict(self):
         self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py')
 
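For reference, a minimal end-to-end sketch of the param_groups path exercised by the new test, assuming a two-GPU job launched with python -m paddle.distributed.launch; the TinyNet model, layer sizes, and hyperparameters below are illustrative placeholders rather than anything defined in this patch.

# Hypothetical two-GPU sketch; names and sizes are illustrative only.
import paddle
from paddle.distributed import fleet


class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear1 = paddle.nn.Linear(10, 8)
        self.linear2 = paddle.nn.Linear(8, 10)

    def forward(self, x):
        return self.linear2(self.linear1(x))


strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "sharding_degree": 2,
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)

model = TinyNet()
# One dict per group; keys other than "params" override the optimizer defaults.
param_groups = [
    {"params": model.linear1.parameters(), "learning_rate": 0.1},
    {"params": model.linear2.parameters(), "weight_decay": 0.02},
]
# ClipGradByGlobalNorm is rejected for param_groups in sharding (see the assert
# added in step() above), so a per-tensor clip such as ClipGradByNorm is used here.
clip = paddle.nn.ClipGradByNorm(0.7)
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=param_groups, grad_clip=clip
)

model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

loss = model(paddle.randn([4, 10])).mean()
loss.backward()
optimizer.step()  # each sharding rank updates only the slice of each group it owns
optimizer.clear_grad()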