From c6e82785aa7ed6e9a92ddfe48d8f5628e6443d4b Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 22 Mar 2019 22:37:13 +0800 Subject: [PATCH] init async_sparse_param_update_recorder --- .../operators/distributed/CMakeLists.txt | 1 + .../async_sparse_param_update_recorder.h | 109 ++++++++++++++++++ ...async_sparse_param_update_recorder_test.cc | 56 +++++++++ 3 files changed, 166 insertions(+) create mode 100644 paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h create mode 100644 paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 972b4f67a8..7e14a73d63 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -51,6 +51,7 @@ endif() cc_test(rpc_server_test SRCS rpc_server_test.cc DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL) cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope) +cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS simple_threadpool) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory) diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h new file mode 100644 index 0000000000..17f0bf0272 --- /dev/null +++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h @@ -0,0 +1,109 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include + +#include + +namespace paddle { +namespace operators { +namespace distributed { + +class ConcurrentSet { + public: + ConcurrentSet() : pool_(new ::ThreadPool(1)) {} + ~ConcurrentSet() {} + + std::future Update(const std::vector& rows) { + auto task = [this, &rows] { + for (auto row : rows) { + set_.insert(row); + } + }; + return pool_->enqueue(std::move(task)); + } + + std::future GetAndClear(std::vector* result) { + auto task = [this, result] { + result->clear(); + result->insert(result->end(), set_.begin(), set_.end()); + set_.clear(); + }; + return pool_->enqueue(std::move(task)); + } + + private: + std::unordered_set set_; + std::unique_ptr<::ThreadPool> pool_{nullptr}; +}; + +class AsyncSparseParamUpdateRecorder { + using TrainerToRows = std::vector>; + + public: + AsyncSparseParamUpdateRecorder( + const std::unordered_map& grad_to_param, + int trainer_num) + : grad_to_param_(grad_to_param) { + for (auto iter = grad_to_param.begin(); iter != grad_to_param.end(); + iter++) { + auto& param_name = iter->second; + param_to_updated_rows_[param_name] = TrainerToRows(); + auto& trainer_to_rows = param_to_updated_rows_[param_name]; + for (auto i = 0; i < trainer_num; ++i) { + trainer_to_rows.emplace_back(new ConcurrentSet()); + } + } + } + + ~AsyncSparseParamUpdateRecorder() {} + + void Update(const std::string& grad_name, + const std::vector& update_rows) { + auto& param_name = grad_to_param_.at(grad_name); + auto& trainer_to_rows = param_to_updated_rows_.at(param_name); + + std::vector> futures; + for (auto& set : trainer_to_rows) { + futures.push_back(set->Update(update_rows)); + } + for (auto& f : futures) { + f.wait(); + } + } + + void GetAndClear(const std::string& param_name, int trainer_id, + std::vector* result) { + param_to_updated_rows_.at(param_name)[trainer_id] + ->GetAndClear(result) + .wait(); + } + + private: + std::unordered_map grad_to_param_; + std::unordered_map param_to_updated_rows_; +}; + +} // namespace distributed +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc new file mode 100644 index 0000000000..598bb59021 --- /dev/null +++ b/paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" + +#include + +#include "gtest/gtest.h" + +namespace paddle { +namespace operators { +namespace distributed { + +TEST(ConcurrentSet, Update) { + ConcurrentSet concurrent_set; + std::vector in1 = {1, 2, 3, 4}; + std::vector in2 = {2, 3, 5, 6}; + + std::vector> futures; + futures.push_back(concurrent_set.Update(in1)); + futures.push_back(concurrent_set.Update(in2)); + + for (auto &f : futures) { + f.wait(); + } + + std::unordered_set in; + std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin())); + std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin())); + + std::vector ret; + concurrent_set.GetAndClear(&ret).wait(); + + std::unordered_set out; + std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin())); + + EXPECT_EQ(in, out); + + concurrent_set.GetAndClear(&ret).wait(); + EXPECT_EQ(ret.size(), 0); +} + +} // namespace distributed +} // namespace operators +} // namespace paddle -- GitLab