diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py index bb6af1b3195f705bfa3813f3c0eb98a72f939212..9c751c5044701b23fce0e9eb82dabdcec1346109 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py @@ -284,7 +284,7 @@ class OffloadHelper(object): break vars_name = [] - if op.type == "adam": + if op.type == "adam" or op.type == "adamw": # {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} = # adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']}) vars_name.append(op.desc.input("Moment1")[0])