Unverified commit 623df429, authored by zhaoyingli, committed by GitHub

support ClipGradByGlobalNorm in sharding (#36012)

* support ClipGradByGlobalNorm in sharding

* support ClipGradByGlobalNorm in sharding

* test=allcase
Parent: 2fd8deea
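
For orientation before the diff: this change lets a ClipGradByGlobalNorm instance be passed directly to DygraphShardingOptimizer in dygraph hybrid-parallel training. A minimal usage sketch, distilled from the test hunk further down (it assumes fleet.init has already been called with a sharding-enabled hybrid configuration, and that model and strategy are defined elsewhere):

    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
        DygraphShardingOptimizer, )

    # `model` and `strategy` are assumed to exist; see the test diff below.
    clip = paddle.nn.ClipGradByGlobalNorm(0.5)
    optimizer = DygraphShardingOptimizer(
        hcg=fleet.get_hybrid_communicate_group(),
        user_defined_strategy=strategy,
        params=model.parameters(),
        inner_optimizer_class=paddle.optimizer.AdamW,
        learning_rate=0.001,
        weight_decay=0.00001,
        grad_clip=clip)

In a hybrid-parallel setup the result is typically passed through fleet.distributed_optimizer(...), which is where the HybridParallelOptimizer hunk below installs the hybrid grad clip.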
@@ -12,5 +12,6 @@
 # See the License for the specific language governing permissions and
 from .hybrid_parallel_optimizer import HybridParallelOptimizer
 from .hybrid_parallel_gradscaler import HybridParallelGradScaler
+from .dygraph_sharding_optimizer import DygraphShardingOptimizer
 __all__ = []
@@ -88,6 +88,13 @@ class HybridParallelClipGrad:
         paddle.distributed.all_reduce(
             global_norm_var_dist, group=self._hcg.get_check_parallel_group())

+        # In sharding mode, params and grads are mapped to different ranks in the optimizer,
+        # so ClipGradByGlobalNorm needs an allreduce to obtain the global norm.
+        if self._hcg.get_sharding_parallel_world_size() > 1:
+            paddle.distributed.all_reduce(
+                global_norm_var_not_dist,
+                group=self._hcg.get_sharding_parallel_group())
+
         global_norm_var = layers.sqrt(global_norm_var_dist +
                                       global_norm_var_not_dist)
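
The idea behind the added branch: with sharding, each rank only holds the gradients of the parameters it owns, so the locally accumulated sum of squared norms (global_norm_var_not_dist here) must be all-reduced across the sharding group before taking the square root. A simplified, illustrative sketch of that computation (the function name and group handle are assumptions for illustration, not the class internals):

    import paddle
    import paddle.distributed as dist

    def sharded_global_grad_norm(grads, sharding_group):
        """Global grad norm when each rank owns only a shard of the parameters."""
        # Local sum of squared norms over the gradients this rank owns.
        local_sum_sq = paddle.add_n([paddle.sum(g * g) for g in grads])
        # Combine the per-rank partial sums across the sharding group.
        dist.all_reduce(local_sum_sq, group=sharding_group)
        return paddle.sqrt(local_sum_sq)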
@@ -139,8 +146,13 @@ class HybridParallelOptimizer:
             logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
                            "optmizer'grad clip will be changed.")
-            self._inner_opt._grad_clip = HybridParallelClipGrad(
-                self._inner_opt._grad_clip, hcg)
+            if self._sharding_enable:
+                # change the sharding inner_optimizer's _grad_clip
+                self._inner_opt._inner_optimizer._grad_clip = HybridParallelClipGrad(
+                    self._inner_opt._grad_clip, hcg)
+            else:
+                self._inner_opt._grad_clip = HybridParallelClipGrad(
+                    self._inner_opt._grad_clip, hcg)

     @imperative_base.no_grad
     @framework.dygraph_only
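
What this hunk encodes: when sharding is enabled, the object that actually applies the gradients is the inner optimizer held by DygraphShardingOptimizer, so the hybrid clip must be installed on _inner_opt._inner_optimizer rather than on the wrapper itself. A stripped-down restatement of the same branching, as a standalone helper written only for illustration (HybridParallelClipGrad is the class from the hunk above):

    def _install_hybrid_clip(inner_opt, hcg, sharding_enable):
        # Wrap the user's original grad clip so the global norm is computed
        # across the hybrid communication groups.
        hybrid_clip = HybridParallelClipGrad(inner_opt._grad_clip, hcg)
        if sharding_enable:
            # Sharding: the inner optimizer performs the parameter update,
            # so it is the one whose _grad_clip must be replaced.
            inner_opt._inner_optimizer._grad_clip = hybrid_clip
        else:
            inner_opt._grad_clip = hybrid_clip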
@@ -183,21 +183,23 @@ class TestDistMPTraning(unittest.TestCase):
                         strategy=None,
                         is_sharding=True,
                         Optimizer="adam"):
+        clip = paddle.nn.ClipGradByGlobalNorm(0.5)
         if Optimizer == "adam":
             if is_sharding:
                 optimizer = DygraphShardingOptimizer(
                     hcg=fleet.get_hybrid_communicate_group(),
                     user_defined_strategy=strategy,
                     params=model.parameters(),
-                    inner_optimizer_class=paddle.optimizer.Adam,
+                    inner_optimizer_class=paddle.optimizer.AdamW,
                     learning_rate=0.001,
-                    weight_decay=0.00001, )
+                    weight_decay=0.00001,
+                    grad_clip=clip)
             else:
-                optimizer = paddle.optimizer.Adam(
+                optimizer = paddle.optimizer.AdamW(
                     parameters=model.parameters(),
                     learning_rate=0.001,
-                    weight_decay=0.00001, )
+                    weight_decay=0.00001,
+                    grad_clip=clip)
         else:
             if is_sharding:
                 optimizer = DygraphShardingOptimizer(
@@ -205,10 +207,13 @@ class TestDistMPTraning(unittest.TestCase):
                     user_defined_strategy=strategy,
                     params=model.parameters(),
                     inner_optimizer_class=paddle.optimizer.Momentum,
-                    learning_rate=0.001, )
+                    learning_rate=0.001,
+                    grad_clip=clip)
             else:
                 optimizer = paddle.optimizer.Momentum(
-                    learning_rate=0.001, parameters=model.parameters())
+                    learning_rate=0.001,
+                    parameters=model.parameters(),
+                    grad_clip=clip)
         return optimizer

     def build_model_optimizer(self, Optimizer="adam"):