diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 5765a5e24b20db17fe6d858460f9d8e3e7cff7f3..0ce5e70788e72b9987fc6d445e72526b40b5f8fe 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -192,6 +192,8 @@ class Accessor: sgd_param.name = "SparseNaiveSGDRule" if common_accessor.accessor_class == "adam": sgd_param.name = "SparseAdamSGDRule" + else: # for fl-ps, because geo accessor is 'sum' + sgd_param.name = "SparseAdamSGDRule" if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule": if not sgd_param.adagrad.HasField("learning_rate"):