From 13dca2c9b342d006208b8bdda208f7a2d26a25d6 Mon Sep 17 00:00:00 2001 From: Chengmo Date: Wed, 5 Feb 2020 11:25:08 +0800 Subject: [PATCH] [Cherry-pick]Fix geo init & send (#22413) * Fix GEO-SGD init & send Bug (#22375) * test=develop, fix geo Send & Init * test=release/1.7,test=develop, cherry-pick 8f36c39 --- paddle/fluid/operators/distributed/communicator.cc | 1 + python/paddle/fluid/communicator.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc index ff9909edd7d..9b14e5ca685 100644 --- a/paddle/fluid/operators/distributed/communicator.cc +++ b/paddle/fluid/operators/distributed/communicator.cc @@ -433,6 +433,7 @@ void GeoSgdCommunicator::Send(const std::vector &sparse_var_names, } GeoSgdDenseParamInit(training_scope_, old_scope_.get(), local_var_name); } + return; } std::shared_ptr ids_table = std::make_shared(); diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py index 58596a926ff..a833072f309 100644 --- a/python/paddle/fluid/communicator.py +++ b/python/paddle/fluid/communicator.py @@ -61,7 +61,8 @@ class Communicator(object): varnames = "&".join(vs["var_names"]) sections = "&".join([str(v) for v in vs["sections"]]) endpoints = "&".join(vs["epmap"]) - is_sparse = "1" if vs["is_sparse"] else "0" + # record parameter sparse or dense + is_sparse = "1" if vs["is_sparse"] == ['True'] else "0" push_var_names.append(k) envs[k] = "#".join([varnames, sections, endpoints, is_sparse]) -- GitLab