未验证 提交 0034273b 编写于 作者: T tangwei12 提交者: GitHub

add service (#29560)

* add service, remove ut on mac

* fix heter_profiler & add heter stop method

* fix code style
上级 c0163837
......@@ -14,3 +14,17 @@ endif()
add_subdirectory(table)
add_subdirectory(test)
# open it until CI support brpc
return()
add_subdirectory(service)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set_source_files_properties(fleet.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(fleet
SRCS fleet.cc
DEPS framework_proto ps_framework_proto ps_service variable_helper scope op_registry fs shell ${RPC_DEPS})
target_link_libraries(fleet z)
此差异已折叠。
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/distributed/communicator_common.h"
#include "paddle/fluid/distributed/service/service.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace distributed {
using framework::LoDTensor;
using framework::Scope;
using framework::SelectedRows;
using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
class FleetWrapper {
public:
virtual ~FleetWrapper() {}
FleetWrapper() {
scale_sparse_gradient_with_batch_size_ = true;
// trainer sleep some time for pserver core dump
sleep_seconds_before_fail_exit_ = 300;
// pserver request server timeout ms
client2client_request_timeout_ms_ = 500000;
// pserver connect server timeout_ms
client2client_connect_timeout_ms_ = 10000;
// pserver request max retry
client2client_max_retry_ = 3;
}
// set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry);
// Pull sparse variables from server in sync mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names
// Param<out>: fea_values
void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values,
int fea_dim,
const std::vector<std::string>& var_emb_names);
// Pull sparse variables from server in async mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim
// Param<out>: fea_values std::future
std::future<int32_t> PullSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_dim);
// Pull sparse variables from server in sync mode
// pull immediately to tensors
void PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id, platform::Place place,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<LoDTensor*>* outputs); // NOLINT
// pull dense variables from server in sync mod
// Param<in>: scope, table_id, var_names
// Param<out>: void
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
// pull dense variables from server in async mod
// Param<in>: scope, table_id, var_names
// Param<out>: pull_dense_status
void PullDenseVarsAsync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* pull_dense_status,
bool in_cpu);
// push dense parameters(not gradients) to server in sync mode
void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
void PushDenseVarsAsync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* push_sparse_status,
float scale_datanorm, int batch_size);
// push dense variables to server in sync mode
void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
void PushSparseVarsAsync(
const Scope& scope, const uint64_t table_id, const std::string& grad,
std::vector<std::future<int32_t>>* push_sparse_status);
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
// sparse_grad_names, batch_size, use_cvm, dump_slot
// Param<out>: push_values, push_sparse_status
void PushSparseVarsWithLabelAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<uint64_t>& fea_keys,
const std::vector<float>& fea_labels,
const std::vector<std::string>& sparse_key_names,
const std::vector<std::string>& sparse_grad_names, const int emb_dim,
std::vector<std::vector<float>>* push_values,
std::vector<std::future<int32_t>>* push_sparse_status,
const int batch_size, const bool use_cvm, const bool dump_slot,
std::vector<uint64_t>* sparse_push_keys, const bool no_cvm);
// Push sparse variables to server in async mode
void PushSparseFromTensorWithLabelAsync(
const Scope& scope, const uint64_t table_id, int fea_dim,
uint64_t padding_id, bool scale_sparse, const std::string& accesor,
const std::string& click_name, platform::Place place,
const std::vector<std::string>& input_names,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<const LoDTensor*>* outputs); // NOLINT
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
// Param<Out>: push_values, push_sparse_status
// init server
void LoadSparseOnServer(const std::string& path, const std::string& meta,
uint32_t table_id);
// init server
// void InitServer(const std::string& dist_desc,
// const std::vector<uint64_t>& host_sign_list, int index);
void InitServer(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, int index);
// init trainer
void InitWorker(const std::string& dist_desc,
const std::vector<std::string>& host_sign_list, Scope* scope,
const RpcCtxMap& send_ctx,
const std::unordered_map<uint64_t, std::vector<std::string>>&
dense_varnames,
const std::map<std::string, std::string>& envs, int node_num,
int index);
// stop server
void StopServer();
// finalize worker to make worker can be stop
void FinalizeWorker();
// run server with ip port
uint64_t RunServer(const std::string& ip, uint32_t port);
// get client info
std::vector<uint64_t> GetClientsInfo();
// create client to client connection
void CreateClient2ClientConnection();
// flush all push requests
void ClientFlush();
// barrier with barrier table
void BarrierWithTable(uint32_t barrier_type);
void PrintTableStat(const uint64_t table_id);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModel(const std::string& path, const int mode);
// mode = 0, save all feature
// mode = 1, save delta feature, which means save diff
void SaveModelOneTable(const uint64_t table_id, const std::string& path,
const int mode);
// clear all models, release their memory
void ClearModel();
// clear one table
void ClearOneTable(const uint64_t table_id);
// shrink sparse table
void ShrinkSparseTable(int table_id);
// shrink dense table
void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay,
int emb_dim);
typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
// register client to client communication
int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
// send client to client message
std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg);
// FleetWrapper singleton
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::distributed::FleetWrapper());
}
return s_instance_;
}
// this performs better than rand_r, especially large data
std::default_random_engine& LocalRandomEngine();
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
private:
static std::shared_ptr<FleetWrapper> s_instance_;
size_t GetAbsoluteSum(size_t start, size_t end, size_t level,
const framework::LoD& lod);
protected:
static bool is_initialized_;
std::map<uint64_t, std::vector<paddle::distributed::Region>> _regions;
bool scale_sparse_gradient_with_batch_size_;
int32_t sleep_seconds_before_fail_exit_;
int client2client_request_timeout_ms_;
int client2client_connect_timeout_ms_;
int client2client_max_retry_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
} // end namespace distributed
} // end namespace paddle
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
brpc_library(sendrecv_rpc SRCS
${BRPC_SRCS}
PROTO sendrecv.proto
DEPS ${BRPC_DEPS} )
set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
set_source_files_properties(communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table ${RPC_DEPS})
cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table ${RPC_DEPS})
cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS})
cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS})
cc_library(brpc_utils SRCS brpc_utils.cc DEPS ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
此差异已折叠。
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/ps_client.h"
namespace paddle {
namespace distributed {
class DownpourPsClientService : public PsService {
public:
DownpourPsClientService() {}
virtual ~DownpourPsClientService() {}
virtual int32_t configure(PSClient *client, size_t rank_id) {
_client = client;
_rank = rank_id;
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
::google::protobuf::Closure *done) override;
protected:
size_t _rank;
PSClient *_client;
};
class DownpourBrpcClosure : public PSClientClosure {
public:
DownpourBrpcClosure(size_t num, PSClientCallBack callback)
: PSClientClosure(callback) {
_waiting_num = num;
_cntls.resize(num);
_requests.resize(num);
_responses.resize(num);
for (size_t i = 0; i < num; ++i) {
_cntls[i].reset(new brpc::Controller());
}
}
virtual ~DownpourBrpcClosure() {}
virtual void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
}
}
PsRequestMessage *request(size_t i) { return &_requests[i]; }
PsResponseMessage *response(size_t i) { return &_responses[i]; }
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
int check_response(size_t request_idx, int cmd_id);
int check_save_response(size_t request_idx, int cmd_id);
std::string get_response(size_t request_idx, int cmd_id);
private:
std::atomic<int32_t> _waiting_num;
std::vector<PsRequestMessage> _requests;
std::vector<PsResponseMessage> _responses;
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
template <class T>
struct array_deleter {
void operator()(T *&x) const { delete[] x; }
};
class BrpcPsClient : public PSClient {
public:
BrpcPsClient() {}
virtual ~BrpcPsClient() {
// _running = false;
// try {
// _async_push_dense_thread.join();
// _async_push_sparse_thread.join();
//} catch (...) {
//}
}
virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
virtual std::future<int32_t> shrink(uint32_t table_id) override;
virtual std::future<int32_t> load(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> clear() override;
virtual std::future<int32_t> clear(uint32_t table_id) override;
virtual std::future<int32_t> stop_server() override;
virtual std::future<int32_t> start_profiler() override;
virtual std::future<int32_t> stop_profiler() override;
virtual void finalize_worker() override;
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id);
virtual std::future<int32_t> push_dense_param(const Region *regions,
size_t region_num,
size_t table_id);
virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
const uint64_t *keys, size_t num);
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
virtual std::future<int32_t> pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
int pserver_idx);
virtual std::future<int32_t> flush();
virtual std::future<int32_t> send_client2client_msg(
int msg_type, int to_client_id, const std::string &msg) override;
private:
virtual int32_t initialize() override;
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
const std::vector<std::string> &param);
inline brpc::Channel *get_sparse_channel(size_t server_id) {
return _server_channels[server_id][0].get();
}
inline brpc::Channel *get_dense_channel(size_t server_id) {
return _server_channels[server_id][1].get();
}
inline brpc::Channel *get_cmd_channel(size_t server_id) {
return _server_channels[server_id][2].get();
}
bool _running = false;
bool _flushing = false;
std::atomic<uint32_t> _async_call_num; //异步请求计数
std::vector<std::shared_ptr<brpc::Channel>>
_client_channels; // client2client
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
_server_channels; // client2server
virtual std::future<int32_t> push_dense_raw_gradient(
int table_id, float *total_send_data, size_t total_send_data_size,
void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient(
size_t table_id, const uint64_t *keys, const float **update_values,
size_t num, void *done) override;
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
size_t table_id, const uint64_t *keys, const float **update_values,
uint32_t num, void *done, int pserver_idx) override;
virtual std::future<int32_t> push_sparse_param(size_t table_id,
const uint64_t *keys,
const float **update_values,
size_t num,
void *done) override;
virtual size_t get_server_nums() { return _server_channels.size(); }
private:
int32_t start_client_service();
float _mae = 0;
float _mse = 0;
uint16_t _push_times = 0;
brpc::Server _server;
DownpourPsClientService _service;
std::atomic_uint grad_num_{0};
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "Eigen/Dense"
#include "butil/endpoint.h"
#include "iomanip"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace distributed {
int32_t BrpcPsServer::initialize() {
auto &service_config = _config.downpour_server_param().service_param();
if (!service_config.has_service_class()) {
LOG(ERROR) << "miss service_class in ServerServiceParameter";
return -1;
}
auto *service = CREATE_CLASS(PsBaseService, service_config.service_class());
if (service == NULL) {
LOG(ERROR) << "service is unregistered, service_name:"
<< service_config.service_class();
return -1;
}
_service.reset(service);
if (service->configure(this) != 0 || service->initialize() != 0) {
LOG(ERROR) << "service initialize failed, service_name:"
<< service_config.service_class();
return -1;
}
if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
LOG(ERROR) << "service add to brpc failed, service:"
<< service_config.service_class();
return -1;
}
return 0;
}
uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port);
VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
int num_threads = std::thread::hardware_concurrency();
brpc::ServerOptions options;
options.num_threads = num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port;
return 0;
}
VLOG(0) << "BrpcPsServer::start registe_ps_server";
_environment->registe_ps_server(ip, port, _rank);
VLOG(0) << "BrpcPsServer::start wait";
cv_.wait(lock, [&] { return stoped_; });
PSHost host;
host.ip = ip;
host.port = port;
host.rank = _rank;
VLOG(0) << "BrpcPsServer::start return host.rank";
return host.rank;
}
int32_t BrpcPsServer::port() { return _server.listen_address().port; }
int32_t PsService::initialize() {
_is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &PsService::stop_server;
_service_handler_map[PS_PULL_DENSE_TABLE] = &PsService::pull_dense;
_service_handler_map[PS_PUSH_DENSE_TABLE] = &PsService::push_dense;
_service_handler_map[PS_PULL_SPARSE_TABLE] = &PsService::pull_sparse;
_service_handler_map[PS_PUSH_SPARSE_TABLE] = &PsService::push_sparse;
_service_handler_map[PS_SAVE_ONE_TABLE] = &PsService::save_one_table;
_service_handler_map[PS_SAVE_ALL_TABLE] = &PsService::save_all_table;
_service_handler_map[PS_SHRINK_TABLE] = &PsService::shrink_table;
_service_handler_map[PS_LOAD_ONE_TABLE] = &PsService::load_one_table;
_service_handler_map[PS_LOAD_ALL_TABLE] = &PsService::load_all_table;
_service_handler_map[PS_CLEAR_ONE_TABLE] = &PsService::clear_one_table;
_service_handler_map[PS_CLEAR_ALL_TABLE] = &PsService::clear_all_table;
_service_handler_map[PS_PUSH_DENSE_PARAM] = &PsService::push_dense_param;
_service_handler_map[PS_PRINT_TABLE_STAT] = &PsService::print_table_stat;
_service_handler_map[PS_PULL_GEO_PARAM] = &PsService::pull_geo_param;
_service_handler_map[PS_PUSH_SPARSE_PARAM] = &PsService::push_sparse_param;
_service_handler_map[PS_BARRIER] = &PsService::barrier;
_service_handler_map[PS_START_PROFILER] = &PsService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info();
return 0;
}
#define CHECK_TABLE_EXIST(table, request, response) \
if (table == NULL) { \
std::string err_msg("table not found with table_id:"); \
err_msg.append(std::to_string(request.table_id())); \
set_response_code(response, -1, err_msg.c_str()); \
return -1; \
}
int32_t PsService::initialize_shard_info() {
if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) {
return 0;
}
size_t shard_num = _server->environment()->get_ps_servers().size();
auto &table_map = *(_server->table());
for (auto itr : table_map) {
itr.second->set_shard(_rank, shard_num);
}
_is_initialize_shard_info = true;
}
return 0;
}
void PsService::service(google::protobuf::RpcController *cntl_base,
const PsRequestMessage *request,
PsResponseMessage *response,
google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-");
if (!request->has_table_id()) {
set_response_code(*response, -1, "PsRequestMessage.tabel_id is required");
return;
}
response->set_err_code(0);
response->set_err_msg("");
auto *table = _server->table(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
set_response_code(*response, -1, err_msg.c_str());
return;
}
serviceHandlerFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(table, *request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
}
int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_dense");
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(
response, -1,
"PsRequestMessage.datas is requeired at least 1 for num of dense");
return 0;
}
uint32_t num = *(const uint32_t *)request.params(0).c_str();
if (num < 0) {
set_response_code(response, -1,
"PsRequestMessage.datas[0] is invalid, num must >= 0");
return 0;
}
std::vector<float> res_data;
res_data.resize(num * table->value_accesor()->select_size() / sizeof(float));
table->pull_dense(res_data.data(), num);
cntl->response_attachment().append((char *)res_data.data(),
res_data.size() * sizeof(float));
return 0;
}
int32_t PsService::push_dense_param(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense_param");
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_buffer;
auto &req_io_buffer = cntl->request_attachment();
auto req_buffer_size = req_io_buffer.size();
if (req_buffer_size < 1) {
set_response_code(response, -1, "req attachment is empty");
return 0;
}
push_buffer.resize(0);
push_buffer.reserve(req_buffer_size);
const char *data = (const char *)cntl->request_attachment().fetch(
const_cast<char *>(push_buffer.data()), req_buffer_size);
uint32_t num = *(const uint32_t *)data;
const float *values = (const float *)(data + sizeof(uint32_t));
if (table->push_dense_param(values, num) != 0) {
set_response_code(response, -1, "push_dense_param failed");
}
return 0;
}
int32_t PsService::push_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense");
CHECK_TABLE_EXIST(table, request, response)
auto req_buffer_size = request.data().size();
if (req_buffer_size < 1) {
// set_response_code(response, 0, "push dense data is empty");
return 0;
}
/*
Push Content:
|--num--|---valuesData---|
|--4B---|----------------|
*/
uint32_t num = *(const uint32_t *)(request.data().data());
const float *values =
(const float *)(request.data().data() + sizeof(uint32_t));
if (table->push_dense(values, num) != 0) {
set_response_code(response, -1, "push_dense failed");
}
return 0;
}
int32_t PsService::barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(response, -1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
}
auto trainer_id = request.client_id();
auto barrier_type = request.params(0);
table->barrier(trainer_id, barrier_type);
return 0;
}
int32_t PsService::push_sparse_param(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse_param");
CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data();
if (push_data.size() < 1) {
// set_response_code(response, 0, "push sparse data is empty");
return 0;
}
if (request.params_size() < 1) {
set_response_code(response, -1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
}
uint32_t num = *(uint32_t *)(request.params(0).c_str());
/*
Push Content:
|---keysData---|---valuesData---|
|---8*{num}B---|----------------|
*/
const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values =
(const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->push_sparse_param(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse_param error");
}
return 0;
}
int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_geo_param");
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_sparse_request_buffer;
auto trainer_id = request.client_id();
std::vector<float> values;
std::vector<uint64_t> ids;
table->pull_geo_param(trainer_id, &values, &ids);
uint32_t num = ids.size();
cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
cntl->response_attachment().append((char *)ids.data(),
ids.size() * sizeof(uint64_t));
cntl->response_attachment().append((char *)values.data(),
values.size() * sizeof(float));
return 0;
}
int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_sparse");
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_sparse_request_buffer;
auto &req_io_buffer = cntl->request_attachment();
auto req_buffer_size = req_io_buffer.size();
if (req_buffer_size < 1) {
set_response_code(response, -1, "req attachment is empty");
return 0;
}
if (request.params_size() < 1) {
set_response_code(response, -1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
}
uint32_t num = *(uint32_t *)(request.params(0).c_str());
push_sparse_request_buffer.resize(0);
push_sparse_request_buffer.reserve(req_buffer_size);
const char *data = (const char *)cntl->request_attachment().fetch(
const_cast<char *>(push_sparse_request_buffer.data()), req_buffer_size);
/*
Attachment Content:
|---keysData---|
|---8*{num}B---|
*/
const uint64_t *keys = (const uint64_t *)data;
std::vector<float> res_data;
res_data.resize(num * table->value_accesor()->select_size() / sizeof(float));
table->pull_sparse(res_data.data(), keys, num);
cntl->response_attachment().append((char *)res_data.data(),
res_data.size() * sizeof(float));
return 0;
}
int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse");
CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data();
if (push_data.size() < 1) {
// set_response_code(response, 0, "push sparse data is empty");
return 0;
}
if (request.params_size() < 1) {
set_response_code(response, -1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
}
uint32_t num = *(uint32_t *)(request.params(0).c_str());
/*
Push Content:
|---keysData---|---valuesData---|
|---8*{num}B---|----------------|
*/
const uint64_t *keys = (const uint64_t *)push_data.data();
const float *values =
(const float *)(push_data.data() + sizeof(uint64_t) * num);
if (table->push_sparse(keys, values, num) != 0) {
set_response_code(response, -1, "push_sparse error");
}
return 0;
}
int32_t PsService::print_table_stat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat();
paddle::framework::BinaryArchive ar;
ar << ret.first << ret.second;
std::string table_info(ar.Buffer(), ar.Length());
response.set_data(table_info);
return 0;
}
int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"PsRequestMessage.datas is requeired at least 2 for path & load_param");
return -1;
}
if (table->load(request.params(0), request.params(1)) != 0) {
set_response_code(response, -1, "table load failed");
return -1;
}
return 0;
}
int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) {
LOG(ERROR) << "load table[" << itr.first << "] failed";
return -1;
}
}
return 0;
}
int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"PsRequestMessage.datas is requeired at least 2, path&mode");
return -1;
}
table->flush();
int32_t feasign_size = 0;
feasign_size = table->save(request.params(0), request.params(1));
if (feasign_size < 0) {
set_response_code(response, -1, "table save failed");
return -1;
}
return feasign_size;
}
int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
int32_t all_feasign_size = 0;
int32_t feasign_size = 0;
for (auto &itr : table_map) {
feasign_size = save_one_table(itr.second.get(), request, response, cntl);
if (feasign_size < 0) {
LOG(ERROR) << "save table[" << itr.first << "] failed";
return -1;
}
}
return 0;
}
int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
table->flush();
if (table->shrink() != 0) {
set_response_code(response, -1, "table shrink failed");
}
return 0;
}
int32_t PsService::clear_one_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
table->flush();
table->clear();
return 0;
}
int32_t PsService::clear_all_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
for (auto &itr : table_map) {
if (clear_one_table(itr.second.get(), request, response, cntl) != 0) {
return -1;
}
}
return 0;
}
int32_t PsService::stop_server(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto *p_server = _server;
std::thread t_stop([p_server]() {
p_server->stop();
LOG(INFO) << "Server Stoped";
});
t_stop.detach();
return 0;
}
int32_t PsService::stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::DisableProfiler(platform::EventSortingKey::kDefault,
string::Sprintf("server_%s_profile", _rank));
return 0;
}
int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include <memory>
#include <vector>
#include "paddle/fluid/distributed/service/server.h"
namespace paddle {
namespace distributed {
class BrpcPsServer : public PSServer {
public:
BrpcPsServer() {}
virtual ~BrpcPsServer() {}
virtual uint64_t start(const std::string &ip, uint32_t port);
virtual int32_t stop() {
std::unique_lock<std::mutex> lock(mutex_);
stoped_ = true;
cv_.notify_all();
_server.Stop(1000);
_server.Join();
return 0;
}
virtual int32_t port();
private:
virtual int32_t initialize();
mutable std::mutex mutex_;
std::condition_variable cv_;
bool stoped_ = false;
brpc::Server _server;
std::shared_ptr<PsBaseService> _service;
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class PsService;
typedef int32_t (PsService::*serviceHandlerFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl);
class PsService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
::google::protobuf::Closure *done) override;
private:
int32_t initialize_shard_info();
int32_t pull_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_dense_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t pull_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t pull_geo_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t push_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t save_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t shrink_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t clear_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_server(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t start_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t print_table_stat(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
std::vector<float> _ori_values;
};
class DownpourPServerBrpcClosure : public PServerClosure {
public:
DownpourPServerBrpcClosure(size_t num, PServerCallBack callback)
: PServerClosure(callback) {
_waiting_num = num;
_cntls.resize(num);
_requests.resize(num);
_responses.resize(num);
for (size_t i = 0; i < num; ++i) {
_cntls[i].reset(new brpc::Controller());
}
}
virtual ~DownpourPServerBrpcClosure() {}
virtual void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
}
}
PsRequestMessage *request(size_t i) { return &_requests[i]; }
PsResponseMessage *response(size_t i) { return &_responses[i]; }
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
int check_response(size_t request_idx, int cmd_id) { return 1; }
int check_save_response(size_t request_idx, int cmd_id) { return 1; }
private:
std::atomic<int32_t> _waiting_num;
std::vector<PsRequestMessage> _requests;
std::vector<PsResponseMessage> _responses;
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
};
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include <limits>
#include <memory>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace distributed {
framework::proto::VarType::Type VarMessageToVarType(
VariableMessage::Type type) {
switch (type) {
case VariableMessage::FP32:
return framework::proto::VarType::FP32; // NOLINT
case VariableMessage::FP64:
return framework::proto::VarType::FP64; // NOLINT
case VariableMessage::INT32:
return framework::proto::VarType::INT32; // NOLINT
case VariableMessage::INT64:
return framework::proto::VarType::INT64; // NOLINT
case VariableMessage::BOOL:
return framework::proto::VarType::BOOL; // NOLINT
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"VarMessageToVarType:Unsupported type %d", type));
}
}
void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name,
const std::vector<std::string>& send_var_name_val,
const std::vector<std::string>& recv_var_name_val,
const platform::DeviceContext& ctx, const framework::Scope* scope,
MultiVarMsg* request, butil::IOBuf* iobuf) {
// 1. message_name
request->set_message_name(message_name);
// 2. var_names
for (auto& send_var_name : send_var_name_val) {
request->add_send_var_names(send_var_name);
}
for (auto& recv_var_name : recv_var_name_val) {
request->add_recv_var_names(recv_var_name);
}
// 3. VarMessage
for (auto& send_var_name : send_var_name_val) {
auto* send_var_msg = request->add_var_messages();
butil::IOBuf temp_iobuf;
send_var_msg->set_varname(send_var_name);
framework::Variable* var = scope->FindVar(send_var_name);
if (var->IsType<framework::LoDTensor>()) {
SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf);
} else if (var->IsType<framework::SelectedRows>()) {
SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf);
}
iobuf->append(temp_iobuf);
}
}
void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
var_msg->set_type(::paddle::LOD_TENSOR);
const framework::LoD lod = tensor->lod();
if (lod.size() > 0) {
var_msg->set_lod_level(lod.size());
for (auto& each : lod) {
VarMsg::LodData* lod_inner = var_msg->add_lod();
for (auto& d : each) {
lod_inner->add_lod_data(d);
}
}
}
var_msg->set_data_type(static_cast<VarMsg::Type>(tensor->type()));
for (auto& dim : framework::vectorize(tensor->dims())) {
var_msg->add_dims(dim);
}
// IO Buffer
if (platform::is_cpu_place(tensor->place())) {
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(tensor->data<void>()),
data_len);
} else {
#ifdef PADDLE_WITH_CUDA
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
delete[] temp_ptr;
#endif
}
}
void SerializeSelectedRows(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf) {
framework::SelectedRows* slr = var->GetMutable<framework::SelectedRows>();
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
var_msg->set_type(::paddle::SELECTED_ROWS);
var_msg->set_slr_height(slr->height());
auto* var_data = var_msg->mutable_data();
var_data->clear();
var_data->resize(rows->size() * sizeof(int64_t));
char* data_ptr = const_cast<char*>(var_data->data());
if (platform::is_cpu_place(tensor->place())) {
memcpy(data_ptr, &(*rows)[0], rows->size() * sizeof(int64_t));
} else {
#ifdef PADDLE_WITH_CUDA
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), data_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
&(*rows)[0], rows->size() * sizeof(int64_t), stream);
#endif
}
var_msg->set_data_type(static_cast<VarMsg::Type>(tensor->type()));
for (auto& dim : framework::vectorize(tensor->dims())) {
var_msg->add_dims(dim);
}
// IO Buffer
if (platform::is_cpu_place(tensor->place())) {
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(tensor->data<void>()),
data_len);
} else {
#ifdef PADDLE_WITH_CUDA
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
iobuf->append(reinterpret_cast<const char*>(&data_len), 8);
iobuf->append(reinterpret_cast<const char*>(temp_ptr), data_len);
delete[] temp_ptr;
#endif
}
}
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
framework::Scope* scope) {
butil::IOBufBytesIterator io_buffer_itr(*iobuf);
// size_t shard_buffer_remain = res_io_buffer.size();
for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size();
++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->Var(msg.varname());
if (msg.type() == ::paddle::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
}
}
}
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope) {
butil::IOBufBytesIterator io_buffer_itr(*iobuf);
// size_t shard_buffer_remain = res_io_buffer.size();
for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size();
++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->FindVar(msg.varname());
PADDLE_ENFORCE_NE(var, nullptr,
platform::errors::InvalidArgument(
"Not find variable %s in scope.", msg.varname()));
if (msg.type() == ::paddle::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
}
}
}
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& io_buffer_itr,
const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace();
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
std::vector<int> vec_dim;
for (auto& x : msg.dims()) {
vec_dim.push_back(x);
}
tensor->Resize(framework::make_ddim(vec_dim));
framework::LoD lod;
for (int i = 0; i < msg.lod_level(); ++i) {
framework::Vector<size_t> v;
for (int j = 0; j < msg.lod(i).lod_data_size(); ++j) {
v.push_back(msg.lod(i).lod_data(j));
}
lod.push_back(v);
}
tensor->set_lod(lod);
void* tensor_data =
tensor->mutable_data(place, VarMessageToVarType(msg.data_type()));
// IO Buffer
if (platform::is_cpu_place(place)) {
unsigned long data_len;
io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
io_buffer_itr.copy_and_forward(tensor_data, data_len);
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
unsigned long data_len;
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data,
platform::CPUPlace(), (void*)temp_ptr,
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
delete[] temp_ptr;
#endif
}
}
void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& io_buffer_itr,
const platform::DeviceContext& ctx) {
const auto place = ctx.GetPlace();
auto* slr = var->GetMutable<framework::SelectedRows>();
framework::Tensor* tensor = slr->mutable_value();
slr->set_height(msg.slr_height());
std::vector<int64_t> tmp_rows(msg.slr_height());
memcpy(&tmp_rows[0], msg.data().data(), msg.slr_height() * sizeof(int64_t));
slr->set_rows(tmp_rows);
std::vector<int> vec_dim;
for (auto& x : msg.dims()) {
vec_dim.push_back(x);
}
tensor->Resize(framework::make_ddim(vec_dim));
void* tensor_data =
tensor->mutable_data(place, VarMessageToVarType(msg.data_type()));
// IO Buffer
if (platform::is_cpu_place(place)) {
unsigned long data_len;
io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
io_buffer_itr.copy_and_forward(tensor_data, data_len);
} else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
char* temp_ptr =
new char[tensor->numel() * framework::SizeOfType(tensor->type())];
unsigned long data_len;
io_buffer_itr.copy_and_forward((void*)(&data_len), 8);
io_buffer_itr.copy_and_forward(temp_ptr, data_len);
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data,
platform::CPUPlace(), temp_ptr,
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
delete[] temp_ptr;
#endif
}
}
} // namespace distributed
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "brpc/channel.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/port.h"
namespace grpc {
class ByteBuffer;
} // namespace grpc
namespace paddle {
namespace framework {
class Scope;
class Variable;
} // namespace framework
namespace platform {
class DeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name,
const std::vector<std::string>& send_var_name_val,
const std::vector<std::string>& recv_var_name_val,
const platform::DeviceContext& ctx, const framework::Scope* scope,
MultiVarMsg* var_msg, butil::IOBuf* iobuf);
void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf);
void SerializeSelectedRows(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request,
butil::IOBuf* iobuf);
// Deserialize for Server
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
framework::Scope* scope);
// Deserialize for Client
void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
const butil::IOBuf* iobuf,
const platform::DeviceContext& ctx,
const framework::Scope* scope);
void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& iobuf,
const platform::DeviceContext& ctx);
void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& iobuf,
const platform::DeviceContext& ctx);
} // namespace distributed
} // namespace paddle
此差异已折叠。
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <stdint.h>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/communicator_common.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/ps_client.h"
DECLARE_bool(communicator_is_sgd_optimizer);
namespace paddle {
namespace distributed {
using Scope = framework::Scope;
using Variable = framework::Variable;
template <typename T>
class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0,
platform::errors::InvalidArgument(
"The capacity must be greater than 0."));
}
bool Push(const T &elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.push_back(elem);
}
cv_.notify_one();
return true;
}
bool Push(T &&elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.emplace_back(std::move(elem));
}
cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
cv_.notify_one();
return rc;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
private:
const size_t capacity_;
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable cv_;
};
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
inline void MergeVars(const std::string &var_name,
const std::vector<std::shared_ptr<Variable>> &vars,
Scope *scope, bool merge_add = true) {
PADDLE_ENFORCE_NE(vars.empty(), true, platform::errors::InvalidArgument(
"vector vars are empty."));
auto cpu_place = platform::CPUPlace();
auto &var0 = vars[0];
auto *out_var = scope->Var(var_name);
if (var0->IsType<framework::LoDTensor>()) {
auto dims = var0->Get<framework::LoDTensor>().dims();
VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
<< "; merge add: " << merge_add;
// init output tensor
auto *out_t = out_var->GetMutable<framework::LoDTensor>();
out_t->mutable_data<T>(dims, cpu_place);
// check the input dims
for (auto &var : vars) {
auto &var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
var_t.dims(), dims,
platform::errors::InvalidArgument("vars should have the same dims."));
}
// set output tensor to 0.
auto cpu_ctx = paddle::platform::CPUDeviceContext();
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext, T>
constant_functor;
constant_functor(cpu_ctx, out_t, static_cast<T>(0));
// sum all vars to out
auto result = EigenVector<T>::Flatten(*out_t);
for (auto &var : vars) {
auto &in_t = var->Get<framework::LoDTensor>();
auto in = EigenVector<T>::Flatten(in_t);
result.device(*cpu_ctx.eigen_device()) = result + in;
}
if (!merge_add) {
result.device(*cpu_ctx.eigen_device()) =
result / static_cast<T>(vars.size());
}
} else if (var0->IsType<framework::SelectedRows>()) {
auto &slr0 = var0->Get<framework::SelectedRows>();
auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
std::vector<const paddle::framework::SelectedRows *> inputs;
inputs.reserve(vars.size());
for (auto &var : vars) {
inputs.push_back(&var->Get<framework::SelectedRows>());
}
auto dev_ctx = paddle::platform::CPUDeviceContext();
if (merge_add) {
paddle::operators::math::scatter::MergeAdd<
paddle::platform::CPUDeviceContext, T>
merge_add;
merge_add(dev_ctx, inputs, out_slr);
} else {
paddle::operators::math::scatter::MergeAverage<
paddle::platform::CPUDeviceContext, T>
merge_average;
merge_average(dev_ctx, inputs, out_slr);
}
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
} else {
PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!",
var0->Type()));
}
}
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
using RecvCtxMap = std::unordered_map<uint64_t, std::vector<std::string>>;
using SparseValue = std::unordered_map<int64_t, std::vector<float>>;
class Communicator {
public:
Communicator();
explicit Communicator(const std::map<std::string, std::string> &envs_) {
VLOG(0) << "Communicator Init Envs";
for (auto &iter : envs_) {
envs[iter.first] = iter.second;
VLOG(0) << iter.first << ": " << iter.second;
}
barrier_table_id_ = std::stoi(envs.at("barrier_table_id"));
trainer_id_ = std::stoi(envs.at("trainer_id"));
trainers_ = std::stoi(envs.at("trainers"));
}
virtual void InitBrpcClient(const std::string &dist_desc,
const std::vector<std::string> &host_sign_list);
// 1. recv dense param
virtual void RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope);
// 2. send dense param
virtual void RpcSendDenseParam(const std::vector<std::string> &varnames,
int table_id, const Scope &scope);
// 3. send dense grad
virtual void RpcSendDense(const CommContext &ctx, const Scope &scope);
// 4. send sparse grad
virtual void RpcSendSparse(const std::string &var_name, int table_id,
const Scope &scope);
// 5. send sparse param
virtual void RpcSendSparseParam(const std::string &varname, int table_id,
const Scope &scope);
// 6. recv sparse param
virtual void RpcRecvSparse(const std::string &varname, int table_id,
Scope *scope);
virtual ~Communicator() {}
virtual void RpcProfilerControl();
virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);
virtual void Start() = 0;
virtual void Stop() = 0;
virtual bool IsRunning() { return running_; }
virtual void Clean() {}
virtual bool Check(const int table_id) = 0;
virtual bool Check(const std::vector<std::string> &var_tables) = 0;
virtual void Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) = 0;
virtual void RecvNoBarrier() {}
virtual void Barrier() {}
virtual void BarrierWithTable(uint32_t barrier_type) {
auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type);
rets.wait();
}
virtual void BarrierTriggerDecrement() {}
virtual void BarrierTriggerReset(int init_counter) {}
virtual void InitEnvs() = 0;
virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) {}
static Communicator *GetInstance() { return communicator_.get(); }
static std::shared_ptr<Communicator> GetInstantcePtr() {
return communicator_;
}
template <typename T>
static Communicator *InitInstance(
const RpcCtxMap &send_ctx, const RecvCtxMap &recv_ctx,
const std::string &dist_desc,
const std::vector<std::string> &host_sign_list, Scope *recv_scope,
const std::map<std::string, std::string> &envs) {
std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>, send_ctx,
recv_ctx, dist_desc, host_sign_list, recv_scope,
std::ref(envs));
return communicator_.get();
}
// Init is called by InitInstance.
template <typename T>
static void InitWithRpcCtx(const RpcCtxMap &send_ctx,
const RecvCtxMap &recv_ctx,
const std::string &dist_desc,
const std::vector<std::string> &host_sign_list,
Scope *recv_scope,
const std::map<std::string, std::string> &envs) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T(std::ref(envs)));
communicator_->InitEnvs();
communicator_->InitBrpcClient(dist_desc, host_sign_list);
communicator_->InitImpl(send_ctx, recv_ctx, recv_scope);
}
}
PSClient *GetPsClient() { return _worker_ptr.get(); }
std::shared_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
return _worker_ptr;
}
std::shared_ptr<PSClient> _worker_ptr; // pointer to worker
protected:
bool running_ = false;
bool waiting_ = true;
bool flushing_ = false;
bool do_server_profiler_ = false;
static std::shared_ptr<Communicator> communicator_;
static std::once_flag init_flag_;
std::unordered_map<std::string, std::string> envs;
// 计算每个shard 对 dense的存储量
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
}
void init_gflag(const std::string &gflags);
paddle::distributed::PSParameter _ps_param;
paddle::distributed::PaddlePSEnvironment _ps_env;
int servers_ = 0;
int trainers_;
int trainer_id_ = 0;
int barrier_table_id_ = 0;
RpcCtxMap send_varname_to_ctx_;
RecvCtxMap recv_varname_to_ctx_;
Scope *recv_scope_; // should be global scope
std::unique_ptr<Scope> xpu_temp_scope_;
std::atomic<uint32_t> _async_call_num{0};
};
class AsyncCommunicator : public Communicator {
public:
AsyncCommunicator() : Communicator() {}
explicit AsyncCommunicator(const std::map<std::string, std::string> &envs)
: Communicator(envs) {}
~AsyncCommunicator();
void InitEnvs() {
independent_recv_ = static_cast<bool>(
std::stoi(envs.at("communicator_independent_recv_thread")));
min_send_grad_num_before_recv_ =
std::stoi(envs.at("communicator_min_send_grad_num_before_recv"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
}
void Start() override;
void Stop() override;
void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
virtual void MainThread();
virtual void RecvThread();
virtual bool Check(const int table_id);
virtual bool Check(const std::vector<std::string> &var_tables);
void Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) override;
virtual void SendByCommunicator();
virtual void SendGlobalStep(int batches) {}
virtual void RecvByCommunicator();
virtual void RecvNoBarrier();
virtual int BatchesCounter() { return 1; }
virtual void BarrierSend() {}
virtual void BarrierRecv() {}
virtual void BarrierWeakUp() {}
protected:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
int min_send_grad_num_before_recv_;
int thread_pool_size_;
int max_merge_var_num_;
int send_wait_times_;
int send_queue_size_;
bool need_global_step_ = false;
bool independent_recv_ = true;
int parallel_task_nums_ = 0;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::unique_ptr<std::thread> recv_thread_{nullptr};
std::unique_ptr<Scope> send_scope_; // an independent scope
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
};
class HalfAsyncCommunicator : public AsyncCommunicator {
public:
HalfAsyncCommunicator() {}
explicit HalfAsyncCommunicator(const std::map<std::string, std::string> &envs)
: AsyncCommunicator(envs) {}
void InitEnvs() {
// enfore to recv after send
independent_recv_ = false;
min_send_grad_num_before_recv_ = 0;
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
VLOG(0) << "HalfAsyncCommunicator Initialized";
}
void MainThread() override;
void SendByCommunicator() override;
void Clean() override;
void Barrier() override;
void BarrierTriggerDecrement() override;
void BarrierTriggerReset(int initial_val) override;
int BatchesCounter();
void BarrierWeakUp();
protected:
// mutex for Wait for barrier
std::mutex barrier_mutex_;
std::condition_variable barrier_cond_;
std::atomic<int64_t> barrier_trigger_{0};
std::atomic<int64_t> barrier_counter_{0};
};
class SyncCommunicator : public HalfAsyncCommunicator {
public:
SyncCommunicator() : HalfAsyncCommunicator() {}
explicit SyncCommunicator(const std::map<std::string, std::string> &envs)
: HalfAsyncCommunicator(envs) {}
void InitEnvs() {
// enfore to recv after send
independent_recv_ = false;
min_send_grad_num_before_recv_ = 0;
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size"));
need_global_step_ =
static_cast<bool>(std::stoi(envs.at("need_global_step")));
VLOG(0) << "SyncCommunicator Initialized";
}
void BarrierSend();
void BarrierRecv();
private:
std::vector<std::string> pserver_endpoints_{};
};
class GeoCommunicator : public AsyncCommunicator {
public:
GeoCommunicator() : AsyncCommunicator() {}
explicit GeoCommunicator(const std::map<std::string, std::string> &envs)
: AsyncCommunicator(envs) {}
void InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RecvCtxMap &recv_varname_to_ctx,
Scope *recv_scope) override;
void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;
void InitDense(std::vector<std::string> &varnames, int table_id);
void InitSparse(const std::string &var_name, int table_id);
void SendDense(const CommContext &send_ctx);
void RecvDense(const CommContext &send_ctx);
std::vector<int64_t> MergeSparseIds(const std::string &varname);
void SendSparse(const std::string &varname, std::vector<int64_t> &sparse_ids,
int table_id, int ep_idx);
void RecvSparse(const std::string &varname, int table_id, int ep_idx);
void MainThread() override;
void InitEnvs() {
independent_recv_ = false;
min_send_grad_num_before_recv_ = 0;
send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times"));
thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size"));
// id_queue's size
max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num"));
send_queue_size_ = max_merge_var_num_;
VLOG(0) << "GeoCommunicator Initialized";
}
void Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) override;
void SendByCommunicator() { return; }
void SendGlobalStep(int batches) override { return; }
void RecvByCommunicator() override { return; }
inline std::string GradToParam(const std::string var_name) {
std::string param_name = var_name.substr(0, var_name.size() - 5);
return param_name;
}
inline std::string SplitedGradToParam(const std::string delta_name) {
// delta_name: emb.delta0
auto pos = delta_name.find(".block");
std::string param_name = delta_name.substr(0, pos);
return param_name;
}
private:
// parameter for delta calc and send
std::shared_ptr<Scope> delta_scope_;
// parameter for storage the pserver param after last recv
std::shared_ptr<Scope> old_scope_;
// parameter on pserver
std::shared_ptr<Scope> pserver_scope_;
std::unordered_map<
std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<std::vector<int64_t>>>>>
sparse_id_queues_;
};
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/env.h"
namespace paddle {
namespace distributed {} // namespace distributed
} // namespace paddle
此差异已折叠。
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/service/heter_client.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/timer.h"
DECLARE_int32(rpc_deadline);
namespace paddle {
namespace distributed {
DEFINE_int32(pserver_timeout_ms, 10800000, "pserver request server timeout_ms");
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
bool HeterClient::is_initialized_ = false;
void HeterClient::MainThread() {
while (running_) {
RpcProfilerControl();
}
}
void HeterClient::Stop() {
running_ = false;
if (!is_initialized_) {
VLOG(0) << "HeterClient is not inited, do nothing";
} else {
if (main_thread_) {
auto status = StopHeterWorker();
status.wait();
main_thread_->join();
main_thread_.reset(nullptr);
}
VLOG(1) << "HeterClient Stop Done";
}
}
void HeterClient::RpcProfilerControl() {
if (trainer_id_ == 0) {
if (!do_server_profiler_ && platform::IsProfileEnabled()) {
// send profiler start flag
do_server_profiler_ = true;
auto start_status = StartProfiler();
start_status.wait();
} else if (do_server_profiler_ && !platform::IsProfileEnabled()) {
// send profiler end flag
auto stop_status = StopProfiler();
stop_status.wait();
do_server_profiler_ = false;
}
}
}
void HeterClient::CreateClient2XpuConnection() {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connection_type = "single";
options.timeout_ms = pserver_timeout_ms;
xpu_channels_.resize(xpu_list_.size());
for (size_t i = 0; i < xpu_list_.size(); ++i) {
xpu_channels_[i].reset(new brpc::Channel());
if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterServer channel init fail";
}
}
}
void HeterClient::SendAndRecvAsync(
const std::vector<std::string>& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& message_name,
const std::vector<std::string>& send_var_name,
const std::vector<std::string>& recv_var_name) {
platform::RecordEvent record_event("HeterClient->SendAndRecvAsync");
const platform::DeviceContext* p_ctx = &ctx;
const framework::Scope* p_scope = &scope;
const std::string message_name_val = message_name;
const std::vector<std::string> send_var_name_val = send_var_name;
const std::vector<std::string> recv_var_name_val = recv_var_name;
VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: "
<< message_name_val;
// Todo: get correct channel
int num = trainer_id_ % xpu_channels_.size();
brpc::Controller cntl;
cntl.set_timeout_ms(pserver_timeout_ms);
distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment();
::paddle::PsService_Stub stub(xpu_channels_[num].get());
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);
stub.SendAndRecvVariable(&cntl, &request, &response, NULL);
PADDLE_ENFORCE_NE(
cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendAndRecv meets brpc error, error message is %s",
cntl.ErrorText()));
VLOG(4) << "call heter_worker success";
auto& response_io_buffer = cntl.response_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer,
ctx, p_scope);
}
std::future<int32_t> HeterClient::SendCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
size_t request_call_num = xpu_channels_.size();
paddle::distributed::DownpourBrpcClosure* closure =
new paddle::distributed::DownpourBrpcClosure(
request_call_num, [request_call_num, cmd_id](void* done) {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
ret = -1;
break;
}
}
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(cmd_id);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(trainer_id_);
for (const auto& param : params) {
closure->request(i)->add_params(param);
}
::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms(
pserver_timeout_ms); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> HeterClient::StartProfiler() {
return SendCmd(-1, PS_START_PROFILER, {});
}
std::future<int32_t> HeterClient::StopProfiler() {
return SendCmd(-1, PS_STOP_PROFILER, {});
}
} // end namespace distributed
} // end namespace paddle
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -16,3 +16,16 @@ cc_test(geo_table_test SRCS geo_table_test.cc DEPS common_table table tensor_acc
set_source_files_properties(barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(barrier_table_test SRCS barrier_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS})
# open it until CI support brpc
return()
set_source_files_properties(brpc_service_dense_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(brpc_service_dense_sgd_test SRCS brpc_service_dense_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(brpc_service_sparse_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(brpc_service_sparse_sgd_test SRCS brpc_service_sparse_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_function ${COMMON_DEPS} ${RPC_DEPS})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册