Unverified commit d247cf17, authored by H Haohongxiang, committed by GitHub

fix bugs in mp_layers, pp_layers and HybridParallelClipGrad (#36144)

* fix calling bug of HybridParallelClipGrad

* fix bugs of HybridParallelClipGrad

* add unittest of pp with HybridParallelClipGrad

* fix bugs in mp_layers.py

* update

* fix bugs in pp_layers.py

* update
Parent ec148cab
@@ -52,6 +52,7 @@ class HybridParallelClipGrad:
        params_and_grads = []
        sum_square_list_dist = []
        sum_square_list_not_dist = []
        for p, g in params_grads:
            if g is None:
                continue
@@ -64,29 +65,38 @@ class HybridParallelClipGrad:
            square = layers.square(merge_grad)
            sum_square = layers.reduce_sum(square)
            not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
                hasattr(p, 'is_firstly_shared') and
                getattr(p, 'is_firstly_shared', True))
            if not_shared_enable:
                if p.is_distributed:
                    sum_square_list_dist.append(sum_square)
                else:
                    sum_square_list_not_dist.append(sum_square)
        # all parameters have been filtered out
        if len(sum_square_list_dist) + len(sum_square_list_not_dist) == 0:
            return params_grads
        global_norm_var_dist = layers.concat(sum_square_list_dist) if len(
            sum_square_list_dist) != 0 else layers.concat(
                [paddle.to_tensor([0.])])
        global_norm_var_dist = layers.reduce_sum(global_norm_var_dist)
        global_norm_var_not_dist = layers.concat(
            sum_square_list_not_dist) if len(
                sum_square_list_not_dist) != 0 else layers.concat(
                    [paddle.to_tensor([0.])])
        global_norm_var_not_dist = layers.reduce_sum(global_norm_var_not_dist)
        # add all reduce to get global norm of distributed params_and_grads in world size
        # all reduce is not needed while getting global norm of non-distributed params_and_grads
        # add all reduce to get global norm of distributed params_and_grads
        if self._hcg.get_model_parallel_world_size() > 1:
            paddle.distributed.all_reduce(
                global_norm_var_dist, group=self._hcg.get_check_parallel_group())
                global_norm_var_dist,
                group=self._hcg.get_check_parallel_group())
        # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
        if self._hcg.get_pipe_parallel_world_size() > 1:
            paddle.distributed.all_reduce(
                global_norm_var_not_dist,
                group=self._hcg.get_pipe_parallel_group())
        # In Sharding mode, param and grad are mapped to different ranks in the optimizer.
        # ClipGradByGlobalNorm needs an allreduce to get the global norm
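The comments above carry the core idea of the fix: the squared norms of distributed (tensor-parallel-sliced) parameters are summed across the model-parallel group, the norms of replicated parameters are combined across pipeline stages only, and the two results are added before taking the square root. A minimal single-process sketch of that arithmetic, with plain Python sums standing in for the two all_reduce calls; the helper and argument names are illustrative, not from the patch:

import math

def hybrid_global_norm(sq_dist_per_mp_rank, sq_not_dist_per_pp_stage):
    # Each rank contributes the sum of squared grads of the parameters it
    # actually holds: distributed params are sliced across the mp group,
    # replicated params differ only between pipeline stages.
    sq_dist = sum(sq_dist_per_mp_rank)           # stands in for all_reduce over
                                                 # hcg.get_check_parallel_group()
    sq_not_dist = sum(sq_not_dist_per_pp_stage)  # stands in for all_reduce over
                                                 # hcg.get_pipe_parallel_group()
    return math.sqrt(sq_dist + sq_not_dist)

def clip_scale(global_norm, max_norm):
    # standard global-norm clipping factor applied to every gradient
    return max_norm / max(global_norm, max_norm)

norm = hybrid_global_norm([4.0, 5.0], [1.0, 2.0])   # sqrt(12) ~= 3.46
print(norm, clip_scale(norm, max_norm=0.5))         # every grad scaled by ~0.144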
@@ -143,8 +153,8 @@ class HybridParallelOptimizer:
        if isinstance(self._inner_opt._grad_clip,
                      ClipGradByGlobalNorm) and not self._use_dp_mode:
            logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
                           "optmizer'grad clip will be changed.")
            logger.warning("While using ClipGradByGlobalNorm in TensorParallel, PipelineParallel " \
                           "or Sharding, the grad clip of original optimizer will be changed.")
            if self._sharding_enable:
                # change sharding inner_optimizer's _grad_clip
......
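As the rewritten warning states, when ClipGradByGlobalNorm is combined with TensorParallel, PipelineParallel or Sharding, fleet.distributed_optimizer replaces the original optimizer's grad clip with the hybrid-aware logic shown above. For orientation, a minimal single-card sketch of the behaviour that logic reproduces across ranks; the model, seed and 0.5 threshold are illustrative, while paddle.nn.ClipGradByGlobalNorm and paddle.optimizer.SGD are used the same way as in the new unit test further down:

import paddle

paddle.seed(2021)
model = paddle.nn.Linear(4, 4)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=0.5)
opt = paddle.optimizer.SGD(learning_rate=0.001,
                           parameters=model.parameters(),
                           grad_clip=clip)

loss = (model(paddle.randn([8, 4])) ** 2).mean()
loss.backward()

# unclipped global norm: sqrt of the summed squared gradients of all params
sq_sum = sum(float((p.grad ** 2).sum()) for p in model.parameters())
print("global grad norm before clipping:", sq_sum ** 0.5)

opt.step()        # the attached clip rescales grads to norm <= 0.5 first
opt.clear_grad()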
@@ -70,7 +70,7 @@ class VocabParallelEmbedding(Layer):
            dtype=self._dtype,
            is_bias=False)
        self.weight.is_distributed = True
        self.weight.is_distributed = True if self.is_mp else False

    def forward(self, x):
        if self.is_mp:
@@ -135,7 +135,7 @@ class ColumnParallelLinear(Layer):
            dtype=self._dtype,
            is_bias=False)
        self.weight.is_distributed = True
        self.weight.is_distributed = True if self.is_mp else False

        if has_bias:
            # initialize bias to zero like Megatron
@@ -144,7 +144,7 @@ class ColumnParallelLinear(Layer):
                attr=paddle.nn.initializer.Constant(value=0.0),
                dtype=self._dtype,
                is_bias=True)
            self.bias.is_distributed = True
            self.bias.is_distributed = True if self.is_mp else False
        else:
            self.bias = None
@@ -212,7 +212,7 @@ class RowParallelLinear(Layer):
            dtype=self._dtype,
            is_bias=False)
        self.weight.is_distributed = True
        self.weight.is_distributed = True if self.is_mp else False

        if has_bias:
            self.bias = self.create_parameter(
......
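The four one-line changes above all fix the same problem: when a layer is built with a model-parallel world size of 1 (self.is_mp is False), its weight and bias are not actually partitioned, so marking them is_distributed would make a fully replicated tensor look like a sliced one to the rest of the stack, including the norm bucketing in HybridParallelClipGrad. A toy sketch of how that flag feeds the bucketing in the first hunk; FakeParam and the helper are illustrative stand-ins:

def bucket_squared_norms(params_and_sq_norms):
    # mirrors the branch on p.is_distributed in HybridParallelClipGrad:
    # "dist" norms are later all-reduced over the model-parallel group,
    # the rest only over the pipeline group
    sq_dist, sq_not_dist = [], []
    for param, sq_norm in params_and_sq_norms:
        (sq_dist if getattr(param, "is_distributed", False)
         else sq_not_dist).append(sq_norm)
    return sq_dist, sq_not_dist

class FakeParam:  # stand-in for a Paddle parameter
    def __init__(self, is_mp):
        # the fixed rule from mp_layers.py: only mark the tensor as
        # distributed when tensor parallelism is actually in use
        self.is_distributed = True if is_mp else False

print(bucket_squared_norms([(FakeParam(is_mp=True), 4.0)]))   # ([4.0], [])
print(bucket_squared_norms([(FakeParam(is_mp=False), 4.0)]))  # ([], [4.0])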
@@ -261,6 +261,10 @@ class PipelineLayer(Layer):
                    src=min(comm['ranks']),
                    group=comm['group'])

            for param in comm['layer'].parameters():
                if self.global_rank != min(comm['ranks']):
                    setattr(param, 'is_firstly_shared', False)

    def allreduce_shared_weight_gradients(self):
        for key, comm in self.shared_comm.items():
            param = getattr(self.shared_layers[key], comm['weight_attr'])
@@ -316,6 +320,9 @@ class PipelineLayer(Layer):
                    self.shared_layers[layer.layer_name] = layer.build_layer()
                    self.shared_weight_attrs[
                        layer.layer_name] = layer.shared_weight_attr
                    for param in self.shared_layers[
                            layer.layer_name].parameters():
                        setattr(param, "is_firstly_shared", True)

                if layer.forward_func is None:
                    self.run_function.append(self.shared_layers[
......
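In these two hunks, every parameter of a shared layer (for example a word embedding tied between the first and last pipeline stage) is first marked is_firstly_shared=True when the layer is built, and then flipped to False on every rank except the owner, min(comm['ranks']). Combined with the not_shared_enable filter in the first hunk, each tied weight therefore enters the global gradient norm exactly once. A small stand-alone sketch of that filter; FakeParam is an illustrative stand-in:

def counted_in_global_norm(p):
    # equivalent to the not_shared_enable condition in HybridParallelClipGrad:
    # count the parameter unless it is a shared weight owned by another rank
    return (not hasattr(p, 'is_firstly_shared')) or \
        getattr(p, 'is_firstly_shared', True)

class FakeParam:  # stand-in for a Paddle parameter
    pass

regular, owner_copy, mirror_copy = FakeParam(), FakeParam(), FakeParam()
owner_copy.is_firstly_shared = True    # rank == min(comm['ranks'])
mirror_copy.is_firstly_shared = False  # every other rank in the shared group

print([counted_in_global_norm(p) for p in (regular, owner_copy, mirror_copy)])
# [True, True, False] -> the tied weight is counted only on its owner rank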
@@ -53,6 +53,13 @@ class TestDistPPTraning(unittest.TestCase):
        }
        fleet.init(is_collective=True, strategy=strategy)

    def build_optimizer(self, model):
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True)
        optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
                                         parameters=model.parameters())
        return scheduler, optimizer

    def test_pp_model(self):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
@@ -63,10 +70,7 @@ class TestDistPPTraning(unittest.TestCase):
        #construct model a
        model_a = AlexNet(10)
        scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True)
        optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
                                           parameters=model_a.parameters())
        scheduler_a, optimizer_a = self.build_optimizer(model_a)

        param_len = len(model_a.parameters())
@@ -76,10 +80,7 @@ class TestDistPPTraning(unittest.TestCase):
        # construct model b
        model_b = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
        scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True)
        optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
                                           parameters=model_b.parameters())
        scheduler_b, optimizer_b = self.build_optimizer(model_b)

        model_b = fleet.distributed_model(model_b)
        optimizer_b = fleet.distributed_optimizer(optimizer_b)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import paddle
import unittest
from hybrid_parallel_pp_alexnet import TestDistPPTraning


class TestPPClipGrad(TestDistPPTraning):
    def build_optimizer(self, model):
        grad_clip = paddle.nn.ClipGradByGlobalNorm(0.5)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True)
        optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
                                         grad_clip=grad_clip,
                                         parameters=model.parameters())
        return scheduler, optimizer


if __name__ == "__main__":
    unittest.main()
@@ -42,6 +42,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
    def test_hybrid_parallel_recompute(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_recompute.py')

    def test_hybrid_parallel_pp_clip_grad(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_clip_grad.py')


if __name__ == "__main__":
    unittest.main()