From 56d305348a05be0d8d61c8c8f3367fc69fad638a Mon Sep 17 00:00:00 2001 From: ziyoujiyi <73728031+ziyoujiyi@users.noreply.github.com> Date: Fri, 23 Sep 2022 14:21:54 +0800 Subject: [PATCH] fl-ps bug fix (#46356) * back fl * delete ssl cert * . * make warning * . * unittest paral degree * solve unittest * heter & multi cloud commm ready * . * . * fl-ps v1.0 * . * support N + N mode * . * . * . * . * delete print * . * . * . * . * fix bug * . * . * fl-ps with coordinator ready * merge dev * update message parse only * update fl client scheduler * fix bug * update multithreads sync * fix ci errors * update role_maker.py * update role_maker.py * fix ci error: windows py import error * fix ci error: windows py import error * fix windows ci pylib import error * add dump fields & params * try to fix windows import fleet error * fix ps FLAGS error * fix logging risk * fix logging possible risk * write trainer_desc file * support split sparse params in local & remote * fix import paddle.fluid.core.PSGPU * fix import paddle.fluid.core.PSGPU * add remote_sparse & local_sparse config * fix unittest * fix test_dist_fleet_geo table error * fix PADDLE_ENFORCE error * fix other's pr conflict * forbidden ssd table * . * recover ssd table code * recover file mode * debug auc 0.5 * adapt for nn fl-ps * adapt for nn fl-ps * add learning_rate_0 intializer op * recover ssd table * modify file mode * flps del fake-init op * bug fix --- python/paddle/distributed/ps/the_one_ps.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 5765a5e24b2..0ce5e70788e 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -192,6 +192,8 @@ class Accessor: sgd_param.name = "SparseNaiveSGDRule" if common_accessor.accessor_class == "adam": sgd_param.name = "SparseAdamSGDRule" + else: # for fl-ps, because geo accessor is 'sum' + sgd_param.name = "SparseAdamSGDRule" if sgd_param.name == "SparseAdaGradSGDRule" or sgd_param.name == "StdAdaGradSGDRule": if not sgd_param.adagrad.HasField("learning_rate"): -- GitLab