Unverified commit c40122d9, authored by sneaxiy, committed by GitHub

fix sharding_stage1 amp O2 decorate bug (#48960)

Parent fd373579
@@ -252,14 +252,26 @@ def check_models(models):
             )
 
 
+def _is_valid_optimizer(optimizer):
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizer,
+    )
+
+    return isinstance(
+        optimizer,
+        (
+            paddle.optimizer.Optimizer,
+            paddle.fluid.optimizer.Optimizer,
+            DygraphShardingOptimizer,
+        ),
+    )
+
+
 def check_optimizers(optimizers):
     for optimizer in optimizers:
-        if not isinstance(
-            optimizer,
-            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
-        ):
+        if not _is_valid_optimizer(optimizer):
             raise RuntimeError(
-                "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format(
+                "Current train mode is pure fp16, optimizers should be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer or DygraphShardingOptimizer, but receive {}.".format(
                     type(optimizer)
                 )
            )
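
Note (not part of the commit): the hunk above widens the pure-fp16 optimizer check so that an optimizer wrapped by DygraphShardingOptimizer passes validation instead of raising. A minimal stand-alone sketch of that pattern, using hypothetical stand-in classes rather than Paddle's real types:

# Stand-in classes for illustration only; they are not Paddle's real types.
class _BaseOptimizer:  # plays the role of paddle.optimizer.Optimizer
    pass


class _ShardingWrapper:  # plays the role of DygraphShardingOptimizer
    def __init__(self, inner_optimizer):
        self._inner_optimizer = inner_optimizer


def _is_valid_optimizer_sketch(optimizer):
    # Same shape as the new check: plain optimizers and the sharding wrapper both pass.
    return isinstance(optimizer, (_BaseOptimizer, _ShardingWrapper))


assert _is_valid_optimizer_sketch(_BaseOptimizer())
assert _is_valid_optimizer_sketch(_ShardingWrapper(_BaseOptimizer()))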
@@ -477,6 +489,20 @@ class StateDictHook:
             state_dict[key] = param_applied
 
 
+def _set_multi_precision(optimizer, multi_precision):
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
+        DygraphShardingOptimizer,
+    )
+
+    optimizer = (
+        optimizer._inner_optimizer
+        if isinstance(optimizer, DygraphShardingOptimizer)
+        else optimizer
+    )
+    if hasattr(optimizer, "_multi_precision"):
+        optimizer._multi_precision = multi_precision
+
+
 @dygraph_only
 def amp_decorate(
     models,
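
Note (not part of the commit): the new _set_multi_precision helper unwraps the sharding wrapper before touching _multi_precision, presumably because the flag lives on the wrapped inner optimizer rather than on the wrapper itself. A sketch of that unwrap-then-set pattern, reusing the hypothetical stand-in classes from the previous note:

def _set_multi_precision_sketch(optimizer, multi_precision):
    # Unwrap the sharding wrapper so the flag is set on the real optimizer.
    if isinstance(optimizer, _ShardingWrapper):
        optimizer = optimizer._inner_optimizer
    if hasattr(optimizer, "_multi_precision"):
        optimizer._multi_precision = multi_precision


inner = _BaseOptimizer()
inner._multi_precision = False
_set_multi_precision_sketch(_ShardingWrapper(inner), True)
assert inner._multi_precision is True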
@@ -582,10 +608,7 @@ def amp_decorate(
     if optimizers is not None:
         # check optimizers
         optimizers_is_list = False
-        if isinstance(
-            optimizers,
-            (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
-        ):
+        if _is_valid_optimizer(optimizers):
             optimizers_is_list = False
             optimizers = [optimizers]
             check_optimizers(optimizers)
@@ -596,13 +619,10 @@ def amp_decorate(
             raise TypeError(
                 "optimizers must be either a single optimizer or a list of optimizers."
             )
-        # supprot master_weight
-        for idx_opt in range(len(optimizers)):
-            if hasattr(optimizers[idx_opt], '_multi_precision'):
-                if master_weight is False:
-                    optimizers[idx_opt]._multi_precision = False
-                else:
-                    optimizers[idx_opt]._multi_precision = True
+        # support master_weight
+        use_multi_precision = not (master_weight is False)
+        for opt in optimizers:
+            _set_multi_precision(opt, use_multi_precision)
 
     if save_dtype is not None:
         if not (save_dtype in ['float16', 'bfloat16', 'float32', 'float64']):
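
Note (not part of the commit): the rewritten loop keeps the original master_weight semantics: only an explicit False disables multi precision, while both None (the default) and True enable it. A quick check of that mapping:

for master_weight in (None, True, False):
    use_multi_precision = not (master_weight is False)
    print(master_weight, "->", use_multi_precision)  # None -> True, True -> True, False -> False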