diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc
index 1d6732dd21e5c82057ee939a95f7812fa938fdbd..ecca873c9d06909f564a8aa0fab996b1c33e6912 100644
--- a/paddle/fluid/operators/distributed/communicator.cc
+++ b/paddle/fluid/operators/distributed/communicator.cc
@@ -89,6 +89,8 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
   VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
   VLOG(0) << "communicator_merge_sparse_grad: "
           << FLAGS_communicator_merge_sparse_grad;
+  VLOG(0) << "communicator_is_sgd_optimizer: "
+          << FLAGS_communicator_is_sgd_optimizer;
 
   if (send_varname_to_ctx.size() == 0) {
     VLOG(0) << "nothing need to be send, will not start send_thread";
diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h
index 50582e6f34b7ed3f523cb5d1e293d817ee4ec2c8..be61a0281cd42f5a0e1f0738701f4d9c30932972 100644
--- a/paddle/fluid/operators/distributed/communicator.h
+++ b/paddle/fluid/operators/distributed/communicator.h
@@ -24,6 +24,7 @@ limitations under the License. */
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include "gflags/gflags.h"
 
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/variable.h"
@@ -37,6 +38,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
 
+DECLARE_bool(communicator_is_sgd_optimizer);
+
 namespace paddle {
 namespace operators {
 namespace distributed {
@@ -138,8 +141,10 @@ inline void MergeVars(const std::string& var_name,
       auto in = EigenVector<float>::Flatten(in_t);
       result.device(*cpu_ctx.eigen_device()) = result + in;
     }
-    result.device(*cpu_ctx.eigen_device()) =
-        result / static_cast<float>(vars.size());
+    if (!FLAGS_communicator_is_sgd_optimizer) {
+      result.device(*cpu_ctx.eigen_device()) =
+          result / static_cast<float>(vars.size());
+    }
   } else if (var0->IsType<framework::SelectedRows>()) {
     auto& slr0 = var0->Get<framework::SelectedRows>();
     auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
@@ -151,9 +156,16 @@ inline void MergeVars(const std::string& var_name,
       inputs.push_back(&var->Get<framework::SelectedRows>());
     }
     auto dev_ctx = paddle::platform::CPUDeviceContext();
-    math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
-        merge_average;
-    merge_average(dev_ctx, inputs, out_slr);
+    if (FLAGS_communicator_is_sgd_optimizer) {
+      math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
+          merge_add;
+      merge_add(dev_ctx, inputs, out_slr);
+    } else {
+      math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
+          merge_average;
+      merge_average(dev_ctx, inputs, out_slr);
+    }
+
     VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
             << " dims: " << slr0.value().dims();
   } else {
diff --git a/paddle/fluid/operators/distributed/communicator_test.cc b/paddle/fluid/operators/distributed/communicator_test.cc
index 66e36d012b10a0e1d627ee44dcde9e68f66cc719..5294ac33d15611a003eeb7971891e8ca85ec6a73 100644
--- a/paddle/fluid/operators/distributed/communicator_test.cc
+++ b/paddle/fluid/operators/distributed/communicator_test.cc
@@ -42,7 +42,6 @@ TEST(communicator, merge_lod_tensors) {
     }
     out_value += static_cast<float>(i);
   }
-  out_value = out_value / 10.0;
   const std::string out_name = "Out";
   std::unique_ptr<framework::Scope> scope;
   scope.reset(new framework::Scope());
@@ -96,7 +95,7 @@ TEST(communicator, merge_selected_rows) {
   std::vector<float> out_values;
   out_values.reserve(10);
   for (auto i = 0; i < 10; ++i) {
-    out_values.push_back(static_cast<float>((i * (10 - i)) / 10.0));
+    out_values.push_back(static_cast<float>(i * (10 - i)));
   }
   for (auto i = 0; i < out_slr.rows().size(); ++i) {
     ASSERT_EQ(out_slr.rows()[i], i);
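Note on the MergeVars change above: with FLAGS_communicator_is_sgd_optimizer at its default of true, per-thread gradients are merged by summation; only when the flag is off are they averaged. The framework-free Python sketch below (merge_vars is a hypothetical illustration, not a Paddle API) mirrors that logic and explains why the test expectations drop the division by 10.0:

    # Hypothetical sketch of the merge semantics toggled by
    # FLAGS_communicator_is_sgd_optimizer; not part of the Paddle code base.
    def merge_vars(grads, is_sgd_optimizer=True):
        """Merge per-thread dense gradients: elementwise sum, divided by
        the gradient count only when the optimizer is not plain SGD."""
        merged = [0.0] * len(grads[0])
        for grad in grads:
            merged = [m + g for m, g in zip(merged, grad)]
        if not is_sgd_optimizer:  # mirrors the new branch in MergeVars
            merged = [m / len(grads) for m in merged]
        return merged

    # Ten single-element gradients holding 0.0 .. 9.0:
    print(merge_vars([[float(i)] for i in range(10)]))         # [45.0] (sum)
    print(merge_vars([[float(i)] for i in range(10)], False))  # [4.5]  (average)

With the flag at its default, the merged result in communicator_test.cc is therefore the plain sum of the ten inputs, matching the updated assertions.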
diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc
index f79adf707083cf2cd95db3af5c047b9c9d849198..56362391a25d2e09b366399b496507776f60e67d 100644
--- a/paddle/fluid/operators/distributed/parameter_send.cc
+++ b/paddle/fluid/operators/distributed/parameter_send.cc
@@ -139,6 +139,13 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
     auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
 
     auto &send_rows = send_slr.rows();
+    if (send_rows.size() == 0) {
+      LOG(WARNING) << "WARNING: The variable sent to pserver is empty, which "
+                      "may cause an unknown error. Please check the state of "
+                      "use_double_buffer in pyreader async mode; you need to "
+                      "set it to false.";
+    }
+
     std::vector<std::vector<size_t>> outs_rows_idx;
     std::vector<std::vector<size_t>> outs_dense_idx;
 
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index c77d8b4e70b3577c6db3af9280c7e2997c308d90..9b5e0c92fca470cad82191f41500120cecdf7f6f 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -199,7 +199,9 @@ DEFINE_bool(
  */
 DEFINE_int32(communicator_max_merge_var_num, 20,
              "max var num to merge and send");
-
+DEFINE_bool(communicator_is_sgd_optimizer, true,
+            "if the optimizer is SGD, the gradient sent to the pserver is "
+            "the sum of the gradients calculated by each thread");
 /**
  * Distributed related FLAG
  * Name: FLAGS_communicator_send_queue_size
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 295dae7217299e8258a49aea05130755d615b4bc..14106668531e429e2d2d2e59ee6e9bf5f669002e 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -206,6 +206,7 @@ def __bootstrap__():
         read_env_flags.append('communicator_fake_rpc')
         read_env_flags.append('communicator_send_wait_times')
         read_env_flags.append('communicator_merge_sparse_grad')
+        read_env_flags.append('communicator_is_sgd_optimizer')
         if core.is_compiled_with_brpc():
             read_env_flags.append('max_body_size')
             #set brpc max body size
diff --git a/python/paddle/fluid/tests/unittests/test_dist_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_ctr.py
index 91947ded35330c7b35eb4560daa98c53653a13f4..f1bbce89821d57c76a91b48f5a71b253d78ff2dc 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_ctr.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_ctr.py
@@ -113,7 +113,8 @@ class TestDistCTR2x2_ASYNC2(TestDistBase):
             "FLAGS_communicator_send_queue_size": "2",
             "FLAGS_communicator_max_merge_var_num": "2",
             "FLAGS_communicator_max_send_grad_num_before_recv": "2",
-            "FLAGS_communicator_independent_recv_thread": "0"
+            "FLAGS_communicator_independent_recv_thread": "0",
+            "FLAGS_communicator_is_sgd_optimizer": "0"
         }
 
         self.check_with_place(
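Because the new flag is registered in read_env_flags, it can be toggled from the environment before paddle.fluid is imported, exactly as the unit test's env dict does. A minimal usage sketch (names are illustrative):

    # Minimal sketch: turn the sum-merge behavior off (e.g. for a non-SGD
    # optimizer) before importing paddle.fluid, so that __bootstrap__()
    # picks the flag up from the environment.
    import os
    os.environ["FLAGS_communicator_is_sgd_optimizer"] = "0"

    import paddle.fluid as fluid  # FLAGS_* env vars are consumed at import

The empty-variable warning added to parameter_send.cc points the same way as its message: in pyreader async mode, build the reader with use_double_buffer=False so that empty gradient batches are not pushed to the pserver.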