Unverified commit c40122d9, authored by sneaxiy, committed by GitHub

fix sharding_stage1 amp O2 decorate bug (#48960)

Parent: fd373579
@@ -252,14 +252,26 @@ def check_models(models):
             )


+def _is_valid_optimizer(optimizer):
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizer,
+    )
+
+    return isinstance(
+        optimizer,
+        (
+            paddle.optimizer.Optimizer,
+            paddle.fluid.optimizer.Optimizer,
+            DygraphShardingOptimizer,
+        ),
+    )
+
+
 def check_optimizers(optimizers):
     for optimizer in optimizers:
-        if not isinstance(
-            optimizer,
-            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
-        ):
+        if not _is_valid_optimizer(optimizer):
             raise RuntimeError(
-                "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format(
+                "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer or DygraphShardingOptimizer, but receive {}.".format(
                     type(optimizer)
                 )
             )
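Why the check needed widening (a minimal standalone sketch, not Paddle code): under sharding stage 1 the optimizer handed to AMP O2 decoration is a DygraphShardingOptimizer that wraps, rather than subclasses, the base optimizer, so the old isinstance check rejected it. The FakeInnerOptimizer and FakeShardingOptimizer classes below are hypothetical stand-ins used only to illustrate the check.

# Hypothetical stand-ins, for illustration only -- no paddle import needed.
class FakeInnerOptimizer:              # stands in for paddle.optimizer.Optimizer
    pass


class FakeShardingOptimizer:           # stands in for DygraphShardingOptimizer:
    def __init__(self, inner):         # it wraps the inner optimizer instead of subclassing it
        self._inner_optimizer = inner


opt = FakeShardingOptimizer(FakeInnerOptimizer())

# Old check: only the base optimizer classes are accepted, so the wrapper is rejected.
print(isinstance(opt, FakeInnerOptimizer))                            # False -> RuntimeError before this fix

# New check: the wrapper type itself is whitelisted, mirroring _is_valid_optimizer.
print(isinstance(opt, (FakeInnerOptimizer, FakeShardingOptimizer)))   # True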
@@ -477,6 +489,20 @@ class StateDictHook:
                 state_dict[key] = param_applied


+def _set_multi_precision(optimizer, multi_precision):
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizer,
+    )
+
+    optimizer = (
+        optimizer._inner_optimizer
+        if isinstance(optimizer, DygraphShardingOptimizer)
+        else optimizer
+    )
+    if hasattr(optimizer, "_multi_precision"):
+        optimizer._multi_precision = multi_precision
+
+
 @dygraph_only
 def amp_decorate(
     models,
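The new helper sets the flag on the optimizer that actually owns the master weights: a DygraphShardingOptimizer is first unwrapped through its _inner_optimizer attribute. A minimal sketch of that unwrap-then-set pattern, again using hypothetical stand-in classes rather than the real Paddle types:

class FakeInnerOptimizer:              # hypothetical stand-in; real optimizers expose _multi_precision
    def __init__(self):
        self._multi_precision = False


class FakeShardingOptimizer:           # hypothetical stand-in for DygraphShardingOptimizer
    def __init__(self, inner):
        self._inner_optimizer = inner


def set_multi_precision(optimizer, multi_precision):
    # Unwrap the sharding wrapper so the flag lands on the inner optimizer.
    if isinstance(optimizer, FakeShardingOptimizer):
        optimizer = optimizer._inner_optimizer
    if hasattr(optimizer, "_multi_precision"):
        optimizer._multi_precision = multi_precision


wrapped = FakeShardingOptimizer(FakeInnerOptimizer())
set_multi_precision(wrapped, True)
print(wrapped._inner_optimizer._multi_precision)   # True -- master weights enabled on the inner optimizer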
@@ -582,10 +608,7 @@ def amp_decorate(
     if optimizers is not None:
         # check optimizers
         optimizers_is_list = False
-        if isinstance(
-            optimizers,
-            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
-        ):
+        if _is_valid_optimizer(optimizers):
             optimizers_is_list = False
             optimizers = [optimizers]
             check_optimizers(optimizers)
@@ -596,13 +619,10 @@ def amp_decorate(
             raise TypeError(
                 "optimizers must be either a single optimizer or a list of optimizers."
             )
-    # supprot master_weight
-    for idx_opt in range(len(optimizers)):
-        if hasattr(optimizers[idx_opt], '_multi_precision'):
-            if master_weight is False:
-                optimizers[idx_opt]._multi_precision = False
-            else:
-                optimizers[idx_opt]._multi_precision = True
+    # support master_weight
+    use_multi_precision = not (master_weight is False)
+    for opt in optimizers:
+        _set_multi_precision(opt, use_multi_precision)
     if save_dtype is not None:
         if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
...
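The rewritten loop keeps the original semantics of master_weight: only an explicit False disables multi-precision, while None (the default) and True both enable it. A tiny sketch of that mapping, taken directly from the new expression:

for master_weight in (None, True, False):
    use_multi_precision = not (master_weight is False)
    print(master_weight, "->", use_multi_precision)
# None -> True, True -> True, False -> False

Because the per-optimizer branching is now delegated to _set_multi_precision, the same code path also reaches the inner optimizer of a DygraphShardingOptimizer, which is what AMP O2 decoration encounters under sharding stage 1.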