提交 c6e82785 编写于 作者: Q Qiao Longfei

init async_sparse_param_update_recorder

上级 039d783d
...@@ -51,6 +51,7 @@ endif() ...@@ -51,6 +51,7 @@ endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL) DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope) cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS simple_threadpool)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory) cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <ThreadPool.h>
namespace paddle {
namespace operators {
namespace distributed {
class ConcurrentSet {
public:
ConcurrentSet() : pool_(new ::ThreadPool(1)) {}
~ConcurrentSet() {}
std::future<void> Update(const std::vector<int64_t>& rows) {
auto task = [this, &rows] {
for (auto row : rows) {
set_.insert(row);
}
};
return pool_->enqueue(std::move(task));
}
std::future<void> GetAndClear(std::vector<int64_t>* result) {
auto task = [this, result] {
result->clear();
result->insert(result->end(), set_.begin(), set_.end());
set_.clear();
};
return pool_->enqueue(std::move(task));
}
private:
std::unordered_set<int64_t> set_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
};
class AsyncSparseParamUpdateRecorder {
using TrainerToRows = std::vector<std::unique_ptr<ConcurrentSet>>;
public:
AsyncSparseParamUpdateRecorder(
const std::unordered_map<std::string, std::string>& grad_to_param,
int trainer_num)
: grad_to_param_(grad_to_param) {
for (auto iter = grad_to_param.begin(); iter != grad_to_param.end();
iter++) {
auto& param_name = iter->second;
param_to_updated_rows_[param_name] = TrainerToRows();
auto& trainer_to_rows = param_to_updated_rows_[param_name];
for (auto i = 0; i < trainer_num; ++i) {
trainer_to_rows.emplace_back(new ConcurrentSet());
}
}
}
~AsyncSparseParamUpdateRecorder() {}
void Update(const std::string& grad_name,
const std::vector<int64_t>& update_rows) {
auto& param_name = grad_to_param_.at(grad_name);
auto& trainer_to_rows = param_to_updated_rows_.at(param_name);
std::vector<std::future<void>> futures;
for (auto& set : trainer_to_rows) {
futures.push_back(set->Update(update_rows));
}
for (auto& f : futures) {
f.wait();
}
}
void GetAndClear(const std::string& param_name, int trainer_id,
std::vector<int64_t>* result) {
param_to_updated_rows_.at(param_name)[trainer_id]
->GetAndClear(result)
.wait();
}
private:
std::unordered_map<std::string, std::string> grad_to_param_;
std::unordered_map<std::string, TrainerToRows> param_to_updated_rows_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include <algorithm>
#include "gtest/gtest.h"
namespace paddle {
namespace operators {
namespace distributed {
TEST(ConcurrentSet, Update) {
ConcurrentSet concurrent_set;
std::vector<int64_t> in1 = {1, 2, 3, 4};
std::vector<int64_t> in2 = {2, 3, 5, 6};
std::vector<std::future<void>> futures;
futures.push_back(concurrent_set.Update(in1));
futures.push_back(concurrent_set.Update(in2));
for (auto &f : futures) {
f.wait();
}
std::unordered_set<int64_t> in;
std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin()));
std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin()));
std::vector<int64_t> ret;
concurrent_set.GetAndClear(&ret).wait();
std::unordered_set<int64_t> out;
std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin()));
EXPECT_EQ(in, out);
concurrent_set.GetAndClear(&ret).wait();
EXPECT_EQ(ret.size(), 0);
}
} // namespace distributed
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册