diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt index ee9037dec1a5d02aa74978f28034ef3ca1ab4182..e99b8b76534369c81e79b274bfdd18fb0e73b394 100644 --- a/paddle/fluid/distributed/CMakeLists.txt +++ b/paddle/fluid/distributed/CMakeLists.txt @@ -14,3 +14,17 @@ endif() add_subdirectory(table) add_subdirectory(test) + +# open it until CI support brpc +return() + +add_subdirectory(service) + +get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + +set_source_files_properties(fleet.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_library(fleet + SRCS fleet.cc + DEPS framework_proto ps_framework_proto ps_service variable_helper scope op_registry fs shell ${RPC_DEPS}) + +target_link_libraries(fleet z) diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc new file mode 100644 index 0000000000000000000000000000000000000000..92211a72e748eb3ca7555a5a68707ad5a52dc4bf --- /dev/null +++ b/paddle/fluid/distributed/fleet.cc @@ -0,0 +1,585 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/distributed/fleet.h" +#include +#include +#include "paddle/fluid/distributed/service/communicator.h" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace distributed { + +using framework::LoDTensor; +using framework::ProgramDesc; +using framework::VarDesc; +using framework::Variable; + +const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100; +std::shared_ptr FleetWrapper::s_instance_ = NULL; +bool FleetWrapper::is_initialized_ = false; + +std::shared_ptr FleetWrapper::pserver_ptr_ = NULL; + +void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, + int connect_timeout_ms, + int max_retry) { + client2client_request_timeout_ms_ = request_timeout_ms; + client2client_connect_timeout_ms_ = connect_timeout_ms; + client2client_max_retry_ = max_retry; +} + +void FleetWrapper::LoadSparseOnServer(const std::string& path, + const std::string& meta, + uint32_t table_id) { + VLOG(3) << "load sparse table " << table_id << " with " << path << " meta " + << meta; + pserver_ptr_->_server_ptr->table(table_id)->load(path, meta); +} + +void FleetWrapper::InitServer(const std::string& dist_desc, + const std::vector& host_sign_list, + int index) { + if (!is_initialized_) { + VLOG(3) << "Going to init server"; + pserver_ptr_ = std::shared_ptr( + new paddle::distributed::PSCore()); + pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(), + index); + is_initialized_ = true; + } else { + VLOG(3) << "Server can be initialized only once"; + } +} + +// void FleetWrapper::InitWorker( +// const std::string& dist_desc, const std::vector& +// host_sign_list, Scope* scope, const RpcCtxMap& send_ctx, const +// std::unordered_map>& +// dense_varnames, +// const std::map& envs, int node_num, int index) +// { +// if (!is_initialized_) { +// VLOG(3) << "Going to init worker"; + +// Communicator::InitInstance( +// send_ctx, dense_varnames, dist_desc, host_sign_list, scope, envs); + +// pserver_ptr_ = std::shared_ptr( +// new paddle::distributed::PSCore()); +// pserver_ptr_->init_worker(dist_desc, _regions, +// const_cast(host_sign_list.data()), +// node_num, index); +// is_initialized_ = true; +// } else { +// VLOG(3) << "Worker can be initialized only once"; +// } +// } + +void FleetWrapper::InitWorker( + const std::string& dist_desc, + const std::vector& host_sign_list, Scope* scope, + const RpcCtxMap& send_ctx, + const std::unordered_map>& + dense_varnames, + const std::map& envs, int node_num, int index) { + if (!is_initialized_) { + VLOG(3) << "Going to init worker"; + + Communicator::InitInstance( + send_ctx, dense_varnames, dist_desc, host_sign_list, scope, envs); + + pserver_ptr_ = std::shared_ptr( + new paddle::distributed::PSCore()); + pserver_ptr_->init_worker(dist_desc, _regions, &host_sign_list, node_num, + index); + is_initialized_ = true; + } else { + VLOG(3) << "Worker can be initialized only once"; + } +} + +void FleetWrapper::StopServer() { + VLOG(3) << "Going to stop server"; + auto* communicator = Communicator::GetInstance(); + auto status = communicator->_worker_ptr->stop_server(); + status.wait(); +} + +void FleetWrapper::FinalizeWorker() { + VLOG(3) << "Going to finalize worker"; + pserver_ptr_->finalize_worker(); +} + +void FleetWrapper::BarrierWithTable(uint32_t barrier_type) { + VLOG(3) << "Going to Barrier worker"; + auto* communicator = Communicator::GetInstance(); + communicator->BarrierWithTable(barrier_type); +} + +uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) { + VLOG(3) << "Going to run server with ip " << ip << " port " << port; + auto ret = pserver_ptr_->run_server(ip, port); + return ret; +} + +std::vector FleetWrapper::GetClientsInfo() { + VLOG(3) << "Going to get client info"; + return pserver_ptr_->get_client_info(); + return std::vector(); +} + +void FleetWrapper::CreateClient2ClientConnection() { + VLOG(3) << "Going to create client2client connection"; + pserver_ptr_->create_client2client_connection( + client2client_request_timeout_ms_, client2client_connect_timeout_ms_, + client2client_max_retry_); +} + +std::future FleetWrapper::PullSparseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, std::vector* fea_keys, + std::vector>* fea_values, int fea_value_dim) { + fea_keys->clear(); + fea_keys->resize(0); + fea_keys->reserve(MAX_FEASIGN_NUM); + for (auto name : var_names) { + Variable* var = scope.FindVar(name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " << name << " is null"; + int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + for (auto i = 0u; i < len; ++i) { + if (ids[i] == 0u) { + continue; + } + fea_keys->push_back(static_cast(ids[i])); + } + } + fea_values->resize(fea_keys->size() + 1); + for (auto& t : *fea_values) { + t.resize(fea_value_dim); + } + std::vector pull_result_ptr; + for (auto& t : *fea_values) { + pull_result_ptr.push_back(t.data()); + } + return pserver_ptr_->_worker_ptr->pull_sparse( + pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size()); +} + +void FleetWrapper::PullSparseVarsSync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, std::vector* fea_keys, + std::vector>* fea_values, int fea_value_dim, + const std::vector& var_emb_names) { + std::vector> pull_sparse_status; + pull_sparse_status.resize(0); + fea_keys->clear(); + fea_keys->resize(0); + fea_keys->reserve(MAX_FEASIGN_NUM); + for (size_t var_index = 0; var_index < var_names.size(); ++var_index) { + const std::string& name = var_names[var_index]; + Variable* var = scope.FindVar(name); + if (var == nullptr) { + continue; + } + LoDTensor* tensor = var->GetMutable(); + CHECK(tensor != nullptr) << "tensor of var " << name << " is null"; + int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + + // skip slots which do not have embedding + const std::string& emb_name = var_emb_names[var_index]; + Variable* emb_var = scope.FindVar(emb_name); + if (emb_var == nullptr) { + continue; + } + + for (auto i = 0u; i < len; ++i) { + if (ids[i] == 0u) { + continue; + } + fea_keys->push_back(static_cast(ids[i])); + } + } + fea_values->resize(fea_keys->size() + 1); + for (auto& t : *fea_values) { + t.resize(fea_value_dim); + } + std::vector pull_result_ptr; + for (auto& t : *fea_values) { + pull_result_ptr.push_back(t.data()); + } + auto status = pserver_ptr_->_worker_ptr->pull_sparse( + pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size()); + pull_sparse_status.push_back(std::move(status)); + for (auto& t : pull_sparse_status) { + t.wait(); + auto status = t.get(); + if (status != 0) { + LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } + } +} + +void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim, + uint64_t padding_id, + platform::Place place, + std::vector* inputs, + std::vector* outputs) { + std::vector fea_keys; + std::vector pull_result_ptr; + fea_keys.reserve(MAX_FEASIGN_NUM / 100); + pull_result_ptr.reserve(MAX_FEASIGN_NUM / 100); + std::vector init_value(fea_dim, 0); + framework::LoDTensor* output = nullptr; + float* output_data = nullptr; + size_t output_index = -1; + size_t output_len = 0; + for (size_t index = 0; index < inputs->size(); ++index) { + const framework::LoDTensor* tensor = inputs->at(index); + const int64_t* ids = tensor->data(); + size_t len = tensor->numel(); + for (size_t i = 0; i < len; ++i, output_len += fea_dim) { + if (!output || output_len == size_t(output->numel())) { + ++output_index; + CHECK(output_index < outputs->size()); // NOLINT + output = outputs->at(output_index); + output->set_lod(tensor->lod()); + output_data = output->mutable_data(place); + output_len = 0; + CHECK(output->numel() % fea_dim == 0); // NOLINT + CHECK(output_data != nullptr); // NOLINT + } + uint64_t real_id = static_cast(ids[i]); + if (real_id == padding_id) { + memcpy(output_data + output_len, init_value.data(), + sizeof(float) * fea_dim); + continue; + } + fea_keys.push_back(real_id); + pull_result_ptr.push_back(output_data + output_len); + } + } + auto* communicator = Communicator::GetInstance(); + auto status = communicator->_worker_ptr->pull_sparse( + pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size()); + status.wait(); + auto ret = status.get(); + if (ret != 0) { + LOG(ERROR) << "fleet pull sparse failed, status[" << ret << "]"; + sleep(sleep_seconds_before_fail_exit_); + } +} + +void FleetWrapper::PullDenseVarsAsync( + const Scope& scope, const uint64_t tid, + const std::vector& var_names, + std::vector>* pull_dense_status, bool in_cpu) { + auto& regions = _regions[tid]; + regions.clear(); + regions.resize(var_names.size()); + for (auto i = 0u; i < var_names.size(); ++i) { + std::string varname = var_names[i]; + if (!in_cpu) { + varname = var_names[i] + "pin"; + } + Variable* var = scope.FindVar(varname); + LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions[i] = std::move(reg); + } + auto status = pserver_ptr_->_worker_ptr->pull_dense(regions.data(), + regions.size(), tid); + pull_dense_status->push_back(std::move(status)); +} + +void FleetWrapper::PullDenseVarsSync( + const Scope& scope, const uint64_t tid, + const std::vector& var_names) { + auto& regions = _regions[tid]; + regions.clear(); + regions.reserve(var_names.size()); + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + auto* communicator = Communicator::GetInstance(); + auto status = communicator->_worker_ptr->pull_dense(regions.data(), + regions.size(), tid); + status.wait(); +} + +void FleetWrapper::PushDenseParamSync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names) { + auto place = platform::CPUPlace(); + std::vector regions; + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + CHECK(var != nullptr) << "var[" << t << "] not found"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->mutable_data(place); + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + auto* communicator = Communicator::GetInstance(); + auto push_status = communicator->_worker_ptr->push_dense_param( + regions.data(), regions.size(), table_id); + push_status.wait(); + auto status = push_status.get(); + CHECK(status == 0) << "push dense param failed, status[" << status << "]"; +} + +void FleetWrapper::PushDenseVarsSync( + Scope* scope, const uint64_t table_id, + const std::vector& var_names) {} + +void FleetWrapper::PushDenseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector>* push_sparse_status, float scale_datanorm, + int batch_size) { + auto* communicator = Communicator::GetInstance(); + PADDLE_ENFORCE_EQ( + communicator->Check(table_id), true, + platform::errors::InvalidArgument( + "can not find table: %s, please check your config", table_id)); + communicator->Send(var_names, scope); +} + +void FleetWrapper::PushSparseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::string& grad_varname, + std::vector>* push_sparse_status) { + std::vector varnames; + varnames.push_back(grad_varname); + + auto* communicator = Communicator::GetInstance(); + PADDLE_ENFORCE_EQ( + communicator->Check(table_id), true, + platform::errors::InvalidArgument( + "can not find table: %s, please check your config", table_id)); + communicator->Send(varnames, scope); +} + +void FleetWrapper::PushSparseVarsWithLabelAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& fea_keys, const std::vector& fea_labels, + const std::vector& sparse_key_names, + const std::vector& sparse_grad_names, const int emb_dim, + std::vector>* push_values, + std::vector>* push_sparse_status, const int batch_size, + const bool use_cvm, const bool dump_slot, + std::vector* sparse_push_keys, const bool no_cvm) { + // not support + return; +} + +void FleetWrapper::PushSparseFromTensorWithLabelAsync( + const Scope& scope, const uint64_t table_id, int fea_dim, + uint64_t padding_id, bool scale_sparse, const std::string& accesor, + const std::string& click_name, platform::Place place, + const std::vector& input_names, + std::vector* inputs, + std::vector* outputs) { + // not support + return; +} + +void FleetWrapper::LoadModel(const std::string& path, const int mode) { + auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "load model from path:" << path << " failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +void FleetWrapper::LoadModelOneTable(const uint64_t table_id, + const std::string& path, const int mode) { + auto ret = + pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "load model of table id: " << table_id + << ", from path: " << path << " failed"; + } +} + +void FleetWrapper::SaveModel(const std::string& path, const int mode) { + auto* communicator = Communicator::GetInstance(); + auto ret = communicator->_worker_ptr->save(path, std::to_string(mode)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "save model failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +void FleetWrapper::SaveModelOneTable(const uint64_t table_id, + const std::string& path, const int mode) { + auto* communicator = Communicator::GetInstance(); + auto ret = + communicator->_worker_ptr->save(table_id, path, std::to_string(mode)); + ret.wait(); + if (ret.get() != 0) { + LOG(ERROR) << "save model of table id: " << table_id + << ", to path: " << path << " failed"; + } +} + +void FleetWrapper::PrintTableStat(const uint64_t table_id) { + auto* communicator = Communicator::GetInstance(); + auto ret = communicator->_worker_ptr->print_table_stat(table_id); + ret.wait(); + int32_t err_code = ret.get(); + if (err_code == -1) { + LOG(ERROR) << "print table stat failed"; + } +} + +void FleetWrapper::ShrinkSparseTable(int table_id) { + auto ret = pserver_ptr_->_worker_ptr->shrink(table_id); + ret.wait(); +} + +void FleetWrapper::ClearModel() { + auto ret = pserver_ptr_->_worker_ptr->clear(); + ret.wait(); +} + +void FleetWrapper::ClearOneTable(const uint64_t table_id) { + auto ret = pserver_ptr_->_worker_ptr->clear(table_id); + ret.wait(); +} + +void FleetWrapper::ShrinkDenseTable(int table_id, Scope* scope, + std::vector var_list, + float decay, int emb_dim) { + std::vector regions; + for (std::string& name : var_list) { + if (name.find("batch_sum") != std::string::npos) { + Variable* var = scope->FindVar(name); + CHECK(var != nullptr) << "var[" << name << "] not found"; + VLOG(0) << "prepare shrink dense batch_sum"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->data(); + + // show_batch_sum += N * log(decay) + std::string size_name = name; + size_name.replace(size_name.find("batch_sum"), size_name.length(), + "batch_size"); + Variable* var_size = scope->FindVar(size_name); + CHECK(var_size != nullptr) << "var[" << size_name << "] not found"; + VLOG(3) << "shrink dense batch_sum: " << name << ", " << size_name; + float* g_size = var_size->GetMutable()->data(); + + for (int k = 0; k < tensor->numel(); k += emb_dim) { + g[k] = g[k] + g_size[k] * log(decay); + } + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + } else { + Variable* var = scope->FindVar(name); + CHECK(var != nullptr) << "var[" << name << "] not found"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->data(); + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + } + auto push_status = pserver_ptr_->_worker_ptr->push_dense_param( + regions.data(), regions.size(), table_id); + push_status.wait(); + auto status = push_status.get(); + if (status != 0) { + // PADDLE_THORW(platform::errors::Fatal( + // "push shrink dense param failed, status is [%d].", status)); + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +void FleetWrapper::ClientFlush() { + auto ret = pserver_ptr_->_worker_ptr->flush(); + ret.wait(); +} + +int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, + MsgHandlerFunc handler) { + VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler"; + VLOG(3) << "pserver_ptr_=" << pserver_ptr_; + VLOG(3) << "_worker_ptr=" << pserver_ptr_->_worker_ptr; + return pserver_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, + handler); +} + +std::future FleetWrapper::SendClientToClientMsg( + int msg_type, int to_client_id, const std::string& msg) { + return pserver_ptr_->_worker_ptr->send_client2client_msg(msg_type, + to_client_id, msg); +} + +std::default_random_engine& FleetWrapper::LocalRandomEngine() { + struct engine_wrapper_t { + std::default_random_engine engine; + + engine_wrapper_t() { + struct timespec tp; + clock_gettime(CLOCK_REALTIME, &tp); + double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9; + static std::atomic x(0); + std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)}; + engine.seed(sseq); + } + }; + thread_local engine_wrapper_t r; + return r.engine; +} + +size_t FleetWrapper::GetAbsoluteSum(size_t start, size_t end, size_t level, + const framework::LoD& lod) { + if (level >= lod.size() - 1) { + return end - start; + } + size_t ret = 0; + for (size_t i = start; i < end - 1; ++i) { + size_t pos1 = lod[level][i]; + size_t pos2 = lod[level][i + 1]; + ret += GetAbsoluteSum(pos1, pos2, level + 1, lod); + } + return ret; +} + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h new file mode 100644 index 0000000000000000000000000000000000000000..7f106fafbf2e2e3cb8fd4e7769d97314ee2f31e5 --- /dev/null +++ b/paddle/fluid/distributed/fleet.h @@ -0,0 +1,246 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "paddle/fluid/distributed/communicator_common.h" +#include "paddle/fluid/distributed/service/service.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/io/shell.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN + +namespace paddle { +namespace distributed { + +using framework::LoDTensor; +using framework::Scope; +using framework::SelectedRows; +using framework::Variable; + +using RpcCtxMap = std::unordered_map; + +class FleetWrapper { + public: + virtual ~FleetWrapper() {} + FleetWrapper() { + scale_sparse_gradient_with_batch_size_ = true; + // trainer sleep some time for pserver core dump + sleep_seconds_before_fail_exit_ = 300; + // pserver request server timeout ms + client2client_request_timeout_ms_ = 500000; + // pserver connect server timeout_ms + client2client_connect_timeout_ms_ = 10000; + // pserver request max retry + client2client_max_retry_ = 3; + } + + // set client to client communication config + void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, + int max_retry); + + // Pull sparse variables from server in sync mode + // Param: scope, table_id, var_names, fea_keys, fea_dim, var_emb_names + // Param: fea_values + void PullSparseVarsSync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector* fea_keys, + std::vector>* fea_values, + int fea_dim, + const std::vector& var_emb_names); + + // Pull sparse variables from server in async mode + // Param: scope, table_id, var_names, fea_keys, fea_dim + // Param: fea_values std::future + std::future PullSparseVarsAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector* fea_keys, + std::vector>* fea_values, int fea_dim); + + // Pull sparse variables from server in sync mode + // pull immediately to tensors + void PullSparseToTensorSync(const uint64_t table_id, int fea_dim, + uint64_t padding_id, platform::Place place, + std::vector* inputs, // NOLINT + std::vector* outputs); // NOLINT + + // pull dense variables from server in sync mod + // Param: scope, table_id, var_names + // Param: void + void PullDenseVarsSync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names); + + // pull dense variables from server in async mod + // Param: scope, table_id, var_names + // Param: pull_dense_status + void PullDenseVarsAsync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector>* pull_dense_status, + bool in_cpu); + + // push dense parameters(not gradients) to server in sync mode + void PushDenseParamSync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names); + + void PushDenseVarsAsync(const Scope& scope, const uint64_t table_id, + const std::vector& var_names, + std::vector>* push_sparse_status, + float scale_datanorm, int batch_size); + + // push dense variables to server in sync mode + void PushDenseVarsSync(Scope* scope, const uint64_t table_id, + const std::vector& var_names); + + void PushSparseVarsAsync( + const Scope& scope, const uint64_t table_id, const std::string& grad, + std::vector>* push_sparse_status); + // This is specially designed for click/show stats in server + // Param: scope, table_id, fea_keys, fea_labels, sparse_key_names, + // sparse_grad_names, batch_size, use_cvm, dump_slot + // Param: push_values, push_sparse_status + void PushSparseVarsWithLabelAsync( + const Scope& scope, const uint64_t table_id, + const std::vector& fea_keys, + const std::vector& fea_labels, + const std::vector& sparse_key_names, + const std::vector& sparse_grad_names, const int emb_dim, + std::vector>* push_values, + std::vector>* push_sparse_status, + const int batch_size, const bool use_cvm, const bool dump_slot, + std::vector* sparse_push_keys, const bool no_cvm); + + // Push sparse variables to server in async mode + void PushSparseFromTensorWithLabelAsync( + const Scope& scope, const uint64_t table_id, int fea_dim, + uint64_t padding_id, bool scale_sparse, const std::string& accesor, + const std::string& click_name, platform::Place place, + const std::vector& input_names, + std::vector* inputs, // NOLINT + std::vector* outputs); // NOLINT + + // Push sparse variables to server in Async mode + // Param: scope, table_id, fea_keys, sparse_grad_names + // Param: push_values, push_sparse_status + + // init server + void LoadSparseOnServer(const std::string& path, const std::string& meta, + uint32_t table_id); + // init server + // void InitServer(const std::string& dist_desc, + // const std::vector& host_sign_list, int index); + void InitServer(const std::string& dist_desc, + const std::vector& host_sign_list, int index); + // init trainer + void InitWorker(const std::string& dist_desc, + const std::vector& host_sign_list, Scope* scope, + const RpcCtxMap& send_ctx, + const std::unordered_map>& + dense_varnames, + const std::map& envs, int node_num, + int index); + + // stop server + void StopServer(); + // finalize worker to make worker can be stop + void FinalizeWorker(); + // run server with ip port + uint64_t RunServer(const std::string& ip, uint32_t port); + // get client info + std::vector GetClientsInfo(); + // create client to client connection + void CreateClient2ClientConnection(); + // flush all push requests + void ClientFlush(); + + // barrier with barrier table + void BarrierWithTable(uint32_t barrier_type); + + void PrintTableStat(const uint64_t table_id); + // mode = 0, load all feature + // mode = 1, load delta feature, which means load diff + void LoadModel(const std::string& path, const int mode); + // mode = 0, load all feature + // mode = 1, load delta feature, which means load diff + void LoadModelOneTable(const uint64_t table_id, const std::string& path, + const int mode); + // mode = 0, save all feature + // mode = 1, save delta feature, which means save diff + void SaveModel(const std::string& path, const int mode); + // mode = 0, save all feature + // mode = 1, save delta feature, which means save diff + void SaveModelOneTable(const uint64_t table_id, const std::string& path, + const int mode); + // clear all models, release their memory + void ClearModel(); + // clear one table + void ClearOneTable(const uint64_t table_id); + // shrink sparse table + void ShrinkSparseTable(int table_id); + // shrink dense table + void ShrinkDenseTable(int table_id, Scope* scope, + std::vector var_list, float decay, + int emb_dim); + + typedef std::function MsgHandlerFunc; + // register client to client communication + int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler); + // send client to client message + std::future SendClientToClientMsg(int msg_type, int to_client_id, + const std::string& msg); + + // FleetWrapper singleton + static std::shared_ptr GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new paddle::distributed::FleetWrapper()); + } + return s_instance_; + } + // this performs better than rand_r, especially large data + std::default_random_engine& LocalRandomEngine(); + + static std::shared_ptr pserver_ptr_; + + private: + static std::shared_ptr s_instance_; + size_t GetAbsoluteSum(size_t start, size_t end, size_t level, + const framework::LoD& lod); + + protected: + static bool is_initialized_; + std::map> _regions; + bool scale_sparse_gradient_with_batch_size_; + int32_t sleep_seconds_before_fail_exit_; + int client2client_request_timeout_ms_; + int client2client_connect_timeout_ms_; + int client2client_max_retry_; + DISABLE_COPY_AND_ASSIGN(FleetWrapper); +}; + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/CMakeLists.txt b/paddle/fluid/distributed/service/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c767ad2b3fa6b0462a379795005c2bddf377516 --- /dev/null +++ b/paddle/fluid/distributed/service/CMakeLists.txt @@ -0,0 +1,40 @@ +set(BRPC_SRCS ps_client.cc server.cc) +set_source_files_properties(${BRPC_SRCS}) + +set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog) + +brpc_library(sendrecv_rpc SRCS + ${BRPC_SRCS} + PROTO sendrecv.proto + DEPS ${BRPC_DEPS} ) + +set_property(GLOBAL PROPERTY RPC_DEPS sendrecv_rpc ${BRPC_DEPS} string_helper) + +get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) + +set_source_files_properties(communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +set_source_files_properties(brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + +set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + + +cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table ${RPC_DEPS}) +cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table ${RPC_DEPS}) + +cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS}) +cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS}) + +cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS}) +cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS}) + +cc_library(brpc_utils SRCS brpc_utils.cc DEPS ${COMMON_DEPS} ${RPC_DEPS}) +cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) +cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc9d017532dff0bb5d17fd65fd231aadeccefcb8 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_client.cc @@ -0,0 +1,879 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "Eigen/Dense" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/archive.h" + +const static int max_port = 65535; + +DEFINE_int32(pserver_push_dense_merge_limit, 12, + "limit max push_dense local merge requests"); + +DEFINE_int32(pserver_push_sparse_merge_limit, 12, + "limit max push_sparse local merge requests"); + +DEFINE_int32(pserver_pull_dense_limit, 12, + "limit max push_sparse local merge requests"); + +DEFINE_int32(pserver_async_push_dense_interval_ms, 10, + "async push_dense to server interval"); + +DEFINE_int32(pserver_async_push_sparse_interval_ms, 10, + "async push_sparse to server interval"); + +DEFINE_bool(pserver_scale_gradient_by_merge, false, + "scale dense gradient when merged"); + +DEFINE_int32(pserver_communicate_compress_type, 0, + "none:0 snappy:1 gzip:2 zlib:3 lz4:4"); + +DEFINE_int32(pserver_max_async_call_num, 13, + "max task num in async_call_server"); + +DEFINE_int32(pserver_timeout_ms, 500000, "pserver request server timeout_ms"); + +DEFINE_int32(pserver_connect_timeout_ms, 10000, + "pserver connect server timeout_ms"); + +DEFINE_int32(pserver_sparse_merge_thread, 1, "pserver sparse merge thread num"); + +namespace paddle { +namespace distributed { + +inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, + uint64_t key) { + size_t remind = shard_num % server_num; + size_t local_shard_num = + remind == 0 ? shard_num / server_num : shard_num / server_num + 1; + return (key % shard_num) / local_shard_num; +} + +void DownpourPsClientService::service( + ::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) { + brpc::ClosureGuard done_guard(done); + int ret = _client->handle_client2client_msg( + request->cmd_id(), request->client_id(), request->data()); + response->set_err_code(0); + response->set_err_msg(""); + if (ret != 0) { + response->set_err_code(-1); + response->set_err_msg("handle_client2client_msg failed"); + } +} + +// 启动client端RpcService 用于数据互发等操作 +int32_t BrpcPsClient::start_client_service() { + if (_service.configure(this, _client_id) != 0) { + LOG(ERROR) + << "service initialize failed, service_name:DownpourPsClientService"; + return -1; + } + _server.AddService(&_service, brpc::SERVER_DOESNT_OWN_SERVICE); + brpc::ServerOptions options; + int start_port = 8500; + options.num_threads = 24; + + if (_server.Start(butil::my_ip_cstr(), brpc::PortRange(start_port, max_port), + &options) != 0) { + LOG(ERROR) << "BrpcPsServer start failed"; + return -1; + } + _env->registe_ps_client(butil::my_ip_cstr(), _server.listen_address().port, + _client_id); + return 0; +} + +int32_t BrpcPsClient::create_client2client_connection( + int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry) { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.timeout_ms = pserver_timeout_ms; + options.connection_type = "pooled"; + options.connect_timeout_ms = pserver_connect_timeout_ms; + options.max_retry = max_retry; + + std::vector client_list = _env->get_ps_clients(); + _client_channels.resize(client_list.size()); + std::ostringstream os; + std::string server_ip_port; + for (size_t i = 0; i < client_list.size(); ++i) { + server_ip_port.assign(client_list[i].ip.c_str()); + server_ip_port.append(":"); + server_ip_port.append(std::to_string(client_list[i].port)); + _client_channels[i].reset(new brpc::Channel()); + if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) { + LOG(ERROR) << "psclient connect to client:" << server_ip_port + << " Failed!"; + } + os << server_ip_port << ","; + } + LOG(INFO) << "Client connect success:" << os.str(); + return 0; +} + +int32_t BrpcPsClient::initialize() { + _async_call_num = 0; + + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.timeout_ms = FLAGS_pserver_timeout_ms; + options.connection_type = "pooled"; + options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms; + options.max_retry = 3; + + std::ostringstream os; + std::string server_ip_port; + std::string client_ip(butil::my_ip_cstr()); + + // 获取server列表,并连接 + std::vector server_list = _env->get_ps_servers(); + _server_channels.resize(server_list.size()); + for (size_t i = 0; i < server_list.size(); ++i) { + server_ip_port.assign(server_list[i].ip.c_str()); + server_ip_port.append(":"); + server_ip_port.append(std::to_string(server_list[i].port)); + for (size_t j = 0; j < _server_channels[i].size(); ++j) { + _server_channels[i][j].reset(new brpc::Channel()); + if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) != + 0) { + LOG(ERROR) << "psclient connect to server:" << server_ip_port + << " Failed!"; + return -1; + } + } + os << server_ip_port << ","; + } + // 启动client探听接口, 并相互建立连接 + start_client_service(); + + _running = true; + _flushing = false; + return 0; +} + +int DownpourBrpcClosure::check_response(size_t request_idx, int cmd_id) { + if (_cntls[request_idx]->Failed()) { + LOG(ERROR) << "resquest cmd_id:" << cmd_id << " failed, " + "err:" + << _cntls[request_idx]->ErrorText(); + return -1; + } + if (_responses[request_idx].err_code() != 0) { + LOG(ERROR) << "response ret bad, server_idx:" << request_idx + << "cmd_id:" << cmd_id + << " err_code:" << _responses[request_idx].err_code() + << " err_msg:" << _responses[request_idx].err_msg(); + return -1; + } + return 0; +} + +int DownpourBrpcClosure::check_save_response(size_t request_idx, int cmd_id) { + uint32_t feasign_size = 0; + if (_cntls[request_idx]->Failed()) { + LOG(ERROR) << "resquest cmd_id:" << cmd_id << " failed, " + "err:" + << _cntls[request_idx]->ErrorText(); + return -1; + } + feasign_size = _responses[request_idx].err_code(); + if (feasign_size < 0) { + LOG(ERROR) << "response ret bad, server_idx:" << request_idx + << "cmd_id:" << cmd_id + << " err_code:" << _responses[request_idx].err_code() + << " err_msg:" << _responses[request_idx].err_msg(); + return -1; + } + return feasign_size; +} + +std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) { + std::string data = _responses[request_idx].data(); + return data; +} + +std::future BrpcPsClient::print_table_stat(uint32_t table_id) { + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [request_call_num, table_id](void *done) { + int ret = 0; + uint64_t feasign_size = 0; + uint64_t mf_size = 0; + paddle::framework::BinaryArchive ar; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PRINT_TABLE_STAT) != 0) { + ret = -1; + break; + } + std::string resp = closure->get_response(i, PS_PRINT_TABLE_STAT); + ar.SetReadBuffer(const_cast(resp.c_str()), resp.length(), + nullptr); + + feasign_size += ar.Get(); + mf_size += ar.Get(); + } + closure->set_promise_value(ret); + std::cout << "table id: " << table_id + << ", feasign size: " << feasign_size + << ", mf size: " << mf_size << std::endl; + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PRINT_TABLE_STAT); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_timeout_ms( + 10800000); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} +std::future BrpcPsClient::send_cmd( + uint32_t table_id, int cmd_id, const std::vector ¶ms) { + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [request_call_num, cmd_id](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, cmd_id) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(cmd_id); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + for (const auto ¶m : params) { + closure->request(i)->add_params(param); + } + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_timeout_ms( + 10800000); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::send_save_cmd( + uint32_t table_id, int cmd_id, const std::vector ¶ms) { + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [request_call_num, cmd_id](void *done) { + int ret = 0; + uint32_t feasign_size = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_save_response(i, cmd_id) < 0) { + ret = -1; + break; + } + feasign_size += closure->check_save_response(i, cmd_id); + } + if (ret == 0) { + closure->set_promise_value(feasign_size); + } else { + closure->set_promise_value(ret); + } + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(cmd_id); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + for (const auto ¶m : params) { + closure->request(i)->add_params(param); + } + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_timeout_ms( + 10800000); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::shrink(uint32_t table_id) { + return send_cmd(table_id, PS_SHRINK_TABLE, {std::string("1")}); +} + +std::future BrpcPsClient::load(const std::string &epoch, + const std::string &mode) { + return send_cmd(-1, PS_LOAD_ALL_TABLE, {epoch, mode}); +} +std::future BrpcPsClient::load(uint32_t table_id, + const std::string &epoch, + const std::string &mode) { + return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); +} + +std::future BrpcPsClient::save(const std::string &epoch, + const std::string &mode) { + return send_save_cmd(-1, PS_SAVE_ALL_TABLE, {epoch, mode}); +} +std::future BrpcPsClient::save(uint32_t table_id, + const std::string &epoch, + const std::string &mode) { + return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); +} + +std::future BrpcPsClient::clear() { + return send_cmd(-1, PS_CLEAR_ALL_TABLE, {}); +} +std::future BrpcPsClient::clear(uint32_t table_id) { + return send_cmd(table_id, PS_CLEAR_ONE_TABLE, {}); +} + +std::future BrpcPsClient::flush() { + _flushing = true; + std::promise promise; + std::future fut = promise.get_future(); + do { + VLOG(3) << "wait _async_call_num:" << _async_call_num; + usleep(100000); // sleep 100ms wait async end + } while (_async_call_num > 0); + promise.set_value(0); + _flushing = false; + return fut; +} + +void BrpcPsClient::finalize_worker() { + flush(); + _running = false; + _server.Stop(1000); + _server.Join(); +} + +std::future BrpcPsClient::stop_server() { + return send_cmd(-1, PS_STOP_SERVER, {}); +} + +std::future BrpcPsClient::start_profiler() { + return send_cmd(-1, PS_START_PROFILER, {}); +} + +std::future BrpcPsClient::stop_profiler() { + return send_cmd(-1, PS_STOP_PROFILER, {}); +} + +std::future BrpcPsClient::barrier(size_t table_id, + uint32_t barrier_type) { + return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); +} + +std::future BrpcPsClient::pull_geo_param(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx) { + auto *accessor = table_accessor(table_id); + DownpourBrpcClosure *closure = + new DownpourBrpcClosure(1, [keys, values, accessor](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + uint32_t shard_nums; + if (closure->check_response(0, PS_PULL_GEO_PARAM) != 0) { + ret = -1; + } + auto &res_io_buffer = closure->cntl(0)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + io_buffer_itr.copy_and_forward((void *)(&shard_nums), sizeof(uint32_t)); + keys->resize(shard_nums); + values->resize(shard_nums * accessor->update_dim()); + io_buffer_itr.copy_and_forward((void *)(keys->data()), + sizeof(uint64_t) * shard_nums); + io_buffer_itr.copy_and_forward((void *)(values->data()), + shard_nums * accessor->update_size()); + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + closure->request(0)->set_cmd_id(PS_PULL_GEO_PARAM); + closure->request(0)->set_table_id(table_id); + closure->request(0)->set_client_id(_client_id); + PsService_Stub rpc_stub(get_cmd_channel(pserver_idx)); + closure->cntl(0)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +std::future BrpcPsClient::push_sparse_param( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) { + auto *accessor = table_accessor(table_id); + // 发送RPC请求 + DownpourBrpcClosure *closure = reinterpret_cast(done); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + size_t request_call_num = _server_channels.size(); + std::vector> ids; + std::vector> value_ptrs; + ids.resize(request_call_num); + value_ptrs.resize(request_call_num); + for (size_t i = 0; i < num; ++i) { + size_t pserver_idx = keys[i] % request_call_num; + ids[pserver_idx].push_back(keys[i]); + value_ptrs[pserver_idx].push_back(update_values[i]); + } + for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + auto kvs = ids[shard_idx]; + auto value_ptr = value_ptrs[shard_idx]; + size_t kv_size = kvs.size(); + uint32_t value_size = accessor->update_size(); + // 发送RPC请求 + auto *push_request = closure->request(shard_idx); + push_request->set_cmd_id(PS_PUSH_SPARSE_PARAM); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params((char *)&kv_size, sizeof(uint32_t)); + auto *push_data = push_request->mutable_data(); + push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); + push_data_ptr += kv_size * sizeof(uint64_t); + for (int i = 0; i < kv_size; ++i) { + memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); + push_data_ptr += accessor->update_size(); + } + PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + closure->cntl(shard_idx)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), + closure->response(shard_idx), closure); + } + return fut; +} + +std::future BrpcPsClient::pull_dense(Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = table_accessor(table_id); + size_t request_call_num = _server_channels.size(); + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + // callback 将各shard结果,顺序填入region + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [request_call_num, num_per_shard, regions, region_num, + accessor](void *done) { + int ret = 0; + size_t region_idx = 0; // 当前填充的region偏移 + size_t region_data_idx = 0; // 当前填充的region内data偏移 + auto *closure = (DownpourBrpcClosure *)done; + size_t shard_data_size = num_per_shard * accessor->select_size(); + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PULL_DENSE_TABLE) != 0) { + ret = -1; + break; + } + auto &res_io_buffer = closure->cntl(i)->response_attachment(); + + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + size_t shard_buffer_remain = res_io_buffer.size(); + if (shard_buffer_remain != shard_data_size) { + LOG(ERROR) << "expect res_size:" << shard_data_size + << ", but size:" << shard_buffer_remain + << ", ignore this response"; + ret = -1; + break; + } + while (shard_buffer_remain > 0 && region_idx < region_num) { + auto ®ion = regions[region_idx]; + if (region.size - region_data_idx >= shard_buffer_remain) { + // region待填充空间 >= 分片buffer数据, 直接拷贝置入 + io_buffer_itr.copy_and_forward( + (void *)(region.data + region_data_idx), shard_buffer_remain); + region_data_idx += shard_buffer_remain; + shard_buffer_remain = 0; + } else if (region.size - region_data_idx == 0) { + // region填满,切换到下一个region + ++region_idx; + region_data_idx = 0; + } else { + // region不足以容纳所有数据,则能放多少 拷贝多少 + io_buffer_itr.copy_and_forward( + (void *)(region.data + region_data_idx), + region.size - region_data_idx); + shard_buffer_remain -= (region.size - region_data_idx); + ++region_idx; + region_data_idx = 0; + } + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PULL_DENSE_TABLE); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + closure->request(i)->add_params((char *)&num_per_shard, + sizeof(num_per_shard)); + PsService_Stub rpc_stub(get_dense_channel(i)); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::push_dense_param(const Region *regions, + size_t region_num, + size_t table_id) { + auto *accessor = table_accessor(table_id); + size_t request_call_num = _server_channels.size(); + // 1.拆分Region数据到shard中,后续多shard并行拷贝数据 + std::vector> regions_partition(request_call_num); + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + size_t shard_data_size = num_per_shard * accessor->update_size(); + size_t current_region_idx = 0; + size_t current_region_data_idx = 0; + for (size_t i = 0; i < request_call_num; ++i) { + size_t shard_data_remain_size = shard_data_size; + while (shard_data_remain_size > 0 && current_region_idx < region_num) { + const auto ®ion = regions[current_region_idx]; + size_t region_remain_size = region.size - current_region_data_idx; + if (shard_data_remain_size >= region_remain_size) { + regions_partition[i].push_back( + Region(region.data + current_region_data_idx, region_remain_size)); + ++current_region_idx; + current_region_data_idx = 0; + shard_data_remain_size -= region_remain_size; + } else { + regions_partition[i].push_back(Region( + region.data + current_region_data_idx, shard_data_remain_size)); + current_region_data_idx += shard_data_remain_size; + shard_data_remain_size = 0; + } + } + } + + DownpourBrpcClosure *closure = + new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_DENSE_PARAM) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + static const int REGION_ASSIGN_BUFFER_SIZE = 1024 * 10; + static char region_assign_buffer[REGION_ASSIGN_BUFFER_SIZE]; //用于数据补齐 + //开始多shard并行拷贝&请求 + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PUSH_DENSE_PARAM); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + auto &request_buffer = closure->cntl(i)->request_attachment(); + request_buffer.append((void *)&num_per_shard, sizeof(uint32_t)); + auto ®ion_list = regions_partition[i]; + size_t fill_remain_size = shard_data_size; + for (auto ®ion : region_list) { + fill_remain_size -= region.size; + request_buffer.append((void *)region.data, region.size); + } + //保证各分片数据对齐 + while (fill_remain_size > 0) { + size_t fill_num = fill_remain_size > REGION_ASSIGN_BUFFER_SIZE + ? REGION_ASSIGN_BUFFER_SIZE + : fill_remain_size; + request_buffer.append((void *)region_assign_buffer, fill_num); + fill_remain_size -= fill_num; + } + PsService_Stub rpc_stub(get_dense_channel(i)); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future BrpcPsClient::push_sparse_raw_gradient( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) { + auto *accessor = table_accessor(table_id); + //发送RPC请求 + DownpourBrpcClosure *closure = reinterpret_cast(done); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + size_t request_call_num = _server_channels.size(); + std::vector> ids; + std::vector> value_ptrs; + ids.resize(request_call_num); + value_ptrs.resize(request_call_num); + + for (size_t i = 0; i < num; ++i) { + size_t pserver_idx = keys[i] % request_call_num; + ids[pserver_idx].push_back(keys[i]); + value_ptrs[pserver_idx].push_back(update_values[i]); + } + + for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) { + auto kvs = ids[shard_idx]; + auto value_ptr = value_ptrs[shard_idx]; + + size_t kv_size = kvs.size(); + uint32_t value_size = accessor->update_size(); + + // 发送RPC请求 + auto *push_request = closure->request(shard_idx); + push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params((char *)&kv_size, sizeof(uint32_t)); + auto *push_data = push_request->mutable_data(); + push_data->resize(kv_size * (sizeof(uint64_t) + accessor->update_size())); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t)); + push_data_ptr += kv_size * sizeof(uint64_t); + + for (int i = 0; i < kv_size; ++i) { + memcpy(push_data_ptr, value_ptr[i], accessor->update_size()); + push_data_ptr += accessor->update_size(); + } + PsService_Stub rpc_stub(get_sparse_channel(shard_idx)); + closure->cntl(shard_idx)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(shard_idx), closure->request(shard_idx), + closure->response(shard_idx), closure); + } + return fut; +} + +std::future BrpcPsClient::push_dense_raw_gradient( + int table_id, float *total_send_data, size_t total_send_data_size, + void *done) { + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = reinterpret_cast(done); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + auto *accessor = table_accessor(table_id); + uint32_t num_per_shard = + dense_dim_per_shard(accessor->fea_dim(), request_call_num); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(PS_PUSH_DENSE_TABLE); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + auto *push_data = closure->request(i)->mutable_data(); + push_data->clear(); + push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(float)); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t)); + memcpy(push_data_ptr + sizeof(uint32_t), + total_send_data + i * num_per_shard, num_per_shard * sizeof(float)); + VLOG(1) << "push_dense_raw_gradient finish memcpy"; + // closure->cntl(i)->set_request_compress_type( + // (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + PsService_Stub rpc_stub(get_dense_channel(i)); + VLOG(1) << "push_dense_raw_gradient get_dense_channel " << i; + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + VLOG(1) << "push_dense_raw_gradient async service " << i; + } + return fut; +} + +std::future BrpcPsClient::pull_sparse(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num) { + size_t request_call_num = _server_channels.size(); + + auto shard_sorted_kvs = std::make_shared< + std::vector>>>(); + shard_sorted_kvs->resize(request_call_num); + + for (size_t i = 0; i < num; ++i) { + size_t shard_id = keys[i] % request_call_num; + shard_sorted_kvs->at(shard_id).push_back({keys[i], select_values[i]}); + } + + auto *accessor = table_accessor(table_id); + size_t value_size = accessor->select_size(); + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [shard_sorted_kvs, value_size](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < ids.size(); ++i) { + if (closure->check_response(i, PS_PULL_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + + auto &request_kvs = shard_sorted_kvs->at(i); + auto &res_io_buffer = closure->cntl(i)->response_attachment(); + butil::IOBufBytesIterator io_buffer_itr(res_io_buffer); + uint64_t last_key = UINT64_MAX; + float *last_value_data = NULL; + + for (size_t kv_idx = 0; kv_idx < request_kvs.size(); ++kv_idx) { + auto *kv_pair = &(request_kvs[kv_idx]); + if (kv_pair->first == last_key) { + memcpy((void *)kv_pair->second, (void *)last_value_data, + value_size); + } else { + last_key = kv_pair->first; + last_value_data = kv_pair->second; + if (value_size != + io_buffer_itr.copy_and_forward((void *)(last_value_data), + value_size)) { + LOG(WARNING) << "res data is lack or not in format"; + ret = -1; + break; + } + } + } + } + closure->set_promise_value(ret); + }); + + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + for (size_t i = 0; i < request_call_num; ++i) { + auto &sorted_kvs = shard_sorted_kvs->at(i); + std::sort(sorted_kvs.begin(), sorted_kvs.end(), + [](const std::pair &k1, + const std::pair &k2) { + return k1.first < k2.first; + }); + + uint64_t last_key = UINT64_MAX; + uint32_t kv_request_count = 0; + size_t sorted_kv_size = sorted_kvs.size(); + auto &request_buffer = closure->cntl(i)->request_attachment(); + for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) { + ++kv_request_count; + last_key = sorted_kvs[kv_idx].first; + request_buffer.append((void *)&last_key, sizeof(uint64_t)); + while (kv_idx < sorted_kv_size - 1 && + last_key == sorted_kvs[kv_idx + 1].first) { + ++kv_idx; + } + } + + if (kv_request_count == 0) { + closure->Run(); + } else { + closure->request(i)->set_cmd_id(PS_PULL_SPARSE_TABLE); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + closure->request(i)->add_params((char *)&kv_request_count, + sizeof(uint32_t)); + PsService_Stub rpc_stub(get_cmd_channel(i)); + closure->cntl(i)->set_log_id(butil::gettimeofday_ms()); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + } + return fut; +} + +std::future BrpcPsClient::send_client2client_msg( + int msg_type, int to_client_id, const std::string &msg) { + auto promise = std::make_shared>(); + std::future fut = promise->get_future(); + if (to_client_id >= _client_channels.size()) { + LOG(FATAL) << "to_client_id is out of range clients, which size is " + << _client_channels.size(); + promise->set_value(-1); + return fut; + } + auto *closure = new DownpourBrpcClosure(1, [msg_type](void *done) { + auto *closure = (DownpourBrpcClosure *)done; + int32_t ret = closure->check_response(0, msg_type + 1000); + closure->set_promise_value(ret); + }); + closure->add_promise(promise); + closure->request(0)->set_cmd_id(msg_type); + closure->request(0)->set_client_id(_client_id); + closure->request(0)->set_data(msg); + PsService_Stub rpc_stub(_client_channels[to_client_id].get()); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +std::future BrpcPsClient::push_sparse_raw_gradient_partial( + size_t table_id, const uint64_t *keys, const float **update_values, + uint32_t num, void *done, int pserver_idx) { + auto *accessor = table_accessor(table_id); + size_t value_size = accessor->update_size(); + DownpourBrpcClosure *closure = reinterpret_cast(done); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + + // 发送RPC请求 + auto *push_request = closure->request(0); + push_request->set_cmd_id(PS_PUSH_SPARSE_TABLE); + push_request->set_table_id(table_id); + push_request->set_client_id(_client_id); + push_request->add_params((char *)&num, sizeof(uint32_t)); + auto *push_data = push_request->mutable_data(); + push_data->resize(num * (sizeof(uint64_t) + value_size)); + char *push_data_ptr = const_cast(push_data->data()); + memcpy(push_data_ptr, keys, num * sizeof(uint64_t)); + push_data_ptr += num * sizeof(uint64_t); + for (int i = 0; i < num; ++i) { + memcpy(push_data_ptr, update_values[i], value_size); + push_data_ptr += value_size; + } + PsService_Stub rpc_stub(get_sparse_channel(pserver_idx)); + closure->cntl(0)->set_request_compress_type( + (brpc::CompressType)FLAGS_pserver_communicate_compress_type); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h new file mode 100644 index 0000000000000000000000000000000000000000..c07165151507951a6c9023906ac3f3666e1209b3 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_client.h @@ -0,0 +1,212 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/ps_client.h" + +namespace paddle { +namespace distributed { + +class DownpourPsClientService : public PsService { + public: + DownpourPsClientService() {} + virtual ~DownpourPsClientService() {} + + virtual int32_t configure(PSClient *client, size_t rank_id) { + _client = client; + _rank = rank_id; + return 0; + } + virtual void service(::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, + ::google::protobuf::Closure *done) override; + + protected: + size_t _rank; + PSClient *_client; +}; + +class DownpourBrpcClosure : public PSClientClosure { + public: + DownpourBrpcClosure(size_t num, PSClientCallBack callback) + : PSClientClosure(callback) { + _waiting_num = num; + + _cntls.resize(num); + _requests.resize(num); + _responses.resize(num); + for (size_t i = 0; i < num; ++i) { + _cntls[i].reset(new brpc::Controller()); + } + } + virtual ~DownpourBrpcClosure() {} + virtual void Run() override { + if (_waiting_num.fetch_sub(1) == 1) { + _callback(this); + delete this; + } + } + PsRequestMessage *request(size_t i) { return &_requests[i]; } + PsResponseMessage *response(size_t i) { return &_responses[i]; } + brpc::Controller *cntl(size_t i) { return _cntls[i].get(); } + int check_response(size_t request_idx, int cmd_id); + int check_save_response(size_t request_idx, int cmd_id); + std::string get_response(size_t request_idx, int cmd_id); + + private: + std::atomic _waiting_num; + std::vector _requests; + std::vector _responses; + std::vector> _cntls; +}; + +template +struct array_deleter { + void operator()(T *&x) const { delete[] x; } +}; + +class BrpcPsClient : public PSClient { + public: + BrpcPsClient() {} + virtual ~BrpcPsClient() { + // _running = false; + // try { + // _async_push_dense_thread.join(); + // _async_push_sparse_thread.join(); + //} catch (...) { + //} + } + virtual int32_t create_client2client_connection( + int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); + virtual std::future shrink(uint32_t table_id) override; + virtual std::future load(const std::string &epoch, + const std::string &mode) override; + virtual std::future load(uint32_t table_id, const std::string &epoch, + const std::string &mode) override; + + virtual std::future save(const std::string &epoch, + const std::string &mode) override; + + virtual std::future save(uint32_t table_id, const std::string &epoch, + const std::string &mode) override; + + virtual std::future clear() override; + + virtual std::future clear(uint32_t table_id) override; + + virtual std::future stop_server() override; + + virtual std::future start_profiler() override; + virtual std::future stop_profiler() override; + + virtual void finalize_worker() override; + + virtual std::future pull_dense(Region *regions, size_t region_num, + size_t table_id); + + virtual std::future push_dense_param(const Region *regions, + size_t region_num, + size_t table_id); + + virtual std::future pull_sparse(float **select_values, + size_t table_id, + const uint64_t *keys, size_t num); + + virtual std::future print_table_stat(uint32_t table_id); + + virtual std::future barrier(size_t table_id, uint32_t barrier_type); + + virtual std::future pull_geo_param(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx); + + virtual std::future flush(); + + virtual std::future send_client2client_msg( + int msg_type, int to_client_id, const std::string &msg) override; + + private: + virtual int32_t initialize() override; + + inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, + uint32_t shard_num) { + return dense_dim_total / shard_num + 1; + } + + std::future send_cmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); + + std::future send_save_cmd(uint32_t table_id, int cmd_id, + const std::vector ¶m); + + inline brpc::Channel *get_sparse_channel(size_t server_id) { + return _server_channels[server_id][0].get(); + } + inline brpc::Channel *get_dense_channel(size_t server_id) { + return _server_channels[server_id][1].get(); + } + inline brpc::Channel *get_cmd_channel(size_t server_id) { + return _server_channels[server_id][2].get(); + } + + bool _running = false; + bool _flushing = false; + std::atomic _async_call_num; //异步请求计数 + + std::vector> + _client_channels; // client2client + std::vector, 3>> + _server_channels; // client2server + virtual std::future push_dense_raw_gradient( + int table_id, float *total_send_data, size_t total_send_data_size, + void *done) override; + + virtual std::future push_sparse_raw_gradient( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) override; + + virtual std::future push_sparse_raw_gradient_partial( + size_t table_id, const uint64_t *keys, const float **update_values, + uint32_t num, void *done, int pserver_idx) override; + + virtual std::future push_sparse_param(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, + void *done) override; + + virtual size_t get_server_nums() { return _server_channels.size(); } + + private: + int32_t start_client_service(); + + float _mae = 0; + float _mse = 0; + uint16_t _push_times = 0; + brpc::Server _server; + DownpourPsClientService _service; + std::atomic_uint grad_num_{0}; +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc new file mode 100644 index 0000000000000000000000000000000000000000..1386e83447567f9e3acfefe8b992a3dbaa045d39 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_server.cc @@ -0,0 +1,530 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include // NOLINT +#include "Eigen/Dense" +#include "butil/endpoint.h" +#include "iomanip" +#include "paddle/fluid/distributed/table/table.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace distributed { + +int32_t BrpcPsServer::initialize() { + auto &service_config = _config.downpour_server_param().service_param(); + if (!service_config.has_service_class()) { + LOG(ERROR) << "miss service_class in ServerServiceParameter"; + return -1; + } + auto *service = CREATE_CLASS(PsBaseService, service_config.service_class()); + if (service == NULL) { + LOG(ERROR) << "service is unregistered, service_name:" + << service_config.service_class(); + return -1; + } + + _service.reset(service); + if (service->configure(this) != 0 || service->initialize() != 0) { + LOG(ERROR) << "service initialize failed, service_name:" + << service_config.service_class(); + return -1; + } + if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) { + LOG(ERROR) << "service add to brpc failed, service:" + << service_config.service_class(); + return -1; + } + return 0; +} + +uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { + std::unique_lock lock(mutex_); + + std::string ip_port = ip + ":" + std::to_string(port); + VLOG(3) << "server of rank " << _rank << " starts at " << ip_port; + int num_threads = std::thread::hardware_concurrency(); + brpc::ServerOptions options; + options.num_threads = num_threads; + + if (_server.Start(ip_port.c_str(), &options) != 0) { + LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port; + return 0; + } + VLOG(0) << "BrpcPsServer::start registe_ps_server"; + _environment->registe_ps_server(ip, port, _rank); + VLOG(0) << "BrpcPsServer::start wait"; + cv_.wait(lock, [&] { return stoped_; }); + + PSHost host; + host.ip = ip; + host.port = port; + host.rank = _rank; + VLOG(0) << "BrpcPsServer::start return host.rank"; + return host.rank; +} + +int32_t BrpcPsServer::port() { return _server.listen_address().port; } + +int32_t PsService::initialize() { + _is_initialize_shard_info = false; + _service_handler_map[PS_STOP_SERVER] = &PsService::stop_server; + _service_handler_map[PS_PULL_DENSE_TABLE] = &PsService::pull_dense; + _service_handler_map[PS_PUSH_DENSE_TABLE] = &PsService::push_dense; + _service_handler_map[PS_PULL_SPARSE_TABLE] = &PsService::pull_sparse; + _service_handler_map[PS_PUSH_SPARSE_TABLE] = &PsService::push_sparse; + _service_handler_map[PS_SAVE_ONE_TABLE] = &PsService::save_one_table; + _service_handler_map[PS_SAVE_ALL_TABLE] = &PsService::save_all_table; + _service_handler_map[PS_SHRINK_TABLE] = &PsService::shrink_table; + _service_handler_map[PS_LOAD_ONE_TABLE] = &PsService::load_one_table; + _service_handler_map[PS_LOAD_ALL_TABLE] = &PsService::load_all_table; + _service_handler_map[PS_CLEAR_ONE_TABLE] = &PsService::clear_one_table; + _service_handler_map[PS_CLEAR_ALL_TABLE] = &PsService::clear_all_table; + _service_handler_map[PS_PUSH_DENSE_PARAM] = &PsService::push_dense_param; + _service_handler_map[PS_PRINT_TABLE_STAT] = &PsService::print_table_stat; + _service_handler_map[PS_PULL_GEO_PARAM] = &PsService::pull_geo_param; + _service_handler_map[PS_PUSH_SPARSE_PARAM] = &PsService::push_sparse_param; + _service_handler_map[PS_BARRIER] = &PsService::barrier; + _service_handler_map[PS_START_PROFILER] = &PsService::start_profiler; + _service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler; + + // shard初始化,server启动后才可从env获取到server_list的shard信息 + initialize_shard_info(); + + return 0; +} + +#define CHECK_TABLE_EXIST(table, request, response) \ + if (table == NULL) { \ + std::string err_msg("table not found with table_id:"); \ + err_msg.append(std::to_string(request.table_id())); \ + set_response_code(response, -1, err_msg.c_str()); \ + return -1; \ + } + +int32_t PsService::initialize_shard_info() { + if (!_is_initialize_shard_info) { + std::lock_guard guard(_initialize_shard_mutex); + if (_is_initialize_shard_info) { + return 0; + } + size_t shard_num = _server->environment()->get_ps_servers().size(); + auto &table_map = *(_server->table()); + for (auto itr : table_map) { + itr.second->set_shard(_rank, shard_num); + } + _is_initialize_shard_info = true; + } + return 0; +} + +void PsService::service(google::protobuf::RpcController *cntl_base, + const PsRequestMessage *request, + PsResponseMessage *response, + google::protobuf::Closure *done) { + brpc::ClosureGuard done_guard(done); + std::string log_label("ReceiveCmd-"); + if (!request->has_table_id()) { + set_response_code(*response, -1, "PsRequestMessage.tabel_id is required"); + return; + } + + response->set_err_code(0); + response->set_err_msg(""); + auto *table = _server->table(request->table_id()); + brpc::Controller *cntl = static_cast(cntl_base); + auto itr = _service_handler_map.find(request->cmd_id()); + if (itr == _service_handler_map.end()) { + std::string err_msg( + "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); + err_msg.append(std::to_string(request->cmd_id())); + set_response_code(*response, -1, err_msg.c_str()); + return; + } + serviceHandlerFunc handler_func = itr->second; + int service_ret = (this->*handler_func)(table, *request, *response, cntl); + if (service_ret != 0) { + response->set_err_code(service_ret); + response->set_err_msg("server internal error"); + } +} + +int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->pull_dense"); + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 1) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 1 for num of dense"); + return 0; + } + uint32_t num = *(const uint32_t *)request.params(0).c_str(); + if (num < 0) { + set_response_code(response, -1, + "PsRequestMessage.datas[0] is invalid, num must >= 0"); + return 0; + } + + std::vector res_data; + res_data.resize(num * table->value_accesor()->select_size() / sizeof(float)); + table->pull_dense(res_data.data(), num); + + cntl->response_attachment().append((char *)res_data.data(), + res_data.size() * sizeof(float)); + + return 0; +} + +int32_t PsService::push_dense_param(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_dense_param"); + CHECK_TABLE_EXIST(table, request, response) + thread_local std::string push_buffer; + auto &req_io_buffer = cntl->request_attachment(); + auto req_buffer_size = req_io_buffer.size(); + if (req_buffer_size < 1) { + set_response_code(response, -1, "req attachment is empty"); + return 0; + } + push_buffer.resize(0); + push_buffer.reserve(req_buffer_size); + const char *data = (const char *)cntl->request_attachment().fetch( + const_cast(push_buffer.data()), req_buffer_size); + + uint32_t num = *(const uint32_t *)data; + + const float *values = (const float *)(data + sizeof(uint32_t)); + if (table->push_dense_param(values, num) != 0) { + set_response_code(response, -1, "push_dense_param failed"); + } + return 0; +} + +int32_t PsService::push_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_dense"); + CHECK_TABLE_EXIST(table, request, response) + auto req_buffer_size = request.data().size(); + if (req_buffer_size < 1) { + // set_response_code(response, 0, "push dense data is empty"); + return 0; + } + + /* + Push Content: + |--num--|---valuesData---| + |--4B---|----------------| + */ + uint32_t num = *(const uint32_t *)(request.data().data()); + const float *values = + (const float *)(request.data().data() + sizeof(uint32_t)); + if (table->push_dense(values, num) != 0) { + set_response_code(response, -1, "push_dense failed"); + } + + return 0; +} + +int32_t PsService::barrier(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + + auto trainer_id = request.client_id(); + auto barrier_type = request.params(0); + table->barrier(trainer_id, barrier_type); + return 0; +} + +int32_t PsService::push_sparse_param(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_sparse_param"); + CHECK_TABLE_EXIST(table, request, response) + auto &push_data = request.data(); + if (push_data.size() < 1) { + // set_response_code(response, 0, "push sparse data is empty"); + return 0; + } + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + uint32_t num = *(uint32_t *)(request.params(0).c_str()); + /* + Push Content: + |---keysData---|---valuesData---| + |---8*{num}B---|----------------| + */ + const uint64_t *keys = (const uint64_t *)push_data.data(); + const float *values = + (const float *)(push_data.data() + sizeof(uint64_t) * num); + if (table->push_sparse_param(keys, values, num) != 0) { + set_response_code(response, -1, "push_sparse_param error"); + } + return 0; +} + +int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->pull_geo_param"); + CHECK_TABLE_EXIST(table, request, response) + thread_local std::string push_sparse_request_buffer; + + auto trainer_id = request.client_id(); + + std::vector values; + std::vector ids; + table->pull_geo_param(trainer_id, &values, &ids); + + uint32_t num = ids.size(); + cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); + cntl->response_attachment().append((char *)ids.data(), + ids.size() * sizeof(uint64_t)); + cntl->response_attachment().append((char *)values.data(), + values.size() * sizeof(float)); + return 0; +} + +int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->pull_sparse"); + CHECK_TABLE_EXIST(table, request, response) + thread_local std::string push_sparse_request_buffer; + auto &req_io_buffer = cntl->request_attachment(); + auto req_buffer_size = req_io_buffer.size(); + if (req_buffer_size < 1) { + set_response_code(response, -1, "req attachment is empty"); + return 0; + } + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + uint32_t num = *(uint32_t *)(request.params(0).c_str()); + push_sparse_request_buffer.resize(0); + push_sparse_request_buffer.reserve(req_buffer_size); + const char *data = (const char *)cntl->request_attachment().fetch( + const_cast(push_sparse_request_buffer.data()), req_buffer_size); + /* + Attachment Content: + |---keysData---| + |---8*{num}B---| + */ + const uint64_t *keys = (const uint64_t *)data; + std::vector res_data; + res_data.resize(num * table->value_accesor()->select_size() / sizeof(float)); + table->pull_sparse(res_data.data(), keys, num); + cntl->response_attachment().append((char *)res_data.data(), + res_data.size() * sizeof(float)); + return 0; +} + +int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::RecordEvent record_event("PsService->push_sparse"); + CHECK_TABLE_EXIST(table, request, response) + auto &push_data = request.data(); + if (push_data.size() < 1) { + // set_response_code(response, 0, "push sparse data is empty"); + return 0; + } + if (request.params_size() < 1) { + set_response_code(response, -1, + "PsRequestMessage.params is requeired at " + "least 1 for num of sparse_key"); + return 0; + } + uint32_t num = *(uint32_t *)(request.params(0).c_str()); + /* + Push Content: + |---keysData---|---valuesData---| + |---8*{num}B---|----------------| + */ + const uint64_t *keys = (const uint64_t *)push_data.data(); + const float *values = + (const float *)(push_data.data() + sizeof(uint64_t) * num); + if (table->push_sparse(keys, values, num) != 0) { + set_response_code(response, -1, "push_sparse error"); + } + return 0; +} + +int32_t PsService::print_table_stat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + std::pair ret = table->print_table_stat(); + paddle::framework::BinaryArchive ar; + ar << ret.first << ret.second; + std::string table_info(ar.Buffer(), ar.Length()); + response.set_data(table_info); + + return 0; +} + +int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 2 for path & load_param"); + return -1; + } + if (table->load(request.params(0), request.params(1)) != 0) { + set_response_code(response, -1, "table load failed"); + return -1; + } + return 0; +} + +int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + for (auto &itr : table_map) { + if (load_one_table(itr.second.get(), request, response, cntl) != 0) { + LOG(ERROR) << "load table[" << itr.first << "] failed"; + return -1; + } + } + return 0; +} + +int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 2, path&mode"); + return -1; + } + table->flush(); + + int32_t feasign_size = 0; + feasign_size = table->save(request.params(0), request.params(1)); + if (feasign_size < 0) { + set_response_code(response, -1, "table save failed"); + return -1; + } + return feasign_size; +} + +int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + int32_t all_feasign_size = 0; + int32_t feasign_size = 0; + + for (auto &itr : table_map) { + feasign_size = save_one_table(itr.second.get(), request, response, cntl); + if (feasign_size < 0) { + LOG(ERROR) << "save table[" << itr.first << "] failed"; + return -1; + } + } + return 0; +} + +int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + table->flush(); + if (table->shrink() != 0) { + set_response_code(response, -1, "table shrink failed"); + } + return 0; +} + +int32_t PsService::clear_one_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + table->flush(); + table->clear(); + return 0; +} + +int32_t PsService::clear_all_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->table()); + for (auto &itr : table_map) { + if (clear_one_table(itr.second.get(), request, response, cntl) != 0) { + return -1; + } + } + return 0; +} + +int32_t PsService::stop_server(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto *p_server = _server; + std::thread t_stop([p_server]() { + p_server->stop(); + LOG(INFO) << "Server Stoped"; + }); + t_stop.detach(); + return 0; +} + +int32_t PsService::stop_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::DisableProfiler(platform::EventSortingKey::kDefault, + string::Sprintf("server_%s_profile", _rank)); + return 0; +} + +int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + platform::EnableProfiler(platform::ProfilerState::kCPU); + return 0; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_ps_server.h b/paddle/fluid/distributed/service/brpc_ps_server.h new file mode 100644 index 0000000000000000000000000000000000000000..0a053848e1eb3c915b6405fcff33b5710c776943 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_ps_server.h @@ -0,0 +1,153 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" + +#include +#include +#include "paddle/fluid/distributed/service/server.h" + +namespace paddle { +namespace distributed { + +class BrpcPsServer : public PSServer { + public: + BrpcPsServer() {} + virtual ~BrpcPsServer() {} + virtual uint64_t start(const std::string &ip, uint32_t port); + virtual int32_t stop() { + std::unique_lock lock(mutex_); + stoped_ = true; + cv_.notify_all(); + + _server.Stop(1000); + _server.Join(); + return 0; + } + virtual int32_t port(); + + private: + virtual int32_t initialize(); + + mutable std::mutex mutex_; + std::condition_variable cv_; + bool stoped_ = false; + brpc::Server _server; + std::shared_ptr _service; + std::vector> _pserver_channels; +}; + +class PsService; + +typedef int32_t (PsService::*serviceHandlerFunc)( + Table *table, const PsRequestMessage &request, PsResponseMessage &response, + brpc::Controller *cntl); + +class PsService : public PsBaseService { + public: + virtual int32_t initialize() override; + + virtual void service(::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, + ::google::protobuf::Closure *done) override; + + private: + int32_t initialize_shard_info(); + int32_t pull_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_dense_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_sparse_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); + int32_t pull_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t pull_geo_param(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t barrier(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t push_sparse(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t load_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t load_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t save_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t save_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t shrink_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t clear_one_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t clear_all_table(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t stop_server(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t start_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + int32_t stop_profiler(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + int32_t print_table_stat(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + bool _is_initialize_shard_info; + std::mutex _initialize_shard_mutex; + std::unordered_map _service_handler_map; + std::unordered_map _msg_handler_map; + std::vector _ori_values; +}; + +class DownpourPServerBrpcClosure : public PServerClosure { + public: + DownpourPServerBrpcClosure(size_t num, PServerCallBack callback) + : PServerClosure(callback) { + _waiting_num = num; + _cntls.resize(num); + _requests.resize(num); + _responses.resize(num); + for (size_t i = 0; i < num; ++i) { + _cntls[i].reset(new brpc::Controller()); + } + } + virtual ~DownpourPServerBrpcClosure() {} + + virtual void Run() override { + if (_waiting_num.fetch_sub(1) == 1) { + _callback(this); + delete this; + } + } + PsRequestMessage *request(size_t i) { return &_requests[i]; } + PsResponseMessage *response(size_t i) { return &_responses[i]; } + brpc::Controller *cntl(size_t i) { return _cntls[i].get(); } + int check_response(size_t request_idx, int cmd_id) { return 1; } + int check_save_response(size_t request_idx, int cmd_id) { return 1; } + + private: + std::atomic _waiting_num; + std::vector _requests; + std::vector _responses; + std::vector> _cntls; +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_utils.cc b/paddle/fluid/distributed/service/brpc_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..abd58bf028c2c19e50d18d8b33ff34e2b92e2d3f --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_utils.cc @@ -0,0 +1,314 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include +#include +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace framework { +class Scope; +class Variable; +} // namespace framework +namespace platform { +class DeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace distributed { + +framework::proto::VarType::Type VarMessageToVarType( + VariableMessage::Type type) { + switch (type) { + case VariableMessage::FP32: + return framework::proto::VarType::FP32; // NOLINT + case VariableMessage::FP64: + return framework::proto::VarType::FP64; // NOLINT + case VariableMessage::INT32: + return framework::proto::VarType::INT32; // NOLINT + case VariableMessage::INT64: + return framework::proto::VarType::INT64; // NOLINT + case VariableMessage::BOOL: + return framework::proto::VarType::BOOL; // NOLINT + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "VarMessageToVarType:Unsupported type %d", type)); + } +} + +void SerializeToMultiVarMsgAndIOBuf( + const std::string& message_name, + const std::vector& send_var_name_val, + const std::vector& recv_var_name_val, + const platform::DeviceContext& ctx, const framework::Scope* scope, + MultiVarMsg* request, butil::IOBuf* iobuf) { + // 1. message_name + request->set_message_name(message_name); + + // 2. var_names + for (auto& send_var_name : send_var_name_val) { + request->add_send_var_names(send_var_name); + } + for (auto& recv_var_name : recv_var_name_val) { + request->add_recv_var_names(recv_var_name); + } + + // 3. VarMessage + for (auto& send_var_name : send_var_name_val) { + auto* send_var_msg = request->add_var_messages(); + butil::IOBuf temp_iobuf; + send_var_msg->set_varname(send_var_name); + + framework::Variable* var = scope->FindVar(send_var_name); + + if (var->IsType()) { + SerializeLodTensor(var, ctx, send_var_msg, &temp_iobuf); + } else if (var->IsType()) { + SerializeSelectedRows(var, ctx, send_var_msg, &temp_iobuf); + } + iobuf->append(temp_iobuf); + } +} + +void SerializeLodTensor(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* var_msg, + butil::IOBuf* iobuf) { + auto* tensor = var->GetMutable(); + var_msg->set_type(::paddle::LOD_TENSOR); + const framework::LoD lod = tensor->lod(); + if (lod.size() > 0) { + var_msg->set_lod_level(lod.size()); + for (auto& each : lod) { + VarMsg::LodData* lod_inner = var_msg->add_lod(); + for (auto& d : each) { + lod_inner->add_lod_data(d); + } + } + } + var_msg->set_data_type(static_cast(tensor->type())); + for (auto& dim : framework::vectorize(tensor->dims())) { + var_msg->add_dims(dim); + } + // IO Buffer + if (platform::is_cpu_place(tensor->place())) { + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(tensor->data()), + data_len); + } else { +#ifdef PADDLE_WITH_CUDA + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(temp_ptr), data_len); + delete[] temp_ptr; +#endif + } +} + +void SerializeSelectedRows(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* var_msg, + butil::IOBuf* iobuf) { + framework::SelectedRows* slr = var->GetMutable(); + auto* tensor = slr->mutable_value(); + auto* rows = slr->mutable_rows(); + + var_msg->set_type(::paddle::SELECTED_ROWS); + var_msg->set_slr_height(slr->height()); + + auto* var_data = var_msg->mutable_data(); + var_data->clear(); + var_data->resize(rows->size() * sizeof(int64_t)); + char* data_ptr = const_cast(var_data->data()); + + if (platform::is_cpu_place(tensor->place())) { + memcpy(data_ptr, &(*rows)[0], rows->size() * sizeof(int64_t)); + } else { +#ifdef PADDLE_WITH_CUDA + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), data_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + &(*rows)[0], rows->size() * sizeof(int64_t), stream); +#endif + } + var_msg->set_data_type(static_cast(tensor->type())); + for (auto& dim : framework::vectorize(tensor->dims())) { + var_msg->add_dims(dim); + } + + // IO Buffer + if (platform::is_cpu_place(tensor->place())) { + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(tensor->data()), + data_len); + } else { +#ifdef PADDLE_WITH_CUDA + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); + iobuf->append(reinterpret_cast(&data_len), 8); + iobuf->append(reinterpret_cast(temp_ptr), data_len); + delete[] temp_ptr; +#endif + } +} + +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + framework::Scope* scope) { + butil::IOBufBytesIterator io_buffer_itr(*iobuf); + // size_t shard_buffer_remain = res_io_buffer.size(); + for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size(); + ++recv_var_index) { + const auto& msg = multi_msg.var_messages(recv_var_index); + auto* var = scope->Var(msg.varname()); + if (msg.type() == ::paddle::LOD_TENSOR) { + DeserializeLodTensor(var, msg, io_buffer_itr, ctx); + } else if (msg.type() == ::paddle::SELECTED_ROWS) { + DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); + } + } +} + +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + const framework::Scope* scope) { + butil::IOBufBytesIterator io_buffer_itr(*iobuf); + // size_t shard_buffer_remain = res_io_buffer.size(); + for (int recv_var_index = 0; recv_var_index < multi_msg.send_var_names_size(); + ++recv_var_index) { + const auto& msg = multi_msg.var_messages(recv_var_index); + auto* var = scope->FindVar(msg.varname()); + PADDLE_ENFORCE_NE(var, nullptr, + platform::errors::InvalidArgument( + "Not find variable %s in scope.", msg.varname())); + if (msg.type() == ::paddle::LOD_TENSOR) { + DeserializeLodTensor(var, msg, io_buffer_itr, ctx); + } else if (msg.type() == ::paddle::SELECTED_ROWS) { + DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); + } + } +} + +void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& io_buffer_itr, + const platform::DeviceContext& ctx) { + const auto place = ctx.GetPlace(); + framework::LoDTensor* tensor = var->GetMutable(); + std::vector vec_dim; + for (auto& x : msg.dims()) { + vec_dim.push_back(x); + } + tensor->Resize(framework::make_ddim(vec_dim)); + + framework::LoD lod; + for (int i = 0; i < msg.lod_level(); ++i) { + framework::Vector v; + for (int j = 0; j < msg.lod(i).lod_data_size(); ++j) { + v.push_back(msg.lod(i).lod_data(j)); + } + lod.push_back(v); + } + tensor->set_lod(lod); + + void* tensor_data = + tensor->mutable_data(place, VarMessageToVarType(msg.data_type())); + + // IO Buffer + if (platform::is_cpu_place(place)) { + unsigned long data_len; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward(tensor_data, data_len); + } else if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + unsigned long data_len; + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward((void*)temp_ptr, data_len); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data, + platform::CPUPlace(), (void*)temp_ptr, + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + delete[] temp_ptr; +#endif + } +} + +void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& io_buffer_itr, + const platform::DeviceContext& ctx) { + const auto place = ctx.GetPlace(); + auto* slr = var->GetMutable(); + framework::Tensor* tensor = slr->mutable_value(); + slr->set_height(msg.slr_height()); + std::vector tmp_rows(msg.slr_height()); + memcpy(&tmp_rows[0], msg.data().data(), msg.slr_height() * sizeof(int64_t)); + slr->set_rows(tmp_rows); + std::vector vec_dim; + for (auto& x : msg.dims()) { + vec_dim.push_back(x); + } + tensor->Resize(framework::make_ddim(vec_dim)); + void* tensor_data = + tensor->mutable_data(place, VarMessageToVarType(msg.data_type())); + // IO Buffer + if (platform::is_cpu_place(place)) { + unsigned long data_len; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward(tensor_data, data_len); + } else if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + char* temp_ptr = + new char[tensor->numel() * framework::SizeOfType(tensor->type())]; + unsigned long data_len; + io_buffer_itr.copy_and_forward((void*)(&data_len), 8); + io_buffer_itr.copy_and_forward(temp_ptr, data_len); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data, + platform::CPUPlace(), temp_ptr, + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + delete[] temp_ptr; +#endif + } +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/brpc_utils.h b/paddle/fluid/distributed/service/brpc_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..aa340c58a7b8b0ed93d1dd67cd747689be9fe094 --- /dev/null +++ b/paddle/fluid/distributed/service/brpc_utils.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "brpc/channel.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/platform/port.h" + +namespace grpc { +class ByteBuffer; +} // namespace grpc +namespace paddle { +namespace framework { +class Scope; +class Variable; +} // namespace framework +namespace platform { +class DeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace distributed { + +using MultiVarMsg = ::paddle::MultiVariableMessage; +using VarMsg = ::paddle::VariableMessage; + +void SerializeToMultiVarMsgAndIOBuf( + const std::string& message_name, + const std::vector& send_var_name_val, + const std::vector& recv_var_name_val, + const platform::DeviceContext& ctx, const framework::Scope* scope, + MultiVarMsg* var_msg, butil::IOBuf* iobuf); + +void SerializeLodTensor(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* var_msg, + butil::IOBuf* iobuf); + +void SerializeSelectedRows(framework::Variable* var, + const platform::DeviceContext& ctx, VarMsg* request, + butil::IOBuf* iobuf); + +// Deserialize for Server +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + framework::Scope* scope); + +// Deserialize for Client +void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, + const butil::IOBuf* iobuf, + const platform::DeviceContext& ctx, + const framework::Scope* scope); + +void DeserializeLodTensor(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& iobuf, + const platform::DeviceContext& ctx); + +void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, + butil::IOBufBytesIterator& iobuf, + const platform::DeviceContext& ctx); + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/communicator.cc b/paddle/fluid/distributed/service/communicator.cc new file mode 100644 index 0000000000000000000000000000000000000000..18776a61a5cee7306ea85114dbeec81287579f34 --- /dev/null +++ b/paddle/fluid/distributed/service/communicator.cc @@ -0,0 +1,1171 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/distributed/service/communicator.h" +#include +#include "paddle/fluid/distributed/table/table.h" + +#include +#include + +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/string/printf.h" +#include "paddle/fluid/string/split.h" + +namespace paddle { +namespace distributed { + +using framework::LoDTensor; +using framework::SelectedRows; + +inline double GetCurrentUS() { + struct timeval time; + gettimeofday(&time, NULL); + return 1e+6 * time.tv_sec + time.tv_usec; +} + +Communicator::Communicator() {} + +void Communicator::init_gflag(const std::string &gflags) { + VLOG(0) << "Init With Gflags:" << gflags; + std::vector flags = paddle::string::split_string(gflags); + if (flags.size() < 1) { + flags.push_back("-max_body_size=314217728"); + flags.push_back("-bthread_concurrency=40"); + flags.push_back("-socket_max_unwritten_bytes=2048000000"); + flags.push_back("-max_connection_pool_size=1950"); + } + auto it = flags.begin(); + flags.insert(it, "exe default"); + char *flags_ptr[flags.size()]; + for (size_t i = 0; i < flags.size(); ++i) { + flags_ptr[i] = (char *)(flags[i].c_str()); + } + int params_cnt = flags.size(); + char **params_ptr = &(flags_ptr[0]); + ::google::ParseCommandLineFlags(¶ms_cnt, ¶ms_ptr, true); +} + +std::once_flag Communicator::init_flag_; +std::shared_ptr Communicator::communicator_(nullptr); + +void Communicator::InitBrpcClient( + const std::string &dist_desc, + const std::vector &host_sign_list) { + // not used, just for psclient's init + std::map> + _dense_pull_regions; + for (auto &iter : recv_varname_to_ctx_) { + auto tid = iter.first; + auto var_names = iter.second; + + auto ®ions = _dense_pull_regions[tid]; + regions.reserve(var_names.size()); + for (auto &t : var_names) { + Variable *var = recv_scope_->FindVar(t); + LoDTensor *tensor = var->GetMutable(); + float *w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + } + + if (_worker_ptr.get() == nullptr) { + google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); + init_gflag(_ps_param.init_gflags()); + servers_ = host_sign_list.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list, servers_); + _worker_ptr = std::shared_ptr( + paddle::distributed::PSClientFactory::create(_ps_param)); + _worker_ptr->configure(_ps_param, _dense_pull_regions, _ps_env, + trainer_id_); + } + return; +} + +void Communicator::RpcRecvDense(const std::vector &varnames, + int table_id, Scope *scope) { + platform::RecordEvent record_event("Communicator->RpcRecvDense"); + std::vector regions; + regions.reserve(varnames.size()); + for (auto &t : varnames) { + Variable *var = scope->Var(t); + LoDTensor *tensor = var->GetMutable(); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + Variable *temp_var = xpu_temp_scope_->Var(t); + LoDTensor *temp_tensor = temp_var->GetMutable(); + temp_tensor->Resize(tensor->dims()); + float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); + paddle::distributed::Region reg(temp_data, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id " + << table_id << " Temp_data[0] " << temp_data[0] + << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; +#endif + } else { + float *w = tensor->mutable_data(tensor->place()); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + } + } + auto status = + _worker_ptr->pull_dense(regions.data(), regions.size(), table_id); + status.wait(); + + for (auto &t : varnames) { + Variable *var = scope->FindVar(t); + LoDTensor *tensor = var->GetMutable(); + VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? " + << platform::is_gpu_place(tensor->place()); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + LoDTensor *temp_tensor = + xpu_temp_scope_->FindVar(t)->GetMutable(); + framework::TensorCopy(*temp_tensor, tensor->place(), tensor); + float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); + VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id " + << table_id << " Temp_data[0] " << temp_data[0] + << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; +#endif + } + } + + return; +} + +void Communicator::RpcSendDenseParam(const std::vector &varnames, + int table_id, const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendDenseParam"); + auto place = platform::CPUPlace(); + std::vector regions; + for (auto &t : varnames) { + Variable *var = scope.FindVar(t); + CHECK(var != nullptr) << "var[" << t << "] not found"; + LoDTensor *tensor = var->GetMutable(); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + Variable *temp_var = xpu_temp_scope_->Var(t); + LoDTensor *temp_tensor = temp_var->GetMutable(); + temp_tensor->Resize(tensor->dims()); + float *temp_data = temp_tensor->mutable_data(platform::CPUPlace()); + framework::TensorCopy(*tensor, platform::CPUPlace(), temp_tensor); + paddle::distributed::Region reg(temp_data, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t + << " table_id " << table_id << " Temp_data[0] " << temp_data[0] + << " Temp_data[-1] " << temp_data[tensor->numel() - 1]; +#endif + } else { + float *w = tensor->mutable_data(place); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(1) << "AsyncCommunicator::RpcSendDenseParam Var " << t + << " talbe_id " << table_id << " Temp_data[0] " << w[0] + << " Temp_data[-1] " << w[tensor->numel() - 1]; + } + } + auto status = + _worker_ptr->push_dense_param(regions.data(), regions.size(), table_id); + status.wait(); + VLOG(4) << "RPC Send Dense Param " << table_id << " done!"; + return; +} + +void Communicator::RpcSendDense(const CommContext &ctx, const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendDense"); + auto &var_names = ctx.origin_varnames; + auto &table_id = ctx.table_id; + auto dense_data = std::make_shared>(); + size_t request_call_num = _worker_ptr->get_server_nums(); + uint32_t num_per_shard = + dense_dim_per_shard(ctx.height_sections[0], request_call_num); + dense_data->resize(num_per_shard * + request_call_num); // accessor->update_dim() = 1 + float *data = dense_data->data(); + uint32_t pos = 0; + for (size_t i = 0; i < var_names.size(); ++i) { + const LoDTensor tensor = scope.FindVar(var_names[i])->Get(); + size_t count = static_cast(tensor.numel()); + const float *g = tensor.data(); + CHECK(pos + count <= dense_data->size()) + << "invalid dense size, cur pos[" << pos << "]" + << " data_num[" << count << "] size[" << dense_data->size() << "]"; + memcpy(data + pos, g, count * sizeof(float)); + pos += count; + } + + ++_async_call_num; + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_DENSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + auto status = _worker_ptr->push_dense_raw_gradient( + table_id, data, dense_data->size(), closure); + status.wait(); + return; +} + +void Communicator::RpcSendSparseParam(const std::string &varname, int table_id, + const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendSparseParam"); + size_t request_call_num = _worker_ptr->get_server_nums(); + std::vector push_g_vec; + + auto *send_var = scope.FindVar(varname); + auto *tensor = send_var->GetMutable(); + auto dim = tensor->dims()[1]; + uint64_t sparse_num = static_cast(tensor->dims()[0]); + std::vector sparse_push_keys(sparse_num); + std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0); + push_g_vec.reserve(sparse_num); + + for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { + push_g_vec.push_back(tensor->data() + i * dim); + } + + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_SPARSE_PARAM) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto status = _worker_ptr->push_sparse_param( + table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), + sparse_push_keys.size(), closure); + status.wait(); + return; +} + +void Communicator::RpcSendSparse(const std::string &var_name, int table_id, + const Scope &scope) { + platform::RecordEvent record_event("Communicator->RpcSendSparse"); + size_t request_call_num = _worker_ptr->get_server_nums(); + std::vector sparse_push_keys; + std::vector push_g_vec; + + auto *send_var = scope.FindVar(var_name); + auto *tensor = send_var->GetMutable(); + auto dim = tensor->value().dims()[1]; + std::transform(tensor->rows().begin(), tensor->rows().end(), + std::back_inserter(sparse_push_keys), + [&](int id) { return static_cast(id); }); + + for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { + push_g_vec.push_back(tensor->mutable_value()->data() + i * dim); + } + + ++_async_call_num; + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, [this, request_call_num](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + --_async_call_num; + }); + auto status = _worker_ptr->push_sparse_raw_gradient( + table_id, sparse_push_keys.data(), (const float **)push_g_vec.data(), + sparse_push_keys.size(), closure); + status.wait(); + return; +} + +void Communicator::RpcRecvSparse(const std::string &varname, int table_id, + Scope *scope) { + platform::RecordEvent record_event("Communicator->RpcRecvSparse"); + auto *send_var = scope->Var(varname); + auto *tensor = send_var->GetMutable(); + auto dim = tensor->dims()[1]; + uint64_t sparse_num = static_cast(tensor->dims()[0]); + + std::vector sparse_push_keys(sparse_num); + std::iota(sparse_push_keys.begin(), sparse_push_keys.end(), 0); + + std::vector push_g_vec; + for (auto i = 0; i < static_cast(sparse_push_keys.size()); ++i) { + push_g_vec.push_back(tensor->data() + i * dim); + } + + auto status = _worker_ptr->pull_sparse((float **)push_g_vec.data(), table_id, + sparse_push_keys.data(), + sparse_push_keys.size()); + status.wait(); + return; +} + +void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) { + if (trainer_id_ == 0) { + for (auto &iter : recv_varname_to_ctx) { + auto &table_id = iter.first; + auto &varnames = iter.second; + RpcSendDenseParam(varnames, table_id, *recv_scope_); + VLOG(1) << "push dense param to table " << table_id + << " from 0' trainer done"; + } + BarrierWithTable(1); + } else { + BarrierWithTable(1); + for (auto &iter : recv_varname_to_ctx) { + auto &table_id = iter.first; + auto &varnames = iter.second; + RpcRecvDense(varnames, table_id, recv_scope_); + VLOG(1) << "pull dense param to table " << table_id + << " from 0' trainer done"; + } + } + BarrierWithTable(1); + return; +} + +void Communicator::RpcProfilerControl() { + if (trainer_id_ == 0) { + if (!do_server_profiler_ && platform::IsProfileEnabled()) { + // send profiler start flag + do_server_profiler_ = true; + auto start_status = _worker_ptr->start_profiler(); + start_status.wait(); + } else if (do_server_profiler_ && !platform::IsProfileEnabled()) { + // send profiler end flag + auto stop_status = _worker_ptr->stop_profiler(); + stop_status.wait(); + do_server_profiler_ = false; + } + } +} + +void AsyncCommunicator::RecvThread() { + if (!independent_recv_) return; + VLOG(3) << "Independent RecvThread Start and Wait"; + + while (running_) { + int grad_num = grad_num_.load(); + if (grad_num > min_send_grad_num_before_recv_) { + RecvByCommunicator(); + grad_num_.store(0); + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + VLOG(1) << "communicator stopped, independent recv thread exit"; +} + +void AsyncCommunicator::RecvByCommunicator() { + if (!running_) return; + RecvNoBarrier(); + VLOG(3) << "run recv graph end"; +} + +void AsyncCommunicator::RecvNoBarrier() { + for (auto &iter : recv_varname_to_ctx_) { + auto &table_id = iter.first; + auto &varnames = iter.second; + RpcRecvDense(varnames, table_id, recv_scope_); + } + + for (auto &iter : recv_varname_to_ctx_) { + auto var_names = iter.second; + for (auto &t : var_names) { + Variable *var = recv_scope_->FindVar(t); + LoDTensor *tensor = var->GetMutable(); + VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? " + << platform::is_gpu_place(tensor->place()); + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + LoDTensor *temp_tensor = + xpu_temp_scope_->FindVar(t)->GetMutable(); + framework::TensorCopy(*temp_tensor, tensor->place(), tensor); +#endif + } + } + } + + return; +} + +void AsyncCommunicator::SendByCommunicator() { + std::vector> tasks; + tasks.reserve(send_varname_to_ctx_.size()); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + + auto send_recv_task = [this, &ctx] { + auto &varnames = ctx.origin_varnames; + auto &table_id = ctx.table_id; + size_t var_nums = varnames.size(); + auto &check_queue = send_varname_to_queue_[varnames[0]]; + std::vector>> vars; + vars.resize(var_nums); + int merged_var_num = 0; + int wait_times = 0; + while (merged_var_num < max_merge_var_num_) { + if (check_queue->Size() == 0) { + VLOG(4) << "wait_times -> " << wait_times; + if (wait_times >= send_wait_times_) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + wait_times++; + continue; + } else { + wait_times = 0; + for (size_t i = 0; i < var_nums; i++) { + auto &var_name = varnames[i]; + auto &var_queue = send_varname_to_queue_[var_name]; + vars[i].push_back(var_queue->Pop()); + } + merged_var_num++; + } + } + if (merged_var_num == 0) return; + + for (size_t i = 0; i < var_nums; i++) { + auto &var_name = varnames[i]; + MergeVars(var_name, vars[i], send_scope_.get(), 1); + } + + if (ctx.is_sparse) { + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variables")); + RpcSendSparse(varnames[0], table_id, *send_scope_); + } else { + RpcSendDense(ctx, *send_scope_); + if (!independent_recv_ && + recv_varname_to_ctx_.find(table_id) != recv_varname_to_ctx_.end()) { + auto recv_varnames = recv_varname_to_ctx_.at(table_id); + RpcRecvDense(recv_varnames, table_id, recv_scope_); + } + } + if (independent_recv_) { + grad_num_.fetch_add(1, std::memory_order_relaxed); + } + }; + tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task))); + } + for (auto &task : tasks) { + task.wait(); + } + return; +} + +void AsyncCommunicator::MainThread() { + VLOG(3) << "AsyncCommunicator MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + SendByCommunicator(); + RpcProfilerControl(); + } + VLOG(1) << "communicator stopped, send thread exit"; +} + +void HalfAsyncCommunicator::MainThread() { + VLOG(3) << "HalfAsyncCommunicator MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + SendByCommunicator(); + BarrierSend(); + RecvByCommunicator(); + BarrierRecv(); + BarrierWeakUp(); + } + VLOG(1) << "communicator stopped, send thread exit"; +} + +void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) { + send_varname_to_ctx_ = std::move(send_varname_to_ctx); + recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); + recv_scope_ = std::move(recv_scope); + send_scope_.reset(new Scope()); + xpu_temp_scope_.reset(new Scope()); + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + auto &varnames = ctx.origin_varnames; + for (auto &var_name : varnames) { + send_varname_to_queue_[var_name] = + std::make_shared>>( + send_queue_size_); + } + } + send_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); +} + +AsyncCommunicator::~AsyncCommunicator() { + running_ = false; + if (main_thread_) main_thread_->join(); + if (recv_thread_) recv_thread_->join(); +} + +void AsyncCommunicator::Start() { + VLOG(1) << "Communicator start"; + if (!communicator_) { + VLOG(0) << "Communicator is not inited, do nothing"; + } else { + VLOG(1) << "start send thread and recv thread"; + waiting_ = true; + running_ = true; + // flushing_ = false; + BarrierTriggerReset(max_merge_var_num_); + // start send and recv thread + main_thread_.reset( + new std::thread(std::bind(&AsyncCommunicator::MainThread, this))); + if (independent_recv_) { + recv_thread_.reset( + new std::thread(std::bind(&AsyncCommunicator::RecvThread, this))); + } + } +} + +void AsyncCommunicator::Stop() { + VLOG(1) << "Communicator stop"; + running_ = false; + if (!communicator_) { + VLOG(0) << "Communicator is not inited, do nothing"; + } else { + if (recv_thread_) { + VLOG(1) << "stop recv thread"; + recv_thread_->join(); + recv_thread_.reset(nullptr); + } + if (main_thread_) { + VLOG(1) << "stop main thread"; + main_thread_->join(); + main_thread_.reset(nullptr); + } + } + VLOG(1) << "Communicator stop done"; +} + +bool AsyncCommunicator::Check(const std::vector &var_tables) { + PADDLE_ENFORCE_EQ( + var_tables.size(), 1, + platform::errors::InvalidArgument("var_tables.size() == 1 is permitted")); + + auto table_name = var_tables[0]; + if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) + return false; + return true; +} + +bool AsyncCommunicator::Check(const int table_id) { + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + if (ctx.table_id == table_id) return true; + } + return false; +} + +void AsyncCommunicator::Send(const std::vector &var_names, + const framework::Scope &scope) { + waiting_ = false; + for (size_t i = 0; i < var_names.size(); i++) { + auto *var = scope.FindVar(var_names[i]); + auto tmp_grad_var = std::make_shared(); + framework::CopyVariable(*var, tmp_grad_var.get()); + send_varname_to_queue_[var_names[i]]->Push(tmp_grad_var); + } +} + +void HalfAsyncCommunicator::Clean() { + for (auto &iter : send_varname_to_queue_) { + auto &var_name = iter.first; + auto &var_queue = iter.second; + + while (var_queue->Size() > 0) { + var_queue->Pop(); + } + + VLOG(3) << "clean var: " << var_name << " done"; + } +} + +void HalfAsyncCommunicator::BarrierTriggerDecrement() { + barrier_trigger_--; + VLOG(3) << "BarrierTriggerDecrement decrement barrier trigger to " + << barrier_trigger_.load(); +} + +void HalfAsyncCommunicator::BarrierTriggerReset(int initial_val) { + barrier_trigger_.store(initial_val); + + VLOG(3) << "BarrierTriggerReset reset barrier trigger to " + << barrier_trigger_.load(); +} + +void HalfAsyncCommunicator::Barrier() { + barrier_counter_++; + + if (!running_) { + VLOG(3) << "Communicator is not running, release barrier"; + return; + } + + { + std::unique_lock lk(barrier_mutex_); + barrier_cond_.wait(lk, [this] { return (barrier_counter_ == 0); }); + } +} + +int HalfAsyncCommunicator::BatchesCounter() { + while (running_) { + if (barrier_counter_.load() >= barrier_trigger_.load() && + barrier_trigger_.load() != 0) { + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + + return barrier_counter_.load(); +} + +void HalfAsyncCommunicator::SendByCommunicator() { + int batches = BatchesCounter(); + VLOG(1) << "HalfAsyncCommunicator::BatchesCounter = " << batches; + if (batches <= 0) return; + + std::vector> tasks; + tasks.reserve(send_varname_to_ctx_.size()); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + auto send_recv_task = [this, &ctx, batches] { + auto &varnames = ctx.origin_varnames; + auto &table_id = ctx.table_id; + size_t var_nums = varnames.size(); + + std::vector>> vars; + vars.resize(var_nums); + for (size_t i = 0; i < var_nums; i++) { + auto &var_name = varnames[i]; + auto &var_queue = send_varname_to_queue_[var_name]; + for (int j = 0; j < batches; j++) vars[i].push_back(var_queue->Pop()); + MergeVars(var_name, vars[i], send_scope_.get(), 1); + } + + if (ctx.is_sparse) { + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variables")); + RpcSendSparse(varnames[0], table_id, *send_scope_); + } else { + RpcSendDense(ctx, *send_scope_); + } + }; + tasks.emplace_back(send_threadpool_->enqueue(std::move(send_recv_task))); + } + for (auto &task : tasks) { + task.wait(); + } + return; +} + +void HalfAsyncCommunicator::BarrierWeakUp() { + barrier_counter_.store(0); + barrier_cond_.notify_all(); +} + +void SyncCommunicator::BarrierSend() { + if (!running_) return; + BarrierWithTable(0); + VLOG(4) << "BarrierSend with SyncCommunicator"; +} + +void SyncCommunicator::BarrierRecv() { + if (!running_) return; + BarrierWithTable(1); + + VLOG(4) << "BarrierRecv with SyncCommunicator"; +} + +void GeoCommunicator::Send(const std::vector &var_names, + const framework::Scope &scope) { + waiting_ = false; + auto before_send = GetCurrentUS(); + auto table_name = var_names[0]; + + size_t splited_var_nums = + send_varname_to_ctx_[table_name].splited_varnames.size(); + + std::unordered_map> ids_table; + + for (size_t j = 0; j < splited_var_nums; j++) { + ids_table.insert(std::pair>( + send_varname_to_ctx_[table_name].splited_varnames[j], + std::unordered_set())); + } + + auto *var = scope.FindVar(table_name); + + PADDLE_ENFORCE_EQ(var->IsType(), true, + platform::errors::InvalidArgument( + "Only need to send Sparse Grad in Geo mode.")); + auto &rows = var->Get().rows(); + + // insert ids which has not been record + for (size_t j = 0; j < rows.size(); j++) { + auto ep_idx = rows[j] % splited_var_nums; + ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx]) + .insert(rows[j]); + } + + for (auto &iter : ids_table) { + auto &key = iter.first; + auto &sparse_ids_set = iter.second; + auto sparse_ids_vec = std::make_shared>(); + sparse_ids_vec->assign(sparse_ids_set.begin(), sparse_ids_set.end()); + sparse_id_queues_.at(key)->Push(sparse_ids_vec); + VLOG(3) << "push " << sparse_ids_vec->size() << " ids to " << key + << "'s queue"; + } + + auto after_send = GetCurrentUS(); + VLOG(2) << "run send op finish. use time " << (after_send - before_send); +} + +void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) { + send_varname_to_ctx_ = std::move(send_varname_to_ctx); + recv_varname_to_ctx_ = std::move(recv_varname_to_ctx); + recv_scope_ = std::move(recv_scope); + + PADDLE_ENFORCE_GT( + send_varname_to_ctx.size(), 0, + platform::errors::InvalidArgument("send var contexts can not be zero")); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + if (!ctx.is_sparse) continue; + auto &varnames = ctx.origin_varnames; + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variables")); + for (auto &splited_var : ctx.splited_varnames) { + parallel_task_nums_ += 1; + sparse_id_queues_.insert( + std::pair>>>>( + splited_var, + std::make_shared< + BlockingQueue>>>( + send_queue_size_))); + } + } + + send_threadpool_.reset(new ::ThreadPool(thread_pool_size_)); + + delta_scope_.reset(new Scope()); + old_scope_.reset(new Scope()); + pserver_scope_.reset(new Scope()); +} + +void GeoCommunicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) { + std::vector> tasks; + tasks.reserve(recv_varname_to_ctx_.size()); + + for (auto &iter : recv_varname_to_ctx_) { + auto &table_id = iter.first; + auto &varnames = iter.second; + + auto recv_task = [this, &table_id, &varnames] { + InitDense(varnames, table_id); + }; + tasks.emplace_back(send_threadpool_->enqueue(std::move(recv_task))); + } + + for (auto &task : tasks) { + task.wait(); + } + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + if (!ctx.is_sparse) return; + auto &varname = ctx.origin_varnames[0]; + auto &table_id = ctx.table_id; + auto param = varname.substr(0, varname.size() - 5); + InitSparse(param, table_id); + } + return; +} + +void GeoCommunicator::InitDense(std::vector &varnames, + int table_id) { + if (trainer_id_ == 0) { + RpcSendDenseParam(varnames, table_id, *recv_scope_); + BarrierWithTable(1); + VLOG(0) << "push dense param to table " << table_id + << " from 0' trainer done"; + } else { + BarrierWithTable(1); + RpcRecvDense(varnames, table_id, recv_scope_); + VLOG(0) << "push dense param to table " << table_id + << " from 0' trainer done"; + } + + // copy to old_scope + for (auto &t : varnames) { + auto *global_var = recv_scope_->FindVar(t); + global_var->GetMutable(); + auto *old_var = old_scope_->Var(t); + old_var->GetMutable(); + framework::CopyVariable(*global_var, old_var); + } + VLOG(1) << "init dense table " << table_id << " done"; +} + +void GeoCommunicator::SendDense(const CommContext &send_ctx) { + platform::RecordEvent record_event("GeoCommunicator->SendDense"); + auto &var_names = send_ctx.origin_varnames; + auto &table_id = send_ctx.table_id; + for (auto &varname : var_names) { + auto param_name = GradToParam(varname); + auto *var_latest = recv_scope_->FindVar(param_name); + auto *var_timestamp = old_scope_->FindVar(param_name); + + PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + PADDLE_ENFORCE_EQ(var_timestamp->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + + auto &t_latest = var_latest->Get(); + auto t_timestamp = var_timestamp->GetMutable(); + + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + auto *var_delta = delta_scope_->Var(varname); + auto *t_delta = var_delta->GetMutable(); + t_delta->mutable_data(t_latest.dims(), cpu_ctx.GetPlace()); + + auto blas = + paddle::operators::math::GetBlas( + cpu_ctx); + blas.VSUB(t_latest.numel(), t_latest.data(), + t_timestamp->data(), t_delta->data()); + + float coefficient = 1.0 / static_cast(trainers_); + blas.SCAL(t_latest.numel(), coefficient, t_delta->data()); + + blas.VADD(t_latest.numel(), t_timestamp->data(), + t_delta->data(), t_timestamp->data()); + } + RpcSendDense(send_ctx, *delta_scope_); + VLOG(1) << "Finish Send Dense " << var_names[0] << ", table_id: " << table_id; + return; +} + +void GeoCommunicator::RecvDense(const CommContext &send_ctx) { + platform::RecordEvent record_event("GeoCommunicator->RecvDense"); + auto &table_id = send_ctx.table_id; + auto &varnames = recv_varname_to_ctx_.at(table_id); + // 1. recv from pserver + RpcRecvDense(varnames, table_id, pserver_scope_.get()); + + // 2.1 pserver - old => delta; 2.2 latest + old => latest 2.3 old => pserver + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + for (auto &varname : varnames) { + auto *var_latest = recv_scope_->FindVar(varname); + auto t_latest = var_latest->GetMutable(); + + auto *var_old = old_scope_->FindVar(varname); + auto t_old = var_old->GetMutable(); + + auto *var_pserver = pserver_scope_->FindVar(varname); + auto t_pserver = var_pserver->Get(); + + auto *var_delta = delta_scope_->Var(varname); + auto *t_delta = var_delta->GetMutable(); + t_delta->mutable_data(t_latest->dims(), cpu_ctx.GetPlace()); + + auto blas = + paddle::operators::math::GetBlas( + cpu_ctx); + blas.VSUB(t_latest->numel(), t_pserver.data(), t_old->data(), + t_delta->data()); + blas.VADD(t_latest->numel(), t_latest->data(), + t_delta->data(), t_latest->data()); + blas.VCOPY(t_latest->numel(), t_pserver.data(), + t_old->data()); + } + VLOG(1) << "Finish Recv Dense " << varnames[0] << ", table_id: " << table_id; + return; +} + +void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) { + VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " begin."; + if (trainer_id_ == 0) { + RpcSendSparseParam(var_name, table_id, *recv_scope_); + BarrierWithTable(1); + VLOG(0) << "push sparse param to table " << table_id + << " from 0' trainer done"; + } else { + BarrierWithTable(1); + RpcRecvSparse(var_name, table_id, recv_scope_); + VLOG(0) << "push dense param to table " << table_id + << " from 0' trainer done"; + } + + VLOG(0) << "Init Sparse " << var_name << " : table " << table_id << " done."; + auto *global_var = recv_scope_->FindVar(var_name); + auto *var = old_scope_->Var(var_name); + framework::CopyVariable(*global_var, var); + return; +} + +std::vector GeoCommunicator::MergeSparseIds( + const std::string &send_varname) { + size_t merge_num = 0, wait_times = 0; + std::unordered_set sparse_ids; + while (merge_num < static_cast(max_merge_var_num_)) { + VLOG(3) << "Merge Number of " << send_varname << " = " << merge_num; + if (sparse_id_queues_.at(send_varname)->Size() > 0) { + wait_times = 0; + std::shared_ptr> pop_ids = + sparse_id_queues_.at(send_varname)->Pop(); + for (size_t j = 0; j < pop_ids->size(); j++) { + sparse_ids.insert(pop_ids->at(j)); + } + merge_num += 1; + VLOG(3) << "sparse_id_queues_(" << send_varname << ") pushed"; + } else if (sparse_id_queues_.at(send_varname)->Size() == 0) { + VLOG(3) << "wait_times -> " << wait_times; + if (wait_times >= static_cast(send_wait_times_)) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + wait_times++; + continue; + } + } + std::vector res; + res.assign(sparse_ids.begin(), sparse_ids.end()); + return res; +} + +void GeoCommunicator::SendSparse(const std::string &varname, + std::vector &sparse_ids, int table_id, + int ep_idx) { + platform::RecordEvent record_event("GeoCommunicator->SendSparse"); + std::string param_name = SplitedGradToParam(varname); + VLOG(1) << "In GeoCommunicator::SendSparse(" << varname << " " << param_name + << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id + << ", ep_idx: " << ep_idx; + + auto *var_latest = recv_scope_->FindVar(param_name); + auto *var_old = old_scope_->FindVar(param_name); + + PADDLE_ENFORCE_EQ(var_latest->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + PADDLE_ENFORCE_EQ(var_old->IsInitialized(), true, + platform::errors::Unavailable( + "%s is not initialized, please check", param_name)); + + auto &t_latest = var_latest->Get(); + auto *t_old = var_old->GetMutable(); + + auto dims1 = t_latest.dims()[1]; + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + + auto *var_delta = delta_scope_->Var(varname); + auto *t_delta = var_delta->GetMutable(); + auto *var_t_value = t_delta->mutable_value(); + var_t_value->Resize({static_cast(sparse_ids.size()), dims1}); + auto *t_value = var_t_value->mutable_data(cpu_ctx.GetPlace()); + + t_delta->set_rows(sparse_ids); + t_delta->set_height(t_latest.dims()[0]); + + auto blas = + paddle::operators::math::GetBlas( + cpu_ctx); + float coefficient = 1.0 / static_cast(trainers_); + + std::vector push_g_vec; + for (auto j = 0; j < static_cast(sparse_ids.size()); ++j) { + blas.VSUB(dims1, t_latest.data() + sparse_ids[j] * dims1, + t_old->data() + sparse_ids[j] * dims1, + t_value + j * dims1); + blas.SCAL(dims1, coefficient, t_value + j * dims1); + blas.VADD(dims1, t_old->data() + sparse_ids[j] * dims1, + t_value + j * dims1, + t_old->data() + sparse_ids[j] * dims1); + push_g_vec.push_back(t_value + j * dims1); + } + + ++_async_call_num; + DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [this](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + if (closure->check_response(0, PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + } + closure->set_promise_value(ret); + --_async_call_num; + }); + auto status = _worker_ptr->push_sparse_raw_gradient_partial( + table_id, (const uint64_t *)sparse_ids.data(), + (const float **)push_g_vec.data(), sparse_ids.size(), closure, ep_idx); + status.wait(); + + VLOG(1) << "Finish Send Sparse " << varname + << ", ids.size = " << sparse_ids.size() << ", table_id: " << table_id; + return; +} + +void GeoCommunicator::RecvSparse(const std::string &varname, int table_id, + int ep_idx) { + platform::RecordEvent record_event("GeoCommunicator->RecvSparse"); + // 1. recv from pserver + std::vector keys; + std::vector values; + auto status = _worker_ptr->pull_geo_param(table_id, &values, &keys, ep_idx); + status.wait(); + + std::string param = SplitedGradToParam(varname); + VLOG(1) << "RecvSparse receive var: " << varname << " " << param << ", " + << table_id << "; ids Size: " << keys.size() + << "; values size: " << values.size(); + + auto *var_latest = recv_scope_->FindVar(param); + auto *var_old = old_scope_->FindVar(param); + + auto *t_latest = var_latest->GetMutable(); + auto *t_old = var_old->GetMutable(); + + auto dims1 = t_latest->dims()[1]; + auto numel = keys.size() * dims1; + + std::vector v_delta; + v_delta.resize(numel); + + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + auto blas = + paddle::operators::math::GetBlas( + cpu_ctx); + + for (auto j = 0; j < static_cast(keys.size()); ++j) { + float *latest_data = t_latest->data() + keys[j] * dims1; + float *old_data = t_old->data() + keys[j] * dims1; + // pserver - old => delta + blas.VSUB(dims1, values.data() + j * dims1, old_data, + v_delta.data() + j * dims1); + // latest + delta => latest + blas.VADD(dims1, latest_data, v_delta.data() + j * dims1, latest_data); + // pserver => old + blas.VCOPY(dims1, values.data() + j * dims1, old_data); + } + VLOG(1) << "Finish Recv Sparse " << param << ", table_id: " << table_id; +} + +void GeoCommunicator::MainThread() { + VLOG(3) << "MainThread start and wait"; + + while (waiting_ && running_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + VLOG(3) << "wait for running"; + } + + while (running_) { + std::vector> tasks; + tasks.reserve(parallel_task_nums_); + + for (auto &iter : send_varname_to_ctx_) { + auto &ctx = iter.second; + auto &varnames = ctx.origin_varnames; + auto &table_id = ctx.table_id; + + if (ctx.is_sparse) { + PADDLE_ENFORCE_EQ( + varnames.size(), 1, + platform::errors::InvalidArgument( + "sparse variables can only be merged by one variables")); + int pserver_num = static_cast(ctx.epmap.size()); + for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) { + // varname: emb@GRAD, param_name: emb, splited_varname: emb.delta0 + auto send_recv_task = [this, table_id, ep_idx, &ctx] { + auto splited_varname = ctx.splited_varnames[ep_idx]; + auto sparse_ids = MergeSparseIds(splited_varname); + SendSparse(splited_varname, sparse_ids, table_id, ep_idx); + RecvSparse(splited_varname, table_id, ep_idx); + }; + tasks.emplace_back( + send_threadpool_->enqueue(std::move(send_recv_task))); + } + } else { + auto send_recv_task = [this, &ctx] { + SendDense(ctx); + RecvDense(ctx); + }; + tasks.emplace_back( + send_threadpool_->enqueue(std::move(send_recv_task))); + } + } + for (auto &task : tasks) { + task.wait(); + } + } +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/communicator.h b/paddle/fluid/distributed/service/communicator.h new file mode 100644 index 0000000000000000000000000000000000000000..a22b006013461c9ca4c10710339ac60550dabec9 --- /dev/null +++ b/paddle/fluid/distributed/service/communicator.h @@ -0,0 +1,561 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gflags/gflags.h" +#include "paddle/fluid/distributed/communicator_common.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/split.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/ps_client.h" + +DECLARE_bool(communicator_is_sgd_optimizer); + +namespace paddle { +namespace distributed { + +using Scope = framework::Scope; +using Variable = framework::Variable; + +template +class BlockingQueue { + public: + explicit BlockingQueue(size_t capacity) : capacity_(capacity) { + PADDLE_ENFORCE_GT(capacity_, 0, + platform::errors::InvalidArgument( + "The capacity must be greater than 0.")); + } + + bool Push(const T &elem) { + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + queue_.push_back(elem); + } + cv_.notify_one(); + return true; + } + + bool Push(T &&elem) { + { + std::unique_lock lock(mutex_); + cv_.wait(lock, [&] { return queue_.size() < capacity_; }); + queue_.emplace_back(std::move(elem)); + } + cv_.notify_one(); + return true; + } + + T Pop() { + std::unique_lock lock(mutex_); + cv_.wait(lock, [=] { return !queue_.empty(); }); + T rc(std::move(queue_.front())); + queue_.pop_front(); + cv_.notify_one(); + return rc; + } + + size_t Cap() const { + std::lock_guard lock(mutex_); + return capacity_; + } + + size_t Size() const { + std::lock_guard lock(mutex_); + return queue_.size(); + } + + private: + const size_t capacity_; + std::deque queue_; + + mutable std::mutex mutex_; + std::condition_variable cv_; +}; + +template +using EigenVector = framework::EigenVector; + +template +inline void MergeVars(const std::string &var_name, + const std::vector> &vars, + Scope *scope, bool merge_add = true) { + PADDLE_ENFORCE_NE(vars.empty(), true, platform::errors::InvalidArgument( + "vector vars are empty.")); + auto cpu_place = platform::CPUPlace(); + auto &var0 = vars[0]; + auto *out_var = scope->Var(var_name); + + if (var0->IsType()) { + auto dims = var0->Get().dims(); + VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims + << "; merge add: " << merge_add; + // init output tensor + auto *out_t = out_var->GetMutable(); + out_t->mutable_data(dims, cpu_place); + // check the input dims + for (auto &var : vars) { + auto &var_t = var->Get(); + PADDLE_ENFORCE_EQ( + var_t.dims(), dims, + platform::errors::InvalidArgument("vars should have the same dims.")); + } + + // set output tensor to 0. + auto cpu_ctx = paddle::platform::CPUDeviceContext(); + paddle::operators::math::SetConstant + constant_functor; + constant_functor(cpu_ctx, out_t, static_cast(0)); + // sum all vars to out + auto result = EigenVector::Flatten(*out_t); + for (auto &var : vars) { + auto &in_t = var->Get(); + auto in = EigenVector::Flatten(in_t); + result.device(*cpu_ctx.eigen_device()) = result + in; + } + if (!merge_add) { + result.device(*cpu_ctx.eigen_device()) = + result / static_cast(vars.size()); + } + } else if (var0->IsType()) { + auto &slr0 = var0->Get(); + auto *out_slr = out_var->GetMutable(); + out_slr->mutable_rows()->clear(); + out_slr->mutable_value()->mutable_data({{}}, cpu_place); + std::vector inputs; + inputs.reserve(vars.size()); + for (auto &var : vars) { + inputs.push_back(&var->Get()); + } + auto dev_ctx = paddle::platform::CPUDeviceContext(); + if (merge_add) { + paddle::operators::math::scatter::MergeAdd< + paddle::platform::CPUDeviceContext, T> + merge_add; + merge_add(dev_ctx, inputs, out_slr); + } else { + paddle::operators::math::scatter::MergeAverage< + paddle::platform::CPUDeviceContext, T> + merge_average; + merge_average(dev_ctx, inputs, out_slr); + } + + VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height() + << " dims: " << slr0.value().dims() << "; merge add: " << merge_add; + } else { + PADDLE_THROW(platform::errors::InvalidArgument("unsupported var type: %s!", + var0->Type())); + } +} + +using RpcCtxMap = std::unordered_map; +using RecvCtxMap = std::unordered_map>; +using SparseValue = std::unordered_map>; + +class Communicator { + public: + Communicator(); + + explicit Communicator(const std::map &envs_) { + VLOG(0) << "Communicator Init Envs"; + for (auto &iter : envs_) { + envs[iter.first] = iter.second; + VLOG(0) << iter.first << ": " << iter.second; + } + barrier_table_id_ = std::stoi(envs.at("barrier_table_id")); + trainer_id_ = std::stoi(envs.at("trainer_id")); + trainers_ = std::stoi(envs.at("trainers")); + } + + virtual void InitBrpcClient(const std::string &dist_desc, + const std::vector &host_sign_list); + // 1. recv dense param + virtual void RpcRecvDense(const std::vector &varnames, + int table_id, Scope *scope); + // 2. send dense param + virtual void RpcSendDenseParam(const std::vector &varnames, + int table_id, const Scope &scope); + // 3. send dense grad + virtual void RpcSendDense(const CommContext &ctx, const Scope &scope); + // 4. send sparse grad + virtual void RpcSendSparse(const std::string &var_name, int table_id, + const Scope &scope); + // 5. send sparse param + virtual void RpcSendSparseParam(const std::string &varname, int table_id, + const Scope &scope); + // 6. recv sparse param + virtual void RpcRecvSparse(const std::string &varname, int table_id, + Scope *scope); + + virtual ~Communicator() {} + virtual void RpcProfilerControl(); + + virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx); + + virtual void Start() = 0; + + virtual void Stop() = 0; + + virtual bool IsRunning() { return running_; } + + virtual void Clean() {} + + virtual bool Check(const int table_id) = 0; + virtual bool Check(const std::vector &var_tables) = 0; + + virtual void Send(const std::vector &var_names, + const framework::Scope &scope) = 0; + + virtual void RecvNoBarrier() {} + + virtual void Barrier() {} + + virtual void BarrierWithTable(uint32_t barrier_type) { + auto rets = _worker_ptr->barrier(barrier_table_id_, barrier_type); + rets.wait(); + } + + virtual void BarrierTriggerDecrement() {} + + virtual void BarrierTriggerReset(int init_counter) {} + + virtual void InitEnvs() = 0; + + virtual void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) {} + + static Communicator *GetInstance() { return communicator_.get(); } + + static std::shared_ptr GetInstantcePtr() { + return communicator_; + } + + template + static Communicator *InitInstance( + const RpcCtxMap &send_ctx, const RecvCtxMap &recv_ctx, + const std::string &dist_desc, + const std::vector &host_sign_list, Scope *recv_scope, + const std::map &envs) { + std::call_once(init_flag_, &Communicator::InitWithRpcCtx, send_ctx, + recv_ctx, dist_desc, host_sign_list, recv_scope, + std::ref(envs)); + return communicator_.get(); + } + + // Init is called by InitInstance. + template + static void InitWithRpcCtx(const RpcCtxMap &send_ctx, + const RecvCtxMap &recv_ctx, + const std::string &dist_desc, + const std::vector &host_sign_list, + Scope *recv_scope, + const std::map &envs) { + if (communicator_.get() == nullptr) { + communicator_.reset(new T(std::ref(envs))); + communicator_->InitEnvs(); + communicator_->InitBrpcClient(dist_desc, host_sign_list); + communicator_->InitImpl(send_ctx, recv_ctx, recv_scope); + } + } + + PSClient *GetPsClient() { return _worker_ptr.get(); } + + std::shared_ptr GetPsClientPtr() { + return _worker_ptr; + } + + std::shared_ptr _worker_ptr; // pointer to worker + + protected: + bool running_ = false; + bool waiting_ = true; + bool flushing_ = false; + bool do_server_profiler_ = false; + static std::shared_ptr communicator_; + static std::once_flag init_flag_; + + std::unordered_map envs; + + // 计算每个shard 对 dense的存储量 + inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total, + uint32_t shard_num) { + return dense_dim_total / shard_num + 1; + } + + void init_gflag(const std::string &gflags); + paddle::distributed::PSParameter _ps_param; + paddle::distributed::PaddlePSEnvironment _ps_env; + int servers_ = 0; + int trainers_; + int trainer_id_ = 0; + int barrier_table_id_ = 0; + RpcCtxMap send_varname_to_ctx_; + RecvCtxMap recv_varname_to_ctx_; + + Scope *recv_scope_; // should be global scope + std::unique_ptr xpu_temp_scope_; + std::atomic _async_call_num{0}; +}; + +class AsyncCommunicator : public Communicator { + public: + AsyncCommunicator() : Communicator() {} + + explicit AsyncCommunicator(const std::map &envs) + : Communicator(envs) {} + + ~AsyncCommunicator(); + + void InitEnvs() { + independent_recv_ = static_cast( + std::stoi(envs.at("communicator_independent_recv_thread"))); + min_send_grad_num_before_recv_ = + std::stoi(envs.at("communicator_min_send_grad_num_before_recv")); + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); + need_global_step_ = + static_cast(std::stoi(envs.at("need_global_step"))); + } + + void Start() override; + + void Stop() override; + + void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) override; + + virtual void MainThread(); + virtual void RecvThread(); + + virtual bool Check(const int table_id); + virtual bool Check(const std::vector &var_tables); + + void Send(const std::vector &var_names, + const framework::Scope &scope) override; + + virtual void SendByCommunicator(); + + virtual void SendGlobalStep(int batches) {} + + virtual void RecvByCommunicator(); + + virtual void RecvNoBarrier(); + + virtual int BatchesCounter() { return 1; } + + virtual void BarrierSend() {} + + virtual void BarrierRecv() {} + + virtual void BarrierWeakUp() {} + + protected: + std::unordered_map>>> + send_varname_to_queue_; + std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; + + int min_send_grad_num_before_recv_; + int thread_pool_size_; + int max_merge_var_num_; + int send_wait_times_; + int send_queue_size_; + bool need_global_step_ = false; + bool independent_recv_ = true; + int parallel_task_nums_ = 0; + + std::unique_ptr main_thread_{nullptr}; + std::unique_ptr recv_thread_{nullptr}; + + std::unique_ptr send_scope_; // an independent scope + std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv +}; + +class HalfAsyncCommunicator : public AsyncCommunicator { + public: + HalfAsyncCommunicator() {} + + explicit HalfAsyncCommunicator(const std::map &envs) + : AsyncCommunicator(envs) {} + + void InitEnvs() { + // enfore to recv after send + independent_recv_ = false; + min_send_grad_num_before_recv_ = 0; + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); + need_global_step_ = + static_cast(std::stoi(envs.at("need_global_step"))); + + VLOG(0) << "HalfAsyncCommunicator Initialized"; + } + + void MainThread() override; + + void SendByCommunicator() override; + + void Clean() override; + + void Barrier() override; + + void BarrierTriggerDecrement() override; + + void BarrierTriggerReset(int initial_val) override; + + int BatchesCounter(); + + void BarrierWeakUp(); + + protected: + // mutex for Wait for barrier + std::mutex barrier_mutex_; + std::condition_variable barrier_cond_; + std::atomic barrier_trigger_{0}; + std::atomic barrier_counter_{0}; +}; + +class SyncCommunicator : public HalfAsyncCommunicator { + public: + SyncCommunicator() : HalfAsyncCommunicator() {} + + explicit SyncCommunicator(const std::map &envs) + : HalfAsyncCommunicator(envs) {} + + void InitEnvs() { + // enfore to recv after send + independent_recv_ = false; + min_send_grad_num_before_recv_ = 0; + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + send_queue_size_ = std::stoi(envs.at("communicator_send_queue_size")); + need_global_step_ = + static_cast(std::stoi(envs.at("need_global_step"))); + + VLOG(0) << "SyncCommunicator Initialized"; + } + + void BarrierSend(); + + void BarrierRecv(); + + private: + std::vector pserver_endpoints_{}; +}; + +class GeoCommunicator : public AsyncCommunicator { + public: + GeoCommunicator() : AsyncCommunicator() {} + + explicit GeoCommunicator(const std::map &envs) + : AsyncCommunicator(envs) {} + + void InitImpl(const RpcCtxMap &send_varname_to_ctx, + const RecvCtxMap &recv_varname_to_ctx, + Scope *recv_scope) override; + + void InitParams(const RecvCtxMap &recv_varname_to_ctx) override; + void InitDense(std::vector &varnames, int table_id); + void InitSparse(const std::string &var_name, int table_id); + + void SendDense(const CommContext &send_ctx); + void RecvDense(const CommContext &send_ctx); + + std::vector MergeSparseIds(const std::string &varname); + void SendSparse(const std::string &varname, std::vector &sparse_ids, + int table_id, int ep_idx); + void RecvSparse(const std::string &varname, int table_id, int ep_idx); + + void MainThread() override; + + void InitEnvs() { + independent_recv_ = false; + min_send_grad_num_before_recv_ = 0; + send_wait_times_ = std::stoi(envs.at("communicator_send_wait_times")); + thread_pool_size_ = std::stoi(envs.at("communicator_thread_pool_size")); + // id_queue's size + max_merge_var_num_ = std::stoi(envs.at("communicator_max_merge_var_num")); + send_queue_size_ = max_merge_var_num_; + VLOG(0) << "GeoCommunicator Initialized"; + } + + void Send(const std::vector &var_names, + const framework::Scope &scope) override; + + void SendByCommunicator() { return; } + + void SendGlobalStep(int batches) override { return; } + + void RecvByCommunicator() override { return; } + + inline std::string GradToParam(const std::string var_name) { + std::string param_name = var_name.substr(0, var_name.size() - 5); + return param_name; + } + + inline std::string SplitedGradToParam(const std::string delta_name) { + // delta_name: emb.delta0 + auto pos = delta_name.find(".block"); + std::string param_name = delta_name.substr(0, pos); + return param_name; + } + + private: + // parameter for delta calc and send + std::shared_ptr delta_scope_; + // parameter for storage the pserver param after last recv + std::shared_ptr old_scope_; + // parameter on pserver + std::shared_ptr pserver_scope_; + + std::unordered_map< + std::string, + std::shared_ptr>>>> + sparse_id_queues_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/env.cc b/paddle/fluid/distributed/service/env.cc new file mode 100644 index 0000000000000000000000000000000000000000..25bc2cc366aaacba32c22a5225d344f8618767d9 --- /dev/null +++ b/paddle/fluid/distributed/service/env.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/env.h" + +namespace paddle { +namespace distributed {} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/env.h b/paddle/fluid/distributed/service/env.h new file mode 100644 index 0000000000000000000000000000000000000000..42f31717f7fba4203cdbd24d59cfa2d9973d5e8a --- /dev/null +++ b/paddle/fluid/distributed/service/env.h @@ -0,0 +1,284 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace distributed { + +struct PSHost { + std::string ip; + uint32_t port; + uint32_t rank; + + PSHost() = default; + PSHost(const std::string ip, uint32_t port, uint32_t rank) + : ip(ip), port(port), rank(rank) {} + + // |---ip---|---port---|--rank--| + // |-32bit--|--20bit---|--12bit-| + // for pslib + uint64_t serialize_to_uint64() { + uint64_t host_label = 0; + host_label = inet_addr(ip.c_str()); + host_label = host_label << 32; + host_label += (port << 12); + host_label += rank; + return host_label; + } + + void parse_from_uint64(uint64_t host_label) { + static uint64_t rank_label_mask = (1L << 12) - 1; + static uint64_t port_label_mask = (1L << 20) - 1; + rank = host_label & rank_label_mask; + port = (host_label >> 12) & port_label_mask; + uint32_t ip_addr = (host_label >> 32); + ip = inet_ntoa(*(in_addr *)&ip_addr); + } + + std::string to_string() { + std::stringstream s; + s << "host: " << ip; + s << " port: " << port; + s << " rank: " << rank; + s << " uint: " << serialize_to_uint64(); + return s.str(); + } + + // for open source parameter server + std::string serialize_to_string() { + std::stringstream s; + s << ip << ":"; + s << port << ":"; + s << rank; + return s.str(); + } + + void parse_from_string(std::string endpoint) { + std::vector endpoint_info; + string_split(endpoint, ':', &endpoint_info); + ip = endpoint_info[0]; + port = std::stoi(endpoint_info[1]); + rank = std::stoi(endpoint_info[2]); + } + + void string_split(const std::string &str, char sep, + std::vector *pieces, bool ignore_null = true) { + pieces->clear(); + if (str.empty()) { + if (!ignore_null) { + pieces->push_back(str); + } + return; + } + size_t pos = 0; + size_t next = str.find(sep, pos); + while (next != std::string::npos) { + pieces->push_back(str.substr(pos, next - pos)); + pos = next + 1; + next = str.find(sep, pos); + } + if (!str.substr(pos).empty()) { + pieces->push_back(str.substr(pos)); + } + } +}; + +class PSEnvironment { + public: + explicit PSEnvironment() {} + virtual ~PSEnvironment() {} + + virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + return 0; + } + virtual int32_t set_ps_servers( + const std::vector *host_endpoint_list, int node_num) { + return 0; + } + + virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + return 0; + } + + virtual int32_t set_ps_clients(std::string *host_endpoint_list, + int node_num) { + return 0; + } + virtual uint64_t get_local_host_sign() { return 0; } + virtual std::vector get_ps_servers() const { return _ps_server_list; } + virtual int32_t registe_ps_server(const std::string &ip, uint32_t port, + int32_t rank) { + return registe_ps_host(ip, port, rank, _ps_server_list, + _ps_server_sign_set); + } + + virtual std::vector get_ps_clients() const { return _ps_client_list; } + virtual int32_t registe_ps_client(const std::string &ip, uint32_t port, + int32_t rank) { + return registe_ps_host(ip, port, rank, _ps_client_list, + _ps_client_sign_set); + } + + virtual std::vector get_client_info() { + std::vector client_info; + for (auto &i : _ps_client_sign_set) { + client_info.push_back(i); + } + return client_info; + } + + virtual std::vector get_client_info(bool use_string_endpoint) { + if (use_string_endpoint) { + std::vector client_info; + for (auto &i : _ps_client_list) { + client_info.push_back(i.serialize_to_string()); + } + return client_info; + } + return {}; + } + + protected: + //注册一个host + virtual int32_t registe_ps_host(const std::string &ip, uint32_t port, + int32_t rank, std::vector &host_list, + std::unordered_set &sign_set) { + PSHost host; + host.ip = ip; + host.port = port; + host.rank = rank; + if (sign_set.count(rank) > 0) { + LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port + << ", rank:" << host.rank + << " already register, ignore register"; + } else { + host_list.push_back(host); + sign_set.insert(rank); + } + // if (sign_set.count(host.serialize_to_uint64()) > 0) { + // LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port + // << ", rank:" << host.rank + // << " already register, ignore register"; + // } else { + // host_list.push_back(host); + // sign_set.insert(host.serialize_to_uint64()); + // } + return 0; + } + + std::vector _ps_client_list; + std::unordered_set _ps_client_sign_set; // for unique filter + + std::vector _ps_server_list; + std::unordered_set _ps_server_sign_set; // for unique filter +}; + +class PaddlePSEnvironment : public PSEnvironment { + public: + explicit PaddlePSEnvironment() {} + virtual ~PaddlePSEnvironment() {} + + virtual int32_t set_ps_servers(uint64_t *host_sign_list, int node_num) { + _ps_server_list.clear(); + _ps_server_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list[i] > 0) { + PSHost host; + host.parse_from_uint64(host_sign_list[i]); + _ps_server_list.push_back(host); + _ps_server_sign_set.insert(host.serialize_to_uint64()); + } + } + std::sort( + _ps_server_list.begin(), _ps_server_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual int32_t set_ps_servers(const std::vector *host_sign_list, + int node_num) { + _ps_server_list.clear(); + _ps_server_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list->at(i) != "") { + PSHost host; + host.parse_from_string(host_sign_list->at(i)); + _ps_server_list.push_back(host); + _ps_server_sign_set.insert(host.rank); + } + } + std::sort( + _ps_server_list.begin(), _ps_server_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual int32_t set_ps_clients(uint64_t *host_sign_list, int node_num) { + _ps_client_list.clear(); + _ps_client_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list[i] > 0) { + PSHost host; + host.parse_from_uint64(host_sign_list[i]); + _ps_client_list.push_back(host); + _ps_client_sign_set.insert(host.serialize_to_uint64()); + } + } + std::sort( + _ps_client_list.begin(), _ps_client_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual int32_t set_ps_clients(std::vector *host_sign_list, + int node_num) { + _ps_client_list.clear(); + _ps_client_sign_set.clear(); + for (int i = 0; i < node_num; ++i) { + if (host_sign_list->at(i) != "") { + PSHost host; + host.parse_from_string(host_sign_list->at(i)); + _ps_client_list.push_back(host); + _ps_client_sign_set.insert(host.rank); + } + } + std::sort( + _ps_client_list.begin(), _ps_client_list.end(), + [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; }); + return 0; + } + + virtual uint64_t get_local_host_sign() { + if (_ps_client_list.size() > 0) { + return _ps_client_list[0].serialize_to_uint64(); + } else { + return 0; + } + } +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/heter_client.cc b/paddle/fluid/distributed/service/heter_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..f4d1f27377f0e6f8a01b288fc48007ed66b2005c --- /dev/null +++ b/paddle/fluid/distributed/service/heter_client.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/heter_client.h" +#include +#include +#include "paddle/fluid/framework/channel.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/io/fs.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/platform/timer.h" + +DECLARE_int32(rpc_deadline); +namespace paddle { +namespace distributed { + +DEFINE_int32(pserver_timeout_ms, 10800000, "pserver request server timeout_ms"); + +std::shared_ptr HeterClient::s_instance_ = NULL; +bool HeterClient::is_initialized_ = false; + +void HeterClient::MainThread() { + while (running_) { + RpcProfilerControl(); + } +} + +void HeterClient::Stop() { + running_ = false; + if (!is_initialized_) { + VLOG(0) << "HeterClient is not inited, do nothing"; + } else { + if (main_thread_) { + auto status = StopHeterWorker(); + status.wait(); + main_thread_->join(); + main_thread_.reset(nullptr); + } + VLOG(1) << "HeterClient Stop Done"; + } +} + +void HeterClient::RpcProfilerControl() { + if (trainer_id_ == 0) { + if (!do_server_profiler_ && platform::IsProfileEnabled()) { + // send profiler start flag + do_server_profiler_ = true; + auto start_status = StartProfiler(); + start_status.wait(); + } else if (do_server_profiler_ && !platform::IsProfileEnabled()) { + // send profiler end flag + auto stop_status = StopProfiler(); + stop_status.wait(); + do_server_profiler_ = false; + } + } +} + +void HeterClient::CreateClient2XpuConnection() { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.connection_type = "single"; + options.timeout_ms = pserver_timeout_ms; + + xpu_channels_.resize(xpu_list_.size()); + for (size_t i = 0; i < xpu_list_.size(); ++i) { + xpu_channels_[i].reset(new brpc::Channel()); + if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) { + VLOG(0) << "HeterServer channel init fail"; + } + } +} + +void HeterClient::SendAndRecvAsync( + const std::vector& ep, const platform::DeviceContext& ctx, + const framework::Scope& scope, const std::string& message_name, + const std::vector& send_var_name, + const std::vector& recv_var_name) { + platform::RecordEvent record_event("HeterClient->SendAndRecvAsync"); + const platform::DeviceContext* p_ctx = &ctx; + const framework::Scope* p_scope = &scope; + const std::string message_name_val = message_name; + const std::vector send_var_name_val = send_var_name; + const std::vector recv_var_name_val = recv_var_name; + + VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: " + << message_name_val; + // Todo: get correct channel + int num = trainer_id_ % xpu_channels_.size(); + + brpc::Controller cntl; + cntl.set_timeout_ms(pserver_timeout_ms); + distributed::MultiVarMsg request, response; + auto& request_io_buffer = cntl.request_attachment(); + ::paddle::PsService_Stub stub(xpu_channels_[num].get()); + distributed::SerializeToMultiVarMsgAndIOBuf( + message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope, + &request, &request_io_buffer); + stub.SendAndRecvVariable(&cntl, &request, &response, NULL); + PADDLE_ENFORCE_NE( + cntl.Failed(), true, + platform::errors::Unimplemented( + "HeterClient::SendAndRecv meets brpc error, error message is %s", + cntl.ErrorText())); + VLOG(4) << "call heter_worker success"; + auto& response_io_buffer = cntl.response_attachment(); + distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer, + ctx, p_scope); +} + +std::future HeterClient::SendCmd( + uint32_t table_id, int cmd_id, const std::vector& params) { + size_t request_call_num = xpu_channels_.size(); + paddle::distributed::DownpourBrpcClosure* closure = + new paddle::distributed::DownpourBrpcClosure( + request_call_num, [request_call_num, cmd_id](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, cmd_id) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(cmd_id); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(trainer_id_); + for (const auto& param : params) { + closure->request(i)->add_params(param); + } + ::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get()); + closure->cntl(i)->set_timeout_ms( + pserver_timeout_ms); // cmd msg don't limit timeout for save/load + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + +std::future HeterClient::StartProfiler() { + return SendCmd(-1, PS_START_PROFILER, {}); +} + +std::future HeterClient::StopProfiler() { + return SendCmd(-1, PS_STOP_PROFILER, {}); +} + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/heter_client.h b/paddle/fluid/distributed/service/heter_client.h new file mode 100644 index 0000000000000000000000000000000000000000..b1c268c3231f9224163dc1e54206a207dc460551 --- /dev/null +++ b/paddle/fluid/distributed/service/heter_client.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN + +namespace paddle { +namespace distributed { + +using MultiVarMsg = ::paddle::MultiVariableMessage; +using VarMsg = ::paddle::VariableMessage; + +typedef std::function HeterRpcCallbackFunc; + +class OnHeterRpcDone : public google::protobuf::Closure { + public: + OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {} + virtual ~OnHeterRpcDone() {} + void Run() { + std::unique_ptr self_guard(this); + handler_(this); + } + + HeterRpcCallbackFunc handler_; + MultiVariableMessage response; + brpc::Controller cntl; +}; + +class HeterClient { + public: + virtual ~HeterClient() {} + + HeterClient() { + running_ = true; + main_thread_.reset( + new std::thread(std::bind(&HeterClient::MainThread, this))); + } + + void CreateClient2XpuConnection(); + + void SendAndRecvAsync(const std::vector& ep, + const platform::DeviceContext& ctx, + const framework::Scope& scope, + const std::string& message_name, + const std::vector& send_var_name, + const std::vector& recv_var_name); + + // HeterClient singleton + static std::shared_ptr GetInstance( + const std::vector& endpoint, const int& trainer_id) { + if (NULL == s_instance_) { + is_initialized_ = true; + s_instance_.reset(new paddle::distributed::HeterClient()); + std::vector xpu_list = {endpoint}; + s_instance_->SetXpuList(endpoint); + s_instance_->SetTrainerID(trainer_id); + s_instance_->CreateClient2XpuConnection(); + } + return s_instance_; + } + + void Stop(); + + void MainThread(); + + void RpcProfilerControl(); + + std::future SendCmd(uint32_t table_id, int cmd_id, + const std::vector& params); + + std::future StartProfiler(); + std::future StopProfiler(); + std::future StopHeterWorker(); + + std::vector& GetXpuList() { return xpu_list_; } + + void SetXpuList(const std::vector& xpu_list) { + xpu_list_ = xpu_list; + }; + + void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; } + + private: + static std::shared_ptr s_instance_; + + protected: + static bool is_initialized_; + std::unique_ptr main_thread_{nullptr}; + std::vector> xpu_channels_; + DISABLE_COPY_AND_ASSIGN(HeterClient); + std::vector xpu_list_; + + bool running_ = false; + int trainer_id_; + bool do_server_profiler_ = false; +}; + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/heter_server.cc b/paddle/fluid/distributed/service/heter_server.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9daf8be1ccb66d6a185d8828d1d5a53417249d2 --- /dev/null +++ b/paddle/fluid/distributed/service/heter_server.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/heter_server.h" +#include +#include +#include "paddle/fluid/framework/fleet/heter_wrapper.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/timer.h" + +namespace paddle { +namespace distributed { + +std::shared_ptr HeterServer::s_instance_ = NULL; + +void HeterServer::RegisterServiceHandler(std::string message_name, + HeterServiceHandler func) { + service_.RegisterServiceHandler(message_name, func); +} + +void HeterServer::StartHeterService() { + server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE); + brpc::ServerOptions options; + if (server_.Start(endpoint_.c_str(), &options) != 0) { + VLOG(0) << "heter server start fail"; + } else { + VLOG(0) << "heter server start success! listen on " << endpoint_; + } + + { + std::lock_guard lock(this->mutex_ready_); + ready_ = 1; + } + condition_ready_.notify_all(); + + server_.Join(); +} + +void HeterServer::SetEndPoint(std::string& endpoint) { + endpoint_ = endpoint; + service_.SetEndpoint(endpoint); +} + +void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); } + +void HeterServer::WaitServerReady() { + std::unique_lock lock(this->mutex_ready_); + condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); +} + +int32_t HeterService::stop_profiler(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl) { + platform::DisableProfiler( + platform::EventSortingKey::kDefault, + string::Sprintf("heter_worker_%s_profile", endpoint_)); + return 0; +} + +int32_t HeterService::start_profiler(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl) { + platform::EnableProfiler(platform::ProfilerState::kAll); + return 0; +} + +int32_t HeterService::stop_heter_worker(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl) { + auto client_id = request.client_id(); + stop_cpu_worker_set_.insert(client_id); + if (stop_cpu_worker_set_.size() == fan_in_) { + is_exit_ = true; + } + return 0; +} + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/heter_server.h b/paddle/fluid/distributed/service/heter_server.h new file mode 100644 index 0000000000000000000000000000000000000000..07fff7adc6e94a24fe52efbca21605c4c4e4a44c --- /dev/null +++ b/paddle/fluid/distributed/service/heter_server.h @@ -0,0 +1,243 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include "brpc/channel.h" +#include "brpc/controller.h" +#include "brpc/server.h" +#include "paddle/fluid/distributed/service/brpc_utils.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN +#include "paddle/fluid/platform/profiler.h" + +namespace paddle { +namespace distributed { + +using MultiVarMsg = ::paddle::MultiVariableMessage; +using VarMsg = ::paddle::VariableMessage; + +class HeterService; +typedef int32_t (HeterService::*serviceHandlerFunc)( + const PsRequestMessage& request, PsResponseMessage& response, + brpc::Controller* cntl); + +typedef std::function HeterRpcCallbackFunc; +typedef std::function + HeterServiceHandler; + +class HeterService : public ::paddle::PsService { + public: + HeterService() { + _service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker; + _service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler; + _service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler; + } + + virtual ~HeterService() {} + + virtual void service(::google::protobuf::RpcController* controller, + const ::paddle::PsRequestMessage* request, + ::paddle::PsResponseMessage* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + std::string log_label("ReceiveCmd-"); + + response->set_err_code(0); + response->set_err_msg(""); + brpc::Controller* cntl = static_cast(controller); + auto itr = _service_handler_map.find(request->cmd_id()); + if (itr == _service_handler_map.end()) { + std::string err_msg( + "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); + err_msg.append(std::to_string(request->cmd_id())); + return; + } + serviceHandlerFunc handler_func = itr->second; + int service_ret = (this->*handler_func)(*request, *response, cntl); + if (service_ret != 0) { + response->set_err_code(service_ret); + response->set_err_msg("server internal error"); + } + }; + + void SendAndRecvVariable(::google::protobuf::RpcController* controller, + const MultiVarMsg* request, MultiVarMsg* response, + ::google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + std::string message_name = request->message_name(); + auto itr = handler_map_.find(message_name); + brpc::Controller* cntl = static_cast(controller); + PADDLE_ENFORCE_NE( + itr, handler_map_.end(), + platform::errors::InvalidArgument( + "HeterService::SendAndRecvVariable Get illegal message_name: %s " + "which is not in HeterService::handler_map_", + message_name)); + itr->second(request, response, cntl); + } + + void RegisterServiceHandler(std::string message_name, + HeterServiceHandler func) { + handler_map_[message_name] = func; + } + + void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; } + void SetFanin(const int& fan_in) { fan_in_ = fan_in; } + bool IsExit() { return is_exit_; } + + private: + int32_t stop_profiler(const PsRequestMessage& request, + PsResponseMessage& response, brpc::Controller* cntl); + + int32_t start_profiler(const PsRequestMessage& request, + PsResponseMessage& response, brpc::Controller* cntl); + + int32_t stop_heter_worker(const PsRequestMessage& request, + PsResponseMessage& response, + brpc::Controller* cntl); + + private: + std::string endpoint_; + std::unordered_map handler_map_; + std::unordered_map _service_handler_map; + std::unordered_set stop_cpu_worker_set_; + int fan_in_; + bool is_exit_ = false; +}; + +class HeterServer { + public: + virtual ~HeterServer() {} + + void Stop() { + server_.Stop(1000); + server_.Join(); + } + + bool IsExit() { return service_.IsExit(); } + + HeterServer() {} + + void RegisterServiceHandler(std::string message_name, + HeterServiceHandler func); + + void StartHeterService(); + + void SetEndPoint(std::string& endpoint); + void SetFanin(int& fan_in); + + // HeterWrapper singleton + static std::shared_ptr GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new HeterServer()); + } + return s_instance_; + } + + void WaitServerReady(); + + private: + static std::shared_ptr s_instance_; + std::string endpoint_; + + protected: + brpc::Server server_; + HeterService service_; + DISABLE_COPY_AND_ASSIGN(HeterServer); + std::mutex mutex_ready_; + std::condition_variable condition_ready_; + int ready_; +}; + +class HeterRequestHandler { + public: + HeterRequestHandler() + : dev_ctx_(nullptr), + executor_(nullptr), + scope_(nullptr), + program_(nullptr) {} + + virtual ~HeterRequestHandler() {} + + void SetScope(framework::Scope* scope) { scope_ = scope; } + void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + void SetProgram(framework::ProgramDesc* program) { program_ = program; } + void SetExecutor(framework::Executor* executor) { executor_ = executor; } + + void SetGradToPreparedCtx( + std::unordered_map< + std::string, std::shared_ptr>* g) { + message_to_prepared_ctx_ = g; + } + + virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) = 0; + + protected: + const platform::DeviceContext* dev_ctx_; + framework::Executor* executor_; + framework::Scope* scope_; + framework::ProgramDesc* program_; + + std::unordered_map>* + message_to_prepared_ctx_; +}; + +class RequestSendAndRecvHandler final : public HeterRequestHandler { + public: + RequestSendAndRecvHandler() {} + virtual ~RequestSendAndRecvHandler() {} + int Handle(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) override { + platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle"); + auto& local_scope = scope_->NewScope(); + auto message_name = request->message_name(); + auto& request_io_buffer = cntl->request_attachment(); + distributed::DeserializeFromMultiVarMsgAndIOBuf( + *request, &request_io_buffer, *dev_ctx_, &local_scope); + executor_->RunPreparedContext( + (*message_to_prepared_ctx_)[message_name].get(), &local_scope, false); + + auto response_var_nums = request->recv_var_names_size(); + std::vector response_var_names(response_var_nums), + empty_var_names{}; + + for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) { + response_var_names[var_idx] = request->recv_var_names(var_idx); + } + auto& response_io_buffer = cntl->response_attachment(); + distributed::SerializeToMultiVarMsgAndIOBuf( + message_name, response_var_names, empty_var_names, *dev_ctx_, + &local_scope, response, &response_io_buffer); + scope_->DeleteScope(&local_scope); + return 0; + } +}; + +} // end namespace distributed +} // end namespace paddle diff --git a/paddle/fluid/distributed/service/ps_client.cc b/paddle/fluid/distributed/service/ps_client.cc new file mode 100644 index 0000000000000000000000000000000000000000..dd5fb9c24b32cebd36f19822d97bde56171dac6d --- /dev/null +++ b/paddle/fluid/distributed/service/ps_client.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/ps_client.h" + +#include + +#include "brpc/server.h" +#include "glog/logging.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { +REGISTER_CLASS(PSClient, BrpcPsClient); + +int32_t PSClient::configure( + const PSParameter &config, + const std::map> ®ions, + PSEnvironment &env, size_t client_id) { + _env = &env; + _config = config; + _dense_pull_regions = regions; + _client_id = client_id; + _config.mutable_worker_param() + ->mutable_downpour_worker_param() + ->mutable_downpour_table_param() + ->CopyFrom(_config.server_param() + .downpour_server_param() + .downpour_table_param()); + + const auto &work_param = _config.worker_param().downpour_worker_param(); + + for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) { + auto *accessor = CREATE_CLASS( + ValueAccessor, + work_param.downpour_table_param(i).accessor().accessor_class()); + accessor->configure(work_param.downpour_table_param(i).accessor()); + accessor->initialize(); + _table_accessors[work_param.downpour_table_param(i).table_id()].reset( + accessor); + } + return initialize(); +} + +PSClient *PSClientFactory::create(const PSParameter &ps_config) { + const auto &config = ps_config.server_param(); + if (!config.has_downpour_server_param()) { + LOG(ERROR) << "miss downpour_server_param in ServerParameter"; + return NULL; + } + + if (!config.downpour_server_param().has_service_param()) { + LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param"; + return NULL; + } + + if (!config.downpour_server_param().service_param().has_client_class()) { + LOG(ERROR) << "miss client_class in " + "ServerParameter.downpour_server_param.service_param"; + return NULL; + } + + const auto &service_param = config.downpour_server_param().service_param(); + PSClient *client = CREATE_CLASS(PSClient, service_param.client_class()); + if (client == NULL) { + LOG(ERROR) << "client is not registered, server_name:" + << service_param.client_class(); + return NULL; + } + + TableManager::instance().initialize(); + LOG(INFO) << "Create PSClient[" << service_param.client_class() + << "] success"; + return client; +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h new file mode 100644 index 0000000000000000000000000000000000000000..23b00b3c816088e26c8d05f090a8bc815038b0d9 --- /dev/null +++ b/paddle/fluid/distributed/service/ps_client.h @@ -0,0 +1,208 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/table/accessor.h" + +namespace paddle { +namespace distributed { + +typedef std::function PSClientCallBack; +class PSClientClosure : public google::protobuf::Closure { + public: + PSClientClosure(PSClientCallBack callback) : _callback(callback) {} + virtual ~PSClientClosure() {} + virtual void set_promise_value(int value) { + for (auto &promise : _promises) { + promise->set_value(value); + } + } + + void add_promise(std::shared_ptr> &promise) { + _promises.push_back(promise); + } + + protected: + PSClientCallBack _callback; + std::vector>> _promises; +}; + +class PSClient { + public: + PSClient() {} + virtual ~PSClient() {} + PSClient(PSClient &&) = delete; + PSClient(const PSClient &) = delete; + + virtual int32_t configure( + const PSParameter &config, + const std::map> + ®ions, + PSEnvironment &_env, size_t client_id) final; + + virtual int32_t create_client2client_connection( + int pserver_timeout_ms, int pserver_connect_timeout_ms, + int max_retry) = 0; + + // 触发table数据退场 + virtual std::future shrink(uint32_t table_id) = 0; + + // 全量table进行数据load + virtual std::future load(const std::string &epoch, + const std::string &mode) = 0; + // 指定table数据load + virtual std::future load(uint32_t table_id, const std::string &epoch, + const std::string &mode) = 0; + // 全量table数据save value_accessor根据mode,可能有不同的save条件 + virtual std::future save(const std::string &epoch, + const std::string &mode) = 0; + // 指定table数据save value_accessor根据mode,可能有不同的save条件 + virtual std::future save(uint32_t table_id, const std::string &epoch, + const std::string &mode) = 0; + + //清空table数据 + virtual std::future clear() = 0; + virtual std::future clear(uint32_t table_id) = 0; + + // pull dense的参数部分,并分块填充到本地网络参数中 + // start和num用于拉取部分参数 + // future结束前keys和values缓冲区不能再次使用 + // client将values按照区块拆包后送交多个sender + // sender聚集同一区块的请求,累计多个填充buffer + // server将参数区块中配置的某一维提取返回 + // 返回数据解包后填充到累计的多个buffer中 + virtual std::future pull_dense(Region *regions, size_t region_num, + size_t table_id) = 0; //保留 + + // firstly push dense param for parameter server + // this is neccessary because dense weight initialized in trainer on cold + // start + virtual std::future push_dense_param(const Region *regions, + size_t region_num, + size_t table_id) = 0; + + // 使用keys进行pull请求,结果填充values + // keys和values的个数均为num个,每个value占用select_size空间 + // future结束前keys和values缓冲区不能再次使用 + // 整合多个线程请求的keys,聚集并分散发送到server + // 返回结果后,遍历buffer并对values赋值 + virtual std::future pull_sparse(float **select_values, + size_t table_id, + const uint64_t *keys, + size_t num) = 0; + + virtual std::future print_table_stat(uint32_t table_id) = 0; + + // 确保所有积攒中的请求都发起发送 + virtual std::future flush() = 0; + // server优雅退出 + virtual std::future stop_server() = 0; + + // server profilera + virtual std::future start_profiler() = 0; + virtual std::future stop_profiler() = 0; + + virtual std::future barrier(size_t table_id, + uint32_t barrier_type) = 0; + + virtual std::future pull_geo_param(size_t table_id, + std::vector *values, + std::vector *keys, + int pserver_idx) = 0; + + virtual void finalize_worker() = 0; + // client to client, 消息发送 + virtual std::future send_client2client_msg(int msg_type, + int to_client_id, + const std::string &msg) { + LOG(FATAL) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + // client2client消息处理,std::function ret (msg_type, from_client_id, msg) + typedef std::function MsgHandlerFunc; + virtual int registe_client2client_msg_handler(int msg_type, + MsgHandlerFunc handler) { + _msg_handler_map[msg_type] = handler; + return 0; + } + virtual int handle_client2client_msg(int msg_type, int from_client_id, + const std::string &msg) { + auto itr = _msg_handler_map.find(msg_type); + if (itr == _msg_handler_map.end()) { + LOG(WARNING) << "unknown client2client_msg type:" << msg_type; + return -1; + } + return itr->second(msg_type, from_client_id, msg); + } + + virtual ValueAccessor *table_accessor(size_t table_id) { + auto itr = _table_accessors.find(table_id); + if (itr == _table_accessors.end()) { + return NULL; + } + return itr->second.get(); + } + + virtual size_t get_server_nums() = 0; + + virtual std::future push_dense_raw_gradient( + int table_id, float *total_send_data, size_t total_send_data_size, + void *done) = 0; + + virtual std::future push_sparse_raw_gradient( + size_t table_id, const uint64_t *keys, const float **update_values, + size_t num, void *done) = 0; + + virtual std::future push_sparse_raw_gradient_partial( + size_t table_id, const uint64_t *keys, const float **update_values, + uint32_t num, void *done, int pserver_idx) = 0; + + virtual std::future push_sparse_param(size_t table_id, + const uint64_t *keys, + const float **update_values, + size_t num, void *done) = 0; + + protected: + virtual int32_t initialize() = 0; + size_t _client_id; + PSParameter _config; + std::map> + _dense_pull_regions; + PSEnvironment *_env; + std::unordered_map> _table_accessors; + std::unordered_map + _msg_handler_map; //处理client2client消息 +}; +REGISTER_REGISTERER(PSClient); + +class PSClientFactory { + public: + static PSClient *create(const PSParameter &config); +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto new file mode 100644 index 0000000000000000000000000000000000000000..8f5c8baa2f82427bf62f4429eb622986d761af9e --- /dev/null +++ b/paddle/fluid/distributed/service/sendrecv.proto @@ -0,0 +1,113 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; +package paddle; +option cc_generic_services = true; +option cc_enable_arenas = true; + +enum PsCmdID { + PS_PULL_DENSE_TABLE = 0; + PS_PUSH_DENSE_TABLE = 1; + PS_PULL_SPARSE_TABLE = 2; + PS_PUSH_SPARSE_TABLE = 3; + PS_SHRINK_TABLE = 4; + PS_SAVE_ONE_TABLE = 5; + PS_SAVE_ALL_TABLE = 6; + PS_LOAD_ONE_TABLE = 7; + PS_LOAD_ALL_TABLE = 8; + PS_CLEAR_ONE_TABLE = 9; + PS_CLEAR_ALL_TABLE = 10; + PS_PUSH_DENSE_PARAM = 11; + PS_STOP_SERVER = 12; + PS_SAVE_ONE_CACHE_TABLE = 13; + PS_GET_CACHE_THRESHOLD = 14; + PS_CACHE_SHUFFLE = 15; + PS_COPY_TABLE = 16; + PS_COPY_TABLE_BY_FEASIGN = 17; + PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18; + PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19; + PS_PRINT_TABLE_STAT = 20; + PS_SAVE_ONE_TABLE_PREFIX = 21; + PS_SAVE_ONE_TABLE_WITH_WHITELIST = 22; + PS_LOAD_ONE_TABLE_WITH_WHITELIST = 23; + PS_PULL_GEO_PARAM = 24; + PS_BARRIER = 25; + PS_PUSH_SPARSE_PARAM = 26; + PS_START_PROFILER = 27; + PS_STOP_PROFILER = 28; +} + +message PsRequestMessage { + required uint32 cmd_id = 1; + optional uint32 table_id = 2; + repeated bytes params = 3; + optional int32 client_id = 4; + optional bytes data = 5; +}; + +message PsResponseMessage { + required int32 err_code = 1 [ default = 0 ]; + required string err_msg = 2 [ default = "" ]; + optional bytes data = 3; +}; + +enum VarType { + LOD_TENSOR = 0; + SELECTED_ROWS = 1; +} + +message VariableMessage { + enum Type { + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + } + + message LodData { repeated int64 lod_data = 1; } + optional string varname = 1; + // TODO(Yancey1989): reference framework::proto::VarDesc::VarType + optional VarType type = 2; + // bool persistable is not needed for sending. + // tensor info: + optional Type data_type = 3; + repeated int64 dims = 4; + + // lod details: + optional int64 lod_level = 5; + repeated LodData lod = 6; + // selected_rows height, aka. original dim0 + optional int64 slr_height = 7; + // tensor data + optional bytes data = 8; +} + +// for SendAndRecv RPC method +message MultiVariableMessage { + // message flags + required string message_name = 1; + repeated string send_var_names = 2; + repeated string recv_var_names = 3; + repeated VariableMessage var_messages = 4; +}; + +service PsService { + rpc service(PsRequestMessage) returns (PsResponseMessage); + rpc SendAndRecvVariable(MultiVariableMessage) returns (MultiVariableMessage); +}; \ No newline at end of file diff --git a/paddle/fluid/distributed/service/server.cc b/paddle/fluid/distributed/service/server.cc new file mode 100644 index 0000000000000000000000000000000000000000..1582b8739c1775fc828d7ab29ea9b0e61ffc5bef --- /dev/null +++ b/paddle/fluid/distributed/service/server.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/service/server.h" +#include "glog/logging.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/table/table.h" + +namespace paddle { +namespace distributed { + +REGISTER_CLASS(PSServer, BrpcPsServer); +REGISTER_CLASS(PsBaseService, PsService); + +PSServer *PSServerFactory::create(const PSParameter &ps_config) { + const auto &config = ps_config.server_param(); + + if (!config.has_downpour_server_param()) { + LOG(ERROR) << "miss downpour_server_param in ServerParameter"; + return NULL; + } + + if (!config.downpour_server_param().has_service_param()) { + LOG(ERROR) << "miss service_param in ServerParameter.downpour_server_param"; + return NULL; + } + + if (!config.downpour_server_param().service_param().has_server_class()) { + LOG(ERROR) << "miss server_class in " + "ServerParameter.downpour_server_param.service_param"; + return NULL; + } + + const auto &service_param = config.downpour_server_param().service_param(); + PSServer *server = CREATE_CLASS(PSServer, service_param.server_class()); + if (server == NULL) { + LOG(ERROR) << "server is not registered, server_name:" + << service_param.server_class(); + return NULL; + } + TableManager::instance().initialize(); + return server; +} + +int32_t PSServer::configure(const PSParameter &config, PSEnvironment &env, + size_t server_rank) { + _config = config.server_param(); + _rank = server_rank; + _environment = &env; + _shuffled_ins = + paddle::framework::MakeChannel>(); + const auto &downpour_param = _config.downpour_server_param(); + + uint32_t barrier_table = UINT32_MAX; + + for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { + auto *table = CREATE_CLASS( + Table, downpour_param.downpour_table_param(i).table_class()); + + if (downpour_param.downpour_table_param(i).table_class() == + "BarrierTable") { + barrier_table = downpour_param.downpour_table_param(i).table_id(); + } + table->initialize(downpour_param.downpour_table_param(i), + config.fs_client_param()); + _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table); + } + + if (barrier_table != UINT32_MAX) { + _table_map[barrier_table]->set_table_map(&_table_map); + } + + return initialize(); +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/server.h b/paddle/fluid/distributed/service/server.h new file mode 100644 index 0000000000000000000000000000000000000000..4faa0f9db2c4c510f7e010616811f1a5fd10af43 --- /dev/null +++ b/paddle/fluid/distributed/service/server.h @@ -0,0 +1,150 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "butil/endpoint.h" +#include "google/protobuf/service.h" +#include "paddle/fluid/distributed/common/registerer.h" +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/channel.h" + +namespace paddle { +namespace distributed { + +class Table; + +class PSServer { + public: + PSServer() {} + virtual ~PSServer() {} + PSServer(PSServer &&) = delete; + PSServer(const PSServer &) = delete; + + virtual int32_t configure(const PSParameter &config, PSEnvironment &env, + size_t server_rank) final; + + // return server_ip + virtual std::string ip() { return butil::my_ip_cstr(); } + // return server_port + virtual int32_t port() = 0; + + virtual uint64_t start(const std::string &ip, uint32_t port) = 0; + virtual int32_t stop() = 0; + + inline size_t rank() const { return _rank; } + + inline PSEnvironment *environment() { return _environment; } + + inline const ServerParameter *config() const { return &_config; } + inline Table *table(size_t table_id) { + auto itr = _table_map.find(table_id); + if (itr != _table_map.end()) { + return itr->second.get(); + } + return NULL; + } + + inline std::unordered_map> *table() { + return &_table_map; + } + + typedef std::function MsgHandlerFunc; + virtual int registe_pserver2pserver_msg_handler(int msg_type, + MsgHandlerFunc handler) { + _msg_handler_map[msg_type] = handler; + return 0; + } + + paddle::framework::Channel> _shuffled_ins; + + protected: + virtual int32_t initialize() = 0; + + protected: + size_t _rank; + ServerParameter _config; + PSEnvironment *_environment; + std::unordered_map> _table_map; + std::unordered_map _msg_handler_map; +}; + +REGISTER_REGISTERER(PSServer); + +typedef std::function PServerCallBack; + +class PServerClosure : public google::protobuf::Closure { + public: + PServerClosure(PServerCallBack callback) : _callback(callback) {} + virtual ~PServerClosure() {} + virtual void set_promise_value(int value) { + for (auto &promise : _promises) { + promise->set_value(value); + } + } + void add_promise(std::shared_ptr> &promise) { + _promises.push_back(promise); + } + + protected: + PServerCallBack _callback; + std::vector>> _promises; +}; + +class PsBaseService : public PsService { + public: + PsBaseService() : _rank(0), _server(NULL), _config(NULL) {} + virtual ~PsBaseService() {} + + virtual int32_t configure(PSServer *server) { + _server = server; + _rank = _server->rank(); + _config = _server->config(); + return 0; + } + virtual void service(::google::protobuf::RpcController *controller, + const ::paddle::PsRequestMessage *request, + ::paddle::PsResponseMessage *response, + ::google::protobuf::Closure *done) override = 0; + + virtual void set_response_code(PsResponseMessage &response, int err_code, + const char *err_msg) { + response.set_err_msg(err_msg); + response.set_err_code(err_code); + LOG(WARNING) << "Resonse err_code:" << err_code << " msg:" << err_msg; + } + + virtual int32_t initialize() = 0; + + protected: + size_t _rank; + PSServer *_server; + const ServerParameter *_config; +}; +REGISTER_REGISTERER(PsBaseService); + +class PSServerFactory { + public: + static PSServer *create(const PSParameter &config); +}; +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/service.cc b/paddle/fluid/distributed/service/service.cc new file mode 100644 index 0000000000000000000000000000000000000000..40a6d2e122718790d1970e7697c05ff862e6f738 --- /dev/null +++ b/paddle/fluid/distributed/service/service.cc @@ -0,0 +1,129 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/distributed/service/service.h" + +#include +#include +#include +#include +#include "paddle/fluid/distributed/service/communicator.h" +#include "paddle/fluid/string/string_helper.h" + +using namespace std; + +namespace paddle { +namespace distributed { + +paddle::distributed::PSParameter load_from_prototxt( + const std::string& filename) { + paddle::distributed::PSParameter param; + int file_descriptor = open(filename.c_str(), O_RDONLY); + + if (file_descriptor == -1) { + VLOG(3) << "FATAL: fail to parse " << filename; + exit(-1); + } + + google::protobuf::io::FileInputStream fileInput(file_descriptor); + if (!google::protobuf::TextFormat::Parse(&fileInput, ¶m)) { + VLOG(3) << "FATAL: fail to parse " << filename; + exit(-1); + } + + close(file_descriptor); + return param; +} + +void PSCore::init_gflag(const std::string& gflags) { + LOG(INFO) << "Init With Gflags:" << gflags; + std::vector flags = paddle::string::split_string(gflags); + if (flags.size() < 1) { + flags.push_back("-max_body_size=314217728"); + flags.push_back("-bthread_concurrency=40"); + flags.push_back("-socket_max_unwritten_bytes=2048000000"); + flags.push_back("-max_connection_pool_size=1950"); + } + auto it = flags.begin(); + flags.insert(it, "exe default"); + char* flags_ptr[flags.size()]; + for (size_t i = 0; i < flags.size(); ++i) { + flags_ptr[i] = (char*)(flags[i].c_str()); + } + int params_cnt = flags.size(); + char** params_ptr = &(flags_ptr[0]); + ::google::ParseCommandLineFlags(¶ms_cnt, ¶ms_ptr, true); +} + +int PSCore::init_server(const std::string& dist_desc, + const std::vector* host_sign_list, + int node_num, int index) { + google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); + init_gflag(_ps_param.init_gflags()); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(host_sign_list, node_num); + int ret = 0; + _server_ptr = std::shared_ptr( + paddle::distributed::PSServerFactory::create(_ps_param)); + ret = _server_ptr->configure(_ps_param, _ps_env, index); + CHECK(ret == 0) << "failed to configure server"; + return ret; +} + +int PSCore::init_worker( + const std::string& dist_desc, + const std::map>& regions, + const std::vector* host_sign_list, int node_num, int index) { + google::protobuf::TextFormat::ParseFromString(dist_desc, &_ps_param); + init_gflag(_ps_param.init_gflags()); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(host_sign_list, node_num); + int ret = 0; + VLOG(1) << "PSCore::init_worker"; + auto* communicator = Communicator::GetInstance(); + ret = communicator->GetPsClient()->configure(_ps_param, regions, _ps_env, + index); + communicator->Start(); + return ret; +} + +std::vector PSCore::get_client_info() { + return _ps_env.get_client_info(); +} + +int PSCore::create_client2client_connection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry) { + int ret = _worker_ptr->create_client2client_connection( + pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); + return ret; +} + +uint64_t PSCore::run_server(const std::string& ip, uint32_t port) { + return _server_ptr->start(ip, port); +} + +int PSCore::finalize_worker() { + _worker_ptr->finalize_worker(); + return 0; +} + +int PSCore::stop_server() { + auto stop_status = _worker_ptr->stop_server(); + stop_status.wait(); + return 0; +} +paddle::distributed::PSParameter* PSCore::get_param() { return &_ps_param; } +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/service/service.h b/paddle/fluid/distributed/service/service.h new file mode 100644 index 0000000000000000000000000000000000000000..97cb864e344bf8a152a494a5365ce7a22f5eb4c8 --- /dev/null +++ b/paddle/fluid/distributed/service/service.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/server.h" + +namespace paddle { +namespace distributed { + +class PSCore { + public: + explicit PSCore() {} + virtual ~PSCore() {} + + virtual int init_server(const std::string& dist_desc, + const std::vector* host_sign_list, + int node_num, int index); + virtual int init_worker( + const std::string& dist_desc, + const std::map>& + regions, + const std::vector* host_sign_list, int node_num, int index); + virtual uint64_t run_server(const std::string& ip, uint32_t port); + virtual int stop_server(); + virtual int finalize_worker(); + virtual std::vector get_client_info(); + virtual int create_client2client_connection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry); + std::shared_ptr + _server_ptr; // pointer to server + std::shared_ptr + _worker_ptr; // pointer to worker + virtual paddle::distributed::PSParameter* get_param(); + + private: + void init_gflag(const std::string& gflags); + paddle::distributed::PSParameter _ps_param; + paddle::distributed::PaddlePSEnvironment _ps_env; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/test/CMakeLists.txt b/paddle/fluid/distributed/test/CMakeLists.txt index e4cc93c9adf65c74c7df6b01a90a34e8a88f502d..405fe7561115e68b45d4f8a84e59e4f12faed18c 100644 --- a/paddle/fluid/distributed/test/CMakeLists.txt +++ b/paddle/fluid/distributed/test/CMakeLists.txt @@ -16,3 +16,16 @@ cc_test(geo_table_test SRCS geo_table_test.cc DEPS common_table table tensor_acc set_source_files_properties(barrier_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(barrier_table_test SRCS barrier_table_test.cc DEPS common_table table tensor_accessor ps_framework_proto ${COMMON_DEPS}) + + +# open it until CI support brpc +return() + +set_source_files_properties(brpc_service_dense_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(brpc_service_dense_sgd_test SRCS brpc_service_dense_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(brpc_service_sparse_sgd_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(brpc_service_sparse_sgd_test SRCS brpc_service_sparse_sgd_test.cc DEPS scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS}) + +set_source_files_properties(brpc_utils_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_function ${COMMON_DEPS} ${RPC_DEPS}) diff --git a/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b2f808a2a82d558a6dabc85b57139f99d8ea389 --- /dev/null +++ b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc @@ -0,0 +1,272 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include // NOLINT +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" + +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/service.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; +namespace distributed = paddle::distributed; + +void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { + auto x_var = scope->Var("x"); + x_var->GetMutable(); +} + +void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0 * (float)i; +} + +void GetDownpourDenseTableProto( + ::paddle::distributed::TableParameter* dense_table_proto) { + dense_table_proto->set_table_id(0); + dense_table_proto->set_table_class("CommonDenseTable"); + dense_table_proto->set_shard_num(256); + dense_table_proto->set_type(::paddle::distributed::PS_DENSE_TABLE); + ::paddle::distributed::TableAccessorParameter* accessor_proto = + dense_table_proto->mutable_accessor(); + ::paddle::distributed::CommonAccessorParameter* common_proto = + dense_table_proto->mutable_common(); + + accessor_proto->set_accessor_class("CommMergeAccessor"); + accessor_proto->set_fea_dim(100); + accessor_proto->set_embedx_dim(1); + + common_proto->set_name("sgd"); + common_proto->set_table_name("MergedDense"); + common_proto->set_trainer_num(1); + common_proto->set_sync(false); + common_proto->add_params("Param"); + common_proto->add_dims(100); + common_proto->add_initializers("fill_constant&1.0"); + common_proto->add_params("LearningRate"); + common_proto->add_dims(1); + common_proto->add_initializers("fill_constant&1.0"); +} + +::paddle::distributed::PSParameter GetServerProto() { + // Generate server proto desc + ::paddle::distributed::PSParameter server_fleet_desc; + ::paddle::distributed::ServerParameter* server_proto = + server_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* dense_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourDenseTableProto(dense_table_proto); + return server_fleet_desc; +} + +::paddle::distributed::PSParameter GetWorkerProto() { + ::paddle::distributed::PSParameter worker_fleet_desc; + ::paddle::distributed::WorkerParameter* worker_proto = + worker_fleet_desc.mutable_worker_param(); + + ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = + worker_proto->mutable_downpour_worker_param(); + + ::paddle::distributed::TableParameter* worker_dense_table_proto = + downpour_worker_proto->add_downpour_table_param(); + GetDownpourDenseTableProto(worker_dense_table_proto); + + ::paddle::distributed::ServerParameter* server_proto = + worker_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* server_dense_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourDenseTableProto(server_dense_table_proto); + + return worker_fleet_desc; +} + +/*-------------------------------------------------------------------------*/ + +std::string ip_ = "127.0.0.1"; +uint32_t port_ = 4214; + +std::vector host_sign_list_; + +std::shared_ptr pserver_ptr_; + +std::shared_ptr worker_ptr_; + +void RunServer() { + ::paddle::distributed::PSParameter server_proto = GetServerProto(); + + auto _ps_env = paddle::distributed::PaddlePSEnvironment(); + LOG(INFO) << "RUN set_ps_servers"; + _ps_env.set_ps_servers(&host_sign_list_, 1); + pserver_ptr_ = std::shared_ptr( + paddle::distributed::PSServerFactory::create(server_proto)); + LOG(INFO) << "RUN configure"; + pserver_ptr_->configure(server_proto, _ps_env, 0); + LOG(INFO) << "RUN start"; + pserver_ptr_->start(ip_, port_); + LOG(INFO) << "End start"; +} + +void RunClient(std::map>& + dense_regions) { + ::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); + paddle::distributed::PaddlePSEnvironment _ps_env; + auto servers_ = host_sign_list_.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + LOG(INFO) << "Run set_ps_servers"; + _ps_env.set_ps_servers(&host_sign_list_, servers_); + LOG(INFO) << "Run Create PSClient"; + worker_ptr_ = std::shared_ptr( + paddle::distributed::PSClientFactory::create(worker_proto)); + LOG(INFO) << "Run configure"; + worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); +} + +void RunBrpcPushDense() { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); + host_sign_list_.push_back(ph_host.serialize_to_string()); + + // Srart Server + std::thread server_thread(RunServer); + sleep(1); + + // Start Client + LOG(INFO) << "Run InitTensorsOnClient"; + framework::Scope client_scope; + platform::CPUPlace place; + InitTensorsOnClient(&client_scope, &place, 100); + std::map> dense_regions; + dense_regions.insert( + std::pair>(0, {})); + auto regions = dense_regions[0]; + framework::Variable* var = client_scope.FindVar("x"); + framework::LoDTensor* tensor = var->GetMutable(); + float* w = tensor->data(); + paddle::distributed::Region reg(w, tensor->numel()); + regions.emplace_back(std::move(reg)); + + LOG(INFO) << "Run RunClient"; + RunClient(dense_regions); + + /*-----------------------Test Server Init----------------------------------*/ + LOG(INFO) << "Run pull_dense_param"; + float* temp = new float[tensor->numel()](); + std::vector temp_region; + paddle::distributed::Region temp_reg(temp, tensor->numel()); + temp_region.emplace_back(std::move(temp_reg)); + auto pull_status = + worker_ptr_->pull_dense(temp_region.data(), temp_region.size(), 0); + pull_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(temp[idx], 1.0); + } + + /*-----------------------Test Push Param----------------------------------*/ + + LOG(INFO) << "Run push_dense_param"; + auto push_status = + worker_ptr_->push_dense_param(regions.data(), regions.size(), 0); + push_status.wait(); + + pull_status = worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + pull_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(w[idx], float(idx)); + } + + /*-----------------------Test Push Grad----------------------------------*/ + + paddle::distributed::DownpourBrpcClosure* closure = + new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < 1; ++i) { + if (closure->check_response(i, paddle::PS_PUSH_DENSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + + LOG(INFO) << "Run pull_dense_grad"; + auto push_grad_status = + worker_ptr_->push_dense_raw_gradient(0, temp, tensor->numel(), closure); + push_grad_status.wait(); + + auto pull_update_status = + worker_ptr_->pull_dense(regions.data(), regions.size(), 0); + pull_update_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(w[idx], float(idx) - 1.0); + } + + LOG(INFO) << "Run stop_server"; + worker_ptr_->stop_server(); + LOG(INFO) << "Run finalize_worker"; + worker_ptr_->finalize_worker(); + server_thread.join(); +} + +TEST(RunBrpcPushDense, Run) { RunBrpcPushDense(); } diff --git a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..224b9ba2fc780a217bbe4a007d624d0f7afcedf0 --- /dev/null +++ b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc @@ -0,0 +1,285 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include // NOLINT +#include +#include // NOLINT + +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" + +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +#include "paddle/fluid/distributed/ps.pb.h" +#include "paddle/fluid/distributed/service/brpc_ps_client.h" +#include "paddle/fluid/distributed/service/brpc_ps_server.h" +#include "paddle/fluid/distributed/service/env.h" +#include "paddle/fluid/distributed/service/ps_client.h" +#include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/distributed/service/service.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; +namespace distributed = paddle::distributed; + +void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { + auto x_var = scope->Var("x"); + x_var->GetMutable(); +} + +void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; +} + +void GetDownpourSparseTableProto( + ::paddle::distributed::TableParameter* sparse_table_proto) { + sparse_table_proto->set_table_id(0); + sparse_table_proto->set_table_class("CommonSparseTable"); + sparse_table_proto->set_shard_num(256); + sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE); + ::paddle::distributed::TableAccessorParameter* accessor_proto = + sparse_table_proto->mutable_accessor(); + ::paddle::distributed::CommonAccessorParameter* common_proto = + sparse_table_proto->mutable_common(); + + accessor_proto->set_accessor_class("CommMergeAccessor"); + accessor_proto->set_fea_dim(0); + accessor_proto->set_embedx_dim(10); + + common_proto->set_name("sgd"); + common_proto->set_table_name("MergedDense"); + common_proto->set_trainer_num(1); + common_proto->set_sync(false); + common_proto->add_params("Param"); + common_proto->add_dims(10); + common_proto->add_initializers("uniform_random&0&-1.0&1.0"); + common_proto->add_params("LearningRate"); + common_proto->add_dims(1); + common_proto->add_initializers("fill_constant&1.0"); +} + +::paddle::distributed::PSParameter GetServerProto() { + // Generate server proto desc + ::paddle::distributed::PSParameter server_fleet_desc; + ::paddle::distributed::ServerParameter* server_proto = + server_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(sparse_table_proto); + return server_fleet_desc; +} + +::paddle::distributed::PSParameter GetWorkerProto() { + ::paddle::distributed::PSParameter worker_fleet_desc; + ::paddle::distributed::WorkerParameter* worker_proto = + worker_fleet_desc.mutable_worker_param(); + + ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto = + worker_proto->mutable_downpour_worker_param(); + + ::paddle::distributed::TableParameter* worker_sparse_table_proto = + downpour_worker_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(worker_sparse_table_proto); + + ::paddle::distributed::ServerParameter* server_proto = + worker_fleet_desc.mutable_server_param(); + ::paddle::distributed::DownpourServerParameter* downpour_server_proto = + server_proto->mutable_downpour_server_param(); + ::paddle::distributed::ServerServiceParameter* server_service_proto = + downpour_server_proto->mutable_service_param(); + server_service_proto->set_service_class("PsService"); + server_service_proto->set_server_class("BrpcPsServer"); + server_service_proto->set_client_class("BrpcPsClient"); + server_service_proto->set_start_server_port(0); + server_service_proto->set_server_thread_num(12); + + ::paddle::distributed::TableParameter* server_sparse_table_proto = + downpour_server_proto->add_downpour_table_param(); + GetDownpourSparseTableProto(server_sparse_table_proto); + + return worker_fleet_desc; +} + +/*-------------------------------------------------------------------------*/ + +std::string ip_ = "127.0.0.1"; +uint32_t port_ = 4209; + +std::vector host_sign_list_; + +std::shared_ptr pserver_ptr_; + +std::shared_ptr worker_ptr_; + +void RunServer() { + ::paddle::distributed::PSParameter server_proto = GetServerProto(); + + auto _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, 1); + pserver_ptr_ = std::shared_ptr( + paddle::distributed::PSServerFactory::create(server_proto)); + pserver_ptr_->configure(server_proto, _ps_env, 0); + pserver_ptr_->start(ip_, port_); +} + +void RunClient(std::map>& + dense_regions) { + ::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); + paddle::distributed::PaddlePSEnvironment _ps_env; + auto servers_ = host_sign_list_.size(); + _ps_env = paddle::distributed::PaddlePSEnvironment(); + _ps_env.set_ps_servers(&host_sign_list_, servers_); + worker_ptr_ = std::shared_ptr( + paddle::distributed::PSClientFactory::create(worker_proto)); + worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0); +} + +void RunBrpcPushSparse() { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + auto ph_host = paddle::distributed::PSHost(ip_, port_, 0); + host_sign_list_.push_back(ph_host.serialize_to_string()); + + // Srart Server + std::thread server_thread(RunServer); + sleep(1); + + // Start Client + framework::Scope client_scope; + platform::CPUPlace place; + InitTensorsOnClient(&client_scope, &place, 100); + std::map> dense_regions; + dense_regions.insert( + std::pair>(0, {})); + auto regions = dense_regions[0]; + framework::Variable* var = client_scope.FindVar("x"); + framework::LoDTensor* tensor = var->GetMutable(); + + RunClient(dense_regions); + std::vector fea_keys(10); + std::vector fea_values(100); + std::vector fea_temp_values(100); + std::vector fea_value_ptr(10); + std::vector fea_temp_value_ptr(10); + + for (size_t idx = 0; idx < fea_keys.size(); ++idx) { + fea_keys[idx] = (uint64_t)idx; + fea_value_ptr[idx] = fea_values.data() + idx * 10; + fea_temp_value_ptr[idx] = fea_temp_values.data() + idx * 10; + } + + /*-----------------------Test Server Init----------------------------------*/ + LOG(INFO) << "Run pull_sparse_param"; + auto pull_status = worker_ptr_->pull_sparse(fea_value_ptr.data(), 0, + fea_keys.data(), fea_keys.size()); + pull_status.wait(); + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + fea_values.data()[idx] *= 2.0; + } + + /*-----------------------Test Push Param----------------------------------*/ + + LOG(INFO) << "Run push_sparse_param"; + paddle::distributed::DownpourBrpcClosure* closure_push_param = + new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < 1; ++i) { + if (closure->check_response(i, paddle::PS_PUSH_SPARSE_PARAM) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + auto push_status = worker_ptr_->push_sparse_param( + 0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(), + closure_push_param); + push_status.wait(); + + auto pull_param_status = worker_ptr_->pull_sparse( + fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size()); + pull_param_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx]); + } + + /*-----------------------Test Push Grad----------------------------------*/ + + paddle::distributed::DownpourBrpcClosure* closure_push_grad = + new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { + int ret = 0; + auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; + for (size_t i = 0; i < 1; ++i) { + if (closure->check_response(i, paddle::PS_PUSH_SPARSE_TABLE) != 0) { + ret = -1; + break; + } + } + closure->set_promise_value(ret); + }); + + LOG(INFO) << "Run pull_sparse_grad"; + std::vector push_g_vec; + for (auto i = 0; i < static_cast(fea_keys.size()); ++i) { + push_g_vec.push_back(tensor->data() + i * 10); + } + auto push_grad_status = worker_ptr_->push_sparse_raw_gradient( + 0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(), + closure_push_grad); + push_grad_status.wait(); + + auto pull_update_status = worker_ptr_->pull_sparse( + fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size()); + pull_update_status.wait(); + + for (size_t idx = 0; idx < tensor->numel(); ++idx) { + EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx] - 1.0); + } + + LOG(INFO) << "Run stop_server"; + worker_ptr_->stop_server(); + LOG(INFO) << "Run finalize_worker"; + worker_ptr_->finalize_worker(); + server_thread.join(); +} + +TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); } diff --git a/paddle/fluid/distributed/test/heter_serde_test.cc b/paddle/fluid/distributed/test/brpc_utils_test.cc similarity index 98% rename from paddle/fluid/distributed/test/heter_serde_test.cc rename to paddle/fluid/distributed/test/brpc_utils_test.cc index 21380921958dbbf3eb6d8cd0589316d5612fd290..ce33cbe6ea39713589eb8f201d95bb2d99e5ff0c 100644 --- a/paddle/fluid/distributed/test/heter_serde_test.cc +++ b/paddle/fluid/distributed/test/brpc_utils_test.cc @@ -23,7 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" -#include "paddle/fluid/distributed/service/heter_serde.h" +#include "paddle/fluid/distributed/service/brpc_utils.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/printf.h" diff --git a/paddle/fluid/distributed/test/geo_table_test.cc b/paddle/fluid/distributed/test/geo_table_test.cc index fffecbe199e055126519526d87d034bea30b331a..5ec1e87dcb693827cbf1aa2024706b8fcd0b7dc9 100644 --- a/paddle/fluid/distributed/test/geo_table_test.cc +++ b/paddle/fluid/distributed/test/geo_table_test.cc @@ -109,7 +109,7 @@ TEST(SparseGeoTable, SSUM) { auto id = geo_pull_ids[i][j]; for (int k = 0; k < emb_dim; k++) { ASSERT_TRUE(abs(geo_pull_values[i][j * emb_dim + k] - - pull_values[id * emb_dim + k]) < 1e-6); + pull_values[id * emb_dim + k]) < 1e-5); } } } diff --git a/paddle/fluid/distributed/test/sparse_table_test.cc b/paddle/fluid/distributed/test/sparse_table_test.cc index 65439014e8f0e26e0e6c2e06d692f416235029cf..6db95c5fac211b94db726ee77c9122a8824c2351 100644 --- a/paddle/fluid/distributed/test/sparse_table_test.cc +++ b/paddle/fluid/distributed/test/sparse_table_test.cc @@ -103,7 +103,7 @@ TEST(CommonSparseTable, SGD) { table->pull_sparse(pull_values.data(), init_keys.data(), init_keys.size()); for (size_t i = 0; i < init_values.size(); ++i) { auto update_val = init_values[i] - 1.0 * total_gradients[i]; - ASSERT_TRUE(abs(update_val - pull_values[i]) < 1e-6); + ASSERT_TRUE(abs(update_val - pull_values[i]) < 1e-5); } }