diff --git a/paddle_fl/core/strategy/fl_distribute_transpiler.py b/paddle_fl/core/strategy/fl_distribute_transpiler.py index 6c5667e404193061f4c1e79d0ee79539974944b7..28b369ce9b3ada38eedc438f10f3b1da9725833b 100644 --- a/paddle_fl/core/strategy/fl_distribute_transpiler.py +++ b/paddle_fl/core/strategy/fl_distribute_transpiler.py @@ -691,7 +691,7 @@ class FLDistributeTranspiler(object): opti_to_param = dict() param_to_opti = dict() for op in self.optimize_ops: - if op.type == "sgd": + if (op.type == "sgd") or (op.type == "adam"): origin_name = op.output("ParamOut") var = self.origin_program.global_block().var(origin_name[0]) new_var_name = "%s.opti.trainer_%d" % (origin_name[0],