未验证 提交 bd79ae8a 编写于 作者: Z zhaocaibei123 提交者: GitHub

brpc_ps_client upgrade (#36943)

* test

* rm test

* add memory_sparse_table and brpc communication upgrade dependency

* fix

* add dense optimizer & fix dump bug & add some strategy fields

* fix

* fix

* remove thread_pool thread_queue

* add memory sparse table

* update memory sparse table

* update memory sparse table

* update cmake

* upgrade brpc_ps_client

* remove show/click_const in ctr_accessor

* fix deconstructor
上级 abd4ab9c
......@@ -14,6 +14,7 @@
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <vector>
......@@ -23,6 +24,7 @@
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
......@@ -53,10 +55,9 @@ class DownpourPsClientService : public PsService {
_rank = rank_id;
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, PsResponseMessage *response,
::google::protobuf::Closure *done) override;
protected:
size_t _rank;
......@@ -77,7 +78,7 @@ class DownpourBrpcClosure : public PSClientClosure {
}
}
virtual ~DownpourBrpcClosure() {}
virtual void Run() override {
void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
......@@ -97,47 +98,87 @@ class DownpourBrpcClosure : public PSClientClosure {
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
struct SharedSparsePushData {
SharedSparsePushData() {}
~SharedSparsePushData() noexcept {}
size_t kv_num;
std::vector<uint64_t> key_list;
std::vector<std::string> value_list;
};
struct SparsePushTaskData {
std::vector<SharedSparsePushData> shared_data; // sparse数据按key hash分片
};
// push sparse 对象池
struct SparseTaskPool {
std::shared_ptr<SparsePushTaskData> get() {
std::lock_guard<std::mutex> lock(_mutex);
if (_pool.empty()) {
return std::make_shared<SparsePushTaskData>();
} else {
auto ret = _pool.back();
_pool.pop_back();
return ret;
}
}
void push(std::shared_ptr<SparsePushTaskData> data) {
std::lock_guard<std::mutex> lock(_mutex);
_pool.push_back(std::move(data));
}
std::vector<std::shared_ptr<SparsePushTaskData>> _pool;
std::mutex _mutex;
};
template <class T>
struct array_deleter {
void operator()(T *&x) const { delete[] x; }
void operator()(T *&x) const { delete[] x; } // NOLINT
};
class BrpcPsClient : public PSClient {
public:
BrpcPsClient() {}
virtual ~BrpcPsClient() {
// _running = false;
// try {
// _async_push_dense_thread.join();
// _async_push_sparse_thread.join();
//} catch (...) {
//}
if (_running) {
flush();
_running = false;
}
if (_async_push_dense_thread.joinable()) {
_async_push_dense_thread.join();
}
if (_async_push_sparse_thread.joinable()) {
_async_push_sparse_thread.join();
}
if (_server_started) {
_server.Stop(1000);
_server.Join();
_server_started = false;
}
}
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
virtual std::future<int32_t> shrink(uint32_t table_id,
const std::string threshold) override;
virtual std::future<int32_t> load(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> shrink(uint32_t table_id,
const std::string threshold) override;
std::future<int32_t> load(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> clear() override;
std::future<int32_t> clear() override;
virtual std::future<int32_t> clear(uint32_t table_id) override;
std::future<int32_t> clear(uint32_t table_id) override;
virtual std::future<int32_t> stop_server() override;
std::future<int32_t> stop_server() override;
virtual std::future<int32_t> start_profiler() override;
virtual std::future<int32_t> stop_profiler() override;
std::future<int32_t> start_profiler() override;
std::future<int32_t> stop_profiler() override;
virtual void finalize_worker() override;
void finalize_worker() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id);
......@@ -146,6 +187,9 @@ class BrpcPsClient : public PSClient {
size_t region_num,
size_t table_id);
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num, size_t table_id);
void push_dense_task_consume();
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num,
......@@ -164,13 +208,16 @@ class BrpcPsClient : public PSClient {
void *done);
virtual std::future<int32_t> flush();
virtual std::future<int32_t> send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) override;
std::future<int32_t> send_client2client_msg(int msg_type, int to_client_id,
const std::string &msg) override;
// for local save sparse
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path);
void print_queue_size();
void print_queue_size_thread();
protected:
virtual size_t get_server_nums() { return _server_channels.size(); }
inline brpc::Channel *get_sparse_channel(size_t server_id) {
......@@ -182,7 +229,7 @@ class BrpcPsClient : public PSClient {
inline brpc::Channel *get_cmd_channel(size_t server_id) {
return _server_channels[server_id][2].get();
}
virtual int32_t initialize() override;
int32_t initialize() override;
private:
// virtual int32_t initialize() override;
......@@ -200,38 +247,74 @@ class BrpcPsClient : public PSClient {
bool _running = false;
bool _flushing = false;
std::atomic<uint32_t> _async_call_num; //异步请求计数
std::atomic<uint32_t> _async_call_num; // 异步请求计数
// 异步push dense task
std::thread _async_push_dense_thread;
typedef AsyncRequestTask<std::shared_ptr<std::vector<float>>> DenseAsyncTask;
std::unordered_map<uint32_t, paddle::framework::Channel<DenseAsyncTask *>>
_push_dense_task_queue_map;
// 异步push sparse task
std::thread _async_push_sparse_thread;
typedef AsyncRequestTask<std::shared_ptr<SparsePushTaskData>> SparseAsyncTask;
std::unordered_map<uint32_t, paddle::framework::Channel<SparseAsyncTask *>>
_push_sparse_task_queue_map;
std::unordered_map<uint32_t, uint32_t> _push_sparse_merge_count_map;
std::thread _print_thread;
int push_sparse_async_shard_merge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
ValueAccessor *accessor);
int push_sparse_async_shard_push(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, int table_id, int shard_idx, // NOLINT
DownpourBrpcClosure *closure, ValueAccessor *accessor);
SparseTaskPool _sparse_task_pool;
std::vector<std::shared_ptr<brpc::Channel>>
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
std::future<int32_t> push_dense_raw_gradient(int table_id,
float *total_send_data,
size_t total_send_data_size,
void *done) override;
std::future<int32_t> push_sparse_raw_gradient(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) override;
virtual std::future<int32_t> push_sparse_param(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
std::future<int32_t> push_sparse_param(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num, void *done) override;
std::future<int32_t> push_sparse(size_t table_id, const uint64_t *keys,
const float **update_values,
size_t num) override;
void push_sparse_task_consume();
private:
int32_t start_client_service();
void push_dense_raw_gradient(std::shared_ptr<DenseAsyncTask> &task, // NOLINT
float *total_send_data,
size_t total_send_data_size,
DownpourBrpcClosure *closure);
float _mae = 0;
float _mse = 0;
uint16_t _push_times = 0;
brpc::Server _server;
DownpourPsClientService _service;
bool _server_started = false;
std::atomic_uint grad_num_{0};
};
} // namespace distributed
......
......@@ -628,6 +628,8 @@ void AsyncCommunicator::Start() {
void AsyncCommunicator::Stop() {
VLOG(1) << "Communicator stop";
_worker_ptr->finalize_worker();
VLOG(0) << "Communicator finalize_worker done";
running_ = false;
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
......
......@@ -114,9 +114,9 @@ class PSClient {
size_t region_num,
size_t table_id) = 0;
// virtual std::future<int32_t> push_dense(const Region *regions,
// size_t region_num,
// size_t table_id) = 0;
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
// 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用
......@@ -222,10 +222,10 @@ class PSClient {
const uint64_t *keys,
const float **update_values,
size_t num, void *done) = 0;
// virtual std::future<int32_t> push_sparse(size_t table_id,
// const uint64_t *keys,
// const float **update_values,
// size_t num) = 0;
virtual std::future<int32_t> push_sparse(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num) = 0;
protected:
virtual int32_t initialize() = 0;
......
......@@ -270,8 +270,8 @@ bool CtrCommonAccessor::create_value(int stage, const float* value) {
return true;
} else if (stage == 1) {
// operation
auto show = CtrCommonPushValue::show_const(value);
auto click = CtrCommonPushValue::click_const(value);
auto show = CtrCommonPushValue::show(const_cast<float*>(value));
auto click = CtrCommonPushValue::click(const_cast<float*>(value));
auto score = show_click_score(show, click);
if (score <= 0) {
return false;
......@@ -302,8 +302,8 @@ std::string CtrCommonAccessor::parse_to_string(const float* v, int param) {
i < common_feature_value.embedx_w_index(); i++) {
os << " " << v[i];
}
auto show = common_feature_value.show_const(v);
auto click = common_feature_value.click_const(v);
auto show = common_feature_value.show(const_cast<float*>(v));
auto click = common_feature_value.click(const_cast<float*>(v));
auto score = show_click_score(show, click);
if (score >= _config.embedx_threshold()) {
for (auto i = common_feature_value.embedx_w_index();
......
......@@ -61,14 +61,7 @@ class CtrCommonAccessor : public ValueAccessor {
float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
float& embedx_w(float* val) { return val[embedx_w_index()]; }
float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
float show_const(const float* val) {
float s = val[show_index()];
return s;
}
float click_const(const float* val) {
float c = val[click_index()];
return c;
}
int embed_sgd_dim;
int embedx_dim;
int embedx_sgd_dim;
......@@ -103,14 +96,6 @@ class CtrCommonAccessor : public ValueAccessor {
static float& click(float* val) {
return val[CtrCommonPushValue::click_index()];
}
static float show_const(const float* val) {
float s = val[show_index()];
return s;
}
static float click_const(const float* val) {
float c = val[click_index()];
return c;
}
static float& embed_g(float* val) {
return val[CtrCommonPushValue::embed_g_index()];
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册