diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt
index 972b4f67a8388ce68952fa90aaa224cd45c6d226..7e14a73d638e5c6cac3238348f60bd5e4809249b 100644
--- a/paddle/fluid/operators/distributed/CMakeLists.txt
+++ b/paddle/fluid/operators/distributed/CMakeLists.txt
@@ -51,6 +51,7 @@ endif()
 cc_test(rpc_server_test SRCS rpc_server_test.cc
         DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
 cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
+cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS simple_threadpool)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h
new file mode 100644
index 0000000000000000000000000000000000000000..17f0bf02720bc0d8f80c518945db0e65525db745
--- /dev/null
+++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h
@@ -0,0 +1,109 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <functional>
+#include <future>  // NOLINT
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include <ThreadPool.h>
+
+namespace paddle {
+namespace operators {
+namespace distributed {
+
+class ConcurrentSet {
+ public:
+  ConcurrentSet() : pool_(new ::ThreadPool(1)) {}
+  ~ConcurrentSet() {}
+
+  std::future<void> Update(const std::vector<int64_t>& rows) {
+    auto task = [this, &rows] {
+      for (auto row : rows) {
+        set_.insert(row);
+      }
+    };
+    return pool_->enqueue(std::move(task));
+  }
+
+  std::future<void> GetAndClear(std::vector<int64_t>* result) {
+    auto task = [this, result] {
+      result->clear();
+      result->insert(result->end(), set_.begin(), set_.end());
+      set_.clear();
+    };
+    return pool_->enqueue(std::move(task));
+  }
+
+ private:
+  std::unordered_set<int64_t> set_;
+  std::unique_ptr<::ThreadPool> pool_{nullptr};
+};
+
+class AsyncSparseParamUpdateRecorder {
+  using TrainerToRows = std::vector<std::unique_ptr<ConcurrentSet>>;
+
+ public:
+  AsyncSparseParamUpdateRecorder(
+      const std::unordered_map<std::string, std::string>& grad_to_param,
+      int trainer_num)
+      : grad_to_param_(grad_to_param) {
+    for (auto iter = grad_to_param.begin(); iter != grad_to_param.end();
+         iter++) {
+      auto& param_name = iter->second;
+      param_to_updated_rows_[param_name] = TrainerToRows();
+      auto& trainer_to_rows = param_to_updated_rows_[param_name];
+      for (auto i = 0; i < trainer_num; ++i) {
+        trainer_to_rows.emplace_back(new ConcurrentSet());
+      }
+    }
+  }
+
+  ~AsyncSparseParamUpdateRecorder() {}
+
+  void Update(const std::string& grad_name,
+              const std::vector<int64_t>& update_rows) {
+    auto& param_name = grad_to_param_.at(grad_name);
+    auto& trainer_to_rows = param_to_updated_rows_.at(param_name);
+
+    std::vector<std::future<void>> futures;
+    for (auto& set : trainer_to_rows) {
+      futures.push_back(set->Update(update_rows));
+    }
+    for (auto& f : futures) {
+      f.wait();
+    }
+  }
+
+  void GetAndClear(const std::string& param_name, int trainer_id,
+                   std::vector<int64_t>* result) {
+    param_to_updated_rows_.at(param_name)[trainer_id]
+        ->GetAndClear(result)
+        .wait();
+  }
+
+ private:
+  std::unordered_map<std::string, std::string> grad_to_param_;
+  std::unordered_map<std::string, TrainerToRows> param_to_updated_rows_;
+};
+
+}  // namespace distributed
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..598bb59021ac4783a6da2f094c8949b298abf16e
--- /dev/null
+++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc
@@ -0,0 +1,56 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" + +#include + +#include "gtest/gtest.h" + +namespace paddle { +namespace operators { +namespace distributed { + +TEST(ConcurrentSet, Update) { + ConcurrentSet concurrent_set; + std::vector in1 = {1, 2, 3, 4}; + std::vector in2 = {2, 3, 5, 6}; + + std::vector> futures; + futures.push_back(concurrent_set.Update(in1)); + futures.push_back(concurrent_set.Update(in2)); + + for (auto &f : futures) { + f.wait(); + } + + std::unordered_set in; + std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin())); + std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin())); + + std::vector ret; + concurrent_set.GetAndClear(&ret).wait(); + + std::unordered_set out; + std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin())); + + EXPECT_EQ(in, out); + + concurrent_set.GetAndClear(&ret).wait(); + EXPECT_EQ(ret.size(), 0); +} + +} // namespace distributed +} // namespace operators +} // namespace paddle