提交 25e2b417 编写于 作者: Q Qiao Longfei

add AsyncSparseParamUpdateRecorder test

上级 c6e82785
......@@ -51,7 +51,7 @@ endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS simple_threadpool)
cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS enforce simple_threadpool)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
......
......@@ -25,6 +25,8 @@
#include <ThreadPool.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
......@@ -62,11 +64,12 @@ class AsyncSparseParamUpdateRecorder {
public:
AsyncSparseParamUpdateRecorder(
const std::unordered_map<std::string, std::string>& grad_to_param,
int trainer_num)
: grad_to_param_(grad_to_param) {
int trainer_num,
const std::unordered_map<std::string, std::string>& grad_to_param)
: trainer_num_(trainer_num), grad_to_param_(grad_to_param) {
for (auto iter = grad_to_param.begin(); iter != grad_to_param.end();
iter++) {
param_to_grad_[iter->second] = iter->first;
auto& param_name = iter->second;
param_to_updated_rows_[param_name] = TrainerToRows();
auto& trainer_to_rows = param_to_updated_rows_[param_name];
......@@ -76,31 +79,35 @@ class AsyncSparseParamUpdateRecorder {
}
}
~AsyncSparseParamUpdateRecorder() {}
~AsyncSparseParamUpdateRecorder() = default;
void Update(const std::string& grad_name,
const std::vector<int64_t>& update_rows) {
auto& param_name = grad_to_param_.at(grad_name);
auto& trainer_to_rows = param_to_updated_rows_.at(param_name);
std::vector<std::future<void>> futures;
for (auto& set : trainer_to_rows) {
futures.push_back(set->Update(update_rows));
}
for (auto& f : futures) {
f.wait();
// no need to wait here because GetAndClear will wait.
set->Update(update_rows);
}
}
void GetAndClear(const std::string& param_name, int trainer_id,
std::vector<int64_t>* result) {
PADDLE_ENFORCE_LT(trainer_id, trainer_num_);
param_to_updated_rows_.at(param_name)[trainer_id]
->GetAndClear(result)
.wait();
}
bool HasParam(const std::string& param_name) {
return param_to_grad_.find(param_name) != param_to_grad_.end();
}
private:
const int trainer_num_;
std::unordered_map<std::string, std::string> grad_to_param_;
std::unordered_map<std::string, std::string> param_to_grad_;
std::unordered_map<std::string, TrainerToRows> param_to_updated_rows_;
};
......
......@@ -22,7 +22,7 @@ namespace paddle {
namespace operators {
namespace distributed {
TEST(ConcurrentSet, Update) {
TEST(ConcurrentSet, All) {
ConcurrentSet concurrent_set;
std::vector<int64_t> in1 = {1, 2, 3, 4};
std::vector<int64_t> in2 = {2, 3, 5, 6};
......@@ -51,6 +51,45 @@ TEST(ConcurrentSet, Update) {
EXPECT_EQ(ret.size(), 0);
}
TEST(AsyncSparseParamUpdateRecorder, All) {
std::unordered_map<std::string, std::string> grad_to_param;
grad_to_param["grad1"] = "param1";
grad_to_param["grad2"] = "param2";
int trainer_num = 10;
AsyncSparseParamUpdateRecorder recorder(trainer_num, grad_to_param);
std::vector<int64_t> in1 = {1, 2, 3, 4};
std::vector<int64_t> in2 = {2, 3, 5, 6};
std::unordered_set<int64_t> in;
std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin()));
std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin()));
recorder.Update("grad1", in1);
recorder.Update("grad1", in2);
EXPECT_TRUE(recorder.HasParam("param1"));
EXPECT_TRUE(recorder.HasParam("param2"));
EXPECT_FALSE(recorder.HasParam("param3"));
std::vector<int64_t> ret;
EXPECT_ANY_THROW(recorder.GetAndClear("param1", trainer_num, &ret));
for (int i = 0; i < trainer_num; ++i) {
std::vector<int64_t> ret;
std::unordered_set<int64_t> out;
recorder.GetAndClear("param1", i, &ret);
std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin()));
EXPECT_EQ(in, out);
recorder.GetAndClear("param1", i, &ret);
EXPECT_EQ(ret.size(), 0);
}
}
} // namespace distributed
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册