未验证 提交 bd79ae8a 编写于 作者: Z zhaocaibei123 提交者: GitHub

brpc_ps_client upgrade (#36943)

* test

* rm test

* add memory_sparse_table and brpc communication upgrade dependency

* fix

* add dense optimizer & fix dump bug & add some strategy fields

* fix

* fix

* remove thread_pool thread_queue

* add memory sparse table

* update memory sparse table

* update memory sparse table

* update cmake

* upgrade brpc_ps_client

* remove show/click_const in ctr_accessor

* fix destructor
上级 abd4ab9c
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <ThreadPool.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -23,6 +24,7 @@ ...@@ -23,6 +24,7 @@
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/distributed/service/brpc_utils.h" #include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/distributed/service/ps_client.h" #include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
...@@ -53,9 +55,8 @@ class DownpourPsClientService : public PsService { ...@@ -53,9 +55,8 @@ class DownpourPsClientService : public PsService {
_rank = rank_id; _rank = rank_id;
return 0; return 0;
} }
virtual void service(::google::protobuf::RpcController *controller, void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, const PsRequestMessage *request, PsResponseMessage *response,
PsResponseMessage *response,
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
protected: protected:
...@@ -77,7 +78,7 @@ class DownpourBrpcClosure : public PSClientClosure { ...@@ -77,7 +78,7 @@ class DownpourBrpcClosure : public PSClientClosure {
} }
} }
virtual ~DownpourBrpcClosure() {} virtual ~DownpourBrpcClosure() {}
virtual void Run() override { void Run() override {
if (_waiting_num.fetch_sub(1) == 1) { if (_waiting_num.fetch_sub(1) == 1) {
_callback(this); _callback(this);
delete this; delete this;
...@@ -97,47 +98,87 @@ class DownpourBrpcClosure : public PSClientClosure { ...@@ -97,47 +98,87 @@ class DownpourBrpcClosure : public PSClientClosure {
std::vector<std::shared_ptr<brpc::Controller>> _cntls; std::vector<std::shared_ptr<brpc::Controller>> _cntls;
}; };
// One hash shard of an asynchronous sparse-push request: stages `kv_num`
// key/value pairs in `key_list` / `value_list` before they are sent to the
// parameter server.
// NOTE(review): values are kept as serialized strings here; confirm the
// producer fills key_list/value_list with exactly kv_num entries.
struct SharedSparsePushData {
  // Rule of Zero: the previous empty user-declared ctor/dtor suppressed the
  // implicit move operations for no benefit, so they were removed.
  size_t kv_num = 0;  // number of valid entries (was left uninitialized before)
  std::vector<uint64_t> key_list;
  std::vector<std::string> value_list;
};
// Payload of one asynchronous sparse-push task.
struct SparsePushTaskData {
  // Sparse data partitioned into shards by key hash (one entry per shard).
  std::vector<SharedSparsePushData> shared_data;
};
// Object pool for push-sparse task payloads: recycles SparsePushTaskData
// instances so their per-shard vectors (and capacity) are reused instead of
// re-allocated on every asynchronous push.
struct SparseTaskPool {
  // Returns a task object, reusing a pooled one when available.
  // NOTE(review): recycled objects are handed back as-is; the caller appears
  // to be expected to overwrite/clear any stale shard data — confirm.
  std::shared_ptr<SparsePushTaskData> get() {
    {
      std::lock_guard<std::mutex> lock(_mutex);
      if (!_pool.empty()) {
        // Move out of the pool: avoids the atomic ref-count increment a
        // shared_ptr copy would cost under the lock.
        auto ret = std::move(_pool.back());
        _pool.pop_back();
        return ret;
      }
    }
    // Allocate outside the critical section to keep lock hold time minimal.
    return std::make_shared<SparsePushTaskData>();
  }

  // Returns a task object to the pool for later reuse.
  void push(std::shared_ptr<SparsePushTaskData> data) {
    std::lock_guard<std::mutex> lock(_mutex);
    _pool.push_back(std::move(data));
  }

  std::vector<std::shared_ptr<SparsePushTaskData>> _pool;
  std::mutex _mutex;  // guards _pool
};
template <class T> template <class T>
struct array_deleter { struct array_deleter {
void operator()(T *&x) const { delete[] x; } void operator()(T *&x) const { delete[] x; } // NOLINT
}; };
class BrpcPsClient : public PSClient { class BrpcPsClient : public PSClient {
public: public:
BrpcPsClient() {} BrpcPsClient() {}
virtual ~BrpcPsClient() { virtual ~BrpcPsClient() {
// _running = false; if (_running) {
// try { flush();
// _async_push_dense_thread.join(); _running = false;
// _async_push_sparse_thread.join(); }
//} catch (...) { if (_async_push_dense_thread.joinable()) {
//} _async_push_dense_thread.join();
}
if (_async_push_sparse_thread.joinable()) {
_async_push_sparse_thread.join();
}
if (_server_started) {
_server.Stop(1000);
_server.Join();
_server_started = false;
}
} }
virtual int32_t create_client2client_connection( virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
virtual std::future<int32_t> shrink(uint32_t table_id, std::future<int32_t> shrink(uint32_t table_id,
const std::string threshold) override; const std::string threshold) override;
virtual std::future<int32_t> load(const std::string &epoch, std::future<int32_t> load(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch, std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
virtual std::future<int32_t> save(const std::string &epoch, std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch, std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
virtual std::future<int32_t> clear() override; std::future<int32_t> clear() override;
virtual std::future<int32_t> clear(uint32_t table_id) override; std::future<int32_t> clear(uint32_t table_id) override;
virtual std::future<int32_t> stop_server() override; std::future<int32_t> stop_server() override;
virtual std::future<int32_t> start_profiler() override; std::future<int32_t> start_profiler() override;
virtual std::future<int32_t> stop_profiler() override; std::future<int32_t> stop_profiler() override;
virtual void finalize_worker() override; void finalize_worker() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num, virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id); size_t table_id);
...@@ -146,6 +187,9 @@ class BrpcPsClient : public PSClient { ...@@ -146,6 +187,9 @@ class BrpcPsClient : public PSClient {
size_t region_num, size_t region_num,
size_t table_id); size_t table_id);
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num, size_t table_id);
void push_dense_task_consume();
virtual std::future<int32_t> pull_sparse(float **select_values, virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id, size_t table_id,
const uint64_t *keys, size_t num, const uint64_t *keys, size_t num,
...@@ -164,13 +208,16 @@ class BrpcPsClient : public PSClient { ...@@ -164,13 +208,16 @@ class BrpcPsClient : public PSClient {
void *done); void *done);
virtual std::future<int32_t> flush(); virtual std::future<int32_t> flush();
virtual std::future<int32_t> send_client2client_msg( std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
int msg_type, int to_client_id, const std::string &msg) override; const std::string &msg) override;
// for local save sparse // for local save sparse
virtual int32_t recv_and_save_table(const uint64_t table_id, virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path); const std::string &path);
void print_queue_size();
void print_queue_size_thread();
protected: protected:
virtual size_t get_server_nums() { return _server_channels.size(); } virtual size_t get_server_nums() { return _server_channels.size(); }
inline brpc::Channel *get_sparse_channel(size_t server_id) { inline brpc::Channel *get_sparse_channel(size_t server_id) {
...@@ -182,7 +229,7 @@ class BrpcPsClient : public PSClient { ...@@ -182,7 +229,7 @@ class BrpcPsClient : public PSClient {
inline brpc::Channel *get_cmd_channel(size_t server_id) { inline brpc::Channel *get_cmd_channel(size_t server_id) {
return _server_channels[server_id][2].get(); return _server_channels[server_id][2].get();
} }
virtual int32_t initialize() override; int32_t initialize() override;
private: private:
// virtual int32_t initialize() override; // virtual int32_t initialize() override;
...@@ -200,38 +247,74 @@ class BrpcPsClient : public PSClient { ...@@ -200,38 +247,74 @@ class BrpcPsClient : public PSClient {
bool _running = false; bool _running = false;
bool _flushing = false; bool _flushing = false;
std::atomic<uint32_t> _async_call_num; //异步请求计数 std::atomic<uint32_t> _async_call_num; // 异步请求计数
// 异步push dense task
std::thread _async_push_dense_thread;
typedef AsyncRequestTask<std::shared_ptr<std::vector<float>>> DenseAsyncTask;
std::unordered_map<uint32_t, paddle::framework::Channel<DenseAsyncTask *>>
_push_dense_task_queue_map;
// 异步push sparse task
std::thread _async_push_sparse_thread;
typedef AsyncRequestTask<std::shared_ptr<SparsePushTaskData>> SparseAsyncTask;
std::unordered_map<uint32_t, paddle::framework::Channel<SparseAsyncTask *>>
_push_sparse_task_queue_map;
std::unordered_map<uint32_t, uint32_t> _push_sparse_merge_count_map;
std::thread _print_thread;
int push_sparse_async_shard_merge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
ValueAccessor *accessor);
int push_sparse_async_shard_push(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
DownpourBrpcClosure *closure, ValueAccessor *accessor);
SparseTaskPool _sparse_task_pool;
std::vector<std::shared_ptr<brpc::Channel>> std::vector<std::shared_ptr<brpc::Channel>>
_client_channels; // client2client _client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>> std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server _server_channels; // client2server
virtual std::future<int32_t> push_dense_raw_gradient( std::future<int32_t> push_dense_raw_gradient(int table_id,
int table_id, float *total_send_data, size_t total_send_data_size, float *total_send_data,
size_t total_send_data_size,
void *done) override; void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient( std::future<int32_t> push_sparse_raw_gradient(size_t table_id,
size_t table_id, const uint64_t *keys, const float **update_values, const uint64_t *keys,
size_t num, void *done) override; const float **update_values,
size_t num,
void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial( std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values, size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) override; uint32_t num, void *done, int pserver_idx) override;
virtual std::future<int32_t> push_sparse_param(size_t table_id, std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num, size_t num, void *done) override;
void *done) override; std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) override;
void push_sparse_task_consume();
private: private:
int32_t start_client_service(); int32_t start_client_service();
void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data,
size_t total_send_data_size,
DownpourBrpcClosure *closure);
float _mae = 0; float _mae = 0;
float _mse = 0; float _mse = 0;
uint16_t _push_times = 0; uint16_t _push_times = 0;
brpc::Server _server; brpc::Server _server;
DownpourPsClientService _service; DownpourPsClientService _service;
bool _server_started = false;
std::atomic_uint grad_num_{0}; std::atomic_uint grad_num_{0};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -628,6 +628,8 @@ void AsyncCommunicator::Start() { ...@@ -628,6 +628,8 @@ void AsyncCommunicator::Start() {
void AsyncCommunicator::Stop() { void AsyncCommunicator::Stop() {
VLOG(1) << "Communicator stop"; VLOG(1) << "Communicator stop";
_worker_ptr->finalize_worker();
VLOG(0) << "Communicator finalize_worker done";
running_ = false; running_ = false;
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing"; VLOG(0) << "Communicator is not inited, do nothing";
......
...@@ -114,9 +114,9 @@ class PSClient { ...@@ -114,9 +114,9 @@ class PSClient {
size_t region_num, size_t region_num,
size_t table_id) = 0; size_t table_id) = 0;
// virtual std::future<int32_t> push_dense(const Region *regions, virtual std::future<int32_t> push_dense(const Region *regions,
// size_t region_num, size_t region_num,
// size_t table_id) = 0; size_t table_id) = 0;
// 使用keys进行pull请求,结果填充values // 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间 // keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用 // future结束前keys和values缓冲区不能再次使用
...@@ -222,10 +222,10 @@ class PSClient { ...@@ -222,10 +222,10 @@ class PSClient {
const uint64_t *keys, const uint64_t *keys,
const float **update_values, const float **update_values,
size_t num, void *done) = 0; size_t num, void *done) = 0;
// virtual std::future<int32_t> push_sparse(size_t table_id, virtual std::future<int32_t> push_sparse(size_t table_id,
// const uint64_t *keys, const uint64_t *keys,
// const float **update_values, const float **update_values,
// size_t num) = 0; size_t num) = 0;
protected: protected:
virtual int32_t initialize() = 0; virtual int32_t initialize() = 0;
......
...@@ -270,8 +270,8 @@ bool CtrCommonAccessor::create_value(int stage, const float* value) { ...@@ -270,8 +270,8 @@ bool CtrCommonAccessor::create_value(int stage, const float* value) {
return true; return true;
} else if (stage == 1) { } else if (stage == 1) {
// operation // operation
auto show = CtrCommonPushValue::show_const(value); auto show = CtrCommonPushValue::show(const_cast<float*>(value));
auto click = CtrCommonPushValue::click_const(value); auto click = CtrCommonPushValue::click(const_cast<float*>(value));
auto score = show_click_score(show, click); auto score = show_click_score(show, click);
if (score <= 0) { if (score <= 0) {
return false; return false;
...@@ -302,8 +302,8 @@ std::string CtrCommonAccessor::parse_to_string(const float* v, int param) { ...@@ -302,8 +302,8 @@ std::string CtrCommonAccessor::parse_to_string(const float* v, int param) {
i < common_feature_value.embedx_w_index(); i++) { i < common_feature_value.embedx_w_index(); i++) {
os << " " << v[i]; os << " " << v[i];
} }
auto show = common_feature_value.show_const(v); auto show = common_feature_value.show(const_cast<float*>(v));
auto click = common_feature_value.click_const(v); auto click = common_feature_value.click(const_cast<float*>(v));
auto score = show_click_score(show, click); auto score = show_click_score(show, click);
if (score >= _config.embedx_threshold()) { if (score >= _config.embedx_threshold()) {
for (auto i = common_feature_value.embedx_w_index(); for (auto i = common_feature_value.embedx_w_index();
......
...@@ -61,14 +61,7 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -61,14 +61,7 @@ class CtrCommonAccessor : public ValueAccessor {
float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; } float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
float& embedx_w(float* val) { return val[embedx_w_index()]; } float& embedx_w(float* val) { return val[embedx_w_index()]; }
float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; } float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
float show_const(const float* val) {
float s = val[show_index()];
return s;
}
float click_const(const float* val) {
float c = val[click_index()];
return c;
}
int embed_sgd_dim; int embed_sgd_dim;
int embedx_dim; int embedx_dim;
int embedx_sgd_dim; int embedx_sgd_dim;
...@@ -103,14 +96,6 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -103,14 +96,6 @@ class CtrCommonAccessor : public ValueAccessor {
static float& click(float* val) { static float& click(float* val) {
return val[CtrCommonPushValue::click_index()]; return val[CtrCommonPushValue::click_index()];
} }
static float show_const(const float* val) {
float s = val[show_index()];
return s;
}
static float click_const(const float* val) {
float c = val[click_index()];
return c;
}
static float& embed_g(float* val) { static float& embed_g(float* val) {
return val[CtrCommonPushValue::embed_g_index()]; return val[CtrCommonPushValue::embed_g_index()];
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册