diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py index 4fd087bbd50afef8080723cba0e897541a164fb0..8197f4368f20fd6b9d612f1195f4a552cfdb4a31 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py @@ -79,7 +79,7 @@ class OffloadHelper(object): # {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} = # adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']}) vars_name.append(op.desc.input("Moment1")[0]) - vars_name.append(op.desc.input("Moment1")[0]) + vars_name.append(op.desc.input("Moment2")[0]) elif op.type == 'momentum': pass elif op.type == 'lars':