diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt
index 7e14a73d638e5c6cac3238348f60bd5e4809249b..4d21fce5b2525e23441df0b0d21b96bb0797d1bb 100644
--- a/paddle/fluid/operators/distributed/CMakeLists.txt
+++ b/paddle/fluid/operators/distributed/CMakeLists.txt
@@ -51,7 +51,7 @@ endif()
 cc_test(rpc_server_test SRCS rpc_server_test.cc
         DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
 cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
-cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS simple_threadpool)
+cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS enforce simple_threadpool)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h
index 17f0bf02720bc0d8f80c518945db0e65525db745..4b071f6706779221864962458a1f6ab5648bc845 100644
--- a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h
+++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h
@@ -25,6 +25,8 @@
 
 #include <ThreadPool.h>
 
+#include "paddle/fluid/platform/enforce.h"
+
 namespace paddle {
 namespace operators {
 namespace distributed {
@@ -62,11 +64,12 @@ class AsyncSparseParamUpdateRecorder {
 
  public:
   AsyncSparseParamUpdateRecorder(
-      const std::unordered_map<std::string, std::string>& grad_to_param,
-      int trainer_num)
-      : grad_to_param_(grad_to_param) {
+      int trainer_num,
+      const std::unordered_map<std::string, std::string>& grad_to_param)
+      : trainer_num_(trainer_num), grad_to_param_(grad_to_param) {
     for (auto iter = grad_to_param.begin(); iter != grad_to_param.end();
          iter++) {
+      param_to_grad_[iter->second] = iter->first;
       auto& param_name = iter->second;
       param_to_updated_rows_[param_name] = TrainerToRows();
       auto& trainer_to_rows = param_to_updated_rows_[param_name];
@@ -76,31 +79,35 @@ class AsyncSparseParamUpdateRecorder {
     }
   }
 
-  ~AsyncSparseParamUpdateRecorder() {}
+  ~AsyncSparseParamUpdateRecorder() = default;
 
   void Update(const std::string& grad_name,
               const std::vector<int64_t>& update_rows) {
     auto& param_name = grad_to_param_.at(grad_name);
     auto& trainer_to_rows = param_to_updated_rows_.at(param_name);
 
-    std::vector<std::future<void>> futures;
     for (auto& set : trainer_to_rows) {
-      futures.push_back(set->Update(update_rows));
-    }
-    for (auto& f : futures) {
-      f.wait();
+      // no need to wait here because GetAndClear will wait.
+      set->Update(update_rows);
     }
   }
 
   void GetAndClear(const std::string& param_name, int trainer_id,
                    std::vector<int64_t>* result) {
+    PADDLE_ENFORCE_LT(trainer_id, trainer_num_);
     param_to_updated_rows_.at(param_name)[trainer_id]
         ->GetAndClear(result)
         .wait();
   }
 
+  bool HasParam(const std::string& param_name) {
+    return param_to_grad_.find(param_name) != param_to_grad_.end();
+  }
+
  private:
+  const int trainer_num_;
   std::unordered_map<std::string, std::string> grad_to_param_;
+  std::unordered_map<std::string, std::string> param_to_grad_;
   std::unordered_map<std::string, TrainerToRows> param_to_updated_rows_;
 };
 
diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc
index 598bb59021ac4783a6da2f094c8949b298abf16e..af29230bad2eba7a3210de6db404cc96633162f4 100644
--- a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc
+++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc
@@ -22,7 +22,7 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
-TEST(ConcurrentSet, Update) {
+TEST(ConcurrentSet, All) {
   ConcurrentSet concurrent_set;
   std::vector<int64_t> in1 = {1, 2, 3, 4};
   std::vector<int64_t> in2 = {2, 3, 5, 6};
@@ -51,6 +51,45 @@
   EXPECT_EQ(ret.size(), 0);
 }
 
+TEST(AsyncSparseParamUpdateRecorder, All) {
+  std::unordered_map<std::string, std::string> grad_to_param;
+  grad_to_param["grad1"] = "param1";
+  grad_to_param["grad2"] = "param2";
+
+  int trainer_num = 10;
+
+  AsyncSparseParamUpdateRecorder recorder(trainer_num, grad_to_param);
+  std::vector<int64_t> in1 = {1, 2, 3, 4};
+  std::vector<int64_t> in2 = {2, 3, 5, 6};
+
+  std::unordered_set<int64_t> in;
+  std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin()));
+  std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin()));
+
+  recorder.Update("grad1", in1);
+  recorder.Update("grad1", in2);
+
+  EXPECT_TRUE(recorder.HasParam("param1"));
+  EXPECT_TRUE(recorder.HasParam("param2"));
+  EXPECT_FALSE(recorder.HasParam("param3"));
+
+  std::vector<int64_t> ret;
+  EXPECT_ANY_THROW(recorder.GetAndClear("param1", trainer_num, &ret));
+
+  for (int i = 0; i < trainer_num; ++i) {
+    std::vector<int64_t> ret;
+    std::unordered_set<int64_t> out;
+
+    recorder.GetAndClear("param1", i, &ret);
+    std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin()));
+
+    EXPECT_EQ(in, out);
+
+    recorder.GetAndClear("param1", i, &ret);
+    EXPECT_EQ(ret.size(), 0);
+  }
+}
+
 }  // namespace distributed
 }  // namespace operators
 }  // namespace paddle